Conservative Q Learning for Offline Reinforcement Learning in JAX

CQL-JAX

This repository implements Conservative Q Learning (CQL) for Offline Reinforcement Learning in JAX (FLAX). The implementation is built on top of the SAC base of JAX-RL.
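For readers new to the method, the conservative term that CQL adds to a SAC-style critic loss can be sketched as below. This is a simplified illustration, not the code in this repository: the function names `cql_penalty` and `critic_loss` are hypothetical, the importance weights and temperature used in full CQL are omitted, and the weight argument is assumed to play the role of the `min_q_weight` flag described under Usage.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def cql_penalty(q_sampled, q_data, min_q_weight=5.0):
    """Simplified conservative penalty (hypothetical helper).

    q_sampled: (batch, num_actions) Q-values for actions sampled per state.
    q_data:    (batch,) Q-values for the actions stored in the dataset.
    """
    # Soft maximum of Q over the sampled actions of each state.
    soft_max_q = logsumexp(q_sampled, axis=1)
    # Penalize the gap between the soft maximum and the dataset Q-value,
    # which pushes Q down on unseen actions and up on dataset actions.
    return min_q_weight * jnp.mean(soft_max_q - q_data)


def critic_loss(q_pred, q_target, q_sampled, min_q_weight=5.0):
    """Bellman error plus the conservative penalty (hypothetical helper)."""
    bellman_error = jnp.mean((q_pred - q_target) ** 2)
    return bellman_error + cql_penalty(q_sampled, q_pred, min_q_weight)
```

A larger weight makes the critic more pessimistic about actions that are not in the dataset, which is presumably why the best value differs across datasets.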

Usage

Install Dependencies-

pip install -r requirements.txt
pip install "jax[cuda111]<=0.21.1" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Run CQL-

python train_offline.py --env_name=hopper-expert-v0 --min_q_weight=5

Please use the following values of min_q_weight on MuJoCo tasks to reproduce the CQL results from the IQL paper-

Domain     medium    medium-replay    medium-expert
walker     10        1                10
hopper     5         5                1
cheetah    90        80               100

For antmaze tasks, min_q_weight=10 was found to work best.
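As a convenience, the recommended values can be kept in a small lookup and turned into a training command. This is only a sketch: `lookup_min_q_weight` and `train_command` are hypothetical helpers that are not part of the repository, and the exact environment name (including its version suffix) is left to the caller.

```python
# Values copied from the MuJoCo table above.
MUJOCO_MIN_Q_WEIGHT = {
    "walker": {"medium": 10, "medium-replay": 1, "medium-expert": 10},
    "hopper": {"medium": 5, "medium-replay": 5, "medium-expert": 1},
    "cheetah": {"medium": 90, "medium-replay": 80, "medium-expert": 100},
}
# Single value recommended above for all antmaze tasks.
ANTMAZE_MIN_Q_WEIGHT = 10


def lookup_min_q_weight(domain, dataset):
    """Return the recommended min_q_weight for a domain and dataset type."""
    if domain == "antmaze":
        return ANTMAZE_MIN_Q_WEIGHT
    return MUJOCO_MIN_Q_WEIGHT[domain][dataset]


def train_command(env_name, domain, dataset):
    """Build the argument list for train_offline.py (flags as shown under Run CQL)."""
    weight = lookup_min_q_weight(domain, dataset)
    return [
        "python",
        "train_offline.py",
        f"--env_name={env_name}",
        f"--min_q_weight={weight}",
    ]
```

The returned list can be passed to `subprocess.run`, or joined with spaces to get a shell command.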

In case of out-of-memory errors in JAX, try running with the following environment variables-

XLA_PYTHON_CLIENT_MEM_FRACTION=0.80 python ...
XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 python ...
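The same variables can also be set from Python, as long as this happens before JAX is imported. The sketch below assumes you control the entry point of the script; `configure_xla_memory` is a hypothetical helper, not something the repository provides.

```python
import os


def configure_xla_memory(mem_fraction=0.80, limit_compile_parallelism=True):
    """Set the variables from the commands above. Call before importing jax."""
    # Fraction of GPU memory the XLA client is allowed to preallocate.
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    if limit_compile_parallelism:
        flag = "--xla_gpu_force_compilation_parallelism=1"
        existing = os.environ.get("XLA_FLAGS", "")
        # Append the flag without discarding flags that are already set.
        if flag not in existing:
            os.environ["XLA_FLAGS"] = f"{existing} {flag}".strip()


configure_xla_memory()
# Import jax only after this point so the settings take effect.
```

Setting the variables on the command line, as shown above, has the same effect and needs no code changes.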

Performance & Runtime

Returns are comparable to the PyTorch implementation and to IQL on the hopper and antmaze-umaze tasks, but noticeably lower on the medium and large antmaze tasks-

Task                         CQL(PyTorch)    CQL(JAX)    IQL
hopper-medium-v2             58.5            74.6        66.3
hopper-medium-replay-v2      95.0            92.1        94.7
hopper-medium-expert-v2      105.4           83.2        91.5
antmaze-umaze-v0             74.0            69.5        87.5
antmaze-umaze-diverse-v0     84.0            78.7        62.2
antmaze-medium-play-v0       61.2            14.2        71.2
antmaze-medium-diverse-v0    53.7            10.7        70.2
antmaze-large-play-v0        15.8            0.0         39.6
antmaze-large-diverse-v0     14.9            0.0         47.5

Wall-clock time averages about 50 minutes, improving on the 80 minutes the IQL paper reports for CQL and narrowing the gap to IQL's 20 minutes.

Task                       CQL(JAX) (min)    IQL (min)
hopper-medium-v2           52                27
hopper-medium-replay-v2    54                30
hopper-medium-expert-v2    57                29

The JAX implementation is more than 4 times faster than the original PyTorch implementation.

For more offline RL algorithm implementations, check out the JAX-RL, IQL and rlkit repositories.

Citation

If you use CQL-JAX in your research, please cite the following-

@misc{cqljax,
  author = {Suri, Karush},
  title = {{Conservative Q Learning in JAX.}},
  url = {https://github.com/karush17/cql-jax},
  year = {2021}
}

Owner
Karush Suri
Deep Learning Researcher at Huawei Noah's Ark Lab, Toronto.