Decision Transformer: A brand new Offline RL Pattern

Overview

DecisionTransformer_StepbyStep

Intro


This is a reproduction of Decision Transformer, a popular NeurIPS 2021 paper.

👍 Paper: Decision Transformer: Reinforcement Learning via Sequence Modeling

👍 Official Git repository: decision-transformer (official)

Decision Transformer

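Decision Transformer is an offline RL method. Offline RL learns a policy for the agent from possibly suboptimal data; in other words, it tries to produce maximally effective behavior from a fixed, limited set of experience.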

👀️ Motivation

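DT treats RL as a sequence modeling problem. Instead of following the traditional RL recipe, it uses a network that outputs actions directly. Traditional RL methods have known problems: for example, bootstrapping while estimating the future return can lead to overestimation, and they rely on the Markov assumption.

DT builds on the Transformer's strong representation capacity and its ability to model temporal sequences.

  • Decision Transformer matches or exceeds the performance of the best current mainstream methods based on dynamic programming;
  • On tasks that require long-term credit assignment (e.g. sparse or delayed reward), Decision Transformer far outperforms the best mainstream methods.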
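
To make the delayed-reward setting concrete, here is a minimal sketch of how a dense reward sequence could be turned into a delayed one, following the description "all rewards moved to end of trajectory" used for the "delayed" mode in the Config further down. The function name `delay_rewards` is hypothetical and is not taken from this repository.

```python
from typing import List


def delay_rewards(rewards: List[float]) -> List[float]:
    """Move the whole episode return onto the final step.

    Hypothetical helper: every intermediate reward becomes 0 and the last
    step carries the sum, so the total return of the episode is unchanged.
    """
    if not rewards:
        return []
    delayed = [0.0] * len(rewards)
    delayed[-1] = float(sum(rewards))
    return delayed


dense = [1.0, 0.5, 2.0, 0.5]
sparse = delay_rewards(dense)
print(sparse)  # [0.0, 0.0, 0.0, 4.0]
```

The agent now only sees a signal at the very end of the episode, which is the kind of task where long-term credit assignment matters.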

🚀️ Core idea of DT

image.png

The core idea of Decision Transformer: states, actions, and returns are fed into modality-specific linear embeddings, and a positional episodic timestep encoding is added. The resulting tokens are fed into a GPT architecture, which uses a causal self-attention mask to predict actions.
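
The token layout and the causal mask described above can be sketched in plain Python. This is only an illustration of the ordering and of which positions may attend to which, not the model code of this repository; the helper names `interleave_tokens` and `causal_mask` are hypothetical.

```python
from typing import List, Tuple


def interleave_tokens(num_steps: int) -> List[Tuple[str, int]]:
    """Return the token order (R_0, s_0, a_0, R_1, s_1, a_1, ...).

    Each entry is (modality, timestep). The three tokens of one step share
    the same episodic timestep, which is what the timestep encoding carries.
    """
    tokens: List[Tuple[str, int]] = []
    for t in range(num_steps):
        tokens.extend([("return", t), ("state", t), ("action", t)])
    return tokens


def causal_mask(seq_len: int) -> List[List[int]]:
    """Lower-triangular mask: position i may attend to positions j <= i."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]


tokens = interleave_tokens(2)
mask = causal_mask(len(tokens))

# In the paper, a_t is predicted from the output at the state token s_t.
# That position can see R_t and s_t, but not a_t or anything later.
state_1 = tokens.index(("state", 1))
visible = [tokens[j] for j, keep in enumerate(mask[state_1]) if keep]
print(visible)
```

With K timesteps of context the sequence therefore has 3K tokens.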

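🎉️ Advantages of DT

  1. No Markov assumption is required;
  2. No learned value function is used as the training target;
  3. It exploits the Transformer to bypass the need for bootstrapping in long-term credit assignment, which avoids the "short-sighted" behavior of temporal-difference learning;
  4. Credit assignment can be performed directly through self-attention. In contrast to Bellman backups, which propagate rewards slowly and are prone to distractor signals, this lets the Transformer keep working effectively when rewards are sparse or distracting.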
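
As a small illustration of points 2 and 3, the return token that DT conditions on can be computed directly from the logged rewards, with no learned value function and no bootstrapped target. The sketch below assumes the undiscounted return-to-go (`gamma=1.0`); the function name `returns_to_go` is hypothetical.

```python
from typing import List


def returns_to_go(rewards: List[float], gamma: float = 1.0) -> List[float]:
    """Return-to-go at step t: the sum of rewards from t to the episode end.

    gamma=1.0 (no discounting) is an assumption of this sketch.
    """
    rtg = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step reuses the sum of the steps after it.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg


print(returns_to_go([1.0, 0.5, 2.0, 0.5]))  # [4.0, 3.0, 2.5, 0.5]
print(returns_to_go([0.0, 0.0, 0.0, 4.0]))  # [4.0, 4.0, 4.0, 4.0]
```

In the second call all reward arrives at the last step, yet every earlier token still carries the full return, so the signal does not have to be propagated backwards step by step.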

Dependencies

1. D4RL (Datasets for Deep Data-Driven Reinforcement Learning)

2. MuJoCo 210

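# Install absl-py and matplotlib before anything else
pip install absl-py
pip install matplotlib

# The usual editable install of D4RL did not work:
#   git clone https://github.com/rail-berkeley/d4rl.git
#   cd d4rl
#   pip install -e .

# Instead, first git clone https://github.com/deepmind/dm_control
# cd into the cloned directory, then
pip install -r requirements.txt
# then
pip install matplotlib
# then install d3rlpy (https://github.com/takuseno/d3rlpy)
pip install d3rlpy
# then install MuJoCo 210:
# install it directly, then add it to the environment variables
# once that is done, go into the d4rl folder and run
python setup.py install
# d4rl 1.1 is now installed successfully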

3. GPT-2


pip install transformers
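
After installing, a quick way to confirm that the packages are importable is a check like the one below. It is only a convenience sketch: the module names `d4rl`, `d3rlpy`, `mujoco_py`, and `transformers` are assumptions about how these dependencies are imported on your machine, and `check_modules` is a hypothetical helper.

```python
import importlib.util
from typing import Dict, Iterable


def check_modules(names: Iterable[str]) -> Dict[str, bool]:
    """Map each top-level module name to whether Python can locate it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


# Assumed import names; adjust them to match your environment.
status = check_modules(["d4rl", "d3rlpy", "mujoco_py", "transformers"])
for name, found in status.items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```

This only checks that the modules can be located; it does not verify that MuJoCo itself is configured correctly.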

Experiments

Group1: Decision Transformer — Hopper-v3-Medium-Dataset

Config parameters

import datetime


class Config:
    env = "hopper"
    dataset = "medium"
    mode = "normal"  # "delayed": all rewards moved to end of trajectory
    device = 'cuda'
    log_dir = 'TB_log/'
    record_algo = 'DT_Hopper_v1'
    # Timestamp used to tell runs apart
    test_cycles = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

    # Model
    model_type = "DT"
    activation_function = 'relu'

    # Scalar
    max_length = 20  # context length K
    pct_traj = 1.
    batch_size = 64
    embed_dim = 128
    n_layer = 3
    n_head = 1
    dropout = 0.1
    lr = 1e-4
    wd = 1e-4
    warmup_steps = 1000
    num_eval_episodes = 100
    max_iters = 50
    num_steps_per_iter = 1000

    # Bool
    log_to_tb = True
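
To give a feel for what these numbers imply, the sketch below works out the total number of gradient steps and a linear learning-rate warmup over `warmup_steps`. The linear form of the warmup is an assumption of this sketch and has not been checked against this repository's training loop; `warmup_factor` is a hypothetical helper.

```python
def warmup_factor(step: int, warmup_steps: int = 1000) -> float:
    """Linear warmup multiplier: ramps from 1/warmup_steps up to 1.0.

    Assumed schedule for illustration only.
    """
    return min((step + 1) / warmup_steps, 1.0)


# Values copied from the Config above.
lr = 1e-4
max_iters = 50
num_steps_per_iter = 1000

total_steps = max_iters * num_steps_per_iter
print(total_steps)  # 50000

for step in (0, 499, 999, 25000):
    print(step, lr * warmup_factor(step))
```

Under this assumption the learning rate reaches its full value after the first 1000 of the 50000 steps, i.e. by the end of the first iteration.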

Results

image.png

Owner
Irving