Deep Reinforcement Learning for Keras.

Overview

Deep Reinforcement Learning for Keras

Build Status Documentation License Join the chat at https://gitter.im/keras-rl/Lobby

What is it?

keras-rl implements some state-of-the art deep reinforcement learning algorithms in Python and seamlessly integrates with the deep learning library Keras.

Furthermore, keras-rl works with OpenAI Gym out of the box. This means that evaluating and playing around with different algorithms is easy.

Of course you can extend keras-rl according to your own needs. You can use built-in Keras callbacks and metrics or define your own. Even more so, it is easy to implement your own environments and even algorithms by simply extending some simple abstract classes. Documentation is available online.

What is included?

As of today, the following algorithms have been implemented:

  • Deep Q Learning (DQN) [1], [2]
  • Double DQN [3]
  • Deep Deterministic Policy Gradient (DDPG) [4]
  • Continuous DQN (CDQN or NAF) [6]
  • Cross-Entropy Method (CEM) [7], [8]
  • Dueling network DQN (Dueling DQN) [9]
  • Deep SARSA [10]
  • Asynchronous Advantage Actor-Critic (A3C) [5]
  • Proximal Policy Optimization Algorithms (PPO) [11]

You can find more information on each agent in the doc.

Installation

  • Install Keras-RL from Pypi (recommended):
pip install keras-rl
  • Install from Github source:
git clone https://github.com/keras-rl/keras-rl.git
cd keras-rl
python setup.py install

Examples

If you want to run the examples, you'll also have to install:

For atari example you will also need:

  • Pillow: pip install Pillow
  • gym[atari]: Atari module for gym. Use pip install gym[atari]

Once you have installed everything, you can try out a simple example:

python examples/dqn_cartpole.py

This is a very simple example and it should converge relatively quickly, so it's a great way to get started! It also visualizes the game during training, so you can watch it learn. How cool is that?

Some sample weights are available on keras-rl-weights.

If you have questions or problems, please file an issue or, even better, fix the problem yourself and submit a pull request!

External Projects

You're using Keras-RL on a project? Open a PR and share it!

Visualizing Training Metrics

To see graphs of your training progress and compare across runs, run pip install wandb and add the WandbLogger callback to your agent's fit() call:

from rl.callbacks import WandbLogger

...

agent.fit(env, nb_steps=50000, callbacks=[WandbLogger()])

For more info and options, see the W&B docs.

Citing

If you use keras-rl in your research, you can cite it as follows:

@misc{plappert2016kerasrl,
    author = {Matthias Plappert},
    title = {keras-rl},
    year = {2016},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/keras-rl/keras-rl}},
}

References

  1. Playing Atari with Deep Reinforcement Learning, Mnih et al., 2013
  2. Human-level control through deep reinforcement learning, Mnih et al., 2015
  3. Deep Reinforcement Learning with Double Q-learning, van Hasselt et al., 2015
  4. Continuous control with deep reinforcement learning, Lillicrap et al., 2015
  5. Asynchronous Methods for Deep Reinforcement Learning, Mnih et al., 2016
  6. Continuous Deep Q-Learning with Model-based Acceleration, Gu et al., 2016
  7. Learning Tetris Using the Noisy Cross-Entropy Method, Szita et al., 2006
  8. Deep Reinforcement Learning (MLSS lecture notes), Schulman, 2016
  9. Dueling Network Architectures for Deep Reinforcement Learning, Wang et al., 2016
  10. Reinforcement learning: An introduction, Sutton and Barto, 2011
  11. Proximal Policy Optimization Algorithms, Schulman et al., 2017
You might also like...
Distributed Deep learning with Keras & Spark
Distributed Deep learning with Keras & Spark

Elephas: Distributed Deep Learning with Keras & Spark Elephas is an extension of Keras, which allows you to run distributed deep learning models at sc

QKeras: a quantization deep learning library for Tensorflow Keras

