Code for "Unsupervised State Representation Learning in Atari"

Overview

Unsupervised State Representation Learning in Atari

Ankesh Anand*, Evan Racah*, Sherjil Ozair*, Yoshua Bengio, Marc-Alexandre Côté, R Devon Hjelm

This repo provides code for the benchmark and techniques introduced in the paper Unsupervised State Representation Learning in Atari.

Install

AtariARI Wrapper

For a minimal install that provides only the AtariARI (Atari Annotated RAM Interface) wrapper, run:

pip install 'gym[atari]'
pip install git+https://github.com/mila-iqia/atari-representation-learning.git

This requires only gym[atari] and lets you experiment with the AtariARI wrapper. To use the code for training representation learning methods and probing them, you need a full installation:

Full installation (AtariARI Wrapper + Training & Probing Code)

# PyTorch and scikit-learn
conda install pytorch torchvision -c pytorch
conda install scikit-learn

# Baselines for Atari preprocessing
# Tensorflow is a dependency, but you don't need to install the GPU version
conda install tensorflow
pip install git+https://github.com/openai/baselines

# pytorch-a2c-ppo-acktr for RL utils
pip install git+https://github.com/ankeshanand/pytorch-a2c-ppo-acktr-gail

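# Clone and install our package (requirements.txt lives in the repo root)
git clone https://github.com/mila-iqia/atari-representation-learning.git
cd atari-representation-learning
pip install -r requirements.txt
pip install git+https://github.com/mila-iqia/atari-representation-learning.git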

Usage

Atari Annotated RAM Interface (AtariARI):

AtariARI exposes the ground-truth labels of the different state variables for each observation. AtariARI is available as a Gym wrapper; to use it, wrap an Atari Gym env with AtariARIWrapper.

import gym
from atariari.benchmark.wrapper import AtariARIWrapper

# Wrap an Atari Gym env so that each step's `info` carries the state labels
env = AtariARIWrapper(gym.make('MsPacmanNoFrameskip-v4'))
obs = env.reset()
obs, reward, done, info = env.step(1)

Now, info is a dictionary of the form:

{'ale.lives': 3,
 'labels': {'enemy_sue_x': 88,
  'enemy_inky_x': 88,
  'enemy_pinky_x': 88,
  'enemy_blinky_x': 88,
  'enemy_sue_y': 80,
  'enemy_inky_y': 80,
  'enemy_pinky_y': 80,
  'enemy_blinky_y': 50,
  'player_x': 88,
  'player_y': 98,
  'fruit_x': 0,
  'fruit_y': 0,
  'ghosts_count': 3,
  'player_direction': 3,
  'dots_eaten_count': 0,
  'player_score': 0,
  'num_lives': 2}}
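
Many of the labels above come in x/y pairs per entity. As a minimal sketch, the snippet below groups such pairs into coordinate tuples. The helper name entity_positions is an illustrative assumption and not part of atariari; the sample values are a subset of the output shown above.

```python
def entity_positions(labels):
    """Group `<entity>_x` / `<entity>_y` labels into (x, y) tuples.

    `labels` is the dictionary stored under info['labels'].
    Labels without a matching x/y partner (e.g. counters) are skipped.
    """
    positions = {}
    for name, x in labels.items():
        if name.endswith("_x"):
            entity = name[:-2]
            y_key = entity + "_y"
            if y_key in labels:
                positions[entity] = (x, labels[y_key])
    return positions


# Subset of the MsPacman labels shown above.
sample_labels = {
    "enemy_blinky_x": 88, "enemy_blinky_y": 50,
    "player_x": 88, "player_y": 98,
    "fruit_x": 0, "fruit_y": 0,
    "ghosts_count": 3,
}

print(entity_positions(sample_labels))
```

In practice you would pass info['labels'] from env.step instead of the hard-coded sample.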

Note: In our experiments, we use additional preprocessing for Atari environments, mainly following Mnih et al., 2014. See atariari/benchmark/envs.py for more info!

If you want the raw RAM annotations (which parts of RAM correspond to each state variable), check out atariari/benchmark/ram_annotations.py.
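
Conceptually, a RAM annotation maps a state-variable name to a byte offset in the Atari 2600's 128 bytes of RAM, and a label is the value read at that offset. The sketch below illustrates the simplest case of that idea. The offsets and the names EXAMPLE_ANNOTATION and labels_from_ram are made up for illustration; the real mappings are in ram_annotations.py and may be structured differently.

```python
RAM_SIZE = 128  # the Atari 2600 has 128 bytes of RAM

# Hypothetical annotation: variable name -> RAM index (illustrative values only).
EXAMPLE_ANNOTATION = {"player_x": 10, "player_y": 16, "player_score": 120}


def labels_from_ram(ram, annotation):
    """Read one integer label per annotated variable from a RAM snapshot."""
    if len(ram) != RAM_SIZE:
        raise ValueError(f"expected {RAM_SIZE} bytes of RAM, got {len(ram)}")
    return {name: int(ram[index]) for name, index in annotation.items()}


# Fake RAM snapshot: all zeros except two of the annotated bytes.
ram = bytearray(RAM_SIZE)
ram[10], ram[16] = 88, 98

print(labels_from_ram(ram, EXAMPLE_ANNOTATION))
```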

Probing


⚠️ Important ⚠️ : The RAM labels are meant for full-sized Atari observations (210 * 160). Probing results won't be accurate if you downsample the observations.

We provide an interface for the included probing tasks.

First, collect episodes for the train, validation, and test splits:

from atariari.benchmark.episodes import get_episodes

(tr_episodes, val_episodes,
 tr_labels, val_labels,
 test_episodes, test_labels) = get_episodes(
    env_name="PitfallNoFrameskip-v4",
    steps=50000,
    collect_mode="random_agent",
)

Then probe them using ProbeTrainer and your encoder (my_encoder):

from atariari.benchmark.probe import ProbeTrainer

probe_trainer = ProbeTrainer(my_encoder, representation_len=my_encoder.feature_size)
probe_trainer.train(tr_episodes, val_episodes,
                     tr_labels, val_labels,)
final_accuracies, final_f1_scores = probe_trainer.test(test_episodes, test_labels)

To see how we use ProbeTrainer, check out scripts/run_probe.py
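
The exact structure returned by probe_trainer.test is not spelled out here. As a rough sketch, assuming final_f1_scores behaves like a dictionary mapping a label name to a float (an assumption worth checking against scripts/run_probe.py), a per-label report and a mean could be produced as follows. The helper name summarize_scores is hypothetical and the numbers are placeholders, not results.

```python
def summarize_scores(scores):
    """Return (lines, mean) for a {label_name: score} dictionary."""
    if not scores:
        raise ValueError("scores is empty")
    width = max(len(name) for name in scores)
    # Sort from weakest to strongest label so problem variables come first.
    ordered = sorted(scores.items(), key=lambda item: item[1])
    lines = [f"{name:<{width}}  {value:.2f}" for name, value in ordered]
    mean = sum(scores.values()) / len(scores)
    return lines, mean


# Placeholder values for illustration only (not measured results).
example_f1 = {"player_x": 0.50, "player_y": 0.75, "player_score": 0.25}

lines, mean = summarize_scores(example_f1)
print("\n".join(lines))
print(f"mean F1: {mean:.2f}")
```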

Here is an example of my_encoder:

# Define your encoder
import torch
import torch.nn as nn

class MyEncoder(nn.Module):
    def __init__(self, input_channels, feature_size):
        super().__init__()
        self.feature_size = feature_size
        self.input_channels = input_channels
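        # 64 channels over a 9 x 6 feature map, which is what the CNN below
        # produces for full-sized 210 x 160 observations
        self.final_conv_size = 64 * 9 * 6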
        self.cnn = nn.Sequential(
            nn.Conv2d(input_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Linear(self.final_conv_size, self.feature_size)

    def forward(self, inputs):
        x = self.cnn(inputs)
        x = x.view(x.size(0), -1)
        return self.fc(x)


my_encoder = MyEncoder(input_channels=1, feature_size=256)
# Load pretrained weights
my_encoder.load_state_dict(torch.load("path/to/my/weights.pt"))
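
The value 64 * 9 * 6 used for final_conv_size in MyEncoder follows from the standard output-size formula for an unpadded convolution, floor((size - kernel) / stride) + 1, applied to a 210 × 160 input. The short check below walks through that arithmetic; the function names are illustrative, and the kernel sizes and strides are copied from the four Conv2d layers of MyEncoder.

```python
def conv_out(size, kernel, stride):
    """Output length of an unpadded, undilated convolution along one dimension."""
    return (size - kernel) // stride + 1


# (kernel, stride) for the four Conv2d layers in MyEncoder.
LAYERS = [(8, 4), (4, 2), (4, 2), (3, 1)]


def final_feature_map(height, width, layers=LAYERS):
    """Apply every layer's size formula in turn and return (height, width)."""
    for kernel, stride in layers:
        height = conv_out(height, kernel, stride)
        width = conv_out(width, kernel, stride)
    return height, width


h, w = final_feature_map(210, 160)
print(h, w, 64 * h * w)  # 9 6 3456
```

If you feed the encoder observations of a different size, final_conv_size has to be recomputed the same way.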

Spatio-Temporal DeepInfoMax:

src/ contains implementations of ST-DIM and several other representation learning methods. Here's a sample usage:

python -m scripts.run_probe --method infonce-stdim --env-name {env_name}

Here, env_name is of the form {game}NoFrameskip-v4, such as PongNoFrameskip-v4.
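
To probe several games in a row, the environment id and the command line can be assembled programmatically. This is a small sketch: the helper names env_name_for and probe_command are assumptions, and only the --method and --env-name flags shown above are used.

```python
import shlex


def env_name_for(game):
    """Build a Gym id of the form {game}NoFrameskip-v4."""
    return f"{game}NoFrameskip-v4"


def probe_command(game, method="infonce-stdim"):
    """Assemble the run_probe invocation as an argument list."""
    return [
        "python", "-m", "scripts.run_probe",
        "--method", method,
        "--env-name", env_name_for(game),
    ]


for game in ["Pong", "MsPacman", "Pitfall"]:
    # Print the command; pass the list to subprocess.run to execute it.
    print(shlex.join(probe_command(game)))
```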

Citation

@article{anand2019unsupervised,
  title={Unsupervised State Representation Learning in Atari},
  author={Anand, Ankesh and Racah, Evan and Ozair, Sherjil and Bengio, Yoshua and C{\^o}t{\'e}, Marc-Alexandre and Hjelm, R Devon},
  journal={arXiv preprint arXiv:1906.08226},
  year={2019}
}