Robust fine-tuning of zero-shot models

Related tags

Deep Learningwise-ft
Overview

Robust fine-tuning of zero-shot models

This repository contains code for the paper Robust fine-tuning of zero-shot models by Mitchell Wortsman*, Gabriel Ilharco*, Jong Wook Kim, Mike Li, Simon Kornblith, Rebecca Roelofs, Raphael Gontijo-Lopes, Hannaneh Hajishirzi, Ali Farhadi, Hongseok Namkoong, Ludwig Schmidt.

Abstract

Large pre-trained models such as CLIP offer consistent accuracy across a range of data distributions when performing zero-shot inference (i.e., without fine-tuning on a specific dataset). Although existing fine-tuning approaches substantially improve accuracy in-distribution, they also reduce out-of-distribution robustness. We address this tension by introducing a simple and effective method for improving robustness: ensembling the weights of the zero-shot and fine-tuned models. Compared to standard fine-tuning, the resulting weight-space ensembles provide large accuracy improvements out-of-distribution, while matching or improving in-distribution accuracy. On ImageNet and five derived distribution shifts, weight-space ensembles improve out-of-distribution accuracy by 2 to 10 percentage points while increasing in-distribution accuracy by nearly 1 percentage point relative to standard fine-tuning. These improvements come at no additional computational cost during fine-tuning or inference.

Summary figure

figure1

Compared to standard fine-tuning, weight-space ensembles for fine-tuning (WiSE-FT) improve out-of-distribution (OOD) accuracy without decreasing in-distribution (ID) performance. Top left: Zero-shot CLIP models exhibit high effective robustness and moderate in-distribution accuracy, while standard fine-tuning (end-to-end or with a linear classifier) attains higher ID accuracy and less effective robustness. Top right: Our method linearly interpolates between the zero-shot and fine-tuned models with a mixing coefficient alpha in [0,1]. Bottom: On five distribution shifts derived from ImageNet (ImageNetV2, ImageNet-R, ImageNet Sketch, ObjectNet, and ImageNet-A), WiSE-FT improves average OOD accuracy by 8.7 percentage points (pp) when fine-tuning end-to-end (+2.1 pp when fine-tuning a linear classifier) while maintaining ID accuracy.

Code

Overview

WiSE-FT can be implemented in a few lines of code in addition to standard fine-tuning, as shown below. See src/wise_ft.py for more details.

# Load models
zeroshot = ImageClassifier.load(zeroshot_checkpoint)
finetuned = ImageClassifier.load(finetuned_checkpoint)
theta_0 = zeroshot.state_dict()
theta_1 = finetuned.state_dict()

# make sure checkpoints are compatible
assert set(theta_0.keys()) == set(theta_1.keys())

# interpolate between checkpoints with mixing coefficient alpha
theta = {
    key: (1-alpha) * theta_0[key] + alpha * theta_1[key]
    for key in theta_0.keys()
}

# update the model acccording to the new weights
finetuned.load_state_dict(theta)

# evaluate
evaluate(finetuned, args)

Install dependencies

conda env create
conda activate wiseft

Add directory to PYTHONPATH:

cd wise-ft
export PYTHONPATH="$PYTHONPATH:$PWD"

Download data

When necessary, please refer to datasets.md for instructions on how to download datasets.

Run WiSE-FT

Sample command when zeroshot and fine-tuned models are available:

python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Sample command for running WiSE-FT from scratch using ViT-B/32:

python src/wise_ft.py   \
    --train-dataset=ImageNet  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=512  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Note: the flag --freeze-encoder controls whether only a linear classifier is fine-tuned, or if all weights are fine-tuned (end-to-end).

Plotting results

Sample command for generating a scatter plot:

python src/scatter_plot.py  \
    --eval-datasets=ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --results-db=results.jsonl  \
    --save plots

We show samples of expected behavior below when running the commands above using ViT-B/16 (models can be downloaded here):

ImageNet-Sketch         ImageNet-A

ImageNet-R         ImageNetV2

ObjectNet

Citing

If you found this repository useful, please consider citing:

