Pytorch implementation of Value Iteration Networks (NIPS 2016 best paper)

Overview

VIN: Value Iteration Networks

Architecture of Value Iteration Network

A quick thank you

A few others have released amazing related work which helped inspire and improve my own implementation. It goes without saying that this release would not be nearly as good if it were not for all of the following:

Why another VIN implementation?

  1. The Pytorch VIN model in this repository is, in my opinion, more readable and closer to the original Theano implementation than others I have found (both Tensorflow and Pytorch).
  2. This is not simply an implementation of the VIN model in Pytorch, it is also a full Python implementation of the gridworld environments as used in the original MATLAB implementation.
  3. Provide a more extensible research base for others to build off of without needing to jump through the possible MATLAB paywall.

Installation

This repository requires following packages:

Use pip to install the necessary dependencies:

pip install -U -r requirements.txt 

Note that PyTorch cannot be installed directly from PyPI; refer to http://pytorch.org/ for custom installation instructions specific to your needs.

How to train

8x8 gridworld

python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128

16x16 gridworld

python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128

28x28 gridworld

python train.py --datafile dataset/gridworld_28x28.npz --imsize 28 --lr 0.002 --epochs 30 --k 36 --batch_size 128

Flags:

  • datafile: The path to the data files.
  • imsize: The size of input images. One of: [8, 16, 28]
  • lr: Learning rate with RMSProp optimizer. Recommended: [0.01, 0.005, 0.002, 0.001]
  • epochs: Number of epochs to train. Default: 30
  • k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • l_i: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
  • l_h: Number of channels in first convolutional layer. Default: 150, described in paper.
  • l_q: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.
  • batch_size: Batch size. Default: 128

How to test / visualize paths (requires training first)

8x8 gridworld

python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10

16x16 gridworld

python test.py --weights trained/vin_16x16.pth --imsize 16 --k 20

28x28 gridworld

python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36

To visualize the optimal and predicted paths simply pass:

--plot

Flags:

  • weights: Path to trained weights.
  • imsize: The size of input images. One of: [8, 16, 28]
  • plot: If supplied, the optimal and predicted paths will be plotted
  • k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • l_i: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
  • l_h: Number of channels in first convolutional layer. Default: 150, described in paper.
  • l_q: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.

Results

Gridworld Sample One Sample Two
8x8
16x16
28x28

Datasets

Each data sample consists of an obstacle image and a goal image followed by the (x, y) coordinates of current state in the gridworld.

Dataset size 8x8 16x16 28x28
Train set 81337 456309 1529584
Test set 13846 77203 251755

The datasets (8x8, 16x16, and 28x28) included in this repository can be reproduced using the dataset/make_training_data.py script. Note that this script is not optimized and runs rather slowly (also uses a lot of memory :D)

Performance: Success Rate

This is the success rate from rollouts of the learned policy in the environment (taken over 5000 randomly generated domains).

Success Rate 8x8 16x16 28x28
PyTorch 99.69% 96.99% 91.07%

Performance: Test Accuracy

NOTE: This is the accuracy on test set. It is different from the table in the paper, which indicates the success rate from rollouts of the learned policy in the environment.

