Official implementation for NIPS'17 paper: PredRNN: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal LSTMs.

Overview

PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning

The predictive learning of spatiotemporal sequences aims to generate future images by learning from the historical context, where the visual dynamics are believed to have modular structures that can be learned with compositional subsystems.

First version at NeurIPS 2017

This repo first contains a PyTorch implementation of PredRNN (2017) [paper], a recurrent network with a pair of memory cells that operate in nearly independent transition manners, and finally form unified representations of the complex environment.

Concretely, besides the original memory cell of LSTM, this network is featured by a zigzag memory flow that propagates in both bottom-up and top-down directions across all layers, enabling the learned visual dynamics at different levels of RNNs to communicate.

New in PredRNN-V2 (2021)

This repo also includes the implementation of PredRNN-V2 (2021) [paper], which improves PredRNN in the following two aspects.

1. Memory Decoupling

We find that the pair of memory cells in PredRNN contain undesirable, redundant features, and thus present a memory decoupling loss to encourage them to learn modular structures of visual dynamics.

decouple

2. Reverse Scheduled Sampling

Reverse scheduled sampling is a new curriculum learning strategy for seq-to-seq RNNs. As opposed to scheduled sampling, it gradually changes the training process of the PredRNN encoder from using the previously generated frame to using the previous ground truth. Benefits: (1) It makes the training converge quickly by reducing the encoder-forecaster training gap. (2) It enforces the model to learn more from long-term input context.

rss

Evaluation in LPIPS

LPIPS is more sensitive to perceptual human judgments, the lower the better.

Moving MNIST KTH action
PredRNN 0.109 0.204
PredRNN-V2 0.071 0.139

Prediction examples

mnist

kth

radar

Get Started

  1. Install Python 3.7, PyTorch 1.3, and OpenCV 3.4.
  2. Download data. This repo contains code for two datasets: the Moving Mnist dataset and the KTH action dataset.
  3. Train the model. You can use the following bash script to train the model. The learned model will be saved in the --save_dir folder. The generated future frames will be saved in the --gen_frm_dir folder.
  4. You can get pretrained models from here.
cd mnist_script/
sh predrnn_mnist_train.sh
sh predrnn_v2_mnist_train.sh

cd kth_script/
sh predrnn_kth_train.sh
sh predrnn_v2_kth_train.sh

Citation

If you find this repo useful, please cite the following papers.

@inproceedings{wang2017predrnn,
  title={{PredRNN}: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal {LSTM}s},
  author={Wang, Yunbo and Long, Mingsheng and Wang, Jianmin and Gao, Zhifeng and Yu, Philip S},
  booktitle={Advances in Neural Information Processing Systems},
  pages={879--888},
  year={2017}
}

@misc{wang2021predrnn,
      title={{PredRNN}: A Recurrent Neural Network for Spatiotemporal Predictive Learning}, 
      author={Wang, Yunbo and Wu, Haixu and Zhang, Jianjin and Gao, Zhifeng and Wang, Jianmin and Yu, Philip S and Long, Mingsheng},
      year={2021},
      eprint={2103.09504},
      archivePrefix={arXiv},
}
Owner
THUML: Machine Learning Group @ THSS
Machine Learning Group, School of Software, Tsinghua University
THUML: Machine Learning Group @ THSS
Code for "Graph-Evolving Meta-Learning for Low-Resource Medical Dialogue Generation". [AAAI 2021]

Graph Evolving Meta-Learning for Low-resource Medical Dialogue Generation Code to be further cleaned... This repo contains the code of the following p

Shuai Lin 29 Nov 01, 2022
This is an official implementation for "DeciWatch: A Simple Baseline for 10x Efficient 2D and 3D Pose Estimation"

DeciWatch: A Simple Baseline for 10× Efficient 2D and 3D Pose Estimation This repo is the official implementation of "DeciWatch: A Simple Baseline for

117 Dec 24, 2022
Monitora la qualità della ricezione dei segnali radio nelle province siciliane.

FMap-server Monitora la qualità della ricezione dei segnali radio nelle province siciliane. Conversion data Frequency - StationName maps are stored in

Triglie 5 May 24, 2021
Hydra Lightning Template for Structured Configs

Hydra Lightning Template for Structured Configs Template for creating projects with pytorch-lightning and hydra. How to use this template? Create your

