Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!

Robust Video Matting (RVM)


English | 中文

Official repository for the paper Robust High-Resolution Video Matting with Temporal Guidance. RVM is specifically designed for robust human video matting. Unlike existing neural models that process frames as independent images, RVM uses a recurrent neural network to process videos with temporal memory. RVM can perform matting in real time on any video without additional inputs. It achieves 76 FPS at 4K and 104 FPS at HD on an Nvidia GTX 1080 Ti GPU. The project was developed at ByteDance Inc.


News

  • [Aug 25 2021] Source code and pretrained models are published.
  • [Jul 27 2021] Paper is accepted by WACV 2022.

Showreel

Watch the showreel video (YouTube, Bilibili) to see the model's performance.

All footage in the video is available on Google Drive and Baidu Pan (code: tb3w).


Demo

  • Webcam Demo: Run the model live in your browser. Visualize recurrent states.
  • Colab Demo: Test our model on your own videos with free GPU.

Download

We recommend MobileNetV3 models for most use cases. ResNet50 models are the larger variant with small performance improvements. Our model is available on various inference frameworks; see the inference documentation for more instructions.

| Framework | Download | Notes |
| --- | --- | --- |
| PyTorch | rvm_mobilenetv3.pth, rvm_resnet50.pth | Official weights for PyTorch. Doc |
| TorchHub | Nothing to download. | Easiest way to use our model in your PyTorch project. Doc |
| TorchScript | rvm_mobilenetv3_fp32.torchscript, rvm_mobilenetv3_fp16.torchscript, rvm_resnet50_fp32.torchscript, rvm_resnet50_fp16.torchscript | For inference on mobile, consider exporting int8 quantized models yourself. Doc |
| ONNX | rvm_mobilenetv3_fp32.onnx, rvm_mobilenetv3_fp16.onnx, rvm_resnet50_fp32.onnx, rvm_resnet50_fp16.onnx | Tested on ONNX Runtime with CPU and CUDA backends. Provided models use opset 12. Doc, Exporter |
| TensorFlow | rvm_mobilenetv3_tf.zip, rvm_resnet50_tf.zip | TensorFlow 2 SavedModel. Doc |
| TensorFlow.js | rvm_mobilenetv3_tfjs_int8.zip | Run the model on the web. Demo, Starter Code |
| CoreML | rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel, rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel | CoreML does not support dynamic resolutions; models for other resolutions can be exported yourself. Models require iOS 13+. s denotes downsample_ratio. Doc, Exporter |

All models are available on Google Drive and Baidu Pan (code: gym7).
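
For the ONNX models, here is a minimal single-frame sketch with ONNX Runtime. The input and output names (src, r1i-r4i, downsample_ratio) follow the official exporter; the zero-initialized recurrent states and the 0~1 RGB range are assumptions based on the inference docs, so treat this as a sketch rather than the canonical usage.

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('rvm_mobilenetv3_fp32.onnx', providers=['CPUExecutionProvider'])
src = np.random.rand(1, 3, 1080, 1920).astype(np.float32)  # One RGB frame, values in 0 ~ 1.
rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4       # Initial recurrent states.

fgr, pha, *rec = sess.run(None, {
    'src': src,
    'r1i': rec[0], 'r2i': rec[1], 'r3i': rec[2], 'r4i': rec[3],
    'downsample_ratio': np.array([0.25], dtype=np.float32),
})
# Feed the returned rec back as r1i..r4i for the next frame.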


PyTorch Example

  1. Install dependencies:
pip install -r requirements_inference.txt
  2. Load the model:
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
  3. To convert videos, we provide a simple conversion API:
from inference import convert_video

convert_video(
    model,                           # The model, can be on any device (cpu or cuda).
    input_source='input.mp4',        # A video file or an image sequence directory.
    output_type='video',             # Choose "video" or "png_sequence".
    output_composition='output.mp4', # File path if video; directory path if png sequence.
    output_video_mbps=4,             # Output video mbps. Not needed for png sequence.
    downsample_ratio=None,           # A hyperparameter to adjust or use None for auto.
    seq_chunk=12,                    # Process n frames at once for better parallelism.
)
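The converter can also export the alpha channel or the foreground on their own. A minimal sketch, assuming the output_alpha and output_foreground arguments described in the inference docs:

convert_video(
    model,
    input_source='input.mp4',
    output_type='png_sequence',
    output_alpha='alpha/',           # Directory for raw alpha frames.
    seq_chunk=12,
)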
  4. Or write your own inference code:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)

bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background.
rec = [None] * 4                                       # Initial recurrent states.
downsample_ratio = 0.25                                # Adjust based on your video.

with torch.no_grad():
    for src in DataLoader(reader):                     # RGB tensor normalized to 0 ~ 1.
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio)  # Cycle the recurrent states.
        com = fgr * pha + bgr * (1 - pha)              # Composite to green background. 
        writer.write(com)                              # Write frame.
  5. The models and converter API are also available through TorchHub.
# Load the model.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50"

# Converter API.
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
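
The TorchHub objects drop into the same workflow as steps 2 and 3. For example, a minimal sketch (assuming output_type defaults to "video"):

model = model.eval().cuda()

convert_video(
    model,
    input_source='input.mp4',
    output_composition='output.mp4',
    downsample_ratio=None,           # Auto.
)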

Please see the inference documentation for details on the downsample_ratio hyperparameter, more converter arguments, and more advanced usage.
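
When downsample_ratio=None, the converter picks a ratio automatically. As a rough sketch of such a heuristic, assuming the documented advice of keeping the downsampled long side near 512 px for typical human footage (not necessarily the repo's exact rule):

def auto_downsample_ratio(h, w):
    # Scale so the longer side of the downsampled frame is at most 512 px.
    return min(512 / max(h, w), 1)

print(auto_downsample_ratio(1080, 1920))  # ~0.267 for HD input.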


Training and Evaluation

Please refer to the training documentation to train and evaluate your own model.


Speed

Speed is measured with inference_speed_test.py for reference.

| GPU | dType | HD (1920x1080) | 4K (3840x2160) |
| --- | --- | --- | --- |
| RTX 3090 | FP16 | 172 FPS | 154 FPS |
| RTX 2060 Super | FP16 | 134 FPS | 108 FPS |
| GTX 1080 Ti | FP32 | 104 FPS | 74 FPS |
  • Note 1: HD uses downsample_ratio=0.25, 4K uses downsample_ratio=0.125. All tests use batch size 1 and frame chunk 1.
  • Note 2: GPUs before the Turing architecture do not support FP16 inference, so the GTX 1080 Ti uses FP32.
  • Note 3: We only measure tensor throughput. The video conversion script provided in this repo is expected to be much slower, because it does not utilize hardware video encoding/decoding and does not perform tensor transfers on parallel threads. If you are interested in implementing hardware video encoding/decoding in Python, please refer to PyNvCodec.
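
For context on Note 3, a hedged sketch of a pure tensor-throughput measurement in the spirit of inference_speed_test.py (not its exact code; FP16 path shown):

import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda().half()
src = torch.randn(1, 3, 1080, 1920, device='cuda', dtype=torch.float16)
rec = [None] * 4

with torch.no_grad():
    for _ in range(10):                            # Warm-up.
        fgr, pha, *rec = model(src, *rec, 0.25)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    n = 100
    for _ in range(n):
        fgr, pha, *rec = model(src, *rec, 0.25)
    end.record()
    torch.cuda.synchronize()
    print(f'{n * 1000 / start.elapsed_time(end):.1f} FPS')  # Frames per second.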

Comments
  • [Questions] - Training Procedure


    Hi,

    I have some questions about the training procedure:

    1. In the paper, you mention training Stage 1 for 15 epochs, while in the code you set the instructions to 20 epochs. Is there a reason for this change? Will the results be similar?
    2. I could not get access to Distinctions-646; I had no reply from the authors/maintainers of the dataset. Based on your indicated file structure, I've built a similar dataset, which adds uncertainty to the quality of my training, but it is a risk I am willing to take. To have a comparison baseline for the stages that do not depend on this dataset (stages 1-3), would you mind sharing your partial PyTorch training weights (stage1/epoch19.pth, stage2/epoch21.pth, and stage3/epoch22.pth)?
    3. What is the minimum resolution you used for the background images while training?

    For the 3rd time, thank you very much for your contribution to the field. It is brilliant work. Looking forward to your future work.

    opened by SamHSlva 39
  • hardsigmoid replacement


    I've been trying to export an onnx model replacing the hardsigmoid operator.

    I have modified the site-packages/torch/onnx/symbolic_opset9.py file this way:

    @parse_args("v")
    def hardswish(g, self):
        hardsigmoid = g.op('HardSigmoid', self, alpha_f=1 / 6)
        return g.op("Mul", self, hardsigmoid)

    @parse_args("v")
    def hardsigmoid(g, self):
        hardsigmoid = g.op('HardSigmoid', self, alpha_f=1 / 6)
        return g.op("Mul", self, hardsigmoid)

    But I am not sure at all if this is the right way to replace them with primitive ops.

    When I export the ONNX model with this change, I still get an error with the inference engine I am trying to use: "OnnxImportException: Unknown type HardSigmoid encountered while parsing layer 396".

    opened by livingbeams 15
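
    One possible workaround, as a sketch under assumptions (an opset >= 11 export, where Clip takes its min/max as inputs; float32 tensors; registration mechanics vary by torch version): PyTorch's hardsigmoid(x) = relu6(x + 3) / 6 = clip(x / 6 + 0.5, 0, 1), so it can be lowered to the primitive Mul, Add, and Clip ops instead of HardSigmoid:

    import torch
    from torch.onnx.symbolic_helper import parse_args

    @parse_args("v")
    def hardsigmoid(g, self):
        # hardsigmoid(x) = relu6(x + 3) / 6 = clip(x / 6 + 0.5, 0, 1).
        sixth = g.op("Constant", value_t=torch.tensor(1 / 6))
        half = g.op("Constant", value_t=torch.tensor(0.5))
        zero = g.op("Constant", value_t=torch.tensor(0.0))
        one = g.op("Constant", value_t=torch.tensor(1.0))
        scaled = g.op("Add", g.op("Mul", self, sixth), half)
        return g.op("Clip", scaled, zero, one)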
  • VideoMatte240K-HD


    If I'm going to train stage 3 and stage 4, the VideoMatte240K HD data will be used. Is it right to modify the following paths from VideoMatte240K_JPEG_SD to VideoMatte240K_JPEG_HD?

    'videomatte': {
        'train': '../matting-data/VideoMatte240K_JPEG_SD/train',
        'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid',
    },

    opened by FengMu1995 7
  • Add Unity example to README?


    Hey there, I just ported RVM to Unity using NatML, an open-source machine learning runtime. I have two questions:

    1. Can I make a PR adding a link in the README to a Unity example project demonstrating RVM?
    2. I published the model under my account on NatML Hub. Would you be interested in signing up on Hub, so that I can transfer the model to you?

    Here's the model on NatML Hub:

    @natsuite/robust-video-matting

    opened by olokobayusuf 7
  •  Some questions about training


    1. How can I eliminate or reduce the edge-flickering problem? Can I increase --seq-length-lr to improve it, and does that work?
    2. I only have composited images, with no foreground images. Is it possible to remove the foreground training and foreground loss, or is there a better way?
    3. How important is foreground prediction for matting?

    Looking forward to your reply

    opened by zhanghongyong123456 5
  • Not Issue 👉 Few questions


    First of all, thank you for working on this project! It looks much stronger than BMV2!

    1. Will it work on Anaconda and Windows 10 just like BMV2 does (not more complicated)?

    2. Will it support the same hardware, or does it need a much more powerful CPU/GPU compared to BMV2?

    3. Can you please say when you will release it again? I missed it the first time, so I can't test it while it's still offline. It would be very nice to have it this week, if possible of course.

    Thanks ahead for the answers, please keep up the good work! ❤

    opened by AlonDan 5
  • Synchronization issues between inferred mask and original video


    Hello! Thanks for the code. I have had some timing issues between the inferred mask output and the original video. I made this comparison by converting both my original video and the output mask video to frames. I obtained the same number of frames in both cases, so the difference may be caused by a misconfiguration on my end. My original video is 30 fps and 1080x1920. If you have a suggestion, I would appreciate it.

    opened by italosalgado14 4
  • Weird results when using Segmentation Pass for inference


    https://github.com/PeterL1n/RobustVideoMatting/blob/f8a26e27198a93a94bfd06e96b8d5a34d0660f81/inference.py#L127

    I changed this line to use Segmentation Pass. (use the pretrained weights rvm_mobilenetv3.pth)

    pha, *rec = model(src, *rec, segmentation_pass=True)
    fgr = src * pha
    

    But I got weird mask results, like the attached image (seg_pass_alpha). Why?

    opened by luuil 4
  • Beginner questions about model results


    Hello, two questions:

    1. Besides changing the downsample_ratio parameter to adjust matting accuracy, which other parameters can be changed to improve the results?

    2. Does this project have higher requirements for graphics cards? Does the graphics card model affect the final results?

    I have my own project running the model, but the results are not ideal. Thank you!

    opened by yinjia823 4
  • FP16 is slower than FP32


    I use the pre-trained ONNX model parameters for inference tests (in Python, not C++), with only the onnxruntime, cv2, and numpy libraries, nothing extra. Parameters were downloaded from https://github.com/PeterL1n/RobustVideoMatting/releases/: rvm_mobilenetv3_fp32.onnx and rvm_mobilenetv3_fp16.onnx.

    Inference is on a 1080x1920 video with downsample_ratio=0.25. As a result, FP32 takes about 170 ms per frame, while FP16 takes about 240 ms. Why is FP16 so slow?

    I have adjusted the inputs correctly: src, r1i, r2i, r3i, r4i are np.array([[[[]]]], dtype=np.float32 or np.float16), and downsample_ratio is always np.array([0.25], dtype=np.float32).

    I use a CPU (Intel i5) for inference. Is it so slow because the CPU does not support FP16 operations?

    opened by ZachL1 3
  • foreground prediction details


    Hello, a question about foreground prediction. In the official web demo, the foreground predicted by the model contains, besides the foreground (the person), some background details of the input image (non-person pixels). However, the model I trained myself (I did not modify any official details; the only difference is that I used background images I collected) predicts foregrounds that contain only the person, with none of the input image's background details. At first I suspected the foreground loss covered all pixels (i.e., alpha could be any value, not just greater than 0 as stated in the paper), but after checking the code I found no issue; it matches the paper. What could be causing this? Thanks.

    opened by li-wenquan 3
  • Add Replicate demo and


    Hey @PeterL1n ! 👋

    Thanks for this wonderful project!

    This pull request makes it possible to run your model inside a Docker environment, which makes it easier for other people to run it. We're using an open source tool called Cog to make this process easier.

    This also means we can make a web page where other people can run your model! View it here: https://replicate.com/arielreplicate/robust_video_matting. Replicate also has an API, so people can easily run your model from their code:

    import replicate
    model = replicate.models.get("arielreplicate/robust_video_matting")
    model.predict(...)
    

    If you'd like to modify the Replicate page, let me know and I can transfer ownership to your account.

    In case you're wondering who I am, I'm from Replicate, where we're trying to make machine learning reproducible. We got frustrated that we couldn't run all the really interesting ML work being done. So, we're going round implementing models we like. 😊

    opened by ArielReplicate 0
  • Support of NCNN version


    Hello, how are you? Thanks for contributing to this project. I tried to use an NCNN version of this model on Windows with C++. There are several GitHub repos using NCNN models of it in C++, but I can NOT run them: there are some issues when extracting the output data from the model, and the models are old, so it is impossible to run them. Could you support an NCNN version here?

    opened by rose-jinyang 0
  • Problem with exporting alpha-mask on the Replicate/COG version


    I tried both the Replicate page and the local COG variant. When predicting with alpha-mask, this error occurs every time:

    FileNotFoundError: [Errno 2] No such file or directory: 'alpha-mask.mp4'

    Exporting with green-screen and foreground-mask works, however. Maybe this is an issue with mp4 not supporting alpha-transparent video, so it fails?

    opened by SomeOrdinaryDude 0
  • A question about src_sm in model.py


    I see "x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])" in mobilenetv3.py and ''' f1, f2, f3, f4 = self.backbone(src_sm) ... hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4) ''' in model.py. This means the input src_sm of the decoder has not been normalized. Is that your intention?

    opened by FengMu1995 0
  • Performance using grayscale images


    Hey there,

    thanks for this amazing tool! Does anybody know what the performance on grayscale images is like? I want to use it in the dark with an infrared camera.

    If retraining is required, how many GPU hours do you think are necessary?

    :)

    opened by bytosaur 1
  • Slow inference and low GPU use.


    I'm running inference.py at ~4.2 it/s, and it barely loads my RTX 2060 (0-13% utilization). The inference_speed_test script gives me ~33.2 it/s on the same model and video settings. Changing --workers on the convert_video() function did nothing. Am I missing something? How can I run inference faster, using the hardware's full potential?

    Thanks.

    opened by sharp-trickster 0