Decision Transformer: A brand new Offline RL Pattern

Overview

DecisionTransformer_StepbyStep

Intro

Decision Transformer: A brand new Offline RL Pattern.

这是关于NeurIPS 2021 热门论文Decision Transformer的复现。

👍 原文地址: Decision Transformer: Reinforcement Learning via Sequence Modeling

👍 官方的Git仓库: decision-transformer(official)

Decision Transformer

Decision Transformer属于Offline RL,所谓Offline RL,即从次优数据中学习策略来分配Agent,即从固定、有限的经验中产生最大有效的行为。

👀️ Motivation

DT将RL看成一个序列建模问题(Sequence Modeling Problem ),不用传统RL方法,而使用网络直接输出动作进行决策。传统RL方法存在一些问题,比如估计未来Return过程中Bootstrapping过程会导致Overestimate; 马尔可夫假设;

DT借助了Transformer的强大表征能力和时序建模能力。

  • Decision Transformer的表现达到甚至超过了目前最好的基于dynamic programming的主流方法;
  • 在一些需要long-term credit assignment的task【例如sparse reward或者delayed reward等】,Decision Transformer的表现远超过了最好的主流方法.

🚀️ DT的核心思想

image.png

Decision Transformer的核心思想; States、Actions、Returns被Fed into Modality-Specific的线性Embedding;并添加了带有时间步信息的positional episodic timestep; 这些Tokens被输入一个GPT架构,使用a causal self-attention mask来预测actions。

🎉️ DT的优势

  1. 无需Markov假设;
  2. 没有使用一个可学习的Value Function作为Training Target;
  3. 利用Transformer的特性,绕过长期信用分配进行“自举bootstrapping”的需要,避免了时序差分学习的“短视”行为;
  4. 可以通过self-attention直接执行信度分配。这与缓慢传播奖励并容易产生干扰信号的 Bellman Backup 相反,可以使 Transformer 在奖励稀少或分散注意力的情况下仍然有效地工作.

Dependencies

1. D4RL ( Dataset for Deep Data-Driven Reinforcement Learning )

2. MUJOCO 210

# 安装之前先安装absl-py和matplotlib 
pip install absl-py 
pip install matplotlib 

"""
git clone https://github.com/rail-berkeley/d4rl.git
cd d4rl
pip install -e . # 这种方法不好使 !! 
"""

#首先在https://github.com/deepmind/dm_control这个库git clone
# cd
pip install -r requirement.txt 
# 然后 
pip install matplotlib 
# 然后 https://github.com/takuseno/d3rlpy 
pip install d3rlpy 
# 然后安装mujoco 210  
# 直接安装,然后添加环境变量 
# 装完之后进d4rl文件夹下
python setup.py install 
# 成功安装 d4rl 1.1 

3. GPT-2


pip install transformers

Experiments

Group1: Decision Transformer — Hopper-v3-Medium-Dataset

参数Config

class Config:
    env = "hopper"
    dataset = "medium"
    mode = "normal" # "delayed" : all rewards moved to end of trajectory
    device = 'cuda'
    log_dir = 'TB_log/'
    record_algo = 'DT_Hopper_v1'
    test_cycles = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

    # 模型
    model_type = "DT"
    activation_function = 'relu'

    # Scalar
    max_length = 20 # max_len # K
    pct_traj = 1.
    batch_size = 64
    embed_dim = 128
    n_layer = 3
    n_head = 1
    dropout = 0.1
    lr = 1e-4
    wd = 1e-4
    warmup_steps = 1000
    num_eval_episodes = 100
    max_iters = 50
    num_steps_per_iter = 1000

    # Bool
    log_to_tb = True

效果

image.png

Owner
Irving
Irving
git《FSCE: Few-Shot Object Detection via Contrastive Proposal Encoding》(CVPR 2021) GitHub: [fig8]

FSCE: Few-Shot Object Detection via Contrastive Proposal Encoding (CVPR 2021) This repo contains the implementation of our state-of-the-art fewshot ob

233 Dec 29, 2022
Jupyter notebooks for the code samples of the book "Deep Learning with Python"

Jupyter notebooks for the code samples of the book "Deep Learning with Python"

François Chollet 16.2k Dec 30, 2022
Boostcamp CV Serving For Python

Boostcamp-CV-Serving Prerequisites MySQL GCP Cloud Storage GCP key file Sentry Streamlit Cloud Secrets: .streamlit/secrets.toml #DO NOT SHARE THIS I

