Deep Reinforcement Learning for Keras.

Overview

Deep Reinforcement Learning for Keras

Build Status Documentation License Join the chat at https://gitter.im/keras-rl/Lobby

What is it?

keras-rl implements some state-of-the art deep reinforcement learning algorithms in Python and seamlessly integrates with the deep learning library Keras.

Furthermore, keras-rl works with OpenAI Gym out of the box. This means that evaluating and playing around with different algorithms is easy.

Of course you can extend keras-rl according to your own needs. You can use built-in Keras callbacks and metrics or define your own. Even more so, it is easy to implement your own environments and even algorithms by simply extending some simple abstract classes. Documentation is available online.

What is included?

As of today, the following algorithms have been implemented:

  • Deep Q Learning (DQN) [1], [2]
  • Double DQN [3]
  • Deep Deterministic Policy Gradient (DDPG) [4]
  • Continuous DQN (CDQN or NAF) [6]
  • Cross-Entropy Method (CEM) [7], [8]
  • Dueling network DQN (Dueling DQN) [9]
  • Deep SARSA [10]
  • Asynchronous Advantage Actor-Critic (A3C) [5]
  • Proximal Policy Optimization Algorithms (PPO) [11]

You can find more information on each agent in the doc.

Installation

  • Install Keras-RL from Pypi (recommended):
pip install keras-rl
  • Install from Github source:
git clone https://github.com/keras-rl/keras-rl.git
cd keras-rl
python setup.py install

Examples

If you want to run the examples, you'll also have to install:

For atari example you will also need:

  • Pillow: pip install Pillow
  • gym[atari]: Atari module for gym. Use pip install gym[atari]

Once you have installed everything, you can try out a simple example:

python examples/dqn_cartpole.py

This is a very simple example and it should converge relatively quickly, so it's a great way to get started! It also visualizes the game during training, so you can watch it learn. How cool is that?

Some sample weights are available on keras-rl-weights.

If you have questions or problems, please file an issue or, even better, fix the problem yourself and submit a pull request!

External Projects

You're using Keras-RL on a project? Open a PR and share it!

Visualizing Training Metrics

To see graphs of your training progress and compare across runs, run pip install wandb and add the WandbLogger callback to your agent's fit() call:

from rl.callbacks import WandbLogger

...

agent.fit(env, nb_steps=50000, callbacks=[WandbLogger()])

For more info and options, see the W&B docs.

Citing

If you use keras-rl in your research, you can cite it as follows:

@misc{plappert2016kerasrl,
    author = {Matthias Plappert},
    title = {keras-rl},
    year = {2016},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/keras-rl/keras-rl}},
}

References

  1. Playing Atari with Deep Reinforcement Learning, Mnih et al., 2013
  2. Human-level control through deep reinforcement learning, Mnih et al., 2015
  3. Deep Reinforcement Learning with Double Q-learning, van Hasselt et al., 2015
  4. Continuous control with deep reinforcement learning, Lillicrap et al., 2015
  5. Asynchronous Methods for Deep Reinforcement Learning, Mnih et al., 2016
  6. Continuous Deep Q-Learning with Model-based Acceleration, Gu et al., 2016
  7. Learning Tetris Using the Noisy Cross-Entropy Method, Szita et al., 2006
  8. Deep Reinforcement Learning (MLSS lecture notes), Schulman, 2016
  9. Dueling Network Architectures for Deep Reinforcement Learning, Wang et al., 2016
  10. Reinforcement learning: An introduction, Sutton and Barto, 2011
  11. Proximal Policy Optimization Algorithms, Schulman et al., 2017
You might also like...
Distributed Deep learning with Keras & Spark
Distributed Deep learning with Keras & Spark

Elephas: Distributed Deep Learning with Keras & Spark Elephas is an extension of Keras, which allows you to run distributed deep learning models at sc

QKeras: a quantization deep learning library for Tensorflow Keras

QKeras github.com/google/qkeras QKeras 0.8 highlights: Automatic quantization using QKeras; Stochastic behavior (including stochastic rouding) is disa

MMdnn is a set of tools to help users inter-operate among different deep learning frameworks. E.g. model conversion and visualization. Convert models between Caffe, Keras, MXNet, Tensorflow, CNTK, PyTorch Onnx and CoreML.
MMdnn is a set of tools to help users inter-operate among different deep learning frameworks. E.g. model conversion and visualization. Convert models between Caffe, Keras, MXNet, Tensorflow, CNTK, PyTorch Onnx and CoreML.

MMdnn MMdnn is a comprehensive and cross-framework tool to convert, visualize and diagnose deep learning (DL) models. The "MM" stands for model manage

Advanced Deep Learning with TensorFlow 2 and Keras (Updated for 2nd Edition)
Advanced Deep Learning with TensorFlow 2 and Keras (Updated for 2nd Edition)

Advanced Deep Learning with TensorFlow 2 and Keras (Updated for 2nd Edition)

Keras like implementation of Deep Learning architectures from scratch using numpy.

Mini-Keras Keras like implementation of Deep Learning architectures from scratch using numpy. How to contribute? The project contains implementations

Realtime Face Anti Spoofing with Face Detector based on Deep Learning using Tensorflow/Keras and OpenCV
Realtime Face Anti Spoofing with Face Detector based on Deep Learning using Tensorflow/Keras and OpenCV

Realtime Face Anti-Spoofing Detection 🤖 Realtime Face Anti Spoofing Detection with Face Detector to detect real and fake faces Please star this repo

This source code is implemented using keras library based on "Automatic ocular artifacts removal in EEG using deep learning"

