Open-source codebase for EfficientZero, from "Mastering Atari Games with Limited Data" at NeurIPS 2021.

Overview

EfficientZero (NeurIPS 2021)

Environments

EfficientZero requires Python 3 (>= 3.6) and PyTorch (>= 1.8.0) with the development headers.
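To confirm these minimums before building, you can use a small check script. The script below is a minimal sketch and is not part of the repository. The helper names (parse_version, version_at_least, check_environment) are hypothetical. It imports torch only inside check_environment, so the version helpers work even when PyTorch is not installed.

```python
import re
import sys


def parse_version(version):
    """Extract the leading numeric components, e.g. '1.8.0+cu111' -> (1, 8, 0)."""
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def version_at_least(installed, required):
    """Compare two dotted version strings, padding the shorter one with zeros."""
    a, b = parse_version(installed), parse_version(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment():
    """Return a list of problems; an empty list means the stated minimums are met."""
    problems = []
    if sys.version_info < (3, 6):
        problems.append("Python >= 3.6 is required")
    try:
        import torch  # imported lazily so the helpers above work without torch
    except ImportError:
        problems.append("PyTorch is not installed")
    else:
        if not version_at_least(torch.__version__, "1.8.0"):
            problems.append(f"PyTorch >= 1.8.0 is required, found {torch.__version__}")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("WARNING:", problem)
```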

We recommend using torch amp (--amp_type torch_amp) to accelerate training.

Prerequisites

Before starting training, you need to build the C++/Cython external packages:

cd core/ctree
bash make.sh

The distributed framework of this codebase is built on Ray.

Installation

To install the other packages required by this codebase, run pip install -r requirements.txt.

Usage

Quick start

  • Train: python main.py --env BreakoutNoFrameskip-v4 --case atari --opr train --amp_type torch_amp --num_gpus 1 --num_cpus 10 --cpu_actor 1 --gpu_actor 1 --force
  • Test: python main.py --env BreakoutNoFrameskip-v4 --case atari --opr test --amp_type torch_amp --num_gpus 1 --load_model --model_path model.p
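The two quick-start commands share their common flags and differ only in the operation-specific ones. If you want to vary the environment or worker counts from a script, you can assemble the argument lists in Python from the flags shown above. The helper below is a hypothetical convenience (build_command is not part of the codebase).

```python
import sys


def build_command(opr, env="BreakoutNoFrameskip-v4", num_gpus=1, num_cpus=10,
                  cpu_actor=1, gpu_actor=1, model_path="model.p"):
    """Return the quick-start argv for main.py using the flags from this README."""
    if opr not in ("train", "test"):
        raise ValueError("opr must be 'train' or 'test'")
    cmd = [sys.executable, "main.py", "--env", env, "--case", "atari",
           "--opr", opr, "--amp_type", "torch_amp", "--num_gpus", str(num_gpus)]
    if opr == "train":
        # Training also sets CPU/worker counts and overwrites old results.
        cmd += ["--num_cpus", str(num_cpus), "--cpu_actor", str(cpu_actor),
                "--gpu_actor", str(gpu_actor), "--force"]
    else:
        # Evaluation loads a saved model instead.
        cmd += ["--load_model", "--model_path", model_path]
    return cmd
```

For example, running subprocess.run(build_command("train"), check=True) from the repository root would launch the training command above.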

Bash file

We provide train.sh and test.sh for training and evaluation.

  • Train:
    • With 4 GPUs (3090): bash train.sh
  • Test: bash test.sh

Required Arguments             Description
--env                          Name of the environment
--case {atari}                 Switches between domains (default: atari)
--opr {train,test}             Selects the operation to perform
--amp_type {torch_amp,none}    Uses torch amp for acceleration
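You can express the required arguments and their allowed values with Python's argparse. The parser below is a minimal sketch that mirrors the table above. It is not a copy of the parser in main.py, which may declare these flags differently. With this setup, argparse rejects invalid choices such as --opr eval before any training code runs.

```python
import argparse


def make_parser():
    """Hypothetical parser mirroring the 'Required Arguments' table."""
    parser = argparse.ArgumentParser(description="EfficientZero (sketch of the documented flags)")
    parser.add_argument("--env", required=True, help="Name of the environment")
    parser.add_argument("--case", choices=["atari"], default="atari",
                        help="Switches between domains (default: atari)")
    parser.add_argument("--opr", choices=["train", "test"], required=True,
                        help="Selects the operation to perform")
    parser.add_argument("--amp_type", choices=["torch_amp", "none"], required=True,
                        help="Uses torch amp for acceleration")
    return parser
```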

Other Arguments                        Description
--force                                Overwrites the result directory
--num_gpus 4                           Number of available GPUs
--num_cpus 96                          Number of available CPUs
--cpu_actor 14                         Number of CPU workers
--gpu_actor 20                         Number of GPU workers
--seed 0                               Random seed
--use_priority                         Uses priorities when sampling from the replay buffer
--use_max_priority                     Assigns the max priority to newly collected data
--amp_type 'torch_amp'                 Uses torch amp for acceleration
--info 'EZ-V0'                         Tags for your experiments
--p_mcts_num 8                         Number of parallel envs in self-play
--revisit_policy_search_rate 0.99      Rate of reanalyzing policies
--use_root_value                       Uses root values in value targets (requires more GPU actors)
--render                               Renders during evaluation
--save_video                           Saves videos of the evaluation
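The --use_priority and --use_max_priority flags refer to prioritized replay. The sketch below shows the standard proportional scheme, in which an item's sampling probability is proportional to its priority raised to an exponent alpha. New data can optionally receive the current maximum priority, which makes it likely to be sampled at least once. The class name, alpha = 0.6, and the epsilon term are assumptions for illustration. The repository's replay buffer may compute and update priorities differently.

```python
import random


class PrioritizedReplay:
    """Proportional sampling: P(i) = (p_i + eps)**alpha / sum_k (p_k + eps)**alpha."""

    def __init__(self, use_priority=True, use_max_priority=True, alpha=0.6,
                 default_priority=1.0, eps=1e-6):
        self.use_priority = use_priority          # False -> uniform sampling
        self.use_max_priority = use_max_priority  # True -> new data gets the max priority
        self.alpha = alpha
        self.default_priority = default_priority
        self.eps = eps                            # keeps zero-priority items sampleable
        self.items = []
        self.priorities = []

    def add(self, item, priority=None):
        if self.use_max_priority and self.priorities:
            priority = max(self.priorities)
        elif priority is None:
            priority = self.default_priority
        self.items.append(item)
        self.priorities.append(priority)

    def sample(self, batch_size, rng=random):
        """Draw indices with replacement; returns (indices, items)."""
        weights = None
        if self.use_priority:
            weights = [(p + self.eps) ** self.alpha for p in self.priorities]
        indices = rng.choices(range(len(self.items)), weights=weights, k=batch_size)
        return indices, [self.items[i] for i in indices]

    def update(self, indices, new_priorities):
        """Overwrite priorities after training, e.g. with new prediction errors."""
        for i, p in zip(indices, new_priorities):
            self.priorities[i] = p
```

Because newly added items receive the current maximum priority, they rank among the most likely samples until training lowers their priority through update().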

Architecture Designs

The architecture of the training pipeline is shown as follows:

Some suggestions

  • To use a smaller model, choose smaller dimensions for the projection layers (e.g., 256/64) and the LSTM hidden layer (e.g., 64) in the config.
  • For GPUs with 10G of memory instead of 20G, you can allocate 0.25 GPU to each GPU worker (@ray.remote(num_gpus=0.25)) in core/reanalyze_worker.py.
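The fractional allocation works because Ray treats num_gpus as a scheduling quantity. An actor declared with @ray.remote(num_gpus=0.25) reserves a quarter of a GPU, so Ray can place up to four such actors on one device. The sketch below works through the arithmetic; the helper names are hypothetical. Note that Ray does not enforce GPU memory limits, so each actor must actually fit within its share of the 10G.

```python
import math


def actors_per_gpu(gpu_fraction):
    """Number of actors requesting `gpu_fraction` of a GPU that fit on one device."""
    if not 0 < gpu_fraction <= 1:
        raise ValueError("gpu_fraction must be in (0, 1]")
    # Small tolerance guards against float error just below a whole number.
    return math.floor(1 / gpu_fraction + 1e-9)


def max_concurrent_actors(num_gpus, gpu_fraction):
    """Upper bound on simultaneously scheduled GPU actors across all GPUs."""
    return num_gpus * actors_per_gpu(gpu_fraction)
```

For example, on a 4-GPU machine, 0.25 GPU per worker allows at most 16 such workers to be scheduled at once. The actual number is lower if other actors also reserve GPU.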

New environment registration

To apply EfficientZero to a new environment such as MuJoCo, follow these steps to register it:

  1. Following the layout of config/atari, create a directory for the environment at config/mujoco.
  2. Implement your MujocoConfig(BaseConfig) class, along with the models and your environment wrapper.
  3. Register the case in main.py.
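Steps 2 and 3 amount to adding a config class and letting --case select it. The sketch below shows one way to wire this up with a small registry. All names in it are hypothetical rather than the repository's actual interfaces: BaseConfig is a stand-in, and the same applies to make_model, make_env, CASES, and get_config. With this pattern, you can derive the --case choices from the registry (for example, choices=sorted(CASES)), so a new domain only needs to be registered once.

```python
class BaseConfig:
    """Hypothetical stand-in for the repository's BaseConfig."""

    def __init__(self, env_name):
        self.env_name = env_name

    def make_model(self):
        raise NotImplementedError  # subclasses build their own networks

    def make_env(self, seed=None):
        raise NotImplementedError  # subclasses wrap their own environments


# Hypothetical registry mapping --case values to config classes.
CASES = {}


def register_case(name):
    """Class decorator that records a config class under a --case name."""
    def decorator(cls):
        if name in CASES:
            raise ValueError(f"case {name!r} is already registered")
        CASES[name] = cls
        return cls
    return decorator


@register_case("mujoco")
class MujocoConfig(BaseConfig):
    def make_model(self):
        # Placeholder for the MuJoCo-specific model.
        return {"kind": "mujoco-model"}

    def make_env(self, seed=None):
        # Placeholder for your environment wrapper around the MuJoCo task.
        return {"env": self.env_name, "seed": seed}


def get_config(case, env_name):
    """Look up the registered config class for --case and instantiate it."""
    try:
        return CASES[case](env_name)
    except KeyError:
        raise ValueError(f"invalid --case option: {case!r}") from None
```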

Results

Evaluation with 32 seeds for 3 different runs (different seeds).

Citation

If you find this repo useful, please cite our paper:

@inproceedings{ye2021mastering,
  title={Mastering Atari Games with Limited Data},
  author={Weirui Ye and Shaohuai Liu and Thanard Kurutach and Pieter Abbeel and Yang Gao},
  booktitle={NeurIPS},
  year={2021}
}

Contact

If you have any questions or want to use the code, please contact [email protected].

Acknowledgement

We thank the following GitHub repositories for their valuable codebase implementations:

https://github.com/koulanurag/muzero-pytorch

https://github.com/werner-duvaud/muzero-general

https://github.com/pytorch/ELF

Owner
Weirui Ye