QKeras github.com/google/qkeras QKeras 0.8 highlights: Automatic quantization using QKeras; Stochastic behavior (including stochastic rouding) is disa

MMdnn is a set of tools to help users inter-operate among different deep learning frameworks. E.g. model conversion and visualization. Convert models between Caffe, Keras, MXNet, Tensorflow, CNTK, PyTorch Onnx and CoreML.
MMdnn is a set of tools to help users inter-operate among different deep learning frameworks. E.g. model conversion and visualization. Convert models between Caffe, Keras, MXNet, Tensorflow, CNTK, PyTorch Onnx and CoreML.

MMdnn MMdnn is a comprehensive and cross-framework tool to convert, visualize and diagnose deep learning (DL) models. The "MM" stands for model manage

Advanced Deep Learning with TensorFlow 2 and Keras (Updated for 2nd Edition)
Advanced Deep Learning with TensorFlow 2 and Keras (Updated for 2nd Edition)

Advanced Deep Learning with TensorFlow 2 and Keras (Updated for 2nd Edition)

Keras like implementation of Deep Learning architectures from scratch using numpy.

Mini-Keras Keras like implementation of Deep Learning architectures from scratch using numpy. How to contribute? The project contains implementations

Realtime Face Anti Spoofing with Face Detector based on Deep Learning using Tensorflow/Keras and OpenCV
Realtime Face Anti Spoofing with Face Detector based on Deep Learning using Tensorflow/Keras and OpenCV

Realtime Face Anti-Spoofing Detection 🤖 Realtime Face Anti Spoofing Detection with Face Detector to detect real and fake faces Please star this repo

This source code is implemented using keras library based on "Automatic ocular artifacts removal in EEG using deep learning"