CSP_Deep_EEG This source code is implemented using keras library based on "Automatic ocular artifacts removal in EEG using deep learning" {https://www

Vision Deep-Learning using Tensorflow, Keras.

Welcome! I am a computer vision deep learning developer working in Korea. This is my blog, and you can see everything I've studied here. https://www.n

A deep learning network built with TensorFlow and Keras to classify gender and estimate age.
A deep learning network built with TensorFlow and Keras to classify gender and estimate age.

Convolutional Neural Network (CNN). This repository contains a source code of a deep learning network built with TensorFlow and Keras to classify gend

Releases(v0.4.2)
efficient neural audio synthesis in the waveform domain

neural waveshaping synthesis real-time neural audio synthesis in the waveform domain paper • website • colab • audio by Ben Hayes, Charalampos Saitis,

Ben Hayes 169 Dec 23, 2022
PyTorch package for the discrete VAE used for DALL·E.

Overview [Blog] [Paper] [Model Card] [Usage] This is the official PyTorch package for the discrete VAE used for DALL·E. Installation Before running th

OpenAI 9.5k Jan 05, 2023
Short and long time series classification using convolutional neural networks

time-series-classification Short and long time series classification via convolutional neural networks In this project, we present a novel framework f

35 Oct 22, 2022
Code repo for "Transformer on a Diet" paper

Transformer on a Diet Reference: C Wang, Z Ye, A Zhang, Z Zhang, A Smola. "Transformer on a Diet". arXiv preprint arXiv (2020). Installation pip insta

cgraywang 31 Sep 26, 2021
Code for the paper Hybrid Spectrogram and Waveform Source Separation

Demucs Music Source Separation This is the 3rd release of Demucs (v3), featuring hybrid source separation. For the waveform only Demucs (v2): Go this

Meta Research 4.8k Jan 04, 2023
Very large and sparse networks appear often in the wild and present unique algorithmic opportunities and challenges for the practitioner

Sparse network learning with snlpy Very large and sparse networks appear often in the wild and present unique algorithmic opportunities and challenges

Andrew Stolman 1 Apr 30, 2021
Internship Assessment Task for BaggageAI.

BaggageAI Internship Task Problem Statement: You are given two sets of images:- background and threat objects. Background images are the background x-

Arya Shah 10 Nov 14, 2022
Understanding Convolutional Neural Networks from Theoretical Perspective via Volterra Convolution

nnvolterra Run Code Compile first: make compile Run all codes: make all Test xconv: make npxconv_test MNIST dataset needs to be downloaded, converted

1 May 24, 2022
BBScan py3 - BBScan py3 With Python

BBScan_py3 This repository is forked from lijiejie/BBScan 1.5. I migrated the fo

baiyunfei 12 Dec 30, 2022
Fine-grained Control of Image Caption Generation with Abstract Scene Graphs

Faster R-CNN pretrained on VisualGenome This repository modifies maskrcnn-benchmark for object detection and attribute prediction on VisualGenome data

Shizhe Chen 7 Apr 20, 2021
PyTorch module to use OpenFace's nn4.small2.v1.t7 model

OpenFace for Pytorch Disclaimer: This codes require the input face-images that are aligned and cropped in the same way of the original OpenFace. * I m

Pete Tae-hoon Kim 176 Dec 12, 2022
[ACMMM 2021 Oral] Enhanced Invertible Encoding for Learned Image Compression

InvCompress Official Pytorch Implementation for "Enhanced Invertible Encoding for Learned Image Compression", ACMMM 2021 (Oral) Figure: Our framework

96 Nov 30, 2022
Einshape: DSL-based reshaping library for JAX and other frameworks.

Einshape: DSL-based reshaping library for JAX and other frameworks. The jnp.einsum op provides a DSL-based unified interface to matmul and tensordot o

DeepMind 62 Nov 30, 2022
A new video text spotting framework with Transformer

TransVTSpotter: End-to-end Video Text Spotter with Transformer Introduction A Multilingual, Open World Video Text Dataset and End-to-end Video Text Sp

weijiawu 67 Jan 03, 2023
Scalable, Portable and Distributed Gradient Boosting (GBDT, GBRT or GBM) Library, for Python, R, Java, Scala, C++ and more. Runs on single machine, Hadoop, Spark, Dask, Flink and DataFlow

eXtreme Gradient Boosting Community | Documentation | Resources | Contributors | Release Notes XGBoost is an optimized distributed gradient boosting l

Distributed (Deep) Machine Learning Community 23.6k Dec 31, 2022
Machine learning notebooks in different subjects optimized to run in google collaboratory

Notebooks Name Description Category Link Training pix2pix This notebook shows a simple pipeline for training pix2pix on a simple dataset. Most of the

Zaid Alyafeai 363 Dec 06, 2022
This is a tensorflow-based rotation detection benchmark, also called AlphaRotate.

AlphaRotate: A Rotation Detection Benchmark using TensorFlow Abstract AlphaRotate is maintained by Xue Yang with Shanghai Jiao Tong University supervi

yangxue 972 Jan 05, 2023
Paddle implementation for "Highly Efficient Knowledge Graph Embedding Learning with Closed-Form Orthogonal Procrustes Analysis" (NAACL 2021)

ProcrustEs-KGE Paddle implementation for Highly Efficient Knowledge Graph Embedding Learning with Orthogonal Procrustes Analysis 🙈 A more detailed re

Lincedo Lab 4 Jun 09, 2021
MMRazor: a model compression toolkit for model slimming and AutoML

Documentation: https://mmrazor.readthedocs.io/ English | 简体中文 Introduction MMRazor is a model compression toolkit for model slimming and AutoML, which

OpenMMLab 899 Jan 02, 2023
Data visualization app for H&M competition in kaggle

handm_data_visualize_app Data visualization app by streamlit for H&M competition in kaggle. competition page: https://www.kaggle.com/competitions/h-an

Kyohei Uto 12 Apr 30, 2022