Train an RL agent to execute natural language instructions in a 3D Environment (PyTorch)

Overview

Gated-Attention Architectures for Task-Oriented Language Grounding

This is a PyTorch implementation of the AAAI-18 paper:

Gated-Attention Architectures for Task-Oriented Language Grounding
Devendra Singh Chaplot, Kanthashree Mysore Sathyendra, Rama Kumar Pasumarthi, Dheeraj Rajagopal, Ruslan Salakhutdinov
Carnegie Mellon University

Project Website: https://sites.google.com/view/gated-attention


This repository contains:

  • Code for training an A3C-LSTM agent using Gated-Attention
  • Code for Doom-based language grounding environment
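
As a rough illustration of the Gated-Attention idea named above: in the paper, the instruction embedding is passed through a sigmoid-activated linear layer to produce one gate per convolutional feature map, and each feature map is multiplied elementwise by its gate. The sketch below uses plain Python lists instead of PyTorch tensors so that it runs without dependencies. The function name, argument names, and toy numbers are hypothetical and are not taken from this repository's code.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gated_attention(feature_maps, instruction_embedding, weights, biases):
    """Scale each feature map by a gate computed from the instruction.

    feature_maps: list of d maps, each an H x W nested list of floats.
    instruction_embedding: list of k floats.
    weights: d x k nested list, biases: list of d floats (a linear layer).
    Returns (gated_maps, gates); gated_maps has the same shape as feature_maps.
    """
    # One gate in (0, 1) per feature map: sigmoid(W x + b).
    gates = [
        sigmoid(sum(w * x for w, x in zip(row, instruction_embedding)) + b)
        for row, b in zip(weights, biases)
    ]
    # Broadcast each gate over its H x W map (Hadamard product).
    gated_maps = [
        [[gate * value for value in line] for line in fmap]
        for gate, fmap in zip(gates, feature_maps)
    ]
    return gated_maps, gates


# Toy input: two 2x2 feature maps and a 2-dimensional instruction embedding.
feature_maps = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [7.0, 8.0]],
]
instruction_embedding = [1.0, 0.0]
weights = [[10.0, 0.0], [-10.0, 0.0]]
biases = [0.0, 0.0]

gated, gates = gated_attention(feature_maps, instruction_embedding, weights, biases)
# The first map is kept almost unchanged and the second is almost suppressed.
print(gates)
```

In a tensor library the same operation is a single broadcasted multiplication between the gate vector and the stacked feature maps.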

Dependencies

(We recommend using Anaconda)

Usage

Using the Environment

For running a random agent:

python env_test.py

To play in the environment:

python env_test.py --interactive 1

To change the difficulty of the environment (easy/medium/hard):

python env_test.py -d easy

Training Gated-Attention A3C-LSTM agent

To train an A3C-LSTM agent with 32 processes:

python a3c_main.py --num-processes 32 --evaluate 0

The code will save the best model at ./saved/model_best.

To test the pre-trained model for Multitask Generalization:

python a3c_main.py --evaluate 1 --load saved/pretrained_model

To test the pre-trained model for Zero-shot Task Generalization:

python a3c_main.py --evaluate 2 --load saved/pretrained_model

To visualize the model while testing, add --visualize 1:

python a3c_main.py --evaluate 2 --load saved/pretrained_model --visualize 1

To test the trained model, use --load saved/model_best in the above commands.

All arguments for a3c_main.py:

  -h, --help            show this help message and exit
  -l MAX_EPISODE_LENGTH, --max-episode-length MAX_EPISODE_LENGTH
                        maximum length of an episode (default: 30)
  -d DIFFICULTY, --difficulty DIFFICULTY
                        Difficulty of the environment, "easy", "medium" or
                        "hard" (default: hard)
  --living-reward LIVING_REWARD
                        Default reward at each time step (default: 0, change
                        to -0.005 to encourage shorter paths)
  --frame-width FRAME_WIDTH
                        Frame width (default: 300)
  --frame-height FRAME_HEIGHT
                        Frame height (default: 168)
  -v VISUALIZE, --visualize VISUALIZE
                        Visualize the environment (default: 0, use 0 for
                        faster training)
  --sleep SLEEP         Sleep between frames for better visualization
                        (default: 0)
  --scenario-path SCENARIO_PATH
                        Doom scenario file to load (default: maps/room.wad)
  --interactive INTERACTIVE
                        Interactive mode enables human to play (default: 0)
  --all-instr-file ALL_INSTR_FILE
                        All instructions file (default:
                        data/instructions_all.json)
  --train-instr-file TRAIN_INSTR_FILE
                        Train instructions file (default:
                        data/instructions_train.json)
  --test-instr-file TEST_INSTR_FILE
                        Test instructions file (default:
                        data/instructions_test.json)
  --object-size-file OBJECT_SIZE_FILE
                        Object size file (default: data/object_sizes.txt)
  --lr LR               learning rate (default: 0.001)
  --gamma G             discount factor for rewards (default: 0.99)
  --tau T               parameter for GAE (default: 1.00)
  --seed S              random seed (default: 1)
  -n N, --num-processes N
                        how many training processes to use (default: 4)
  --num-steps NS        number of forward steps in A3C (default: 20)
  --load LOAD           model path to load, 0 to not reload (default: 0)
  -e EVALUATE, --evaluate EVALUATE
                        0:Train, 1:Evaluate MultiTask Generalization
                        2:Evaluate Zero-shot Generalization (default: 0)
  --dump-location DUMP_LOCATION
                        path to dump models and log (default: ./saved/)
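
The --gamma, --tau, and --num-steps options control how rewards collected over a rollout are turned into training targets. The following is a minimal sketch of the standard discounted-return and Generalized Advantage Estimation (GAE) recursions, using the default values listed above. Whether a3c_main.py computes them exactly this way is an assumption; the function names and the toy rewards and value estimates are hypothetical.

```python
def discounted_returns(rewards, bootstrap_value, gamma=0.99):
    """Return R_t = r_t + gamma * R_{t+1}, seeded with the critic's
    estimate for the state that follows the last step of the rollout."""
    returns = []
    running = bootstrap_value
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def gae_advantages(rewards, values, bootstrap_value, gamma=0.99, tau=1.0):
    """GAE: A_t = delta_t + gamma * tau * A_{t+1}, where
    delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)."""
    advantages = []
    gae = 0.0
    next_value = bootstrap_value
    for reward, value in zip(reversed(rewards), reversed(values)):
        delta = reward + gamma * next_value - value
        gae = delta + gamma * tau * gae
        advantages.append(gae)
        next_value = value
    advantages.reverse()
    return advantages


# Toy 3-step rollout that ends the episode (bootstrap value 0).
rewards = [0.0, 0.0, 1.0]   # hypothetical rewards
values = [0.5, 0.5, 0.5]    # hypothetical critic estimates V(s_t)

returns = discounted_returns(rewards, bootstrap_value=0.0)
advantages = gae_advantages(rewards, values, bootstrap_value=0.0)
print(returns)
print(advantages)
```

With the default --tau of 1.00, each advantage equals the discounted return minus the value estimate for that step; smaller values of tau weight nearby rewards more heavily and rely more on the critic.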

Demonstration videos:

Multitask Generalization video: https://www.youtube.com/watch?v=YJG8fwkv7gA

Zero-shot Task Generalization video: https://www.youtube.com/watch?v=JziCKsLrudE

Different stages of training: https://www.youtube.com/watch?v=o_G6was03N0

Cite as

Chaplot, D.S., Sathyendra, K.M., Pasumarthi, R.K., Rajagopal, D. and Salakhutdinov, R., 2017. Gated-Attention Architectures for Task-Oriented Language Grounding. arXiv preprint arXiv:1706.07230. (PDF)

Bibtex:

@article{chaplot2017gated,
  title={Gated-Attention Architectures for Task-Oriented Language Grounding},
  author={Chaplot, Devendra Singh and Sathyendra, Kanthashree Mysore and Pasumarthi, Rama Kumar and Rajagopal, Dheeraj and Salakhutdinov, Ruslan},
  journal={arXiv preprint arXiv:1706.07230},
  year={2017}
}

Acknowledgements

This repository uses the ViZDoom API (https://github.com/mwydmuch/ViZDoom) and parts of the code from the API. The implementation of A3C is borrowed from https://github.com/ikostrikov/pytorch-a3c. The poisson-disc code is borrowed from https://github.com/IHautaI/poisson-disc.