@article{wortsman2021robust,
  title={Robust fine-tuning of zero-shot models},
  author={Wortsman, Mitchell and Ilharco, Gabriel and Kim, Jong Wook and Li, Mike and Kornblith, Simon and Roelofs, Rebecca and Gontijo-Lopes, Raphael and Hajishirzi, Hannaneh and Farhadi, Ali and Namkoong, Hongseok and Schmidt, Ludwig},
  journal={arXiv preprint arXiv:2109.01903},
  note={\url{https://arxiv.org/abs/2109.01903}},
  year={2021}
}
An Empirical Investigation of Model-to-Model Distribution Shifts in Trained Convolutional Filters

CNN-Filter-DB An Empirical Investigation of Model-to-Model Distribution Shifts in Trained Convolutional Filters Paul Gavrikov, Janis Keuper Paper: htt

Paul Gavrikov 18 Dec 30, 2022
I-BERT: Integer-only BERT Quantization

I-BERT: Integer-only BERT Quantization HuggingFace Implementation I-BERT is also available in the master branch of HuggingFace! Visit the following li

Sehoon Kim 139 Dec 27, 2022
Machine-in-the-Loop Rewriting for Creative Image Captioning

Machine-in-the-Loop Rewriting for Creative Image Captioning Data Annotated sources of data used in the paper: Data Source URL Mohammed et al. Link Gor

Vishakh P 6 Jul 24, 2022
Jarvis Project is a basic virtual assistant that uses TensorFlow for learning.

Jarvis_proyect Jarvis Project is a basic virtual assistant that uses TensorFlow for learning. Latest version 0.1 Features: Good morning protocol Tell

Anze Kovac 3 Aug 31, 2022
Code for "Finding Regions of Heterogeneity in Decision-Making via Expected Conditional Covariance" at NeurIPS 2021

Finding Regions of Heterogeneity in Decision-Making via Expected Conditional Covariance Justin Lim, Christina X Ji, Michael Oberst, Saul Blecker, Leor

Sontag Lab 3 Feb 03, 2022
Self-supervised learning (SSL) is a method of machine learning

Self-supervised learning (SSL) is a method of machine learning. It learns from unlabeled sample data. It can be regarded as an intermediate form between supervised and unsupervised learning.

Ashish Patel 4 May 26, 2022
[NeurIPS'21 Spotlight] PyTorch code for our paper "Aligned Structured Sparsity Learning for Efficient Image Super-Resolution"

ASSL This repository is for a new network pruning method (Aligned Structured Sparsity Learning, ASSL) for efficient single image super-resolution (SR)

Huan Wang 47 Nov 28, 2022
A tensorflow model that predicts if the image is of a cat or of a dog.

Quick intro Hello and thank you for your interest in my project! This is the backend part of a two-repo application. The other part can be found here

Tudor Matei 0 Mar 08, 2022
Submission to Twitter's algorithmic bias bounty challenge

Twitter Ethics Challenge: Pixel Perfect Submission to Twitter's algorithmic bias bounty challenge, by Travis Hoppe (@metasemantic). Abstract We build

Travis Hoppe 4 Aug 19, 2022
TJU Deep Learning & Neural Network

Deep_Learning & Neural_Network_Lab 实验环境 Python 3.9 Anaconda3(官网下载或清华镜像都行) PyTorch 1.10.1(安装代码如下) conda install pytorch torchvision torchaudio cudatool

St3ve Lee 1 Jan 19, 2022
Differentiable scientific computing library

xitorch: differentiable scientific computing library xitorch is a PyTorch-based library of differentiable functions and functionals that can be widely

98 Dec 26, 2022
Tutorials and implementations for "Self-normalizing networks"

Self-Normalizing Networks Tutorials and implementations for "Self-normalizing networks"(SNNs) as suggested by Klambauer et al. (arXiv pre-print). Vers

Institute of Bioinformatics, Johannes Kepler University Linz 1.6k Jan 07, 2023
Making self-supervised learning work on molecules by using their 3D geometry to pre-train GNNs. Implemented in DGL and Pytorch Geometric.

3D Infomax improves GNNs for Molecular Property Prediction Video | Paper We pre-train GNNs to understand the geometry of molecules given only their 2D

Hannes Stärk 95 Dec 30, 2022
Implementing yolov4 target detection and tracking based on nao robot

Implementing yolov4 target detection and tracking based on nao robot

6 Apr 19, 2022
Team nan solution repository for FPT data-centric competition. Data augmentation, Albumentation, Mosaic, Visualization, KNN application

FPT_data_centric_competition - Team nan solution repository for FPT data-centric competition. Data augmentation, Albumentation, Mosaic, Visualization, KNN application

Pham Viet Hoang (Harry) 2 Oct 30, 2022
[ICCV 2021 Oral] SnowflakeNet: Point Cloud Completion by Snowflake Point Deconvolution with Skip-Transformer

This repository contains the source code for the paper SnowflakeNet: Point Cloud Completion by Snowflake Point Deconvolution with Skip-Transformer (ICCV 2021 Oral). The project page is here.

AllenXiang 65 Dec 26, 2022
CV backbones including GhostNet, TinyNet and TNT, developed by Huawei Noah's Ark Lab.

CV Backbones including GhostNet, TinyNet, TNT (Transformer in Transformer) developed by Huawei Noah's Ark Lab. GhostNet Code TinyNet Code TNT Code Pyr

HUAWEI Noah's Ark Lab 3k Jan 08, 2023
PyTorch DepthNet Training on Still Box dataset

DepthNet training on Still Box Project page This code can replicate the results of our paper that was published in UAVg-17. If you use this repo in yo

Clément Pinard 115 Nov 21, 2022
Pytorch library for end-to-end transformer models training and serving

Pytorch library for end-to-end transformer models training and serving

Mikhail Grankin 768 Jan 01, 2023
An University Project of Quera Web Crawling.

WebCrawlerProject An University Project of Quera Web Crawling. خزشگر اینستاگرام در این پروژه شما باید با استفاده از کتابخانه های زیر یک خزشگر اینستاگر

Mahdi 3 Aug 12, 2022