Robust fine-tuning of zero-shot models

Related tags

Deep Learningwise-ft
Overview

Robust fine-tuning of zero-shot models

This repository contains code for the paper Robust fine-tuning of zero-shot models by Mitchell Wortsman*, Gabriel Ilharco*, Jong Wook Kim, Mike Li, Simon Kornblith, Rebecca Roelofs, Raphael Gontijo-Lopes, Hannaneh Hajishirzi, Ali Farhadi, Hongseok Namkoong, Ludwig Schmidt.

Abstract

Large pre-trained models such as CLIP offer consistent accuracy across a range of data distributions when performing zero-shot inference (i.e., without fine-tuning on a specific dataset). Although existing fine-tuning approaches substantially improve accuracy in-distribution, they also reduce out-of-distribution robustness. We address this tension by introducing a simple and effective method for improving robustness: ensembling the weights of the zero-shot and fine-tuned models. Compared to standard fine-tuning, the resulting weight-space ensembles provide large accuracy improvements out-of-distribution, while matching or improving in-distribution accuracy. On ImageNet and five derived distribution shifts, weight-space ensembles improve out-of-distribution accuracy by 2 to 10 percentage points while increasing in-distribution accuracy by nearly 1 percentage point relative to standard fine-tuning. These improvements come at no additional computational cost during fine-tuning or inference.

Summary figure

figure1

Compared to standard fine-tuning, weight-space ensembles for fine-tuning (WiSE-FT) improve out-of-distribution (OOD) accuracy without decreasing in-distribution (ID) performance. Top left: Zero-shot CLIP models exhibit high effective robustness and moderate in-distribution accuracy, while standard fine-tuning (end-to-end or with a linear classifier) attains higher ID accuracy and less effective robustness. Top right: Our method linearly interpolates between the zero-shot and fine-tuned models with a mixing coefficient alpha in [0,1]. Bottom: On five distribution shifts derived from ImageNet (ImageNetV2, ImageNet-R, ImageNet Sketch, ObjectNet, and ImageNet-A), WiSE-FT improves average OOD accuracy by 8.7 percentage points (pp) when fine-tuning end-to-end (+2.1 pp when fine-tuning a linear classifier) while maintaining ID accuracy.

Code

Overview

WiSE-FT can be implemented in a few lines of code in addition to standard fine-tuning, as shown below. See src/wise_ft.py for more details.

# Load models
zeroshot = ImageClassifier.load(zeroshot_checkpoint)
finetuned = ImageClassifier.load(finetuned_checkpoint)
theta_0 = zeroshot.state_dict()
theta_1 = finetuned.state_dict()

# make sure checkpoints are compatible
assert set(theta_0.keys()) == set(theta_1.keys())

# interpolate between checkpoints with mixing coefficient alpha
theta = {
    key: (1-alpha) * theta_0[key] + alpha * theta_1[key]
    for key in theta_0.keys()
}

# update the model acccording to the new weights
finetuned.load_state_dict(theta)

# evaluate
evaluate(finetuned, args)

Install dependencies

conda env create
conda activate wiseft

Add directory to PYTHONPATH:

cd wise-ft
export PYTHONPATH="$PYTHONPATH:$PWD"

Download data

When necessary, please refer to datasets.md for instructions on how to download datasets.

Run WiSE-FT

Sample command when zeroshot and fine-tuned models are available:

python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Sample command for running WiSE-FT from scratch using ViT-B/32:

python src/wise_ft.py   \
    --train-dataset=ImageNet  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=512  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Note: the flag --freeze-encoder controls whether only a linear classifier is fine-tuned, or if all weights are fine-tuned (end-to-end).

Plotting results

Sample command for generating a scatter plot:

python src/scatter_plot.py  \
    --eval-datasets=ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --results-db=results.jsonl  \
    --save plots

We show samples of expected behavior below when running the commands above using ViT-B/16 (models can be downloaded here):

ImageNet-Sketch         ImageNet-A

ImageNet-R         ImageNetV2

ObjectNet

Citing

If you found this repository useful, please consider citing:

