Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!

Overview

Robust Video Matting (RVM)

Teaser

English | 中文

Official repository for the paper Robust High-Resolution Video Matting with Temporal Guidance. RVM is specifically designed for robust human video matting. Unlike existing neural models that process frames as independent images, RVM uses a recurrent neural network to process videos with temporal memory. RVM can perform matting in real time on any video without additional inputs. It achieves 76 FPS at 4K and 104 FPS at HD on an Nvidia GTX 1080 Ti GPU. The project was developed at ByteDance Inc.


News

  • [Aug 25 2021] Source code and pretrained models are published.
  • [Jul 27 2021] Paper is accepted by WACV 2022.

Showreel

Watch the showreel video (YouTube, Bilibili) to see the model's performance.

All footage in the video is available on Google Drive and Baidu Pan (code: tb3w).


Demo

  • Webcam Demo: Run the model live in your browser. Visualize recurrent states.
  • Colab Demo: Test our model on your own videos with a free GPU.

Download

We recommend MobileNetV3 models for most use cases. ResNet50 models are the larger variant with small performance improvements. Our model is available on various inference frameworks. See the inference documentation for more instructions.

Framework: PyTorch
Download: rvm_mobilenetv3.pth, rvm_resnet50.pth
Notes: Official weights for PyTorch. Doc

Framework: TorchHub
Download: Nothing to download.
Notes: Easiest way to use our model in your PyTorch project. Doc

Framework: TorchScript
Download: rvm_mobilenetv3_fp32.torchscript, rvm_mobilenetv3_fp16.torchscript, rvm_resnet50_fp32.torchscript, rvm_resnet50_fp16.torchscript
Notes: For inference on mobile, consider exporting int8-quantized models yourself. Doc

Framework: ONNX
Download: rvm_mobilenetv3_fp32.onnx, rvm_mobilenetv3_fp16.onnx, rvm_resnet50_fp32.onnx, rvm_resnet50_fp16.onnx
Notes: Tested on ONNX Runtime with CPU and CUDA backends. Provided models use opset 12. Doc, Exporter

Framework: TensorFlow
Download: rvm_mobilenetv3_tf.zip, rvm_resnet50_tf.zip
Notes: TensorFlow 2 SavedModel. Doc

Framework: TensorFlow.js
Download: rvm_mobilenetv3_tfjs_int8.zip
Notes: Run the model on the web. Demo, Starter Code

Framework: CoreML
Download: rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel, rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel
Notes: CoreML does not support dynamic resolution; you can export other resolutions yourself. Models require iOS 13+. s denotes downsample_ratio. Doc, Exporter

All models are available in Google Drive and Baidu Pan (code: gym7).
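
As a quick orientation for the ONNX models, a per-frame loop with ONNX Runtime might look like the sketch below. The input names (src, r1i-r4i, downsample_ratio) follow the exporter's convention; the output order (fgr, pha, then the four recurrent states) and the zero-initialized recurrent states are assumptions here, so verify them against the ONNX documentation or sess.get_outputs() before relying on this sketch.

import numpy as np
import onnxruntime as ort

# Hedged sketch, not the official inference script.
sess = ort.InferenceSession('rvm_mobilenetv3_fp32.onnx', providers=['CPUExecutionProvider'])

rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4    # Initial recurrent states (assumed broadcastable zeros).
downsample_ratio = np.array([0.25], dtype=np.float32)   # Adjust per resolution.

def matte_frame(src):
    # src: float32 array of shape [1, 3, H, W], RGB normalized to 0 ~ 1.
    global rec
    fgr, pha, *rec = sess.run(None, {
        'src': src,
        'r1i': rec[0], 'r2i': rec[1], 'r3i': rec[2], 'r4i': rec[3],
        'downsample_ratio': downsample_ratio,
    })
    return fgr, pha    # Foreground and alpha, both in 0 ~ 1.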


PyTorch Example

  1. Install dependencies:
pip install -r requirements_inference.txt
  2. Load the model:
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
  3. To convert videos, we provide a simple conversion API:
from inference import convert_video

convert_video(
    model,                           # The model, can be on any device (cpu or cuda).
    input_source='input.mp4',        # A video file or an image sequence directory.
    output_type='video',             # Choose "video" or "png_sequence"
    output_composition='output.mp4', # File path if video; directory path if png sequence.
    output_video_mbps=4,             # Output video mbps. Not needed for png sequence.
    downsample_ratio=None,           # Adjust this hyperparameter, or use None for auto.
    seq_chunk=12,                    # Process n frames at once for better parallelism.
)
  4. Or write your own inference code:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)

bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background.
rec = [None] * 4                                       # Initial recurrent states.
downsample_ratio = 0.25                                # Adjust based on your video.

with torch.no_grad():
    for src in DataLoader(reader):                     # RGB tensor normalized to 0 ~ 1.
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio)  # Cycle the recurrent states.
        com = fgr * pha + bgr * (1 - pha)              # Composite to green background. 
        writer.write(com)                              # Write frame.
  5. The models and converter API are also available through TorchHub.
# Load the model.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50"

# Converter API.
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
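
The two can then be combined exactly like the local API in step 3, for example:

convert_video(
    model,                           # The TorchHub model loaded above.
    input_source='input.mp4',        # A video file or an image sequence directory.
    output_composition='output.mp4', # Output composition video.
)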

Please see the inference documentation for details on the downsample_ratio hyperparameter, more converter arguments, and more advanced usage.
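
For rough intuition consistent with the Speed notes below (0.25 for HD, 0.125 for 4K), the ratio can be chosen so that the internally downsampled frame is around 512 px on its longer side. The helper below is a hypothetical illustration of that rule of thumb, not a function from this repo; see the inference documentation for the authoritative guidance.

def pick_downsample_ratio(height, width, target_long_side=512):
    # Hypothetical helper: scale the longer side to roughly target_long_side pixels.
    # 1920x1080 gives ~0.27 (close to 0.25); 3840x2160 gives ~0.13 (close to 0.125).
    return min(target_long_side / max(height, width), 1.0)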


Training and Evaluation

Please refer to the training documentation to train and evaluate your own model.


Speed

Speed is measured with inference_speed_test.py for reference.

GPU dType HD (1920x1080) 4K (3840x2160)
RTX 3090 FP16 172 FPS 154 FPS
RTX 2060 Super FP16 134 FPS 108 FPS
GTX 1080 Ti FP32 104 FPS 74 FPS
  • Note 1: HD uses downsample_ratio=0.25, 4K uses downsample_ratio=0.125. All tests use batch size 1 and frame chunk 1.
  • Note 2: GPUs older than the Turing architecture do not support FP16 inference, so the GTX 1080 Ti uses FP32.
  • Note 3: We only measure tensor throughput. The video conversion script provided in this repo is expected to be much slower, because it does not use hardware video encoding/decoding and does not perform tensor transfers on parallel threads. If you are interested in implementing hardware video encoding/decoding in Python, please refer to PyNvCodec.
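
For intuition about what "tensor throughput" means here, a simplified timing loop in the spirit of the speed test (not the repo's actual inference_speed_test.py) could look like this, assuming a CUDA GPU with FP16 support:

import time
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda().half()             # FP16 on a Turing or newer GPU.
src = torch.randn(1, 3, 1080, 1920, device='cuda', dtype=torch.half)   # Dummy HD frame.
rec = [None] * 4
downsample_ratio = 0.25

with torch.no_grad():
    for _ in range(10):                                     # Warm-up iterations.
        fgr, pha, *rec = model(src, *rec, downsample_ratio)
    torch.cuda.synchronize()
    start = time.time()
    n = 100
    for _ in range(n):
        fgr, pha, *rec = model(src, *rec, downsample_ratio)
    torch.cuda.synchronize()
    print(f'{n / (time.time() - start):.1f} FPS')           # Pure tensor throughput, no video I/O.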

Project Members

Comments
  • [Questions] - Training Procedure

    Hi,

    I have some questions about the training procedure:

    1. In the paper, you mention training Stage 1 for 15 epochs, while the code instructions set it to 20 epochs. Is there a reason for this change? Will the results be similar?
    2. I could not get access to Distinctions-646; I received no reply from the dataset's authors/maintainers. Based on your indicated file structure, I built a similar dataset, which adds uncertainty to the quality of my training, but it is a risk I am willing to take. Since stages 1-3 do not depend on this dataset, to have a comparison baseline, would you mind sharing your partial training weights for PyTorch (stage1/epoch19.pth, stage2/epoch21.pth, and stage3/epoch22.pth)?
    3. What is the minimum resolution you used for the background images during training?

    For the third time, thank you very much for your contribution to the field. It is brilliant work. Looking forward to your future work.

    opened by SamHSlva 39
  • hardsigmoid replacement

    I've been trying to export an onnx model replacing the hardsigmoid operator.

    I have modified the site-packages/torch/onnx/symbolic_opset9.py file this way:

    @parse_args("v")
    def hardswish(g, self):
        hardsigmoid = g.op('HardSigmoid', self, alpha_f=1 / 6)
        return g.op("Mul", self, hardsigmoid)

    @parse_args("v")
    def hardsigmoid(g, self):
        hardsigmoid = g.op('HardSigmoid', self, alpha_f=1 / 6)
        return g.op("Mul", self, hardsigmoid)

    But I am not sure at all whether this is the right way to replace them with primitive ops.

    When I export the ONNX model with this change, I still get an error, "OnnxImportException: Unknown type HardSigmoid encountered while parsing layer 396", with the inference engine I am trying to use.

    opened by livingbeams 15
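
    For reference, the ONNX HardSigmoid op with alpha=1/6 and the default beta=0.5 computes clamp(x / 6 + 0.5, 0, 1), which matches PyTorch's hardsigmoid, and hardswish(x) equals x * hardsigmoid(x). So the hardswish symbolic above is consistent with that definition, while the hardsigmoid symbolic would return the HardSigmoid node directly rather than multiplying by the input again. A small numerical check of the two identities (independent of the export itself) could look like:

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-6, 6, steps=1001)

    hs_ref = torch.clamp(x / 6 + 0.5, min=0.0, max=1.0)  # HardSigmoid with alpha=1/6, beta=0.5.
    hw_ref = x * hs_ref                                   # x * hardsigmoid(x).

    assert torch.allclose(F.hardsigmoid(x), hs_ref)       # Matches torch's hardsigmoid.
    assert torch.allclose(F.hardswish(x), hw_ref)         # Matches torch's hardswish.
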
  • VideoMatte240K-HD

    If I am going to train stage 3 and stage 4, the VideoMatte HD data will be used. Is it right to modify the following paths from VideoMatte240K_JPEG_SD to VideoMatte240K_JPEG_HD?

    'videomatte': {
        'train': '../matting-data/VideoMatte240K_JPEG_SD/train',
        'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid',
    },

    opened by FengMu1995 7
  • Add Unity example to README?

    Hey there, I just ported RVM to Unity using NatML, an open-source machine learning runtime. I have questions:

    1. Can I make a PR to add a link into the README to a Unity example project demonstrating using RVM?
    2. I published the model under my account on NatML Hub. Would you be interested in signing up on Hub, so that I can transfer the model to you?

    Here's the model on NatML Hub:

    @natsuite/robust-video-matting

    opened by olokobayusuf 7
  •  Some questions about training

    1. How can I eliminate or reduce the edge flickering problem? Can I increase --seq-length-lr, i.e. does increasing the sequence length help?
    2. I only have composited images and no foreground images. Is it possible to remove the foreground training and the foreground loss, or is there a better way?
    3. How important is foreground prediction for matting?

    Looking forward to your reply

    opened by zhanghongyong123456 5
  • Not Issue 👉 Few questions

    First of all, thank you for working on this project! It looks much stronger than BMV2!

    1. Will it work on Anaconda and Windows 10 just like BMV2 does (not more complicated)?

    2. Will it support the same hardware, or does it need a much more powerful CPU/GPU compared to BMV2?

    3. Can you please tell us when you will release it again? I missed it the first time and can't test it because it's still offline. It would be very nice to have it this week, if possible of course.

    Thanks ahead for the answers, please keep up the good work! ❤

    opened by AlonDan 5
  • Synchronization issues between inferred mask and original video

    Hello! Thanks for the code. I have had some timing issues between the inferred mask output and the original video. I made this comparison by converting both my original video and the output mask video to frames. I obtained the same number of frames in both cases, so the difference may be caused by a misconfiguration on my side. My original video is 30 fps and 1080x1920. If you have a suggestion, I would appreciate it.

    opened by italosalgado14 4
  • Weird results when use Segmentation Pass for inference

    https://github.com/PeterL1n/RobustVideoMatting/blob/f8a26e27198a93a94bfd06e96b8d5a34d0660f81/inference.py#L127

    I changed this line to use the segmentation pass (using the pretrained weights rvm_mobilenetv3.pth):

    pha, *rec = model(src, *rec, segmentation_pass=True)
    fgr = src * pha
    

    But I got weird mask results, something like this, why?

    seg_pass_alpha

    opened by luuil 4
  • Beginner questions about model results

    Thanks for the great work; two questions to consult:

    1. Besides changing the downsample_ratio parameter to adjust the matting accuracy, which other parameters can be changed to improve the results?

    2. Does this project have higher requirements for graphics cards? Does the graphics card model affect the final results?

    I have my own project that runs the model, but the results are not very good yet. Thanks again!

    opened by yinjia823 4
  • FP16 is slower than FP32

    I use pre-trained ONNX model parameters for inference tests (in Python not C++), only onnxruntime, cv2 and numpy libraries, nothing extra. Parameters downloaded from https://github.com/PeterL1n/RobustVideoMatting/releases/: rvm_mobilenetv3_fp32.onnx and rvm_mobilenetv3_fp16.onnx

    Inference on a 1080x1920 video with downsample_ratio=0.25. As a result, FP32 takes about 170 ms per frame, while FP16 takes about 240 ms. Why is FP16 so slow?

    I have adjusted the inputs correctly: for src, r1i, r2i, r3i, r4i it is np.array([[[[]]]], dtype=np.float32 or np.float16), and for downsample_ratio it is always np.array([0.25], dtype=np.float32).

    I use a CPU (Intel i5) for inference. Is it so slow because the CPU does not support FP16 operations?

    opened by ZachL1 3
  • foreground prediction details

    Hi, a question about foreground prediction. In the official web demo, the predicted foreground image contains not only the foreground (the person) but also background details of the input image (non-person pixels). However, the model I trained myself (without modifying any of the official details; the only difference is that I used background images I collected myself) predicts foreground images that contain only the person, with no background details from the input image. At first I suspected that the foreground loss might cover all pixels (i.e. alpha can be any value, not only greater than 0 as stated in the paper), but after checking the code I found no problem; it is consistent with the paper. What could cause this? Thanks.

    opened by li-wenquan 3
  • Add Replicate demo and

    Hey @PeterL1n ! 👋

    Thanks for this wonderful project!

    This pull request makes it possible to run your model inside a Docker environment, which makes it easier for other people to run it. We're using an open source tool called Cog to make this process easier.

    This also means we can make a web page where other people can run your model! View it here: https://replicate.com/arielreplicate/robust_video_matting. Replicate also has an API, so people can easily run your model from their code:

    import replicate
    model = replicate.models.get("arielreplicate/robust_video_matting")
    model.predict(...)
    

    If you'd like to modify the Replicate page, let me know and I can transfer ownership to your account.

    In case you're wondering who I am, I'm from Replicate, where we're trying to make machine learning reproducible. We got frustrated that we couldn't run all the really interesting ML work being done. So, we're going round implementing models we like. 😊

    opened by ArielReplicate 0
  • Support of NCNN version

    Hello, how are you? Thanks for contributing to this project. I tried to use an NCNN version of this model with C++ on Windows. There are several GitHub repos using the NCNN model in C++, but I cannot run them because there are issues when extracting the output data from the model, and the models are old, so it is impossible to run them. Could you support an NCNN version here?

    opened by rose-jinyang 0
  • Problem with exporting alpha-mask on the Replicate/COG version

    I tried both the Replicate page and the local COG variants. When predicting with alpha-mask, this error occurs every time:

    FileNotFoundError: [Errno 2] No such file or directory: 'alpha-mask.mp4'

    Exporting with green-screen and foreground-mask works, however. Maybe this is an issue with mp4 not supporting alpha-transparent video, so it fails?

    opened by SomeOrdinaryDude 0
  • A question about src_sm in the model.py

    I see "x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])" in mobilenetv3.py and ''' f1, f2, f3, f4 = self.backbone(src_sm) ... hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4) ''' in model.py. This means the input src_sm of the decoder has not been normalized. Is that your intention?

    opened by FengMu1995 0
  • Performance using grayscale images

    Hey there,

    Thanks for this amazing tool! Does anybody know how the model performs on grayscale images? I want to use it in the dark, and I have an infrared camera.

    If retraining is required, how many GPU hours do you think are necessary?

    :)

    opened by bytosaur 1
  • Slow inference and low GPU use.

    I'm running inference.py and it runs at ~4.2 it/s. It barely loads my RTX 2060 (0-13% utilization). The inference_speed_test script gives me ~33.2 it/s with the same model and video settings. Changing --workers on the convert_video() function did nothing. Am I missing something? How can I run inference faster using the full hardware potential?

    Thanks.

    opened by sharp-trickster 0