Jungwon Seo 19 Feb 22, 2022
Space Ship Simulator using python

FlyOver Basic space-ship simulator using python How to run? Just double click run.py What modules do i need? All modules that i currently using is bui

0 Oct 09, 2022
code for Fast Point Cloud Registration with Optimal Transport

robot This is the repository for the paper "Accurate Point Cloud Registration with Robust Optimal Transport". We are in the process of refactoring the

28 Jan 04, 2023
StudioGAN is a Pytorch library providing implementations of representative Generative Adversarial Networks (GANs) for conditional/unconditional image generation.

StudioGAN is a Pytorch library providing implementations of representative Generative Adversarial Networks (GANs) for conditional/unconditional image generation.

3k Jan 08, 2023
Garbage classification using structure data.

垃圾分类模型使用说明 1.包含以下数据文件 文件 描述 data/MaterialMapping.csv 物体以及其归类的信息 data/TestRecords 光谱原始测试数据 CSV 文件 data/TestRecordDesc.zip CSV 文件描述文件 data/Boundaries.cs

wenqi 1 Dec 10, 2021
Vehicle Detection Using Deep Learning and YOLO Algorithm

VehicleDetection Vehicle Detection Using Deep Learning and YOLO Algorithm Dataset take or find vehicle images for create a special dataset for fine-tu

Maryam Boneh 96 Jan 05, 2023
GoodNews Everyone! Context driven entity aware captioning for news images

This is the code for a CVPR 2019 paper, called GoodNews Everyone! Context driven entity aware captioning for news images. Enjoy! Model preview: Huge T

117 Dec 19, 2022
An Image compression simulator that uses Source Extractor and Monte Carlo methods to examine the post compressive effects different compression algorithms have.

ImageCompressionSimulation An Image compression simulator that uses Source Extractor and Monte Carlo methods to examine the post compressive effects o

James Park 1 Dec 11, 2021
This repo is about implementing different approaches of pose estimation and also is a sub-task of the smart hospital bed project :smile:

Pose-Estimation This repo is a sub-task of the smart hospital bed project which is about implementing the task of pose estimation 😄 Many thanks to th

Max 11 Oct 17, 2022
NeRViS: Neural Re-rendering for Full-frame Video Stabilization

Neural Re-rendering for Full-frame Video Stabilization

Yu-Lun Liu 9 Jun 17, 2022
RLDS stands for Reinforcement Learning Datasets

RLDS RLDS stands for Reinforcement Learning Datasets and it is an ecosystem of tools to store, retrieve and manipulate episodic data in the context of

Google Research 135 Jan 01, 2023
The pytorch implementation of the paper "text-guided neural image inpainting" at MM'2020

TDANet: Text-Guided Neural Image Inpainting, MM'2020 (Oral) MM | ArXiv This repository implements the paper "Text-Guided Neural Image Inpainting" by L

LisaiZhang 75 Dec 22, 2022
Brain Tumor Detection with Tensorflow Neural Networks.

Brain-Tumor-Detection A convolutional neural network model built with Tensorflow & Keras to detect brain tumor and its different variants. Data of the

404ErrorNotFound 5 Aug 23, 2022
Classify bird species based on their songs using SIamese Networks and 1D dilated convolutions.

The goal is to classify different birds species based on their songs/calls. Spectrograms have been extracted from the audio samples and used as features for classification.

Aditya Dutt 9 Dec 27, 2022
Applying CLIP to Point Cloud Recognition.

PointCLIP: Point Cloud Understanding by CLIP This repository is an official implementation of the paper 'PointCLIP: Point Cloud Understanding by CLIP'

Renrui Zhang 175 Dec 24, 2022
Tensorflow implementation of "BEGAN: Boundary Equilibrium Generative Adversarial Networks"

BEGAN in Tensorflow Tensorflow implementation of BEGAN: Boundary Equilibrium Generative Adversarial Networks. Requirements Python 2.7 or 3.x Pillow tq

Taehoon Kim 922 Dec 21, 2022
Code for the ECIR'22 paper "Evaluating the Robustness of Retrieval Pipelines with Query Variation Generators"

Query Variation Generators This repository contains the code and annotation data for the ECIR'22 paper "Evaluating the Robustness of Retrieval Pipelin

Gustavo Penha 12 Nov 20, 2022
Blender Add-On for slicing meshes with planes

MeshSlicer Blender Add-On for slicing meshes with multiple overlapping planes at once. This is a simple Blender addon to slice a silmple mesh with mul

52 Dec 12, 2022