PyTorch code of my WACV 2022 paper Improving Model Generalization by Agreement of Learned Representations from Data Augmentation

Overview

Improving Model Generalization by Agreement of Learned Representations from Data Augmentation (WACV 2022)

Paper

ArXiv

Why it matters?

When data augmentation is applied on an input image, a model is forced to learn invariant features to improve model generalization (Figure 1).

Since data augmentation incurs little overhead, why not generate 2 data augmented images (also known as 2 positive samples) from a given input. Then, force the model to agree on the common invariant features to support the correct label (Figure 2). It turns out that maximizing this agreement further improves model model generalization. We call our method AgMax.

Unlike label smoothing, AgMax consistently improves model accuracy. For example on ImageNet1k for 90 epochs, the ResNet50 performance is as follows:

Data Augmentation Baseline Label Smoothing AgMax (Ours)
Standard 76.4 76.8 76.9
CutOut 76.2 76.5 77.1
MixUp 76.5 76.7 77.6
CutMix 76.3 76.4 77.4
AutoAugment (AA) 76.2 76.2 77.1
CutOut+AA 75.7 75.7 76.6
MixUp+AA 75.9 76.5 77.1
CutMix+AA 75.5 75.5 77.0

The figure below demonstrates consistent improvement across different data augmnentation methods:

Install requirements

pip3 install -r requirements.txt

Train

For example, train ResNet50 with AgMax on 2 GPUs for 90 epochs, SGD with lr=0.1 and multistep learning rate scheduler:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=ResNet50-standard-agmax --train \
--multisteplr --dataset=imagenet --epochs=90 --save

Compare the results without AgMax:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=ResNet50-standard --train \
--multisteplr --dataset=imagenet --epochs=90 --save

Test

Using a pre-trained model:

ResNet101 trained with CutMix, AutoAugment and AgMax:

mkdir checkpoints
cd checkpoints
wget https://github.com/roatienza/agmax/releases/download/agmax-0.1.0/imagenet-agmax-mi-ResNet101-cutmix-auto_augment-81.19-mlp-4096.pth
cd ..
python3 main.py --config=ResNet101-auto_augment-cutmix-agmax --eval \
--dataset=imagenet \
--resume imagenet-agmax-mi-ResNet101-cutmix-auto_augment-81.19-mlp-4096.pth

ResNet50 trained with CutMix, AutoAugment and AgMax:

python3 main.py --config=ResNet50-auto_augment-cutmix-agmax --eval --n-units=2048 \
--dataset=imagenet --resume imagenet-agmax-ResNet50-cutmix-auto_augment-79.12-mlp-2048.pth

Other pre-trained models (Baselines):

Citation

If you find this work useful, please cite:

@inproceedings{atienza2022agmax,
  title={Improving Model Generalization by Agreement of Learned Representations from Data Augmentation},
  author={Atienza, Rowel},
  booktitle = {IEEE/CVF Winter Conference on Applications of Computer Vision},
  year={2022},
  pubstate={published},
  tppubtype={inproceedings}
}
You might also like...
Code for the AAAI-2022 paper: Imagine by Reasoning: A Reasoning-Based Implicit Semantic Data Augmentation for Long-Tailed Classification

Imagine by Reasoning: A Reasoning-Based Implicit Semantic Data Augmentation for Long-Tailed Classification (AAAI 2022) Prerequisite PyTorch = 1.2.0 P

Code for the paper Relation Prediction as an Auxiliary Training Objective for Improving Multi-Relational Graph Representations (AKBC 2021).
Code for the paper Relation Prediction as an Auxiliary Training Objective for Improving Multi-Relational Graph Representations (AKBC 2021).

Relation Prediction as an Auxiliary Training Objective for Knowledge Base Completion This repo provides the code for the paper Relation Prediction as

Sharpness-Aware Minimization for Efficiently Improving Generalization
Sharpness-Aware Minimization for Efficiently Improving Generalization

