Repo for our ICML21 paper Unsupervised Learning of Visual 3D Keypoints for Control

Overview

Unsupervised Learning of Visual 3D Keypoints for Control

[Project Website] [Paper]

Boyuan Chen¹, Pieter Abbeel¹, Deepak Pathak²
¹UC Berkeley ²Carnegie Mellon University


This is the codebase for our paper on unsupervised learning of visual 3D keypoints for control. We propose an unsupervised learning method that learns temporally consistent 3D keypoints via interaction. We jointly train an RL policy with the keypoint detector and show that 3D keypoints improve the sample efficiency of task learning in a variety of environments. If you find this work helpful to your research, please cite us as:

@inproceedings{chen2021unsupervised,
    title={Unsupervised Learning of Visual 3D Keypoints for Control},
    author={Boyuan Chen and Pieter Abbeel and Deepak Pathak},
    year={2021},
    booktitle={ICML}
}

Environment Setup

If you want to run Meta-world experiments, make sure your MuJoCo binaries and a valid license key are in ~/.mujoco. Otherwise, remove metaworld and mujoco-py from requirements.txt to avoid installation errors.

# clone this repo
git clone https://github.com/buoyancy99/unsup-3d-keypoints
cd unsup-3d-keypoints

# setup conda environment
conda create -n keypoint3d python=3.7.5
conda activate keypoint3d
pip3 install -r requirements.txt
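If you skip Meta-world, the requirements.txt edit described above can be scripted instead of done by hand. The helpers below are a hypothetical sketch, not part of this repo. They assume the mujoco-py 2.0 convention of a license key at ~/.mujoco/mjkey.txt, and that the metaworld and mujoco-py requirements can be recognized by name anywhere on their lines (which also catches git+https URLs).

```bash
#!/usr/bin/env bash
# Hypothetical helpers (not part of the repo) for the optional MuJoCo dependencies.

# Succeeds if a MuJoCo license key is present. ASSUMES the mujoco-py 2.0 layout:
# the key file at <dir>/mjkey.txt, with <dir> defaulting to ~/.mujoco.
has_mujoco_key() {
  local dir="${1:-$HOME/.mujoco}"
  [ -f "$dir/mjkey.txt" ]
}

# Write a copy of a requirements file without metaworld / mujoco-py lines.
# Matching is case-insensitive and anywhere on the line, so VCS URLs such as
# git+https://...metaworld... are dropped as well.
strip_mujoco_reqs() {
  local src="${1:-requirements.txt}" dst="${2:-requirements-nomujoco.txt}"
  if [ ! -f "$src" ]; then
    echo "strip_mujoco_reqs: $src not found" >&2
    return 1
  fi
  # grep -v exits with 1 when every line is filtered out; only >1 is a real error
  grep -v -i -E 'metaworld|mujoco[-_]py' "$src" > "$dst" || [ $? -eq 1 ]
}
```

For example, from the repo root, run has_mujoco_key and, if it fails, run strip_mujoco_reqs and then pip3 install -r requirements-nomujoco.txt in place of the install step above.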

Run Experiments

During training, logs are stored in data/, visualizations in images/, and checkpoints in ckpts/. You can use TensorBoard to visualize the training logs, or plot the monitor files.
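For a quick number without opening TensorBoard, the sketch below averages the most recent episode rewards in one monitor file. The function name is hypothetical, and it ASSUMES the Stable-Baselines3 monitor CSV layout (a '#'-prefixed JSON line, then a header with an r column for episode reward); check a file under data/ before relying on it.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo). Prints the mean episode reward over
# the last N episodes (default 100) of one monitor CSV. ASSUMES the
# Stable-Baselines3 layout: a '#'-prefixed JSON line, then a header with an "r" column.
monitor_mean_reward() {
  local file="$1" last_n="${2:-100}"
  awk -F, -v n="$last_n" '
    BEGIN { n += 0; if (n < 1) n = 1 }  # numeric window size, at least 1
    /^#/  { next }                      # skip the JSON metadata line
    !col  {                             # first remaining line is the CSV header
      for (i = 1; i <= NF; i++) if ($i == "r") col = i
      next
    }
    { r[++cnt] = $col }                 # one reward per episode row
    END {
      if (cnt == 0) { print "nan"; exit }
      start = (cnt > n) ? cnt - n + 1 : 1
      for (i = start; i <= cnt; i++) sum += r[i]
      printf "%.4f\n", sum / (cnt - start + 1)
    }' "$file"
}
```

For full learning curves, tensorboard --logdir data should pick up whatever event files training writes there.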

Quick start with pre-trained weights

# Visualize metaworld-hammer environment
python3 visualize.py --algo ppokeypoint -t hammer -v 1 -m 3d -j --offset_crop --decode_first_frame --num_keypoint 6 --decode_attention --seed 99 -u -e 0007

# Visualize metaworld-close-box environment
python3 visualize.py --algo ppokeypoint -t bc -v 1 -m 3d -j --offset_crop --decode_first_frame --num_keypoint 6 --decode_attention --seed 99 -u -e 0008
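Both quick-start commands share every flag except the task (-t) and the experiment id (-e). A small hypothetical helper can print the pair so they run back to back; the flags are copied verbatim from the two commands above.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): print the visualize.py command for
# each pre-trained checkpoint. Each entry is "task:experiment_id".
print_pretrained_visualize() {
  local pair task exp
  for pair in hammer:0007 bc:0008; do
    task="${pair%%:*}"   # text before the colon, e.g. hammer
    exp="${pair##*:}"    # text after the colon, e.g. 0007
    echo "python3 visualize.py --algo ppokeypoint -t $task -v 1 -m 3d -j --offset_crop --decode_first_frame --num_keypoint 6 --decode_attention --seed 99 -u -e $exp"
  done
}
```

Piping the output to bash (print_pretrained_visualize | bash) from the repo root would run both visualizations in sequence.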

Reproduce keypoints similar to the two pre-trained checkpoints

# To reproduce keypoint visualizations similar to the two checkpoints above, use these commands.
# Feel free to try other seeds via [--seed]. Seeding makes training deterministic on a given machine, but on GPU there is no guarantee across devices. If your GPU model differs from mine, you may not get exactly the same checkpoints, but the resulting keypoints should look similar.

python3 train.py --algo ppokeypoint -t hammer -v 1 -e 0007 -m 3d -j --total_timesteps 6000000 --offset_crop --decode_first_frame --num_keypoint 6 --decode_attention --seed 200 -u

python3 train.py --algo ppokeypoint -t bc -v 1 -e 0008 -m 3d -j --total_timesteps 6000000 --offset_crop --decode_first_frame --num_keypoint 6 --decode_attention --seed 200 -u
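To try several seeds as suggested above, a hypothetical helper can print one reproduction command per seed. It gives each run its own experiment id (-e), on the assumption that the id is what keeps the data/, images/ and ckpts/ outputs of different runs apart; pick ids that do not clash with the ones used in this README.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo).
# Usage: print_seed_sweep TASK FIRST_ID SEED [SEED ...]
# Prints one train.py command per seed, with experiment ids FIRST_ID, FIRST_ID+1, ...
print_seed_sweep() {
  local task="$1" first_id="$2"; shift 2
  local id=$((10#$first_id)) seed   # 10# avoids octal parsing of ids like 0008
  for seed in "$@"; do
    printf 'python3 train.py --algo ppokeypoint -t %s -v 1 -e %04d -m 3d -j --total_timesteps 6000000 --offset_crop --decode_first_frame --num_keypoint 6 --decode_attention --seed %s -u\n' "$task" "$id" "$seed"
    id=$((id + 1))
  done
}
```

For example, print_seed_sweep hammer 0100 200 201 202 | bash would train three hammer runs one after another.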

Train & Visualize PyBullet Ant with Keypoint3D (Ours)

# use -t antnc to train the colorless ant
python3 train.py --algo ppokeypoint -t ant -v 1 -e 0001 -m 3d --frame_stack 2 -j --total_timesteps 5000000 --num_keypoint 16 --latent_stack --decode_first_frame --offset_crop --mean_depth 1.7 --decode_attention --separation_coef 0.005 --seed 99 -u

# After a checkpoint is saved, visualize
python3 visualize.py --algo ppokeypoint -t ant -v 1 -e 0001 -m 3d --frame_stack 2 -j --total_timesteps 5000000 --num_keypoint 16 --latent_stack --decode_first_frame --offset_crop --mean_depth 1.7 --decode_attention --separation_coef 0.005 --seed 99 -u
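Since the visualize command must repeat the training flags exactly, one way to avoid copy-paste drift is to keep them in a single bash array. The function below is a hypothetical sketch; it exposes only -t (ant or antnc) and -e as parameters and otherwise uses the flags from the two commands above.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo). Set DRY_RUN to any non-empty value
# to print the command instead of running it.
# Usage: kp3d_ant train|visualize [ant|antnc] [EXP_ID]
kp3d_ant() {
  local mode="$1" task="${2:-ant}" exp_id="${3:-0001}"
  case "$mode" in
    train|visualize) ;;
    *) echo "usage: kp3d_ant train|visualize [ant|antnc] [EXP_ID]" >&2; return 2 ;;
  esac
  # Flags copied from the ant commands above; only -t and -e are parameters here.
  local flags=(--algo ppokeypoint -t "$task" -v 1 -e "$exp_id" -m 3d --frame_stack 2 -j
    --total_timesteps 5000000 --num_keypoint 16 --latent_stack --decode_first_frame
    --offset_crop --mean_depth 1.7 --decode_attention --separation_coef 0.005
    --seed 99 -u)
  ${DRY_RUN:+echo} python3 "$mode.py" "${flags[@]}"
}
```

For example, kp3d_ant train antnc 0009 followed by kp3d_ant visualize antnc 0009 would train and then visualize the colorless ant under its own (arbitrarily chosen) experiment id.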

Train PyBullet Ant with baselines

# RAD PPO baseline
python3 train.py --algo pporad -t ant -v 1 -e 0002 --total_timesteps 5000000 --frame_stack 2 --seed 99 -u

# Vanilla PPO baseline
python3 train.py --algo ppopixel -t ant -v 1 -e 0003 --total_timesteps 5000000 --frame_stack 2 --seed 99 -u

Train & Visualize 'Close-Box' environment in Meta-world with Keypoint3D (Ours)

python3 train.py --algo ppokeypoint -t bc -v 1 -e 0004 -m 3d -j --offset_crop --decode_first_frame --num_keypoint 32 --decode_attention --total_timesteps 4000000 --seed 99 -u

# After a checkpoint is saved, visualize
python3 visualize.py --algo ppokeypoint -t bc -v 1 -e 0004 -m 3d -j --offset_crop --decode_first_frame --num_keypoint 32 --decode_attention --total_timesteps 4000000 --seed 99 -u
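Visualization needs a saved checkpoint. The hypothetical helper below checks for one by printing the newest file under ckpts/ whose path contains the experiment id. How this repo actually names its checkpoints is an ASSUMPTION here, so adjust the -path pattern if nothing matches.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): print the newest file under a
# checkpoint directory whose path contains the experiment id.
# ASSUMPTION: the id appears somewhere in each checkpoint path under ckpts/.
latest_ckpt() {
  local exp_id="$1" dir="${2:-ckpts}" f newest=""
  [ -d "$dir" ] || { echo "latest_ckpt: $dir not found" >&2; return 1; }
  while IFS= read -r f; do
    # -nt is true when $f was modified more recently than $newest
    if [ -z "$newest" ] || [ "$f" -nt "$newest" ]; then
      newest="$f"
    fi
  done < <(find "$dir" -type f -path "*${exp_id}*")
  [ -n "$newest" ] || { echo "latest_ckpt: no checkpoint matching $exp_id yet" >&2; return 1; }
  printf '%s\n' "$newest"
}
```

For example, running latest_ckpt 0004 before the close-box visualize command confirms that experiment 0004 has written a checkpoint.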

Train 'Close-Box' environment in Meta-world with baselines

# RAD PPO baseline
python3 train.py --algo pporad -t bc -v 1 -e 0005 --total_timesteps 4000000 --seed 99 -u

# Vanilla PPO baseline
python3 train.py --algo ppopixel -t bc -v 1 -e 0006 --total_timesteps 4000000 --seed 99 -u

Other environments in general

# All training commands follow this format
python3 train.py -a [algo name] -t [env name] -v [env version] -e [experiment id] [...]

# To visualize, run visualize.py instead of train.py with the same options
python3 visualize.py -a [algo name] -t [env name] -v [env version] -e [experiment id] [...]

# For the colorless ant, change the [-t ant] flag in the ant example to [-t antnc]
# For Meta-world, change the [-t bc] flag in the close-box example to another task abbreviation, such as [-t door]

# For a full list of arguments and their meanings:
python3 train.py -h
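Because visualization reuses the training options, a pair of hypothetical wrappers can record the options at training time and replay them later. This is a sketch, not part of the repo: options are stored one per line in a .kp3d_args/ directory (an arbitrary choice) keyed by experiment id, which the wrappers pass as -e.

```bash
#!/usr/bin/env bash
# Hypothetical wrappers (not part of the repo): kp3d_train saves the options it
# was given, keyed by experiment id, and kp3d_visualize replays exactly those
# options with visualize.py. Set DRY_RUN to any non-empty value to print only.
KP3D_ARGS_DIR="${KP3D_ARGS_DIR:-.kp3d_args}"

# Usage: kp3d_train EXP_ID [options...]
kp3d_train() {
  local exp_id="$1"; shift
  mkdir -p "$KP3D_ARGS_DIR"
  # One option per line; assumes no option value contains a newline
  printf '%s\n' -e "$exp_id" "$@" > "$KP3D_ARGS_DIR/$exp_id.args"
  ${DRY_RUN:+echo} python3 train.py -e "$exp_id" "$@"
}

# Usage: kp3d_visualize EXP_ID
kp3d_visualize() {
  local exp_id="$1" line args=()
  local file="$KP3D_ARGS_DIR/$exp_id.args"
  [ -f "$file" ] || { echo "kp3d_visualize: no saved options for $exp_id" >&2; return 1; }
  while IFS= read -r line; do args+=("$line"); done < "$file"
  ${DRY_RUN:+echo} python3 visualize.py "${args[@]}"
}
```

For example, kp3d_train 0004 -a ppokeypoint -t bc -v 1 followed later by kp3d_visualize 0004 would run visualize.py with exactly the options used for training.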

Update Log

Date       Notes
Jun/15/21  Initial release of the code. Email me if you have questions or find any errors in this version.
Jun/16/21  Added all Meta-world environments, with notes about placeholder observations.