@article{wortsman2021robust,
  title={Robust fine-tuning of zero-shot models},
  author={Wortsman, Mitchell and Ilharco, Gabriel and Kim, Jong Wook and Li, Mike and Kornblith, Simon and Roelofs, Rebecca and Gontijo-Lopes, Raphael and Hajishirzi, Hannaneh and Farhadi, Ali and Namkoong, Hongseok and Schmidt, Ludwig},
  journal={arXiv preprint arXiv:2109.01903},
  note={\url{https://arxiv.org/abs/2109.01903}},
  year={2021}
}
Official Pytorch implementation of Scene Representation Networks: Continuous 3D-Structure-Aware Neural Scene Representations

Scene Representation Networks This is the official implementation of the NeurIPS submission "Scene Representation Networks: Continuous 3D-Structure-Aw

Vincent Sitzmann 365 Jan 06, 2023
Official implementation for "Low-light Image Enhancement via Breaking Down the Darkness"

Low-light Image Enhancement via Breaking Down the Darkness by Qiming Hu, Xiaojie Guo. 1. Dependencies Python3 PyTorch=1.0 OpenCV-Python, TensorboardX

Qiming Hu 30 Jan 01, 2023
Real-Time-Student-Attendence-System - Real Time Student Attendence System

Real-Time-Student-Attendence-System The Student Attendance Management System Pro

Rounak Das 1 Feb 15, 2022
Code accompanying the paper on "An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers" published at NeurIPS, 2021

Code for "An Empirical Investigation of Domian Generalization with Empirical Risk Minimizers" (NeurIPS 2021) Motivation and Introduction Domain Genera

Meta Research 15 Dec 27, 2022
Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"

Memory Efficient Attention Pytorch Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(

Phil Wang 180 Jan 05, 2023
A real-time motion capture system that estimates poses and global translations using only 6 inertial measurement units

TransPose Code for our SIGGRAPH 2021 paper "TransPose: Real-time 3D Human Translation and Pose Estimation with Six Inertial Sensors". This repository

Xinyu Yi 261 Dec 31, 2022
PyTorch code for the ICCV'21 paper: "Always Be Dreaming: A New Approach for Class-Incremental Learning"

Always Be Dreaming: A New Approach for Data-Free Class-Incremental Learning PyTorch code for the ICCV 2021 paper: Always Be Dreaming: A New Approach f

49 Dec 21, 2022
Repository for the paper "Online Domain Adaptation for Occupancy Mapping", RSS 2020

RSS 2020 - Online Domain Adaptation for Occupancy Mapping Repository for the paper "Online Domain Adaptation for Occupancy Mapping", Robotics: Science

Anthony 26 Sep 22, 2022
The code succinctly shows how our ensemble learning based on deep learning CNN is used for LAM-avulsion-diagnosis.

deep-learning-LAM-avulsion-diagnosis The code succinctly shows how our ensemble learning based on deep learning CNN is used for LAM-avulsion-diagnosis

1 Jan 12, 2022
tree-math: mathematical operations for JAX pytrees

tree-math: mathematical operations for JAX pytrees tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterati

Google 137 Dec 28, 2022
[ACMMM 2021 Oral] Enhanced Invertible Encoding for Learned Image Compression

InvCompress Official Pytorch Implementation for "Enhanced Invertible Encoding for Learned Image Compression", ACMMM 2021 (Oral) Figure: Our framework

96 Nov 30, 2022
This is a repository of our model for weakly-supervised video dense anticipation.

Introduction This is a repository of our model for weakly-supervised video dense anticipation. More results on GTEA, Epic-Kitchens etc. will come soon

2 Apr 09, 2022
Google Landmark Recogntion and Retrieval 2021 Solutions

Google Landmark Recogntion and Retrieval 2021 Solutions In this repository you can find solution and code for Google Landmark Recognition 2021 and Goo

Vadim Timakin 5 Nov 25, 2022
PyElastica is the Python implementation of Elastica, an open-source software for the simulation of assemblies of slender, one-dimensional structures using Cosserat Rod theory.

PyElastica PyElastica is the python implementation of Elastica: an open-source project for simulating assemblies of slender, one-dimensional structure

Gazzola Lab 105 Jan 09, 2023
[Preprint] ConvMLP: Hierarchical Convolutional MLPs for Vision, 2021

Convolutional MLP ConvMLP: Hierarchical Convolutional MLPs for Vision Preprint link: ConvMLP: Hierarchical Convolutional MLPs for Vision By Jiachen Li

SHI Lab 143 Jan 03, 2023
The open-source and free to use Python package miseval was developed to establish a standardized medical image segmentation evaluation procedure

miseval: a metric library for Medical Image Segmentation EVALuation The open-source and free to use Python package miseval was developed to establish

59 Dec 10, 2022
ThunderGBM: Fast GBDTs and Random Forests on GPUs

Documentations | Installation | Parameters | Python (scikit-learn) interface What's new? ThunderGBM won 2019 Best Paper Award from IEEE Transactions o

Xtra Computing Group 647 Jan 04, 2023
RetinaFace: Deep Face Detection Library in TensorFlow for Python

RetinaFace is a deep learning based cutting-edge facial detector for Python coming with facial landmarks.

Sefik Ilkin Serengil 512 Dec 29, 2022
《K-Adapter: Infusing Knowledge into Pre-Trained Models with Adapters》(2020)

K-Adapter: Infusing Knowledge into Pre-Trained Models with Adapters This repository is the implementation of the paper "K-Adapter: Infusing Knowledge

Microsoft 118 Dec 13, 2022
Wav2Vec for speech recognition, classification, and audio classification

Soxan در زبان پارسی به نام سخن This repository consists of models, scripts, and notebooks that help you to use all the benefits of Wav2Vec 2.0 in your

Mehrdad Farahani 140 Dec 15, 2022