Sharpness-Aware-Minimization-TensorFlow This repository provides a minimal implementation of sharpness-aware minimization (SAM) (Sharpness-Aware Minim

ImageNet-CoG is a benchmark for concept generalization. It provides a full evaluation framework for pre-trained visual representations which measure how well they generalize to unseen concepts.

The ImageNet-CoG Benchmark Project Website Paper (arXiv) Code repository for the ImageNet-CoG Benchmark introduced in the paper "Concept Generalizatio

The official repository for paper ''Domain Generalization for Vision-based Driving Trajectory Generation'' submitted to ICRA 2022
The official repository for paper ''Domain Generalization for Vision-based Driving Trajectory Generation'' submitted to ICRA 2022

DG-TrajGen The official repository for paper ''Domain Generalization for Vision-based Driving Trajectory Generation'' submitted to ICRA 2022. Our Meth

Code for
Code for "ShineOn: Illuminating Design Choices for Practical Video-based Virtual Clothing Try-on", accepted at WACV 2021 Generation of Human Behavior Workshop.

ShineOn: Illuminating Design Choices for Practical Video-based Virtual Clothing Try-on [ Paper ] [ Project Page ] This repository contains the code fo

This is the code for the paper "Jinkai Zheng, Xinchen Liu, Wu Liu, Lingxiao He, Chenggang Yan, Tao Mei: Gait Recognition in the Wild with Dense 3D Representations and A Benchmark. (CVPR 2022)"

Gait3D-Benchmark This is the code for the paper "Jinkai Zheng, Xinchen Liu, Wu Liu, Lingxiao He, Chenggang Yan, Tao Mei: Gait Recognition in the Wild

Imposter-detector-2022 - HackED 2022 Team 3IQ - 2022 Imposter Detector
Imposter-detector-2022 - HackED 2022 Team 3IQ - 2022 Imposter Detector

HackED 2022 Team 3IQ - 2022 Imposter Detector By Aneeljyot Alagh, Curtis Kan, Jo

[WACV 2020] Reducing Footskate in Human Motion Reconstruction with Ground Contact Constraints

Reducing Footskate in Human Motion Reconstruction with Ground Contact Constraints Official implementation for Reducing Footskate in Human Motion Recon

Comments
  • Question about the mutual information in the agmax loss

    Question about the mutual information in the agmax loss

    The agmax loss is computed by "agreement_loss, dl = agmax_loss(y, target, self.args.dl_weight)" , I do not know what does the argeement_loss and dl mean, and do not know the relationship between these two loss and the calculation of mutual information. Could you give me some help?

    opened by YananGu 4
Owner
Rowel Atienza
Rowel Atienza
PyTorch Implementation for Fracture Detection in Wrist Bone X-ray Images

wrist-d PyTorch Implementation for Fracture Detection in Wrist Bone X-ray Images note: Paper: Under Review at MPDI Diagnostics Submission Date: Novemb

Fatih UYSAL 5 Oct 12, 2022
RP-GAN: Stable GAN Training with Random Projections

RP-GAN: Stable GAN Training with Random Projections This repository contains a reference implementation of the algorithm described in the paper: Behna

Ayan Chakrabarti 20 Sep 18, 2021
Official Implementation of PCT

Official Implementation of PCT Prerequisites python == 3.8.5 Please make sure you have the following libraries installed: numpy torch=1.4.0 torchvisi

32 Nov 21, 2022
Neural Network Libraries

Neural Network Libraries Neural Network Libraries is a deep learning framework that is intended to be used for research, development and production. W

Sony 2.6k Dec 30, 2022
ARAE-Tensorflow for Discrete Sequences (Adversarially Regularized Autoencoder)

ARAE Tensorflow Code Code for the paper Adversarially Regularized Autoencoders for Generating Discrete Structures by Zhao, Kim, Zhang, Rush and LeCun

19 Nov 12, 2021
Official implementation of GraphMask as presented in our paper Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking.

GraphMask This repository contains an implementation of GraphMask, the interpretability technique for graph neural networks presented in our ICLR 2021

Michael Schlichtkrull 29 Sep 02, 2022
Implementation of ETSformer, state of the art time-series Transformer, in Pytorch

ETSformer - Pytorch Implementation of ETSformer, state of the art time-series Transformer, in Pytorch Install $ pip install etsformer-pytorch Usage im

Phil Wang 121 Dec 30, 2022
Worktory is a python library created with the single purpose of simplifying the inventory management of network automation scripts.

Worktory is a python library created with the single purpose of simplifying the inventory management of network automation scripts.

Renato Almeida de Oliveira 18 Aug 31, 2022
Official implementation of FCL-taco2: Fast, Controllable and Lightweight version of Tacotron2 @ ICASSP 2021

FCL-Taco2: Towards Fast, Controllable and Lightweight Text-to-Speech synthesis (ICASSP 2021) Paper | Demo Block diagram of FCL-taco2, where the decode

Disong Wang 39 Sep 28, 2022
Applications using the GTN library and code to reproduce experiments in "Differentiable Weighted Finite-State Transducers"

gtn_applications An applications library using GTN. Current examples include: Offline handwriting recognition Automatic speech recognition Installing

Facebook Research 68 Dec 29, 2022
Keyhole Imaging: Non-Line-of-Sight Imaging and Tracking of Moving Objects Along a Single Optical Path

Keyhole Imaging Code & Dataset Code associated with the paper "Keyhole Imaging: Non-Line-of-Sight Imaging and Tracking of Moving Objects Along a Singl

Stanford Computational Imaging Lab 20 Feb 03, 2022
Bayesian Generative Adversarial Networks in Tensorflow

Bayesian Generative Adversarial Networks in Tensorflow This repository contains the Tensorflow implementation of the Bayesian GAN by Yunus Saatchi and

Andrew Gordon Wilson 1k Nov 29, 2022
Prototypical Networks for Few shot Learning in PyTorch

Prototypical Networks for Few shot Learning in PyTorch Simple alternative Implementation of Prototypical Networks for Few Shot Learning (paper, code)

Orobix 835 Jan 08, 2023
Block-wisely Supervised Neural Architecture Search with Knowledge Distillation (CVPR 2020)

DNA This repository provides the code of our paper: Blockwisely Supervised Neural Architecture Search with Knowledge Distillation. Illustration of DNA

Changlin Li 215 Dec 19, 2022
Robot Hacking Manual (RHM). From robotics to cybersecurity. Papers, notes and writeups from a journey into robot cybersecurity.

RHM: Robot Hacking Manual Download in PDF RHM v0.4 ┃ Read online The Robot Hacking Manual (RHM) is an introductory series about cybersecurity for robo

Víctor Mayoral Vilches 233 Dec 30, 2022
Fast Neural Representations for Direct Volume Rendering

Fast Neural Representations for Direct Volume Rendering Sebastian Weiss, Philipp Hermüller, Rüdiger Westermann This repository contains the code and s

Sebastian Weiss 20 Dec 03, 2022
Pytorch implementation of Depth-conditioned Dynamic Message Propagation forMonocular 3D Object Detection

DDMP-3D Pytorch implementation of Depth-conditioned Dynamic Message Propagation forMonocular 3D Object Detection, a paper on CVPR2021. Instroduction T

Li Wang 32 Nov 09, 2022
Yolov5 deepsort inference,使用YOLOv5+Deepsort实现车辆行人追踪和计数,代码封装成一个Detector类,更容易嵌入到自己的项目中

使用YOLOv5+Deepsort实现车辆行人追踪和计数,代码封装成一个Detector类,更容易嵌入到自己的项目中。

813 Dec 31, 2022
Human head pose estimation using Keras over TensorFlow.

RealHePoNet: a robust single-stage ConvNet for head pose estimation in the wild.

Rafael Berral Soler 71 Jan 05, 2023
Implementation of ReSeg using PyTorch

Implementation of ReSeg using PyTorch ReSeg: A Recurrent Neural Network-based Model for Semantic Segmentation Pascal-Part Annotations Pascal VOC 2010

Onur Kaplan 46 Nov 23, 2022