Test Accuracy 8x8 16x16 28x28
PyTorch 99.83% 94.84% 88.54%
Comments
  • testing accuracy fairly low

    testing accuracy fairly low

    I just tried to follow the instructions in the repo, and tested models trained but got a fairly low accuracy. I'm using pyTorch 0.1.12_1. Is there anything I should pay attention to?

    opened by xinleipan 10
  • Prebuilt Dataset Generation

    Prebuilt Dataset Generation

    Hello,

    I was wondering how you generated the prebuilt datasets that are downloaded when running download_weights_and_datasets.sh, i.e. what were the max_obs and max_obs_size parameters?

    Did you follow this file in the original repo? https://github.com/avivt/VIN/blob/master/scripts/make_data_gridworld_nips.m

    Thanks, Emilio

    opened by eparisotto 5
  • the rollout accuracy in test script is lower than the test accuracy in train script.

    the rollout accuracy in test script is lower than the test accuracy in train script.

    Hello!

    I have a little doubt.Does the rollout accuracy indicate the success rate? If so, why is it lower than the prediction accuracy? In the Aviv's implementation, the success rate of the 8x8 grid world was as high as 99.6%. Why is the success rate in your experiment relatively low?

    Thanks!

    opened by albzni 4
  • RUN ERROR

    RUN ERROR

    when I run 'python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128', it's ok,but again 'python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128' was run, an error occurred as follows: [email protected]:~/pytorch-value-iteration-networks$ python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 10 --k 20 --batch_size 128 Traceback (most recent call last): File "train.py", line 135, in config.datafile, imsize=config.imsize, train=True, transform=transform) File "/home/ni/pytorch-value-iteration-networks/dataset/dataset.py", line 22, in init self._process(file, self.train) File "/home/ni/pytorch-value-iteration-networks/dataset/dataset.py", line 58, in _process images = images.astype(np.float32) MemoryError

    opened by N-Kingsley 3
  • Problem of running the test script

    Problem of running the test script

    Hello,

    I downloaded the data with the .sh downloading script you provided, I also got an nps weights file after training. When I ran the testing command I got the following error: Traceback (most recent call last): File "/home/research/DL/VIN/pytorch-value-iteration-networks/test.py", line 158, in main(config) File "/home/research/DL/VIN/pytorch-value-iteration-networks/test.py", line 85, in main _, predictions = vin(X_in, S1_in, S2_in, config) File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 357, in call result = self.forward(*input, **kwargs) File "/home/research/DL/VIN/pytorch-value-iteration-networks/model.py", line 64, in forward return logits, self.sm(logits) File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 352, in call for hook in self._forward_pre_hooks.values(): File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 398, in getattr type(self).name, name)) AttributeError: 'Softmax' object has no attribute '_forward_pre_hooks'

    Thanks for helping!

    opened by YantianZha 3
  • Improved readability of the VIN model, in addition to minor changes

    Improved readability of the VIN model, in addition to minor changes

    My main modification is in the forward method of the model where you extract the q_out from the q values, and not repeating q = F.conv2d(...) in two places. I also made minor improvements, such as adding argparse in the dataset creation script and changing .cuda() into .to(device) in test.py.

    opened by shuishida 2
  • Inconsistent tensor sizes when starting training

    Inconsistent tensor sizes when starting training

    Hey there. I'm trying to run

    python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
    

    But I get the following error

    Number of Train Samples: 103926
    Number of Test Samples: 17434
         Epoch | Train Loss | Train Error | Epoch Time
    Traceback (most recent call last):
      File "train.py", line 147, in <module>
        train(net, trainloader, config, criterion, optimizer, use_GPU)
      File "train.py", line 40, in train
        outputs, predictions = net(X, S1, S2, config)
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 224, in __call__
        result = self.forward(*input, **kwargs)
      File "/media/user_home2/j1k1000o/j1k/VINs/pytorch-value-iteration-networks/model.py", line 44, in forward
        q = F.conv2d(torch.cat([r, v], 1), 
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py", line 897, in cat
        return Concat.apply(dim, *iterable)
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/autograd/_functions/tensor.py", line 317, in forward
        return torch.cat(inputs, dim)
    RuntimeError: inconsistent tensor sizes at /opt/conda/conda-bld/pytorch_1502009910772/work/torch/lib/THC/generic/THCTensorMath.cu:141
    

    I've executed

    ./download_weights_and_datasets.sh
    

    as well as

    python ./dataset/make_training_data.py
    

    And I'm running it on an Ubuntu 16.04, python 3.6 and with all the requirements installed.

    Can you help me out?

    opened by juancprzs 2
  • Don't understand VIN last step

    Don't understand VIN last step

        slice_s1 = S1.long().expand(config.imsize, 1, config.l_q, q.size(0))
        slice_s1 = slice_s1.permute(3, 2, 1, 0)
        q_out = q.gather(2, slice_s1).squeeze(2)
    

    What does this 3 lines do?

    opened by QiXuanWang 1
  • KeyError: 'arr_1 is not a file in the archive'

    KeyError: 'arr_1 is not a file in the archive'

    python3 train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128 Traceback (most recent call last): File "train.py", line 135, in config.datafile, imsize=config.imsize, train=True, transform=transform) File "/home/user/pytorch/tutorials/valueiterationnetworks/pytorch-value-iteration-networks/dataset/dataset.py", line 22, in init self._process(file, self.train) File "/home/user/pytorch/tutorials/valueiterationnetworks/pytorch-value-iteration-networks/dataset/dataset.py", line 49, in _process S1 = f['arr_1'] File "/home/user/miniconda3/lib/python3.6/site-packages/numpy/lib/npyio.py", line 255, in getitem raise KeyError("%s is not a file in the archive" % key) KeyError: 'arr_1 is not a file in the archive'

    I got this error, could you please

    opened by derelearnro 1
  • Problem of running dataset/make_training_data.py script

    Problem of running dataset/make_training_data.py script

    Hi

    When I tried to run the make_training_data.py script to generate the gridworld.npz file, I got the following error:

    FileNotFoundError: [Errno 2] No such file or directory: 'dataset/gridworld_28x28.npz'
    

    And I found that line 101 should be modified as follows:

    save_path = "gridworld_{0}x{1}".format(dom_size[0], dom_size[1])
    
    opened by ruqing00 0
Owner
Kent Sommer
Software Engineer @ Toyota Research Institute (SF Bay Area)
Kent Sommer
BERT model training impelmentation using 1024 A100 GPUs for MLPerf Training v1.1

Pre-trained checkpoint and bert config json file Location of checkpoint and bert config json file This MLCommons members Google Drive location contain

SAIT (Samsung Advanced Institute of Technology) 12 Apr 27, 2022
'A C2C E-COMMERCE TRUST MODEL BASED ON REPUTATION' Python implementation

