Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding (CVPR2022)

Overview

by Qiaole Dong*, Chenjie Cao*, Yanwei Fu

Paper and Supplemental Material (arXiv)

LICENSE

Pipeline

Overview of ZITS. First, the TSR model restores structures at low resolution. Then a simple CNN-based upsampler upsamples the edge and line maps. Finally, the upsampled sketch space is encoded and added to the FTR through ZeroRA to restore textures.
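
The last step above, adding encoded structure features to the FTR through ZeroRA, can be illustrated with a toy sketch. It assumes ZeroRA is a zero-initialized residual addition of the form `out = x + alpha * residual`, where `alpha` is a learnable scalar that starts at zero. The class name, the plain-list features, and the hand-set `alpha` are illustrative assumptions, not code from this repository.

```python
from typing import List


class ZeroRA:
    """Toy zero-initialized residual addition: out = x + alpha * residual.

    Assumption: alpha is a single learnable scalar initialized to zero.
    """

    def __init__(self) -> None:
        # With alpha == 0 the block is an identity map on x at the start.
        self.alpha = 0.0

    def __call__(self, x: List[float], residual: List[float]) -> List[float]:
        if len(x) != len(residual):
            raise ValueError("x and residual must have the same length")
        return [xi + self.alpha * ri for xi, ri in zip(x, residual)]


features = [0.5, -1.0, 2.0]      # stand-in for FTR features
structure = [10.0, 10.0, 10.0]   # stand-in for encoded sketch features

zero_ra = ZeroRA()
initial = zero_ra(features, structure)   # alpha == 0: features pass through unchanged

zero_ra.alpha = 0.1                      # pretend training moved alpha away from zero
blended = zero_ra(features, structure)   # structure now contributes to the output
```

Under this assumption, starting at `alpha = 0` means the structure branch does not disturb the texture model at the beginning of training, and its influence grows only as `alpha` is learned.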

TO DO

We have updated weights of TSR!

Our project page is available at https://dqiaole.github.io/ZITS_inpainting/.

  • Releasing inference codes.
  • Releasing pre-trained models.
  • Releasing training codes.

Preparation

  1. Preparing the environment:

    As there are some bugs when using GP loss with DDP (link), we strongly recommend installing Apex without CUDA extensions on torch 1.9.0 for multi-GPU training.

    conda create -n train_env python=3.6
    conda activate train_env
    pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
    pip install -r requirement.txt
    git clone https://github.com/NVIDIA/apex
    cd apex
    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" ./
    
  2. For training, MST provides irregular and segmentation masks (download) with different masking rates. Define the mask file list before training, as in MST.

  3. Download the pretrained masked wireframe detection model to the './ckpt' folder: LSM-HAWP (MST ICCV2021 retrained from HAWP CVPR2020).

  4. Prepare the wireframes:

    As MST trained LSM-HAWP in PyTorch 1.3.1, and this causes problems (link) when tested in PyTorch 1.9, we recommend running line (wireframe) inference with torch==1.3.1. If line detection is not run with torch 1.3.1, performance may drop slightly.

    conda create -n wireframes_inference_env python=3.6
    conda activate wireframes_inference_env
    pip install torch==1.3.1 torchvision==0.4.2
    pip install -r requirement.txt
    

    Then extract the wireframes with the following command:

    python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path> --gpu_ids '0'
    
  5. If you need to train the model, please download the pretrained models for perceptual loss, provided by LaMa:

    mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
    wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
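
Step 2 above asks for a mask file list to be defined before training. A minimal sketch of building such a list follows, assuming the list is a plain text file with one mask path per line. The function name `write_mask_list`, the accepted extensions, and the example directory are hypothetical, so check MST for the exact format it expects.

```python
import os
from typing import Iterable, List


def write_mask_list(mask_dir: str, list_path: str,
                    exts: Iterable[str] = (".png", ".jpg", ".jpeg")) -> List[str]:
    """Collect mask images under mask_dir and write one path per line.

    Assumption: the training code reads a text file of mask paths.
    """
    wanted = tuple(e.lower() for e in exts)
    paths = []
    for root, _dirs, files in os.walk(mask_dir):
        for name in files:
            if name.lower().endswith(wanted):
                paths.append(os.path.join(root, name))
    paths.sort()  # deterministic order across runs
    with open(list_path, "w", encoding="utf-8") as f:
        for p in paths:
            f.write(p + "\n")
    return paths
```

For example, `write_mask_list("./masks/irregular", "irregular_mask_list.txt")` would produce a list file matching the `irregular_mask_list.txt` name used in the TSR training command, where `./masks/irregular` is a placeholder for wherever the downloaded masks were unpacked.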
    

Eval

Download pretrained models on Places2 here.

Link for BaiduDrive, password:qnm5

Batch Test

For batch test, you need to complete steps 3 and 4 above.

Put the pretrained models in the './ckpt' folder. Then modify the config file according to your image, mask, and wireframe paths.

Test on 256x256 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2 --config_file ./config_list/config_ZITS_places2.yml --GPU_ids '0'

Test on 512x512 images:

conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0'

Single Image Test

Note: For single image testing, the 'wireframes_inference_env' environment from step 4 is recommended for better line detection. This code only supports square images; non-square images will be center-cropped.

conda activate wireframes_inference_env
python single_image_test.py --path <ckpt_path> --config_file <config_path> \
 --GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./
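
The note above says that non-square inputs are center-cropped. To preview which region would survive, the crop box can be computed with a small helper. This is a sketch of a standard center crop to the largest centered square. The function name is hypothetical, and the exact cropping done inside `single_image_test.py` may differ.

```python
from typing import Tuple


def center_crop_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side


# A 640x480 landscape image keeps its full height and loses 80 px per side.
landscape_box = center_crop_box(640, 480)
# A 300x500 portrait image keeps its full width and loses 100 px top and bottom.
portrait_box = center_crop_box(300, 500)
```

The returned tuple uses the same (left, upper, right, lower) order as Pillow's `Image.crop`, so it can be passed to that method to crop the image and mask identically before running the test script.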

Training

⚠️ Warning: The training code has not been fully tested since refactoring.

Training TSR

python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 12 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
 --train_line_path [training_wireframes_path] \
 --mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
 --train_epoch 15 --validation_path [validation_data_path] \
 --val_line_path [validation_wireframes_path] \
 --valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP --MaP

Training SSU

We recommend using the pretrained SSU. You can also train your own SSU by following https://github.com/ewrfcas/StructureUpsampling.

Training LaMa First

python FTR_train.py --nodes 1 --gpus 1 --GPU_ids '0' --path ./ckpt/lama_places2 \
--config_file ./config_list/config_LAMA.yml --lama

Training FTR

256:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2 \
--config_file ./config_list/config_ZITS_places2.yml --DDP

256~512:

python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2_HR \
--config_file ./config_list/config_ZITS_HR_places2.yml --DDP

More 1K Results

Acknowledgments

Cite

If you find our code helpful, please consider citing:

@inproceedings{dong2022incremental,
      title={Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding}, 
      author={Qiaole Dong and Chenjie Cao and Yanwei Fu},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      year={2022}
}