Semi-Supervised Semantic Segmentation with Cross-Consistency Training (CCT)

Overview

Paper, Project Page

This repo contains the official implementation of the CVPR 2020 paper: Semi-Supervised Semantic Segmentation with Cross-Consistency Training, which adapts the traditional consistency training framework of semi-supervised learning to semantic segmentation, with extensions to weakly-supervised learning and learning on multiple domains.

Highlights

(1) Consistency Training for semantic segmentation.
We observe that for semantic segmentation, due to the dense nature of the task, the cluster assumption is more easily enforced over the hidden representations than over the inputs.

(2) Cross-Consistency Training.
We propose CCT (Cross-Consistency Training) for semi-supervised semantic segmentation, where we define a number of novel perturbations, and show the effectiveness of enforcing consistency over the encoder's outputs rather than the inputs (a minimal sketch of the objective follows the highlights below).

(3) Using weak-labels and pixel-level labels from multiple domains.
The proposed method is quite simple and flexible, and can easily be extended to use image-level labels and pixel-level labels from multiple domains.
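
To make highlight (2) concrete, here is a minimal sketch of the CCT objective in PyTorch, assuming the encoder, the main decoder, and the perturbed auxiliary decoders are given as callables (hypothetical names; the actual implementation lives in models/model.py):

import torch.nn.functional as F

def cct_loss(encoder, main_decoder, aux_decoders, x_l, y_l, x_ul, w_u):
    # Supervised branch: labeled images through the encoder and main decoder.
    z_l = encoder(x_l)
    sup_loss = F.cross_entropy(main_decoder(z_l), y_l, ignore_index=255)

    # Unsupervised branch: the main decoder's prediction on unlabeled images
    # is the (detached) target for every auxiliary decoder.
    z_ul = encoder(x_ul)
    target = F.softmax(main_decoder(z_ul), dim=1).detach()

    # Each auxiliary decoder perturbs z_ul internally (dropout, cutout, VAT,
    # feature noise, ...) before decoding, and is pushed to agree with the
    # main decoder, enforcing consistency over the encoder's outputs.
    unsup_loss = sum(F.mse_loss(F.softmax(aux(z_ul), dim=1), target)
                     for aux in aux_decoders) / len(aux_decoders)
    return sup_loss + w_u * unsup_loss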

Requirements

This repo was tested with Ubuntu 18.04.3 LTS, Python 3.7, PyTorch 1.1.0, and CUDA 10.0, but it should run with any recent PyTorch version >= 1.1.0.

The required packages are pytorch and torchvision, together with PIL and opencv for data preprocessing and tqdm for showing the training progress, plus some additional modules like dominate to save the results as HTML files. To set up the necessary modules, simply run:

pip install -r requirements.txt

Dataset

In this repo, we use Pascal VOC. To obtain it, first download the original dataset; after extracting the files you will end up with VOCtrainval_11-May-2012/VOCdevkit/VOC2012, containing the image sets, the XML annotations for both object detection and segmentation, and the JPEG images.
The second step is to augment the dataset using the additional annotations provided by Semantic Contours from Inverse Detectors. Download the rest of the annotations, SegmentationClassAug, and add them under VOCtrainval_11-May-2012/VOCdevkit/VOC2012. Now we're set; for training, use the path to VOCtrainval_11-May-2012.
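
Before moving on, a quick sanity check that the dataset is laid out as expected (a minimal sketch; the folder names follow the steps above):

import os

root = "VOCtrainval_11-May-2012/VOCdevkit/VOC2012"
for sub in ["JPEGImages", "ImageSets", "SegmentationClassAug"]:
    assert os.path.isdir(os.path.join(root, sub)), f"missing {sub} under {root}"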

Training

To train a model, first download PASCAL VOC as detailed above, then set data_dir to the dataset path in the config file configs/config.json and set the rest of the parameters, like the number of GPUs, crop size, data augmentation, etc. You can also change the CCT hyperparameters if you wish (more details below). Then simply run:

python train.py --config configs/config.json

The log files and the .pth checkpoints will be saved in saved/EXP_NAME. To monitor the training using tensorboard, please run:

tensorboard --logdir saved

To resume training using a saved .pth model:

python train.py --config configs/config.json --resume saved/CCT/checkpoint.pth
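
Roughly, resuming restores the model and optimizer states from the checkpoint before continuing training. A minimal sketch, assuming the checkpoint keys used by the Pytorch-Template structure this repo is based on ("state_dict", "optimizer", "epoch"):

import torch

def resume(model, optimizer, path="saved/CCT/checkpoint.pth"):
    # The key names are an assumption, following the Pytorch-Template convention.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint["epoch"] + 1  # epoch to continue from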

Results: The results will be saved in saved as an HTML file containing the validation results, named after the experim_name specified in configs/config.json.

Pseudo-labels

If you want to use image-level labels to train the auxiliary decoders as explained in section 3.3 of the paper, first generate the pseudo-labels using the code in pseudo_labels:

cd pseudo_labels
python run.py --voc12_root DATA_PATH

DATA_PATH must point to the folder containing JPEGImages in the Pascal VOC dataset. The results will be saved in pseudo_labels/result/pseudo_labels as PNG files. Set the flag use_weak_labels to True in the config file, and then train the model as detailed above.
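
To verify the generated pseudo-labels, we can inspect one of the PNG files (a sketch; the file name is just an example VOC image id):

import numpy as np
from PIL import Image

mask = np.array(Image.open("pseudo_labels/result/pseudo_labels/2007_000032.png"))
print(mask.shape, np.unique(mask))  # per-pixel class indices, 255 = ignore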

Inference

For inference, we need a pretrained model, the jpg images we'd like to segment, and the config used in training (to load the correct model and other parameters):

python inference.py --config config.json --model best_model.pth --images images_folder

The predictions will be saved as .png images in the outputs/ folder, colorized with the default Pascal VOC palette (see the sketch after the flag list below).

Here are the flags available for inference:

--images       Folder containing the jpg images to segment.
--model        Path to the trained pth model.
--config       The config file used for training the model.
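
For reference, here is a sketch of how the standard Pascal VOC palette can be generated and applied to a saved prediction (the file name is just an example; the inference script may already colorize its outputs):

import numpy as np
from PIL import Image

def voc_palette(n=256):
    # Standard Pascal VOC color map: each class color is built bit by bit.
    palette = np.zeros((n, 3), dtype=np.uint8)
    for i in range(n):
        c = i
        for j in range(8):
            palette[i, 0] |= ((c >> 0) & 1) << (7 - j)
            palette[i, 1] |= ((c >> 1) & 1) << (7 - j)
            palette[i, 2] |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
    return palette

pred = Image.open("outputs/2007_000033.png").convert("P")
pred.putpalette(voc_palette().flatten().tolist())
pred.save("outputs/2007_000033_color.png")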

Pre-trained models

Pre-trained models can be downloaded here.

Citation ✏️ 📄

If you find this repo useful for your research, please consider citing the paper as follows:

@InProceedings{Ouali_2020_CVPR,
  author = {Ouali, Yassine and Hudelot, Celine and Tami, Myriam},
  title = {Semi-Supervised Semantic Segmentation With Cross-Consistency Training},
  booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month = {June},
  year = {2020}
}

For any questions, please contact Yassine Ouali.

Config file details ⚙️

Below we detail the CCT parameters that can be controlled in the config file configs/config.json; the rest of the parameters are self-explanatory.

{
    "name": "CCT",                              
    "experim_name": "CCT",                             // The name the results will take (html and the folder in /saved)
    "n_gpu": 1,                                             // Number of GPUs
    "n_labeled_examples": 1000,                             // Number of labeled examples (choices are 60, 100, 200, 
                                                            // 300, 500, 800, 1000, 1464, and the splits are in dataloaders/voc_splits)
    "diff_lrs": true,
    "ramp_up": 0.1,                                         // The unsupervised loss will be slowly scaled up in the first 10% of Training time
    "unsupervised_w": 30,                                   // Weighting of the unsupervised loss
    "ignore_index": 255,
    "lr_scheduler": "Poly",
    "use_weak_labels": false,                               // If the pseudo-labels were generated, we can use them to train the aux. decoders
    "weakly_loss_w": 0.4,                                   // Weighting of the weakly-supervised loss
    "pretrained": true,

    "model":{
        "supervised": true,                                  // Supervised setting (training only on the labeled examples)
        "semi": false,                                       // Semi-supervised setting
        "supervised_w": 1,                                   // Weighting of the supervised loss

        "sup_loss": "CE",                                    // supervised loss, choices are CE and ab-CE = ["CE", "ABCE"]
        "un_loss": "MSE",                                    // unsupervised loss, choices are CE and KL-divergence = ["MSE", "KL"]

        "softmax_temp": 1,
        "aux_constraint": false,                             // Pair-wise loss (sup. mat.)
        "aux_constraint_w": 1,
        "confidence_masking": false,                         // Confidence masking (sup. mat.)
        "confidence_th": 0.5,

        "drop": 6,                                           // Number of DropOut decoders
        "drop_rate": 0.5,                                    // Dropout probability
        "spatial": true,
    
        "cutout": 6,                                         // Number of G-Cutout decoders
        "erase": 0.4,                                        // We drop 40% of the area
    
        "vat": 2,                                            // Number of I-VAT decoders
        "xi": 1e-6,                                          // VAT parameters
        "eps": 2.0,

        "context_masking": 2,                               // Number of Con-Msk decoders
        "object_masking": 2,                                // Number of Obj-Msk decoders
        "feature_drop": 6,                                  // Number of F-Drop decoders

        "feature_noise": 6,                                 // Number of F-Noise decoders
        "uniform_range": 0.3                                // The range of the noise
    },
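
For example, with ramp_up = 0.1 and unsupervised_w = 30, the weight of the unsupervised loss grows from 0 to 30 over the first 10% of training. A minimal sketch, assuming the exponential ramp-up schedule commonly used in consistency training:

import numpy as np

def unsup_weight(cur_iter, total_iters, ramp_up=0.1, w_max=30.0):
    # Scale the unsupervised loss weight from 0 to w_max over the first
    # ramp_up fraction of training iterations, then keep it constant.
    ramp_iters = max(1, int(ramp_up * total_iters))
    if cur_iter >= ramp_iters:
        return w_max
    t = cur_iter / ramp_iters
    return w_max * float(np.exp(-5.0 * (1.0 - t) ** 2))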

Acknowledgements

  • Pseudo-label generation is based on Jiwoon Ahn's implementation, irn.
  • Code structure was based on Pytorch-Template.
  • The ResNet backbone was downloaded from torchcv.
Comments
  • custom dataset with 4 classes


    Thank you so far for all your great help. I have an issue that I also found in the closed issues, but for me it isn't solved. I have my own custom dataset with 4 classes (background and 3 objects, labeled 0-3), so I changed num_classes = 4 in voc.py. The results when training fully supervised are as shown in the image below; there is one class with an IoU of 0.0. [image] I ran multiple tests using the semi- and weakly-supervised settings; the results are unpredictable and often show 0.0 for the object classes. Only the background has good results. Is there something I need to adjust in the code?

    opened by SuzannaLin 22
  • Training error!


    I want to train VOC2012, but get the error below:

    Traceback (most recent call last):
      File "train.py", line 98, in <module>
        main(config, args.resume)
      File "train.py", line 82, in main
        trainer.train()
      File "/home/byronnar/bigfile/projects/CCT/base/base_trainer.py", line 91, in train
        results = self._train_epoch(epoch)
      File "/home/byronnar/bigfile/projects/CCT/trainer.py", line 76, in _train_epoch
        total_loss, cur_losses, outputs = self.model(x_l=input_l, target_l=target_l, x_ul=input_ul, curr_iter=batch_idx, target_ul=target_ul, epoch=epoch-1)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/model.py", line 93, in forward
        output_l = self.main_decoder(self.encoder(x_l))
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 61, in forward
        x = self.psp(x)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 36, in forward
        align_corners=False) for stage in self.stages])
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 36, in <listcomp>
        align_corners=False) for stage in self.stages])
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
        input = module(input)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 81, in forward
        exponential_average_factor, self.eps)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py", line 1652, in batch_norm
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
    ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])
      0%|                                                                                                         | 0/9118 [00:02<?, ?it/s]
    

    What should I do? Thank you.

    opened by Byronnar 12
  • Poor mIoU when training on 1464 images in a supervised manner


    Hi, I trained the model on the 1464 images in a supervised manner. The highest mIoU on Val is 67.00%, but the reported number in your paper is 69.4%. Here is my config.json file. Can you have a look at which part is wrong?

     {  "name": "CCT",
        "experim_name": "CCT",
        "n_gpu": 1,
        "n_labeled_examples": 1464,
        "diff_lrs": true,
        "ramp_up": 0.1,
        "unsupervised_w": 30,
        "ignore_index": 255,
        "lr_scheduler": "Poly",
        "use_weak_lables":false,
        "weakly_loss_w": 0.4,
        "pretrained": true,
        "model":{
            "supervised": true,
            "semi": false,
            "supervised_w": 1,
    
            "sup_loss": "CE",
            "un_loss": "MSE",
    
            "softmax_temp": 1,
            "aux_constraint": false,
            "aux_constraint_w": 1,
            "confidence_masking": false,
            "confidence_th": 0.5,
    
            "drop": 6,
            "drop_rate": 0.5,
            "spatial": true,
        
            "cutout": 6,
            "erase": 0.4,
        
            "vat": 2,
            "xi": 1e-6,
            "eps": 2.0,
    
            "context_masking": 2,
            "object_masking": 2,
            "feature_drop": 6,
    
            "feature_noise": 6,
            "uniform_range": 0.3
        },
    
    
        "optimizer": {
            "type": "SGD",
            "args":{
                "lr": 1e-2,
                "weight_decay": 1e-4,
                "momentum": 0.9
            }
        },
    
    
        "train_supervised": {
            "data_dir": "../data/VOC2012",
            "batch_size": 10,
            "crop_size": 320,
            "shuffle": true,
            "base_size": 400,
            "scale": true,
            "augment": true,
            "flip": true,
            "rotate": false,
            "blur": false,
            "split": "train_supervised",
            "num_workers": 8
        },
    
        "train_unsupervised": {
            "data_dir": "VOCtrainval_11-May-2012",
            "weak_labels_output": "pseudo_labels/result/pseudo_labels",
            "batch_size": 10,
            "crop_size": 320,
            "shuffle": true,
            "base_size": 400,
            "scale": true,
            "augment": true,
            "flip": true,
            "rotate": false,
            "blur": false,
            "split": "train_unsupervised",
            "num_workers": 8
        },
    
        "val_loader": {
            "data_dir": "../data/VOC2012",
            "batch_size": 1,
            "val": true,
            "split": "val",
            "shuffle": false,
            "num_workers": 4
        },
    
        "trainer": {
            "epochs": 80,
            "save_dir": "saved/",
            "save_period": 5,
      
            "monitor": "max Mean_IoU",
            "early_stop": 10,
            
            "tensorboardX": true,
            "log_dir": "saved/",
            "log_per_iter": 20,
    
            "val": true,
            "val_per_epochs": 5
        }
    }
    
    opened by xiaomengyc 11
  • Fail to reimplement your paper's result for semi-supervised.


    I used the default config file to conduct experiments, but I only got 68.9 mIoU without weak labels and 70.09 mIoU with weak labels, following your readme. These results are far lower than yours. My env is pytorch 1.7.0 and python 3.8.5. Could you provide some advice?

    opened by TyroneLi 8
  • checkerboard


    Hi Yassine, I am using the CCT model to train on a satellite dataset. The images are of size 128x128. For some reason the predictions show a clear checkerboard pattern, as in this example (left: prediction, right: ground truth). [image] Do you have any idea what causes this and how to avoid it?

    opened by SuzannaLin 7
  • inference with 4-channel model


    Hi Yassine! I have managed to train a model with 4 channels, but the inference is not working. I get this error message:

    !python inference.py --config configs/config_70_30_sup_alti.json --model './saved/ABCE_70_30_sup_alti/best_model.pth' --output 'CCT_output/ABCE_70_30_sup_alti/Angers/' --images 'val/Angers/BDORTHO'

    Loading pretrained model: models/backbones/pretrained/3x3resnet50-imagenet.pth
    Traceback (most recent call last):
      File "inference.py", line 155, in <module>
        main()
      File "inference.py", line 102, in main
        conf=config['model'], testing=True, pretrained=True)
      File "/home/scuypers/CCT_4/models/model.py", line 55, in __init__
        self.encoder = Encoder(pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/encoder.py", line 49, in __init__
        model = ResNetBackbone(backbone='deepbase_resnet50_dilated8', pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/resnet_backbone.py", line 145, in ResNetBackbone
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/resnet_models.py", line 227, in deepbase_resnet50
        model = ModuleHelper.load_model(model, pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/module_helper.py", line 109, in load_model
        model.load_state_dict(load_dict)
      File "/home/scuypers/.conda/envs/envCCT/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
        self.__class__.__name__, "\n\t".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for ResNet:
        size mismatch for prefix.conv1.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 4, 3, 3]).

    opened by SuzannaLin 7
  • Loss function


    Thank you for your contribution. I want to know what |Dl| and |Du| represent in your cross-entropy loss formula and semi-supervised loss formula. Thank you for your answer.

    opened by yrcrcy 7
  • Reproducing Cross-domain Experiments


    @yassouali @SuzannaLin Firstly, thanks a lot for your great work! But I still have problems in reproducing the cross-domain experiments.

    I have seen the implementation from https://github.com/tarun005/USSS_ICCV19, however, I notice that there are still many alternatives for the training schemes.

    During training, in each iteration (i.e. each execution of 'optimizer.step()'), do you train the network by forwarding inputs from both datasets? Or by forwarding inputs from only one dataset in the current iteration, then executing "optimizer.step()", then forwarding inputs from the other dataset in the next iteration, and so on?

    Also, are there any tricks to deal with the data imbalance situation, e.g. the CamVid dataset only contains 367 images while the Cityscapes dataset has 2975 training images? (Just like constructing a training batch with different ratios for two datasets or other sorts of things)

    Besides, can you give some hints on hyperparameters, e.g. the number of training iterations, batch size, learning rate, weight decay?

    Looking forward to your reply! Thanks a lot!

    opened by X-Lai 6
  • low performance for full supervised setting


    I modified the config file to set the code to 'supervised' mode, but the result seems to be very low:

    Epoch : 40 | Mean_IoU : 0.699999988079071 | PixelAcc : 0.933 | Val Loss : 0.26163

    compared with 'semi' mode:

    Epoch : 40 | Mean_IoU : 0.7120000123977661 | PixelAcc : 0.931 | Val Loss : 0.31637

    Note that I have changed the supervised list to the 10k+ augmented list in the 'supervised' setting. Did I miss something here?

    opened by zhangyuygss 6
  • How to obtain figure 2(d)


    Hi, thank you for your nice work!

    I want to know how to produce figure 2(d). There are 2048 channels in the hidden representation; how do you visualize them? Thanks for your help!

    opened by reluuu 5
  • low performance in semi-supervised mode when employing weakly_loss with 2 gpus


    Thank you for your nice work!

    I tried to train the model with 1464 labeled samples in semi-supervised mode, using 2 GPUs. I set the number of epochs to 80 and stopped after 50 epochs. But the performance is poor, e.g., the mIoU at epoch 5 is 34.70% while at epoch 10 it is 11.40%. [image]

    I set 'use_weak_labels' to true and 'drop_last' to false, and left the rest as default.

    Have you ever met this situation?

    opened by wqhIris 5