Deep Reinforcement Learning with pytorch & visdom

Overview


  • Sample tests of trained agents (DQN on Breakout, A3C on Pong, DoubleDQN on CartPole, continuous A3C on InvertedPendulum (MuJoCo)).
  • Sample online plotting while training an A3C agent on Pong (with 16 learner processes).

  • Sample log output while training a DQN agent on CartPole (the logging level is currently set to WARNING to suppress the INFO printouts from visdom):

[WARNING ] (MainProcess) <===================================>
[WARNING ] (MainProcess) bash$: python -m visdom.server
[WARNING ] (MainProcess) http://localhost:8097/env/daim_17040900
[WARNING ] (MainProcess) <===================================> DQN
[WARNING ] (MainProcess) <-----------------------------------> Env
[WARNING ] (MainProcess) Creating {gym | CartPole-v0} w/ Seed: 123
[INFO    ] (MainProcess) Making new env: CartPole-v0
[WARNING ] (MainProcess) Action Space: [0, 1]
[WARNING ] (MainProcess) State  Space: 4
[WARNING ] (MainProcess) <-----------------------------------> Model
[WARNING ] (MainProcess) MlpModel (
  (fc1): Linear (4 -> 16)
  (rl1): ReLU ()
  (fc2): Linear (16 -> 16)
  (rl2): ReLU ()
  (fc3): Linear (16 -> 16)
  (rl3): ReLU ()
  (fc4): Linear (16 -> 2)
)
[WARNING ] (MainProcess) No Pretrained Model. Will Train From Scratch.
[WARNING ] (MainProcess) <===================================> Training ...
[WARNING ] (MainProcess) Validation Data @ Step: 501
[WARNING ] (MainProcess) Start  Training @ Step: 501
[WARNING ] (MainProcess) Reporting       @ Step: 2500 | Elapsed Time: 5.32397913933
[WARNING ] (MainProcess) Training Stats:   epsilon:          0.972
[WARNING ] (MainProcess) Training Stats:   total_reward:     2500.0
[WARNING ] (MainProcess) Training Stats:   avg_reward:       21.7391304348
[WARNING ] (MainProcess) Training Stats:   nepisodes:        115
[WARNING ] (MainProcess) Training Stats:   nepisodes_solved: 114
[WARNING ] (MainProcess) Training Stats:   repisodes_solved: 0.991304347826
[WARNING ] (MainProcess) Evaluating      @ Step: 2500
[WARNING ] (MainProcess) Iteration: 2500; v_avg: 1.73136949539
[WARNING ] (MainProcess) Iteration: 2500; tderr_avg: 0.0964358523488
[WARNING ] (MainProcess) Iteration: 2500; steps_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; steps_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; reward_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; reward_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; nepisodes: 107
[WARNING ] (MainProcess) Iteration: 2500; nepisodes_solved: 106
[WARNING ] (MainProcess) Iteration: 2500; repisodes_solved: 0.990654205607
[WARNING ] (MainProcess) Saving Model    @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth ...
[WARNING ] (MainProcess) Saved  Model    @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth.
[WARNING ] (MainProcess) Resume Training @ Step: 2500
...
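
The reported statistics appear to be simple ratios: in the sample above, avg_reward matches total_reward / nepisodes and repisodes_solved matches nepisodes_solved / nepisodes. The sketch below is not part of the repo; it parses a few of the Training Stats lines with a regular expression and checks those assumed relations. The helper name parse_stats is hypothetical.

```python
import math
import re

# A few lines copied from the sample log above.
LOG = """\
[WARNING ] (MainProcess) Training Stats:   epsilon:          0.972
[WARNING ] (MainProcess) Training Stats:   total_reward:     2500.0
[WARNING ] (MainProcess) Training Stats:   avg_reward:       21.7391304348
[WARNING ] (MainProcess) Training Stats:   nepisodes:        115
[WARNING ] (MainProcess) Training Stats:   nepisodes_solved: 114
[WARNING ] (MainProcess) Training Stats:   repisodes_solved: 0.991304347826
"""

# Matches "Training Stats:   <name>: <number>" at the end of a log line.
STAT_RE = re.compile(r"Training Stats:\s+(\w+):\s+([-+0-9.eE]+)\s*$")


def parse_stats(text):
    """Return {stat_name: float} for every 'Training Stats' line in text."""
    stats = {}
    for line in text.splitlines():
        match = STAT_RE.search(line)
        if match:
            stats[match.group(1)] = float(match.group(2))
    return stats


stats = parse_stats(LOG)

# Assumed relations; they hold for the numbers in the sample log.
ratio_reward = stats["total_reward"] / stats["nepisodes"]
ratio_solved = stats["nepisodes_solved"] / stats["nepisodes"]
assert math.isclose(stats["avg_reward"], ratio_reward, rel_tol=1e-9)
assert math.isclose(stats["repisodes_solved"], ratio_solved, rel_tol=1e-9)
```

The same pattern could be adapted to the evaluation lines (Iteration: ...; name: value) by changing the regular expression.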

What is included?

This repo currently contains the following agents:

  • Deep Q Learning (DQN) [1], [2]
  • Double DQN [3]
  • Dueling network DQN (Dueling DQN) [4]
  • Asynchronous Advantage Actor-Critic (A3C) (w/ both discrete/continuous action space support) [5], [6]
  • Sample Efficient Actor-Critic with Experience Replay (ACER) (currently w/ discrete action space support; Truncated Importance Sampling, 1st Order TRPO) [7], [8]

Work in progress:

  • Testing ACER

Future Plans:

  • Deep Deterministic Policy Gradient (DDPG) [9], [10]
  • Continuous DQN (CDQN or NAF) [11]

Code structure & Naming conventions:

NOTE: we follow the exact code structure of pytorch-dnc so that the code is easy to port.

  • ./utils/factory.py

We suggest that users refer to ./utils/factory.py, which lists all the integrated Env, Model, Memory, and Agent classes in dicts. All four core classes are implemented in ./core/. The factory pattern in ./utils/factory.py keeps the code clean: no matter which type of Agent you want to train, or which type of Env you want to train on, you only need to modify some parameters in ./utils/options.py, and ./main.py will do the rest (NOTE: ./main.py never needs to be modified).
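
As a rough illustration of the factory pattern described above, the sketch below keeps one dict per core class and builds an agent purely from configuration strings. Every class name, dict name, and key here is a hypothetical stand-in, not the actual content of ./utils/factory.py or ./core/.

```python
# Hypothetical stand-ins for classes that would live in ./core/.
class GymEnv:
    def __init__(self, game):
        self.game = game


class MlpModel:
    pass


class SequentialMemory:
    pass


class DQNAgent:
    def __init__(self, env, model, memory):
        self.env = env
        self.model = model
        self.memory = memory


# One dict per core class, keyed by the strings used in a configuration.
# The dict names and keys are assumptions made for this example.
EnvDict = {"gym": GymEnv}
ModelDict = {"mlp": MlpModel}
MemoryDict = {"sequential": SequentialMemory}
AgentDict = {"dqn": DQNAgent}


def build_agent(agent_type, env_type, game, model_type, memory_type):
    """Look up each class by its configured name and wire the parts together."""
    env = EnvDict[env_type](game)
    model = ModelDict[model_type]()
    memory = MemoryDict[memory_type]()
    return AgentDict[agent_type](env, model, memory)


agent = build_agent("dqn", "gym", "CartPole-v0", "mlp", "sequential")
print(type(agent).__name__, agent.env.game)
```

With this layout, supporting a new agent or environment only means registering one more dict entry; the code that calls build_agent stays unchanged, which is the property the paragraph above attributes to ./main.py.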

  • namings

To make the code cleaner and more readable, we name variables using the following pattern (mainly in the inherited Agents):

  • *_vb: a torch.autograd.Variable or a list of such objects
  • *_ts: a torch.Tensor or a list of such objects
  • otherwise: normal python datatypes

Dependencies


How to run:

You only need to modify some parameters in ./utils/options.py to train a new configuration.

  • Configure your training in ./utils/options.py:
  • line 14: add an entry to CONFIGS to define your training (agent_type, env_type, game, model_type, memory_type)
  • line 33: choose the entry you just added
  • lines 29-30: fill in your machine/cluster ID (MACHINE) and timestamp (TIMESTAMP) to define your training signature (MACHINE_TIMESTAMP). The model file and the log file of this training will be saved under this signature (./models/MACHINE_TIMESTAMP.pth and ./logs/MACHINE_TIMESTAMP.log respectively). The visdom visualization will also be displayed under this signature: first start the visdom server by running python -m visdom.server & in bash, then open http://localhost:8097/env/MACHINE_TIMESTAMP in your browser.
  • line 32: to train a model, set mode=1 (training visualization will be under http://localhost:8097/env/MACHINE_TIMESTAMP); to test the model of the current training, set mode=2 (testing visualization will be under http://localhost:8097/env/MACHINE_TIMESTAMP_test).
  • Run:

python main.py
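
To make the signature convention concrete, here is a minimal sketch of how the settings described above could map to the model path, log path, and visdom address. It is not the actual content of ./utils/options.py: the example CONFIGS entry, the machine and timestamp values, and the helper name signature_paths are all assumptions.

```python
# Hypothetical values; in the repo these would be set in ./utils/options.py.
CONFIGS = [
    # agent_type, env_type, game,          model_type, memory_type
    ["dqn",       "gym",    "CartPole-v0", "mlp",      "sequential"],
]
CONFIG = 0              # index of the CONFIGS entry to run
MACHINE = "machine1"    # machine/cluster ID
TIMESTAMP = "17080801"  # timestamp identifying this training run
MODE = 1                # 1: train, 2: test


def signature_paths(machine, timestamp, mode):
    """Derive the files and visdom address tied to one training signature."""
    signature = f"{machine}_{timestamp}"
    # Testing is visualized under a separate "<signature>_test" environment.
    visdom_env = signature if mode == 1 else f"{signature}_test"
    return {
        "model": f"./models/{signature}.pth",
        "log": f"./logs/{signature}.log",
        "visdom": f"http://localhost:8097/env/{visdom_env}",
    }


paths = signature_paths(MACHINE, TIMESTAMP, MODE)
for name, value in paths.items():
    print(f"{name}: {value}")
```

Because the model and log files are keyed by the same signature in both modes, switching from mode=1 to mode=2 without changing MACHINE or TIMESTAMP tests the model that was just trained.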


Bonus Scripts :)

We also provide 2 additional scripts for quickly evaluating your results after training. (Dependencies: lmj-plot)

  • plot.sh (e.g., plot from log file: logs/machine1_17080801.log)
  • ./plot.sh machine1 17080801
  • the generated figures will be saved into figs/machine1_17080801/
  • plot_compare.sh (e.g., compare log files: logs/machine1_17080801.log,logs/machine2_17080802.log)

./plot_compare.sh 00 machine1 17080801 machine2 17080802

  • the generated figures will be saved into figs/compare_00/
  • the color coding will be in the order of: red green blue magenta yellow cyan
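
The argument layout of the comparison script (a comparison ID followed by machine/timestamp pairs) and the fixed color order can be sketched as follows. This is only an illustration of the convention described above, not the script itself; the function name compare_plan is hypothetical, and rejecting more than six runs is an assumption made here because only six colors are listed.

```python
# Color order as listed above.
COLORS = ["red", "green", "blue", "magenta", "yellow", "cyan"]


def compare_plan(compare_id, *machine_timestamp):
    """Map (machine, timestamp) pairs to log files and plot colors."""
    if len(machine_timestamp) % 2 != 0:
        raise ValueError("expected machine/timestamp pairs")
    # Split the flat argument list into (machine, timestamp) pairs.
    pairs = list(zip(machine_timestamp[0::2], machine_timestamp[1::2]))
    if len(pairs) > len(COLORS):
        # Assumption: one distinct color per run, so at most six runs.
        raise ValueError("more runs than available colors")
    return {
        "out_dir": f"figs/compare_{compare_id}/",
        "curves": [
            (f"logs/{machine}_{timestamp}.log", color)
            for (machine, timestamp), color in zip(pairs, COLORS)
        ],
    }


plan = compare_plan("00", "machine1", "17080801", "machine2", "17080802")
print(plan["out_dir"])
for log_file, color in plan["curves"]:
    print(color, log_file)
```

Under this reading, the first log file given is drawn in red, the second in green, and so on.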

Repos we referred to during the development of this repo:


Citation

If you find this library useful and would like to cite it, the following would be appropriate:

@misc{pytorch-rl,
  author = {Zhang, Jingwei and Tai, Lei},
  title = {jingweiz/pytorch-rl},
  url = {https://github.com/jingweiz/pytorch-rl},
  year = {2017}
}