ICLR 2021 i-Mix: A Domain-Agnostic Strategy for Contrastive Representation Learning

Introduction

PyTorch code for the ICLR 2021 paper [i-Mix: A Domain-Agnostic Strategy for Contrastive Representation Learning].

@inproceedings{lee2021imix,
  title={i-Mix: A Domain-Agnostic Strategy for Contrastive Representation Learning},
  author={Lee, Kibok and Zhu, Yian and Sohn, Kihyuk and Li, Chun-Liang and Shin, Jinwoo and Lee, Honglak},
  booktitle={ICLR},
  year={2021}
}

Dependencies

  • python 3.7.4
  • numpy 1.17.2
  • pytorch 1.4.0
  • torchvision 0.5.0
  • cudatoolkit 10.1
  • librosa 0.8.0 for speech_commands
  • PIL 6.2.0 for GaussianBlur

Data

  • CIFAR-10/100 are downloaded automatically.
  • For ImageNet, please refer to the [PyTorch ImageNet example]. The folder structure should look like data/imagenet/train/n01440764/.
  • For speech commands, run bash speech_commands/download_speech_commands_dataset.sh.
  • For tabular datasets, download [covtype.data.gz] and [HIGGS.csv.gz], and place them in data/. They are processed when first loaded.
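
The repository processes the tabular files itself the first time they are loaded. As a rough illustration only, reading covtype.data.gz might look like the sketch below. It assumes the standard UCI layout (54 attributes followed by a class label from 1 to 7); the function name load_covtype is hypothetical, and the repository's own preprocessing (for example normalization or data splits) may differ.

```python
import gzip

import numpy as np


def load_covtype(path="data/covtype.data.gz"):
    """Hypothetical loader: returns (features, labels) from the UCI Covertype file."""
    # Each row holds 54 integer attributes followed by the cover type (1..7).
    with gzip.open(path, "rt") as f:
        data = np.loadtxt(f, delimiter=",", dtype=np.int64)
    features = data[:, :-1].astype(np.float32)
    labels = data[:, -1] - 1  # shift labels to 0..6 so they can be used as class indices
    return features, labels
```

Calling load_covtype() with the default path expects the file to be in data/, as described above.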

Running scripts

Please refer to [run.sh].

Plug-in example

For those who want to apply our method in their own code, we provide a minimal example based on [MoCo]:

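# mixup: somewhere in main_moco.py
# (uses torch and torch.nn as nn, both already imported in main_moco.py)
def mixup(input, alpha):
    # Sample one mixing coefficient per example from Beta(alpha, alpha).
    beta = torch.distributions.beta.Beta(alpha, alpha)
    # Random permutation that pairs each example with a mixing partner.
    randind = torch.randperm(input.shape[0], device=input.device)
    lam = beta.sample([input.shape[0]]).to(device=input.device)
    # Keep lam >= 0.5 so the original example always dominates the mix.
    lam = torch.max(lam, 1. - lam)
    # Reshape lam to (N, 1, ..., 1) so it broadcasts over the non-batch dims.
    lam_expanded = lam.view([-1] + [1]*(input.dim()-1))
    output = lam_expanded * input + (1. - lam_expanded) * input[randind]
    return output, randind, lam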

# cutmix: somewhere in main_moco.py
def cutmix(input, alpha):
    beta = torch.distributions.beta.Beta(alpha, alpha)
    randind = torch.randperm(input.shape[0], device=input.device)
    # A single lam (and hence a single box) is shared by the whole batch.
    lam = beta.sample().to(device=input.device)
    lam = torch.max(lam, 1. - lam)
    # rand_bbox returns lam recomputed from the actual (clamped) box area.
    (bbx1, bby1, bbx2, bby2), lam = rand_bbox(input.shape[-2:], lam)
    output = input.clone()
    # Paste the box from each example's mixing partner into the original.
    output[..., bbx1:bbx2, bby1:bby2] = output[randind][..., bbx1:bbx2, bby1:bby2]
    return output, randind, lam

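def rand_bbox(size, lam):
    # size is input.shape[-2:], i.e. (H, W) for NCHW input, so the names W and H
    # below are swapped relative to the image axes. cutmix indexes the tensor in
    # the same order, so the box is still applied consistently.
    W, H = size
    # Side ratio chosen so that the box covers roughly (1 - lam) of the area.
    cut_rat = (1. - lam).sqrt()
    cut_w = (W * cut_rat).to(torch.long)
    cut_h = (H * cut_rat).to(torch.long)

    # Uniformly sample the box center.
    cx = torch.zeros_like(cut_w, dtype=cut_w.dtype).random_(0, W)
    cy = torch.zeros_like(cut_h, dtype=cut_h.dtype).random_(0, H)

    # Clamp the box to the image bounds.
    bbx1 = (cx - cut_w // 2).clamp(0, W)
    bby1 = (cy - cut_h // 2).clamp(0, H)
    bbx2 = (cx + cut_w // 2).clamp(0, W)
    bby2 = (cy + cut_h // 2).clamp(0, H)

    # Recompute lam as the fraction of the image left untouched after clamping.
    new_lam = 1. - (bbx2 - bbx1).to(lam.dtype) * (bby2 - bby1).to(lam.dtype) / (W * H)

    return (bbx1, bby1, bbx2, bby2), new_lam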
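
rand_bbox returns an adjusted lam because integer rounding and clamping at the image border can shrink the box. The following standalone check is a sketch of that behavior, not part of the MoCo patch: it repeats the same logic as rand_bbox and confirms that the returned value equals the fraction of pixels left untouched. The 32x32 size and lam = 0.7 are arbitrary illustrative choices.

```python
import torch


def rand_bbox(size, lam):
    # Same logic as the rand_bbox snippet in this README.
    W, H = size
    cut_rat = (1. - lam).sqrt()
    cut_w = (W * cut_rat).to(torch.long)
    cut_h = (H * cut_rat).to(torch.long)
    cx = torch.zeros_like(cut_w).random_(0, W)
    cy = torch.zeros_like(cut_h).random_(0, H)
    bbx1 = (cx - cut_w // 2).clamp(0, W)
    bby1 = (cy - cut_h // 2).clamp(0, H)
    bbx2 = (cx + cut_w // 2).clamp(0, W)
    bby2 = (cy + cut_h // 2).clamp(0, H)
    new_lam = 1. - (bbx2 - bbx1).to(lam.dtype) * (bby2 - bby1).to(lam.dtype) / (W * H)
    return (bbx1, bby1, bbx2, bby2), new_lam


torch.manual_seed(0)
lam = torch.tensor(0.7)  # illustrative value; cutmix samples it from Beta(alpha, alpha)
(bbx1, bby1, bbx2, bby2), new_lam = rand_bbox((32, 32), lam)

# Mark the pasted region and measure how much of the image is kept.
mask = torch.zeros(32, 32)
mask[int(bbx1):int(bbx2), int(bby1):int(bby2)] = 1.
kept_fraction = 1. - mask.mean()
```

Since the box can only shrink, new_lam is never smaller than the requested lam, which is why the loss is weighted by the returned value rather than the sampled one.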

# https://github.com/facebookresearch/moco/blob/master/main_moco.py#L193
# reduction='none' keeps per-sample losses so each can be weighted by its own lam.
criterion = nn.CrossEntropyLoss(reduction='none').cuda(args.gpu)

# https://github.com/facebookresearch/moco/blob/master/main_moco.py#L302-L303
images[0], target_aux, lam = mixup(images[0], alpha=1.)
# images[0], target_aux, lam = cutmix(images[0], alpha=1.)
# The positive for query i is key i; the mixing partner's key is target_aux[i].
target = torch.arange(images[0].shape[0], dtype=torch.long).cuda()
output, _ = model(im_q=images[0], im_k=images[1])
# Average the per-sample losses into the scalar that loss.item() and loss.backward() expect.
loss = (lam * criterion(output, target) + (1. - lam) * criterion(output, target_aux)).mean()

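# https://github.com/facebookresearch/moco/blob/master/moco/builder.py#L142-L149
# Put the in-batch keys first, then the queue, so that column j < N is the key of
# example j and both target and target_aux index valid columns.
contrast = torch.cat([k, self.queue.clone().detach().t()], dim=0)
logits = torch.mm(q, contrast.t())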
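
To try the pieces above without a full MoCo setup, the following is a minimal, self-contained sketch on random tensors. It is not the training code from the paper: the toy linear encoder, the helper name imix_loss, the temperature of 0.2, and the tensor sizes are all assumptions made for illustration, and the same encoder is used for queries and keys instead of MoCo's momentum encoder.

```python
import torch
import torch.nn.functional as F


def mixup(input, alpha):
    # Same logic as the mixup snippet above: per-sample lam, kept >= 0.5.
    beta = torch.distributions.beta.Beta(alpha, alpha)
    randind = torch.randperm(input.shape[0], device=input.device)
    lam = beta.sample([input.shape[0]]).to(device=input.device)
    lam = torch.max(lam, 1. - lam)
    lam_expanded = lam.view([-1] + [1] * (input.dim() - 1))
    output = lam_expanded * input + (1. - lam_expanded) * input[randind]
    return output, randind, lam


def imix_loss(q, k, queue, randind, lam, temperature=0.2):
    # Hypothetical helper. q, k: L2-normalized (N, C); queue: (C, K) as in MoCo.
    contrast = torch.cat([k, queue.t()], dim=0)        # (N + K, C)
    logits = torch.mm(q, contrast.t()) / temperature   # (N, N + K)
    target = torch.arange(q.shape[0], device=q.device)
    per_sample = (lam * F.cross_entropy(logits, target, reduction='none')
                  + (1. - lam) * F.cross_entropy(logits, randind, reduction='none'))
    return per_sample.mean()


torch.manual_seed(0)
# Toy stand-in for the real encoder (assumption, for illustration only).
encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 16))
images_q = torch.randn(4, 3, 8, 8)   # query views
images_k = torch.randn(4, 3, 8, 8)   # key views
queue = F.normalize(torch.randn(16, 32), dim=0)

mixed, randind, lam = mixup(images_q, alpha=1.)
q = F.normalize(encoder(mixed), dim=1)
with torch.no_grad():
    k = F.normalize(encoder(images_k), dim=1)
loss = imix_loss(q, k, queue, randind, lam)
loss.backward()
```

Only the query view is mixed; the keys stay clean, so each mixed query is pulled toward its own key with weight lam and toward its partner's key with weight 1 - lam.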
