VideoGPT: Video Generation using VQ-VAE and Transformers

[Paper][Website][Colab][Gradio Demo]

We present VideoGPT: a conceptually simple architecture for scaling likelihood-based generative modeling to natural videos. VideoGPT uses a VQ-VAE that learns downsampled discrete latent representations of a raw video by employing 3D convolutions and axial self-attention. A simple GPT-like architecture is then used to autoregressively model the discrete latents using spatio-temporal position encodings. Despite the simplicity in formulation and ease of training, our architecture is able to generate samples competitive with state-of-the-art GAN models for video generation on the BAIR Robot dataset, and to generate high-fidelity natural videos from UCF-101 and the Tumblr GIF Dataset (TGIF). We hope our proposed architecture serves as a reproducible reference for a minimalistic implementation of transformer-based video generation models.
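The paragraph above says the transformer models the VQ-VAE's discrete latents autoregressively with spatio-temporal position encodings. The NumPy sketch below illustrates one common way to set this up, assuming the (T, H, W) code grid is flattened in raster order and that the position encoding is factorized as the sum of one table per axis. The function names and the random tables are illustrative assumptions, not the repo's code; in a trained model the per-axis tables would be learned.

```python
import numpy as np

def flatten_latents(codes: np.ndarray) -> np.ndarray:
    """Flatten a (T, H, W) grid of code indices into a 1-D sequence in raster order."""
    return codes.reshape(-1)

def axial_position_encoding(shape, dim, rng):
    """Return a (T*H*W, dim) array where each position gets emb_t[t] + emb_h[h] + emb_w[w].

    Assumption: one table per axis, randomly initialized here for illustration only.
    """
    t_len, h_len, w_len = shape
    emb_t = rng.normal(size=(t_len, dim))
    emb_h = rng.normal(size=(h_len, dim))
    emb_w = rng.normal(size=(w_len, dim))
    # Broadcast the three tables to (T, H, W, dim) and sum them
    pos = emb_t[:, None, None, :] + emb_h[None, :, None, :] + emb_w[None, None, :, :]
    # Flatten in the same raster order as flatten_latents
    return pos.reshape(t_len * h_len * w_len, dim)

rng = np.random.default_rng(0)
codes = rng.integers(0, 2048, size=(4, 32, 32))      # example 4 x 32 x 32 code grid
tokens = flatten_latents(codes)                       # 4096 tokens
pos = axial_position_encoding(codes.shape, 576, rng)  # one vector per token
print(tokens.shape, pos.shape)  # (4096,) (4096, 576)
```

Factorizing per axis keeps the parameter count at (T + H + W) x dim instead of (T x H x W) x dim, while still giving every position in the grid a distinct encoding.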

Approach

[Figure: VideoGPT approach overview]

Installation

Change the cudatoolkit version to one compatible with your machine.

$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
$ pip install git+https://github.com/wilson1yan/VideoGPT.git

Sparse Attention (Optional)

For limited compute scenarios, it may be beneficial to use sparse attention.

$ sudo apt-get install llvm-9-dev
$ DS_BUILD_SPARSE_ATTN=1 pip install deepspeed

After installing DeepSpeed, you can train a sparse transformer by passing the flag --attn_type sparse to scripts/train_videogpt.py. The default supported sparsity configuration is an N-d strided sparsity layout; however, you can write your own arbitrary layouts.
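As a simplified illustration of what a strided sparsity pattern means, the sketch below builds a 1-D causal mask in the style of the Sparse Transformer's strided pattern: each query attends to the previous `stride` positions plus every `stride`-th earlier position. This is not the repo's N-d layout or DeepSpeed's API; `strided_causal_mask` and its parameters are hypothetical names chosen for this example.

```python
import numpy as np

def strided_causal_mask(seq_len: int, stride: int) -> np.ndarray:
    """Boolean (seq_len, seq_len) mask; True where query i may attend to key j."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    causal = j <= i                   # never attend to future positions
    local = (i - j) < stride          # the most recent `stride` positions
    strided = (i - j) % stride == 0   # every stride-th earlier position
    return causal & (local | strided)

mask = strided_causal_mask(seq_len=16, stride=4)
print(mask.sum(), 16 * 17 // 2)  # 82 136
```

For 16 positions and a stride of 4, the mask keeps 82 of the 136 query-key pairs that full causal attention would use. Full causal attention grows quadratically with sequence length, while this pattern grows roughly as seq_len x (stride + seq_len / stride), so the saving increases for the long token sequences produced by video latents.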

Dataset

The default code accepts data either as an HDF5 file in the format specified in videogpt/data.py, or as a directory with the following structure:

video_dataset/
    train/
        class_0/
            video1.mp4
            video2.mp4
            ...
        class_1/
            video1.mp4
            ...
        ...
        class_n/
            ...
    test/
        class_0/
            video1.mp4
            video2.mp4
            ...
        class_1/
            video1.mp4
            ...
        ...
        class_n/
            ...

An example of such a dataset can be constructed from UCF-101 data by running the script:

sh scripts/preprocess/create_ucf_dataset.sh datasets/ucf101

You may need to install unrar and unzip for the script to work correctly.

If you do not care about classes, the class folders are not necessary and the dataset file structure can be collapsed into train and test directories of just videos.
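A minimal sketch of how either layout can be indexed into (video path, class) pairs using only the standard library. The helper name `index_split` and the `.mp4`-only extension filter are assumptions for illustration; the repo's own loader in videogpt/data.py may behave differently.

```python
from pathlib import Path

VIDEO_EXTS = {".mp4"}  # assumption: extend if your videos use other containers

def index_split(root, split):
    """List (video_path, class_name) pairs for one split, e.g. 'train' or 'test'.

    Handles both layouts: split/class_x/video.mp4 (class_name = 'class_x') and the
    collapsed split/video.mp4 layout (class_name = None).
    """
    split_dir = Path(root) / split
    if not split_dir.is_dir():
        return []
    items = []
    for path in sorted(split_dir.rglob("*")):
        if path.is_file() and path.suffix.lower() in VIDEO_EXTS:
            rel = path.relative_to(split_dir)
            # A file nested one level down takes its parent folder as the class label
            label = rel.parts[0] if len(rel.parts) > 1 else None
            items.append((path, label))
    return items
```

For example, after running the UCF-101 preprocessing script, `index_split('datasets/ucf101', 'train')` would list the training videos, assuming that script produces the directory layout shown above.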

Using Pretrained VQ-VAEs

There are four pretrained VQ-VAE models available. The strides in each model name are the encoder's downsampling factors across T x H x W (time, height, width).

  • bair_stride4x2x2: trained on 16 frame 64 x 64 videos from the BAIR Robot Pushing dataset
  • ucf101_stride4x4x4: trained on 16 frame 128 x 128 videos from UCF-101
  • kinetics_stride4x4x4: trained on 16 frame 128 x 128 videos from Kinetics-600
  • kinetics_stride2x4x4: trained on 16 frame 128 x 128 videos from Kinetics-600, with 2x larger temporal latent codes (achieves slightly better reconstruction)
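Example: encoding and reconstructing a video with a pretrained VQ-VAE.

import torch
from torchvision.io import read_video
from videogpt import load_vqvae
from videogpt.data import preprocess

video_filename = 'path/to/video_file.mp4'
sequence_length = 16
resolution = 128
device = torch.device('cuda')

# Load a pretrained VQ-VAE onto the GPU and put it in inference mode
vqvae = load_vqvae('kinetics_stride2x4x4').to(device).eval()

# read_video returns (video_frames, audio_frames, info); keep the T x H x W x C frames
video = read_video(video_filename, pts_unit='sec')[0]
# Preprocess to the model's resolution and clip length, then add a batch dimension
video = preprocess(video, resolution, sequence_length).unsqueeze(0).to(device)

with torch.no_grad():
    encodings = vqvae.encode(video)        # discrete latent code indices
    video_recon = vqvae.decode(encodings)  # reconstruction from the codes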
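The stride in each model name determines the size of the discrete code grid that VideoGPT's transformer models. A small sketch of that arithmetic, assuming each of T, H and W is divided exactly by the corresponding stride; `PRETRAINED_STRIDES` and `latent_shape` are hypothetical helpers, with the strides read off the model names above.

```python
# Hypothetical table: (T, H, W) strides read off the pretrained model names
PRETRAINED_STRIDES = {
    "bair_stride4x2x2": (4, 2, 2),
    "ucf101_stride4x4x4": (4, 4, 4),
    "kinetics_stride4x4x4": (4, 4, 4),
    "kinetics_stride2x4x4": (2, 4, 4),
}

def latent_shape(model_name, sequence_length, resolution):
    """(T, H, W) of the code grid, assuming each axis is divided exactly by its stride."""
    strides = PRETRAINED_STRIDES[model_name]
    sizes = (sequence_length, resolution, resolution)
    for size, stride in zip(sizes, strides):
        if size % stride:
            raise ValueError(f"{size} is not divisible by stride {stride}")
    return tuple(size // stride for size, stride in zip(sizes, strides))

for name, res in [("bair_stride4x2x2", 64), ("kinetics_stride4x4x4", 128), ("kinetics_stride2x4x4", 128)]:
    t, h, w = latent_shape(name, 16, res)
    print(f"{name}: {t}x{h}x{w} = {t * h * w} codes")
```

Under this assumption, kinetics_stride2x4x4 yields an 8 x 32 x 32 grid (8192 codes) for a 16-frame 128 x 128 clip, twice as many as the 4 x 32 x 32 grid (4096 codes) of kinetics_stride4x4x4. That is the "2x larger temporal latent codes" noted above, and it also doubles the sequence length the transformer must model.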

Training VQ-VAE

Use the scripts/train_vqvae.py script to train a VQ-VAE. Execute python scripts/train_vqvae.py -h for information on all available training settings. A subset of the more relevant settings is listed below, along with their default values.

VQ-VAE Specific Settings

  • --embedding_dim: number of dimensions for the codebook embeddings
  • --n_codes 2048: number of codes in the codebook
  • --n_hiddens 240: number of hidden features in the residual blocks
  • --n_res_layers 4: number of residual blocks
  • --downsample 4 4 4: T H W downsampling stride of the encoder
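The codebook flags above control the quantization step at the heart of a VQ-VAE: each encoder output vector is replaced by its nearest entry in a codebook of --n_codes vectors, each --embedding_dim wide. A minimal NumPy sketch of that nearest-neighbour lookup follows; the function name and example sizes are assumptions, and training-time details (codebook updates, commitment loss, straight-through gradients) are omitted.

```python
import numpy as np

def quantize(z, codebook):
    """Nearest-neighbour vector quantization.

    z:        (..., embedding_dim) encoder outputs
    codebook: (n_codes, embedding_dim) code vectors
    Returns (indices, quantized) with quantized[...] == codebook[indices[...]].
    """
    flat = z.reshape(-1, z.shape[-1])
    # Squared L2 distance ||z - e||^2 = ||z||^2 - 2 z.e + ||e||^2 for every (vector, code) pair
    dists = (
        (flat ** 2).sum(axis=1, keepdims=True)
        - 2 * flat @ codebook.T
        + (codebook ** 2).sum(axis=1)
    )
    indices = dists.argmin(axis=1)
    quantized = codebook[indices]
    return indices.reshape(z.shape[:-1]), quantized.reshape(z.shape)

rng = np.random.default_rng(0)
n_codes, embedding_dim = 2048, 64   # embedding_dim is an arbitrary example value
codebook = rng.normal(size=(n_codes, embedding_dim))
z = rng.normal(size=(4, 8, 8, embedding_dim))   # a small example latent grid
indices, z_q = quantize(z, codebook)
print(indices.shape, z_q.shape)  # (4, 8, 8) (4, 8, 8, 64)
```

The integer `indices` grid is what the transformer later models; `z_q` is what the decoder consumes.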

Training Settings

  • --gpus 2: number of GPUs for distributed training
  • --sync_batchnorm: uses SyncBatchNorm instead of BatchNorm3d when using more than one GPU
  • --gradient_clip_val 1: gradient clipping threshold for training
  • --batch_size 16: batch size per GPU
  • --num_workers 8: number of workers for each DataLoader

Dataset Settings

  • --data_path: path to an HDF5 file, or to a folder containing train and test folders with subdirectories of videos
  • --resolution 128: spatial resolution to train on
  • --sequence_length 16: temporal resolution, or video clip length
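The flags above are passed on the command line. Below is a hypothetical helper that assembles the scripts/train_vqvae.py command from the listed defaults, using only flags named in this section; the dataset path assumes the UCF-101 preprocessing step above was run. Passing the resulting list to `subprocess.run` would launch training.

```python
import shlex

def build_command(script, **flags):
    """Build an argv list for a training script from keyword flags.

    True becomes a bare switch (e.g. --sync_batchnorm), False/None are skipped,
    and lists or tuples expand to several space-separated values.
    """
    argv = ["python", script]
    for name, value in flags.items():
        if value is False or value is None:
            continue
        argv.append(f"--{name}")
        if value is True:
            continue
        if isinstance(value, (list, tuple)):
            argv.extend(str(v) for v in value)
        else:
            argv.append(str(value))
    return argv

cmd = build_command(
    "scripts/train_vqvae.py",
    data_path="datasets/ucf101",
    gpus=2, sync_batchnorm=True, gradient_clip_val=1,
    batch_size=16, num_workers=8,
    resolution=128, sequence_length=16,
    n_codes=2048, n_hiddens=240, n_res_layers=4, downsample=[4, 4, 4],
)
print(shlex.join(cmd))
```

Because --batch_size is per GPU, these settings give an effective batch size of 2 x 16 = 32 per optimizer step, assuming standard data-parallel training. --embedding_dim is left out because no default is listed here.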

Training VideoGPT

You can download a pretrained VQ-VAE or train your own. Afterwards, use the scripts/train_videogpt.py script to train a VideoGPT model for sampling. Execute python scripts/train_videogpt.py -h for information on all available training settings. A subset of the more relevant settings is listed below, along with their default values.

VideoGPT Specific Settings

  • --vqvae kinetics_stride4x4x4: path to a VQ-VAE checkpoint file, or the name of a pretrained model to download. Available pretrained models are bair_stride4x2x2, ucf101_stride4x4x4, kinetics_stride4x4x4, and kinetics_stride2x4x4. BAIR was trained on 64 x 64 videos, and the rest on 128 x 128 videos
  • --n_cond_frames 0: number of frames to condition on. 0 represents a non-frame conditioned model
  • --class_cond: trains a class conditional model if activated
  • --hidden_dim 576: number of transformer hidden features
  • --heads 4: number of heads for multihead attention
  • --layers 8: number of transformer layers
  • --dropout 0.2: dropout probability applied to features after the attention and position-wise feedforward layers
  • --attn_type full: full or sparse attention. Refer to the Installation section to install sparse attention support
  • --attn_dropout 0.3: dropout probability applied to the attention weight matrix
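--attn_dropout acts on the attention weight matrix, while --dropout acts on the features produced by the attention and feedforward layers. The single-head NumPy sketch below shows where attention-weight dropout sits in causal scaled dot-product attention. It is a generic illustration under that reading of the flags, not the repo's multi-head implementation, and it omits the output-feature dropout, feedforward layers and projections.

```python
import numpy as np

def causal_attention(q, k, v, attn_dropout=0.0, rng=None):
    """Single-head causal scaled dot-product attention over (seq_len, d) arrays.

    attn_dropout is applied to the normalized attention weights (inverted dropout).
    """
    seq_len, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    # Causal mask: token i may only attend to tokens j <= i
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    if attn_dropout > 0.0:
        rng = rng if rng is not None else np.random.default_rng()
        keep = rng.random(weights.shape) >= attn_dropout
        weights = weights * keep / (1.0 - attn_dropout)  # rescale kept weights
    return weights @ v

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))
out = causal_attention(x, x, x, attn_dropout=0.3, rng=rng)
print(out.shape)  # (8, 16)
```

At sampling time dropout is disabled (attn_dropout=0.0), so each output depends deterministically on its own and earlier positions only.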

Training Settings

  • --gpus 2: number of GPUs for distributed training
  • --sync_batchnorm: uses SyncBatchNorm instead of BatchNorm3d when using more than one GPU
  • --gradient_clip_val 1: gradient clipping threshold for training
  • --batch_size 16: batch size per GPU
  • --num_workers 8: number of workers for each DataLoader

Dataset Settings

  • --data_path: path to an HDF5 file, or to a folder containing train and test folders with subdirectories of videos
  • --resolution 128: spatial resolution to train on
  • --sequence_length 16: temporal resolution, or video clip length

Sampling VideoGPT

After training, the VideoGPT model can be sampled using the scripts/sample_videogpt.py script. You may need to install ffmpeg first: sudo apt-get install ffmpeg
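Sampling from an autoregressive model over codes proceeds one token at a time, each drawn from the model's predicted distribution given the tokens already sampled; the finished code grid would then be decoded to pixels with the VQ-VAE, as in the vqvae.decode call shown earlier. The sketch below shows that loop with a hypothetical `logits_fn` standing in for a trained transformer. It is not the interface of scripts/sample_videogpt.py, and it ignores frame and class conditioning and key/value caching.

```python
import numpy as np

def sample_codes(logits_fn, shape, temperature=1.0, rng=None):
    """Sample a (T, H, W) grid of code indices one token at a time in raster order.

    logits_fn(prefix) -> (n_codes,) logits for the next token, given the 1-D
    array of already-sampled indices (hypothetical stand-in for the transformer).
    """
    rng = rng if rng is not None else np.random.default_rng()
    n_tokens = int(np.prod(shape))
    seq = np.zeros(n_tokens, dtype=np.int64)
    for i in range(n_tokens):
        logits = np.asarray(logits_fn(seq[:i]), dtype=np.float64) / temperature
        probs = np.exp(logits - logits.max())  # softmax, shifted for stability
        probs /= probs.sum()
        seq[i] = rng.choice(len(probs), p=probs)
    return seq.reshape(shape)

def toy_logits(prefix):
    """Toy model over 8 codes that strongly prefers code len(prefix) % 8."""
    return np.where(np.arange(8) == len(prefix) % 8, 50.0, 0.0)

grid = sample_codes(toy_logits, (2, 2, 2), rng=np.random.default_rng(0))
print(grid.ravel())
```

Lower temperatures sharpen each step's distribution toward the most likely code; a real implementation would also cache attention keys and values instead of re-running the model on the full prefix at every step.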

Reproducing Paper Results

Note that this repo is primarily designed for simplicity and for extending our method. The full paper results can be reproduced using code found in a separate repo; be aware, however, that that code is not as clean.

Citation

Please consider using the following citation when using our code:

@misc{yan2021videogpt,
      title={VideoGPT: Video Generation using VQ-VAE and Transformers}, 
      author={Wilson Yan and Yunzhi Zhang and Pieter Abbeel and Aravind Srinivas},
      year={2021},
      eprint={2104.10157},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}