Model-driven Machine Learning 4 Jul 19, 2022
Few-shot Relation Extraction via Bayesian Meta-learning on Relation Graphs

Few-shot Relation Extraction via Bayesian Meta-learning on Relation Graphs This is an implemetation of the paper Few-shot Relation Extraction via Baye

MilaGraph 36 Nov 22, 2022
Human POSEitioning System (HPS): 3D Human Pose Estimation and Self-localization in Large Scenes from Body-Mounted Sensors, CVPR 2021

Human POSEitioning System (HPS): 3D Human Pose Estimation and Self-localization in Large Scenes from Body-Mounted Sensors Human POSEitioning System (H

Aymen Mir 66 Dec 21, 2022
Official pytorch implementation of paper "Inception Convolution with Efficient Dilation Search" (CVPR 2021 Oral).

IC-Conv This repository is an official implementation of the paper Inception Convolution with Efficient Dilation Search. Getting Started Download Imag

Jie Liu 111 Dec 31, 2022
Neighborhood Reconstructing Autoencoders

Neighborhood Reconstructing Autoencoders The official repository for Neighborhood Reconstructing Autoencoders (Lee, Kwon, and Park, NeurIPS 2021). T

Yonghyeon Lee 24 Dec 14, 2022
Reverse engineering recurrent neural networks with Jacobian switching linear dynamical systems

Reverse engineering recurrent neural networks with Jacobian switching linear dynamical systems This repository is the official implementation of Rever

6 Aug 25, 2022
DirectVoxGO reconstructs a scene representation from a set of calibrated images capturing the scene.

DirectVoxGO reconstructs a scene representation from a set of calibrated images capturing the scene. We achieve NeRF-comparable novel-view synthesis quality with super-fast convergence.

sunset 709 Dec 31, 2022
CryptoFrog - My First Strategy for freqtrade

cryptofrog-strategies CryptoFrog - My First Strategy for freqtrade NB: (2021-04-20) You'll need the latest freqtrade develop branch otherwise you migh

Robert Davey 137 Jan 01, 2023
RAMA: Rapid algorithm for multicut problem

RAMA: Rapid algorithm for multicut problem Solves multicut (correlation clustering) problems orders of magnitude faster than CPU based solvers without

Paul Swoboda 60 Dec 13, 2022
Pairwise learning neural link prediction for ogb link prediction

Pairwise Learning for Neural Link Prediction for OGB (PLNLP-OGB) This repository provides evaluation codes of PLNLP for OGB link property prediction t

Zhitao WANG 31 Oct 10, 2022
This repository contains code and data for "On the Multimodal Person Verification Using Audio-Visual-Thermal Data"

trimodal_person_verification This repository contains the code, and preprocessed dataset featured in "A Study of Multimodal Person Verification Using

ISSAI 7 Aug 31, 2022
Build Low Code Automated Tensorflow, What-IF explainable models in just 3 lines of code.

Build Low Code Automated Tensorflow explainable models in just 3 lines of code.

Hasan Rafiq 170 Dec 26, 2022
Source code related to the article submitted to the International Conference on Computational Science ICCS 2022 in London

POTHER: Patch-Voted Deep Learning-based Chest X-ray Bias Analysis for COVID-19 Detection Source code related to the article submitted to the Internati

Tomasz Szczepański 1 Apr 29, 2022
Hierarchical Aggregation for 3D Instance Segmentation (ICCV 2021)

HAIS Hierarchical Aggregation for 3D Instance Segmentation (ICCV 2021) by Shaoyu Chen, Jiemin Fang, Qian Zhang, Wenyu Liu, Xinggang Wang*. (*) Corresp

Hust Visual Learning Team 145 Jan 05, 2023
Sparse R-CNN: End-to-End Object Detection with Learnable Proposals, CVPR2021

End-to-End Object Detection with Learnable Proposal, CVPR2021

Peize Sun 1.2k Dec 27, 2022
Convolutional neural network web app trained to track our infant’s sleep schedule using our Google Nest camera.

Machine Learning Sleep Schedule Tracker What is it? Convolutional neural network web app trained to track our infant’s sleep schedule using our Google

g-parki 7 Jul 15, 2022
Code for approximate graph reduction techniques for cardinality-based DSFM, from paper

SparseCard Code for approximate graph reduction techniques for cardinality-based DSFM, from paper "Approximate Decomposable Submodular Function Minimi

Nate Veldt 1 Nov 25, 2022