CSP_Deep_EEG This source code is implemented using keras library based on "Automatic ocular artifacts removal in EEG using deep learning" {https://www

Vision Deep-Learning using Tensorflow, Keras.

Welcome! I am a computer vision deep learning developer working in Korea. This is my blog, and you can see everything I've studied here. https://www.n

A deep learning network built with TensorFlow and Keras to classify gender and estimate age.
A deep learning network built with TensorFlow and Keras to classify gender and estimate age.

Convolutional Neural Network (CNN). This repository contains a source code of a deep learning network built with TensorFlow and Keras to classify gend

Releases(v0.4.2)
Python implementation of Bayesian optimization over permutation spaces.

Bayesian Optimization over Permutation Spaces This repository contains the source code and the resources related to the paper "Bayesian Optimization o

Aryan Deshwal 9 Dec 23, 2022
CPPE - 5 (Medical Personal Protective Equipment) is a new challenging object detection dataset

CPPE - 5 CPPE - 5 (Medical Personal Protective Equipment) is a new challenging dataset with the goal to allow the study of subordinate categorization

Rishit Dagli 53 Dec 17, 2022
PyTorch Implementation of SSTNs for hyperspectral image classifications from the IEEE T-GRS paper "Spectral-Spatial Transformer Network for Hyperspectral Image Classification: A FAS Framework."

PyTorch Implementation of SSTN for Hyperspectral Image Classification Paper links: SSTN published on IEEE T-GRS. Also, you can directly find the imple

Zilong Zhong 54 Dec 19, 2022
Research code for CVPR 2021 paper "End-to-End Human Pose and Mesh Reconstruction with Transformers"

MeshTransformer ✨ This is our research code of End-to-End Human Pose and Mesh Reconstruction with Transformers. MEsh TRansfOrmer is a simple yet effec

Microsoft 473 Dec 31, 2022
Official code for the ICCV 2021 paper "DECA: Deep viewpoint-Equivariant human pose estimation using Capsule Autoencoders"

DECA Official code for the ICCV 2021 paper "DECA: Deep viewpoint-Equivariant human pose estimation using Capsule Autoencoders". All the code is writte

23 Dec 01, 2022
(to be released) [NeurIPS'21] Transformers Generalize DeepSets and Can be Extended to Graphs and Hypergraphs

Higher-Order Transformers Kim J, Oh S, Hong S, Transformers Generalize DeepSets and Can be Extended to Graphs and Hypergraphs, NeurIPS 2021. [arxiv] W

Jinwoo Kim 44 Dec 28, 2022
Implementation of Memory-Compressed Attention, from the paper "Generating Wikipedia By Summarizing Long Sequences"

Memory Compressed Attention Implementation of the Self-Attention layer of the proposed Memory-Compressed Attention, in Pytorch. This repository offers

Phil Wang 47 Dec 23, 2022
A Small and Easy approach to the BraTS2020 dataset (2D Segmentation)

BraTS2020 A Light & Scalable Solution to BraTS2020 | Medical Brain Tumor Segmentation (2D Segmentation) Developed the segmentation models for segregat

Gunjan Haldar 0 Jan 19, 2022
ZeroVL - The official implementation of ZeroVL

This repository contains source code necessary to reproduce the results presente

31 Nov 04, 2022
Accelerated SMPL operation, commonly used in generate 3D human mesh, STAR included.

SMPL2 An enchanced and accelerated SMPL operation which commonly used in 3D human mesh generation. It takes a poses, shapes, cam_trans as inputs, outp

JinTian 20 Oct 17, 2022
Python framework for Stochastic Differential Equations modeling

SDElearn: a Python package for SDE modeling This package implements functionalities for working with Stochastic Differential Equations models (SDEs fo

4 May 10, 2022
Prompt Tuning with Rules

PTR Code and datasets for our paper "PTR: Prompt Tuning with Rules for Text Classification" If you use the code, please cite the following paper: @art

THUNLP 118 Dec 30, 2022
This repo contains the code for the paper "Efficient hierarchical Bayesian inference for spatio-temporal regression models in neuroimaging" that has been accepted to NeurIPS 2021.

Dugh-NeurIPS-2021 This repo contains the code for the paper "Efficient hierarchical Bayesian inference for spatio-temporal regression models in neuroi

Ali Hashemi 5 Jul 12, 2022
Deeply Supervised, Layer-wise Prediction-aware (DSLP) Transformer for Non-autoregressive Neural Machine Translation

Non-Autoregressive Translation with Layer-Wise Prediction and Deep Supervision Training Efficiency We show the training efficiency of our DSLP model b

Chenyang Huang 36 Oct 31, 2022
public repo for ESTER dataset and modeling (EMNLP'21)

Project / Paper Introduction This is the project repo for our EMNLP'21 paper: https://arxiv.org/abs/2104.08350 Here, we provide brief descriptions of

PlusLab 19 Oct 27, 2022
Repo for flood prediction using LSTMs and HAND

Abstract Every year, floods cause billions of dollars’ worth of damages to life, crops, and property. With a proper early flood warning system in plac

1 Oct 27, 2021
Official implementation of the paper Vision Transformer with Progressive Sampling, ICCV 2021.

Vision Transformer with Progressive Sampling This is the official implementation of the paper Vision Transformer with Progressive Sampling, ICCV 2021.

yuexy 123 Jan 01, 2023
Large-scale open domain KNOwledge grounded conVERsation system based on PaddlePaddle

Knover Knover is a toolkit for knowledge grounded dialogue generation based on PaddlePaddle. Knover allows researchers and developers to carry out eff

607 Dec 31, 2022
C3d-pytorch - Pytorch porting of C3D network, with Sports1M weights

C3D for pytorch This is a pytorch porting of the network presented in the paper Learning Spatiotemporal Features with 3D Convolutional Networks How to

Davide Abati 311 Jan 06, 2023
This program was designed to detect whether someone is wearing a facemask through a live video stream.

This program was designed to detect whether someone is wearing a facemask through a live video stream. A custom lightweight CNN trained with TensorFlow on a public dataset provided by Kaggle is used

0 Apr 02, 2022