Project description A library providing functionalities to calculate reputation and degree of trust on C2C ecommerce platforms. The work is fully base

Davide Bigotti 2 Dec 14, 2022
This code is the implementation of the paper "Coherence-Based Distributed Document Representation Learning for Scientific Documents".

Introduction This code is the implementation of the paper "Coherence-Based Distributed Document Representation Learning for Scientific Documents". If

tsc 0 Jan 11, 2022
Contains source code for the winning solution of the xView3 challenge

Winning Solution for xView3 Challenge This repository contains source code and pretrained models for my (Eugene Khvedchenya) solution to xView 3 Chall

Eugene Khvedchenya 51 Dec 30, 2022
Code for the paper "Reinforced Active Learning for Image Segmentation"

Reinforced Active Learning for Image Segmentation (RALIS) Code for the paper Reinforced Active Learning for Image Segmentation Dependencies python 3.6

Arantxa Casanova 79 Dec 19, 2022
ISNAS-DIP: Image Specific Neural Architecture Search for Deep Image Prior [CVPR 2022]

ISNAS-DIP: Image-Specific Neural Architecture Search for Deep Image Prior (CVPR 2022) Metin Ersin Arican*, Ozgur Kara*, Gustav Bredell, Ender Konukogl

Özgür Kara 24 Dec 18, 2022
Distance Encoding for GNN Design

Distance-encoding for GNN design This repository is the official PyTorch implementation of the DEGNN and DEAGNN framework reported in the paper: Dista

172 Nov 08, 2022
Example repository for custom C++/CUDA operators for TorchScript

Custom TorchScript Operators Example This repository contains examples for writing, compiling and using custom TorchScript operators. See here for the

106 Dec 14, 2022
Official Implementation of "LUNAR: Unifying Local Outlier Detection Methods via Graph Neural Networks"

LUNAR Official Implementation of "LUNAR: Unifying Local Outlier Detection Methods via Graph Neural Networks" Adam Goodge, Bryan Hooi, Ng See Kiong and

Adam Goodge 25 Dec 28, 2022
Table-Extractor 表格抽取

(t)able-(ex)tractor 本项目旨在实现pdf表格抽取。 Models 版面分析模块(Yolo) 表格结构抽取(ResNet + Transformer) 文字识别模块(CRNN + CTC Loss) Acknowledgements TableMaster attention-i

2 Jan 15, 2022
Official implementation of "Watermarking Images in Self-Supervised Latent-Spaces"

🔍 Watermarking Images in Self-Supervised Latent-Spaces PyTorch implementation and pretrained models for the paper. For details, see Watermarking Imag

Meta Research 32 Dec 13, 2022
competitions-v2

Codabench (formerly Codalab Competitions v2) Installation $ cp .env_sample .env $ docker-compose up -d $ docker-compose exec django ./manage.py migrat

CodaLab 21 Dec 02, 2022
Reinforcement-learning - Repository of the class assignment questions for the course on reinforcement learning

DSE 314/614: Reinforcement Learning This repository containing reinforcement lea

Manav Mishra 4 Apr 15, 2022
Weighted QMIX: Expanding Monotonic Value Function Factorisation

This repo contains the cleaned-up code that was used in "Weighted QMIX: Expanding Monotonic Value Function Factorisation"

whirl 82 Dec 29, 2022
Distributed Asynchronous Hyperparameter Optimization better than HyperOpt.

UltraOpt : Distributed Asynchronous Hyperparameter Optimization better than HyperOpt. UltraOpt is a simple and efficient library to minimize expensive

98 Aug 16, 2022
Official PyTorch Implementation of paper "NeLF: Neural Light-transport Field for Single Portrait View Synthesis and Relighting", EGSR 2021.

NeLF: Neural Light-transport Field for Single Portrait View Synthesis and Relighting Official PyTorch Implementation of paper "NeLF: Neural Light-tran

Ken Lin 38 Dec 26, 2022
PyTorch implementation for Partially View-aligned Representation Learning with Noise-robust Contrastive Loss (CVPR 2021)

2021-CVPR-MvCLN This repo contains the code and data of the following paper accepted by CVPR 2021 Partially View-aligned Representation Learning with

XLearning Group 33 Nov 01, 2022
Sky Computing: Accelerating Geo-distributed Computing in Federated Learning

Sky Computing Introduction Sky Computing is a load-balanced framework for federated learning model parallelism. It adaptively allocate model layers to

HPC-AI Tech 72 Dec 27, 2022
Official code for article "Expression is enough: Improving traffic signal control with advanced traffic state representation"

1 Introduction Official code for article "Expression is enough: Improving traffic signal control with advanced traffic state representation". The code s

Liang Zhang 10 Dec 10, 2022
Robustness via Cross-Domain Ensembles

Robustness via Cross-Domain Ensembles [ICCV 2021, Oral] This repository contains tools for training and evaluating: Pretrained models Demo code Traini

Visual Intelligence & Learning Lab, Swiss Federal Institute of Technology (EPFL) 27 Dec 23, 2022