Semi-Supervised Semantic Segmentation with Cross-Consistency Training (CCT)

Overview

Paper, Project Page

This repo contains the official implementation of the CVPR 2020 paper Semi-Supervised Semantic Segmentation with Cross-Consistency Training, which adapts the traditional consistency-training framework of semi-supervised learning to semantic segmentation, with extensions to weakly-supervised learning and learning on multiple domains.

Highlights

(1) Consistency Training for semantic segmentation.
We observe that for semantic segmentation, due to the dense nature of the task, the cluster assumption is more easily enforced over the hidden representations than over the inputs.

(2) Cross-Consistency Training.
We propose CCT (Cross-Consistency Training) for semi-supervised semantic segmentation, where we define a number of novel perturbations and show the effectiveness of enforcing consistency over the encoder's outputs rather than over the inputs (see the sketch after these highlights).

(3) Using weak-labels and pixel-level labels from multiple domains.
The proposed method is simple and flexible, and can easily be extended to use image-level labels and pixel-level labels from multiple domains.
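To make the objective concrete, here is a minimal sketch of the two-term CCT loss, assuming encoder, main_decoder, and a list of (perturbation, decoder) pairs aux_decoders are already built; the actual implementation lives in models/model.py and differs in its details:

import torch.nn.functional as F

def cct_loss(encoder, main_decoder, aux_decoders, x_l, y_l, x_ul, w_u=30.0):
    # Supervised term: cross-entropy on the labeled batch.
    loss_sup = F.cross_entropy(main_decoder(encoder(x_l)), y_l, ignore_index=255)

    # Unsupervised term: each auxiliary decoder receives a perturbed version
    # of the encoder's output and must agree with the main decoder's
    # prediction, to which no gradient flows.
    z = encoder(x_ul)
    target = main_decoder(z).detach().softmax(dim=1)
    loss_unsup = sum(F.mse_loss(aux(perturb(z)).softmax(dim=1), target)
                     for perturb, aux in aux_decoders) / len(aux_decoders)

    return loss_sup + w_u * loss_unsup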

Requirements

This repo was tested with Ubuntu 18.04.3 LTS, Python 3.7, PyTorch 1.1.0, and CUDA 10.0, but it should be runnable with any recent PyTorch version (>=1.1.0).

The required packages are pytorch and torchvision, together with PIL and opencv for data preprocessing and tqdm for showing the training progress, plus some additional modules like dominate to save the results as HTML files. To set up the necessary modules, simply run:

pip install -r requirements.txt

Dataset

In this repo we use Pascal VOC. To obtain it, first download the original dataset; after extracting the files we'll end up with VOCtrainval_11-May-2012/VOCdevkit/VOC2012, containing the image sets, the XML annotations for both object detection and segmentation, and the JPEG images.
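
For reference, the base archive can be downloaded and extracted with (assuming the official host is still reachable):

wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
tar -xf VOCtrainval_11-May-2012.tar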
The second step is to augment the dataset using the additional annotations provided by Semantic Contours from Inverse Detectors. Download the rest of the annotations, SegmentationClassAug, and add them to the path VOCtrainval_11-May-2012/VOCdevkit/VOC2012. Now we're set; for training, use the path to VOCtrainval_11-May-2012.

Training

To train a model, first download PASCAL VOC as detailed above, then set data_dir to the dataset path in the config file configs/config.json and set the rest of the parameters (number of GPUs, crop size, data augmentation, etc.). You can also change the CCT hyperparameters if you wish; more details below. Then simply run:

python train.py --config configs/config.json

The log files and the .pth checkpoints will be saved in saved/EXP_NAME. To monitor the training using TensorBoard, run:

tensorboard --logdir saved

To resume training using a saved .pth model:

python train.py --config configs/config.json --resume saved/CCT/checkpoint.pth

Results: the validation results will be saved in saved/ as an HTML file named after the experim_name specified in configs/config.json.

Pseudo-labels

If you want to use image-level labels to train the auxiliary decoders, as explained in section 3.3 of the paper, first generate the pseudo-labels using the code in pseudo_labels:

cd pseudo_labels
python run.py --voc12_root DATA_PATH

DATA_PATH must point to the folder containing JPEGImages in the Pascal VOC dataset. The results will be saved in pseudo_labels/result/pseudo_labels as PNG files. The flag use_weak_labels then needs to be set to True in the config file, and we can train the model as detailed above.
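
To sanity-check the generated pseudo-labels, they can be loaded as ordinary label maps; a quick illustrative snippet (the file name below is a hypothetical example):

import numpy as np
from PIL import Image

# "2007_000032.png" is just a hypothetical example file name.
mask = np.array(Image.open("pseudo_labels/result/pseudo_labels/2007_000032.png"))
print(mask.shape, np.unique(mask))  # expect class indices in [0, 20] (255 = ignore)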

Inference

For inference, we need a pretrained model, the jpg images we'd like to segment, and the config used in training (to load the correct model and other parameters):

python inference.py --config config.json --model best_model.pth --images images_folder

The predictions will be saved as .png images in outputs/, colorized with the default Pascal VOC palette.

Here are the flags available for inference:

--images       Folder containing the jpg images to segment.
--model        Path to the trained pth model.
--config       The config file used for training the model.
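
If you need to colorize raw class-index predictions outside of inference.py, the standard Pascal VOC palette can be generated and applied with PIL; a minimal sketch (not the repo's own saving code):

import numpy as np
from PIL import Image

def voc_palette(num_colors=256):
    # Standard Pascal VOC colormap: each index is mapped to a color by
    # distributing its bits over the R, G, and B channels.
    palette = np.zeros((num_colors, 3), dtype=np.uint8)
    for i in range(num_colors):
        c = i
        for j in range(8):
            palette[i, 0] |= ((c >> 0) & 1) << (7 - j)
            palette[i, 1] |= ((c >> 1) & 1) << (7 - j)
            palette[i, 2] |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
    return palette

pred = np.zeros((256, 256), dtype=np.uint8)  # placeholder HxW class-index map
out = Image.fromarray(pred, mode="P")
out.putpalette(voc_palette().flatten().tolist())
out.save("prediction.png")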

Pre-trained models

Pre-trained models can be downloaded here.

Citation ✏️ 📄

If you find this repo useful for your research, please consider citing the paper as follows:

@InProceedings{Ouali_2020_CVPR,
  author = {Ouali, Yassine and Hudelot, Celine and Tami, Myriam},
  title = {Semi-Supervised Semantic Segmentation With Cross-Consistency Training},
  booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month = {June},
  year = {2020}
}

For any questions, please contact Yassine Ouali.

Config file details ⚙️

Below we detail the CCT parameters that can be controlled in the config file configs/config.json; the rest of the parameters are self-explanatory.

{
    "name": "CCT",                              
    "experim_name": "CCT",                             // The name the results will take (html and the folder in /saved)
    "n_gpu": 1,                                             // Number of GPUs
    "n_labeled_examples": 1000,                             // Number of labeled examples (choices are 60, 100, 200, 
                                                            // 300, 500, 800, 1000, 1464, and the splits are in dataloaders/voc_splits)
    "diff_lrs": true,
    "ramp_up": 0.1,                                         // The unsupervised loss will be slowly scaled up in the first 10% of Training time
    "unsupervised_w": 30,                                   // Weighting of the unsupervised loss
    "ignore_index": 255,
    "lr_scheduler": "Poly",
    "use_weak_labels": false,                               // If the pseudo-labels were generated, we can use them to train the aux. decoders
    "weakly_loss_w": 0.4,                                   // Weighting of the weakly-supervised loss
    "pretrained": true,

    "model":{
        "supervised": true,                                  // Supervised setting (training only on the labeled examples)
        "semi": false,                                       // Semi-supervised setting
        "supervised_w": 1,                                   // Weighting of the supervised loss

        "sup_loss": "CE",                                    // supervised loss, choices are CE and ab-CE = ["CE", "ABCE"]
        "un_loss": "MSE",                                    // unsupervised loss, choices are CE and KL-divergence = ["MSE", "KL"]

        "softmax_temp": 1,
        "aux_constraint": false,                             // Pair-wise loss (sup. mat.)
        "aux_constraint_w": 1,
        "confidence_masking": false,                         // Confidence masking (sup. mat.)
        "confidence_th": 0.5,

        "drop": 6,                                           // Number of DropOut decoders
        "drop_rate": 0.5,                                    // Dropout probability
        "spatial": true,
    
        "cutout": 6,                                         // Number of G-Cutout decoders
        "erase": 0.4,                                        // We drop 40% of the area
    
        "vat": 2,                                            // Number of I-VAT decoders
        "xi": 1e-6,                                          // VAT parameters
        "eps": 2.0,

        "context_masking": 2,                               // Number of Con-Msk decoders
        "object_masking": 2,                                // Number of Obj-Msk decoders
        "feature_drop": 6,                                  // Number of F-Drop decoders

        "feature_noise": 6,                                 // Number of F-Noise decoders
        "uniform_range": 0.3                                // The range of the noise
    },
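
For reference, ramp_up and unsupervised_w combine into a time-dependent weight on the unsupervised loss. A common scheme in consistency-training codebases is the exponential ramp-up of Laine & Aila, sketched below (treat this as an illustration, not the repo's exact schedule):

import numpy as np

def unsup_weight(curr_iter, total_iters, final_w=30.0, ramp_up=0.1):
    # Exponential ramp-up: the weight grows smoothly and reaches final_w
    # after ramp_up * total_iters iterations (the first 10% by default).
    ramp_iters = ramp_up * total_iters
    if curr_iter >= ramp_iters:
        return final_w
    frac = np.clip(curr_iter / ramp_iters, 0.0, 1.0)
    return final_w * float(np.exp(-5.0 * (1.0 - frac) ** 2))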

Acknowledgements

  • Pseudo-label generation is based on Jiwoon Ahn's implementation: irn.
  • Code structure was based on Pytorch-Template.
  • ResNet backbone was downloaded from torchcv.
Comments
  • custom dataset with 4 classes

    Thank you so far for all your great help. I have an issue that I also found in the closed issues, but for me it isn't solved. I have my own custom dataset with 4 classes (background and 3 objects, labeled 0-3), so I changed num_classes = 4 in voc.py. When training fully supervised, one of the classes gets an IoU of 0.0. I ran multiple tests using the semi- and weakly-supervised settings; the results are unpredictable and often show 0.0 for the object classes. Only the background has good results. Is there something I need to adjust in the code?

    opened by SuzannaLin 22
  • Training error!

    I want to train VOC2012, but get the error below:

    Traceback (most recent call last):
      File "train.py", line 98, in <module>
        main(config, args.resume)
      File "train.py", line 82, in main
        trainer.train()
      File "/home/byronnar/bigfile/projects/CCT/base/base_trainer.py", line 91, in train
        results = self._train_epoch(epoch)
      File "/home/byronnar/bigfile/projects/CCT/trainer.py", line 76, in _train_epoch
        total_loss, cur_losses, outputs = self.model(x_l=input_l, target_l=target_l, x_ul=input_ul, curr_iter=batch_idx, target_ul=target_ul, epoch=epoch-1)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/model.py", line 93, in forward
        output_l = self.main_decoder(self.encoder(x_l))
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 61, in forward
        x = self.psp(x)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 36, in forward
        align_corners=False) for stage in self.stages])
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 36, in <listcomp>
        align_corners=False) for stage in self.stages])
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
        input = module(input)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 81, in forward
        exponential_average_factor, self.eps)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py", line 1652, in batch_norm
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
    ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])
      0%|                                                                                                         | 0/9118 [00:02<?, ?it/s]
    

    What should I do? Thank you.
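
    The torch.Size([1, 512, 1, 1]) in the error suggests a batch of size 1 reaching the BatchNorm layers of the PSP module (its global pooling collapses the spatial dimensions, leaving a single value per channel). A common workaround, sketched below with an assumed DataLoader setup, is to drop the last incomplete batch:

    import torch

    # `dataset` is assumed to be the training dataset; dropping the last
    # incomplete batch prevents a singleton batch from reaching BatchNorm.
    loader = torch.utils.data.DataLoader(dataset, batch_size=10,
                                         shuffle=True, drop_last=True)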

    opened by Byronnar 12
  • Poor mIoU when training on 1464 images in a supervised manner

    Hi, I trained the model on the 1464 images in a supervised manner. The highest mIoU on Val is 67.00%, but the reported number in your paper is 69.4%. Here is my config.json file. Can you have a look at which part is wrong?

     {  "name": "CCT",
        "experim_name": "CCT",
        "n_gpu": 1,
        "n_labeled_examples": 1464,
        "diff_lrs": true,
        "ramp_up": 0.1,
        "unsupervised_w": 30,
        "ignore_index": 255,
        "lr_scheduler": "Poly",
        "use_weak_lables":false,
        "weakly_loss_w": 0.4,
        "pretrained": true,
        "model":{
            "supervised": true,
            "semi": false,
            "supervised_w": 1,
    
            "sup_loss": "CE",
            "un_loss": "MSE",
    
            "softmax_temp": 1,
            "aux_constraint": false,
            "aux_constraint_w": 1,
            "confidence_masking": false,
            "confidence_th": 0.5,
    
            "drop": 6,
            "drop_rate": 0.5,
            "spatial": true,
        
            "cutout": 6,
            "erase": 0.4,
        
            "vat": 2,
            "xi": 1e-6,
            "eps": 2.0,
    
            "context_masking": 2,
            "object_masking": 2,
            "feature_drop": 6,
    
            "feature_noise": 6,
            "uniform_range": 0.3
        },
    
    
        "optimizer": {
            "type": "SGD",
            "args":{
                "lr": 1e-2,
                "weight_decay": 1e-4,
                "momentum": 0.9
            }
        },
    
    
        "train_supervised": {
            "data_dir": "../data/VOC2012",
            "batch_size": 10,
            "crop_size": 320,
            "shuffle": true,
            "base_size": 400,
            "scale": true,
            "augment": true,
            "flip": true,
            "rotate": false,
            "blur": false,
            "split": "train_supervised",
            "num_workers": 8
        },
    
        "train_unsupervised": {
            "data_dir": "VOCtrainval_11-May-2012",
            "weak_labels_output": "pseudo_labels/result/pseudo_labels",
            "batch_size": 10,
            "crop_size": 320,
            "shuffle": true,
            "base_size": 400,
            "scale": true,
            "augment": true,
            "flip": true,
            "rotate": false,
            "blur": false,
            "split": "train_unsupervised",
            "num_workers": 8
        },
    
        "val_loader": {
            "data_dir": "../data/VOC2012",
            "batch_size": 1,
            "val": true,
            "split": "val",
            "shuffle": false,
            "num_workers": 4
        },
    
        "trainer": {
            "epochs": 80,
            "save_dir": "saved/",
            "save_period": 5,
      
            "monitor": "max Mean_IoU",
            "early_stop": 10,
            
            "tensorboardX": true,
            "log_dir": "saved/",
            "log_per_iter": 20,
    
            "val": true,
            "val_per_epochs": 5
        }
    }
    
    opened by xiaomengyc 11
  • Fail to reimplement your paper's result for semi-supervised.

    I used the default config file to conduct experiments, but following your README I only got 68.9 mIoU without weak labels and 70.09 mIoU with weak labels. These results are far lower than yours. My env is PyTorch 1.7.0 and Python 3.8.5. Could you provide some advice?

    opened by TyroneLi 8
  • checkerboard

    Hi Yassine, I am using the CCT model to train on a satellite dataset. The images are of size 128x128. For some reason the predictions show a clear checkerboard pattern (left: prediction, right: ground truth). Do you have any idea what causes this and how to avoid it?

    opened by SuzannaLin 7
  • inference with 4-channel model

    Hi Yassine! I have managed to train a model with 4 channels, but the inference is not working. I get this error message:

    !python inference.py --config configs/config_70_30_sup_alti.json --model './saved/ABCE_70_30_sup_alti/best_model.pth' --output 'CCT_output/ABCE_70_30_sup_alti/Angers/' --images 'val/Angers/BDORTHO'

    Loading pretrained model:models/backbones/pretrained/3x3resnet50-imagenet.pth
    Traceback (most recent call last):
      File "inference.py", line 155, in <module>
        main()
      File "inference.py", line 102, in main
        conf=config['model'], testing=True, pretrained = True)
      File "/home/scuypers/CCT_4/models/model.py", line 55, in __init__
        self.encoder = Encoder(pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/encoder.py", line 49, in __init__
        model = ResNetBackbone(backbone='deepbase_resnet50_dilated8', pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/resnet_backbone.py", line 145, in ResNetBackbone
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/resnet_models.py", line 227, in deepbase_resnet50
        model = ModuleHelper.load_model(model, pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/module_helper.py", line 109, in load_model
        model.load_state_dict(load_dict)
      File "/home/scuypers/.conda/envs/envCCT/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
        self.__class__.__name__, "\n\t".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for ResNet:
        size mismatch for prefix.conv1.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 4, 3, 3]).
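
    The size mismatch ([64, 3, 3, 3] in the checkpoint vs. [64, 4, 3, 3] in the model) comes from loading 3-channel ImageNet weights into a 4-channel first convolution. One possible workaround, sketched below assuming load_dict and model are already constructed, is to expand the pretrained filters before loading, e.g. initializing the extra channel from the mean of the RGB filters:

    import torch

    # `load_dict` is the 3-channel pretrained state_dict and `model` the
    # 4-channel network (both assumed already constructed).
    w = load_dict["prefix.conv1.weight"]                  # [64, 3, 3, 3]
    extra = w.mean(dim=1, keepdim=True)                   # [64, 1, 3, 3]
    load_dict["prefix.conv1.weight"] = torch.cat([w, extra], dim=1)
    model.load_state_dict(load_dict)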

    opened by SuzannaLin 7
  • Loss function

    Thank you for your contribution. I want to know what |Dl| and |Du| represent in your cross-entropy loss and semi-supervised loss formulas. Thank you for your answer.

    opened by yrcrcy 7
  • Reproducing Cross-domain Experiments

    @yassouali @SuzannaLin Firstly, thanks a lot for your great work! But I still have problems in reproducing the cross-domain experiments.

    I have seen the implementation from https://github.com/tarun005/USSS_ICCV19, however, I notice that there are still many alternatives for the training schemes.

    During training, in each iteration (i.e. each execution of 'optimizer.step()'), do you train the network by forwarding inputs from both two datasets? Or by forwarding inputs from only one dataset in the current iteration, and then executing "optimizer.step()", and then forwarding inputs from the other dataset in the next iteration, and so on?

    Also, are there any tricks to deal with the data imbalance situation, e.g. the CamVid dataset only contains 367 images while the Cityscapes dataset has 2975 training images? (Just like constructing a training batch with different ratios for two datasets or other sorts of things)

    Besides, can you give some hints on hyperparameters, e.g. the number of training iterations, batch size, learning rate, weight decay?

    Looking forward to your reply! Thanks a lot!

    opened by X-Lai 6
  • low performance for full supervised setting

    I modified the config file to set the code to 'supervised' mode, but the result seems to be very low:

    Epoch : 40 | Mean_IoU : 0.699999988079071 | PixelAcc : 0.933 | Val Loss : 0.26163

    compared with 'semi' mode:

    Epoch : 40 | Mean_IoU : 0.7120000123977661 | PixelAcc : 0.931 | Val Loss : 0.31637

    Note that I have changed the supervised list to the 10k+ augmented list in the 'supervised' setting. Did I miss something here?

    opened by zhangyuygss 6
  • How to obtain figure 2(d)

    Hi, thank you for your nice work!

    I want to know how to produce Figure 2(d). The hidden representation has 2048 channels, so how do you visualize it? Thanks for your help!

    opened by reluuu 5
  • low performance in semi-supervised mode when employing weakly_loss with 2 gpus

    Thank you for your nice work!

    I tried training the model with 1464 labeled samples in semi-supervised mode using 2 GPUs. I set the number of epochs to 80 and stopped after 50 epochs, but the performance is poor, e.g., the mIoU at epoch 5 is 34.70% while at epoch 10 it is 11.40%.

    I set 'use_weak_labels' to true, 'drop_last' to false, and left the rest at the defaults.

    Have you ever met this situation?

    opened by wqhIris 5