Learning to Prompt for Vision-Language Models.

Related tags

Deep LearningCoOp
Overview

CoOp

Paper: Learning to Prompt for Vision-Language Models

Authors: Kaiyang Zhou, Jingkang Yang, Chen Change Loy, Ziwei Liu

CoOp (Context Optimization) is a differentiable approach that focuses on continuous prompt learning to facilitate deployment of pre-trained vision language models (like CLIP) in downstream datasets.

Updates

  • 15.10.2021: We find that the best_val model and the last_step model achieve similar performance, so we set TEST.FINAL_MODEL = "last_step" for all datasets to save training time. Why we used best_val: the (tiny) validation set was designed for the linear probe approach, which requires extensive tuning for its hyperparameters, so we used the best_val model for CoOp as well for fair comparison (in this way, both approaches have access to the validation set).

  • 09.10.2021: Important changes are made to Dassl's transforms.py. Please pull the latest commits from https://github.com/KaiyangZhou/Dassl.pytorch and this repo to make sure the code works properly. In particular, 1) center_crop now becomes a default transform in testing (applied after resizing the smaller edge to a certain size to keep the image aspect ratio), and 2) for training, Resize(cfg.INPUT.SIZE) is deactivated when random_crop or random_resized_crop is used. Please read this issue on how these changes might affect the performance.

  • 18.09.2021: We have fixed an error in Dassl which could cause a training data loader to have zero length (so no training will be performed) when the dataset size is smaller than the batch size (due to drop_last=True). Please pull the latest commit for Dassl (>= 8eecc3c). This error led to lower results for CoOp in EuroSAT's 1- and 2-shot settings (others are all correct). We will update the paper on arxiv to fix this error.

How to Install

This code is built on top of the awesome toolbox Dassl.pytorch so you need to install the dassl environment first. Simply follow the instructions described here to install dassl as well as PyTorch. After that, run pip install -r requirements.txt under CoOp/ to install a few more packages required by CLIP (this should be done when dassl is activated). Then, you are ready to go.

Follow DATASETS.md to install the datasets.

How to Run

We provide the running scripts in scripts/. Make sure you change the path in DATA and run the commands under CoOp/scripts/.

Few-Shot Learning

All you need is CoOp/scripts/main.sh, which contains six input arguments.

DATASET takes as input a dataset name, like imagenet or caltech101. The valid names are the files' names in CoOp/configs/datasets/.

CFG means which config file to use, such as rn50, rn101 or vit_b32 (see CoOp/configs/trainers/CoOp/). Note that for ImageNet, we use CoOp/configs/trainers/CoOp/*_ep50.yaml for all settings (please follow the implementation details shown in the paper).

Below we provide examples on how to run CoOp on Caltech101.

CLIP + CoOp (M=16, end):

  • 1 shot: bash main.sh caltech101 rn50_ep50 end 16 1 False
  • 2 shots: bash main.sh caltech101 rn50_ep100 end 16 2 False
  • 4 shots: bash main.sh caltech101 rn50_ep100 end 16 4 False
  • 8 shots: bash main.sh caltech101 rn50 end 16 8 False
  • 16 shots: bash main.sh caltech101 rn50 end 16 16 False

CLIP + CoOp (M=16, mid):

  • 1 shot: bash main.sh caltech101 rn50_ep50 middle 16 1 False
  • 2 shots: bash main.sh caltech101 rn50_ep100 middle 16 2 False
  • 4 shots: bash main.sh caltech101 rn50_ep100 middle 16 4 False
  • 8 shots: bash main.sh caltech101 rn50 middle 16 8 False
  • 16 shots: bash main.sh caltech101 rn50 middle 16 16 False

CLIP + CoOp (M=16, end, CSC):

  • 1 shot: bash main.sh caltech101 rn50_ep50 end 16 1 True
  • 2 shots: bash main.sh caltech101 rn50_ep100 end 16 2 True
  • 4 shots: bash main.sh caltech101 rn50_ep100 end 16 4 True
  • 8 shots: bash main.sh caltech101 rn50 end 16 8 True
  • 16 shots: bash main.sh caltech101 rn50 end 16 16 True

CLIP + CoOp (M=16, mid, CSC):

  • 1 shot: bash main.sh caltech101 rn50_ep50 middle 16 1 True
  • 2 shots: bash main.sh caltech101 rn50_ep100 middle 16 2 True
  • 4 shots: bash main.sh caltech101 rn50_ep100 middle 16 4 True
  • 8 shots: bash main.sh caltech101 rn50 middle 16 8 True
  • 16 shots: bash main.sh caltech101 rn50 middle 16 16 True

After the experiments are finished, you can use parse_test_res.py to calculate the average results instead of manually looking into the log files. Say the structure of output/ is

output
|–– caltech101/
|   |–– CoOp/
|   |   |–– rn50_16shots/
|   |   |   |–– nctx16_cscFalse_ctpend/
|   |   |   |   |–– seed1/
|   |   |   |   |–– seed2/
|   |   |   |   |–– seed3/
|   |   |–– rn50_8shots/
|   |   |   |–– nctx16_cscFalse_ctpend/
|   |   |   |   |–– seed1/
|   |   |   |   |–– seed2/
|   |   |   |   |–– seed3/

To calculate the average results for the folder rn50_16shots/nctx16_cscFalse_ctpend/, you can run

python parse_test_res.py output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend

Then, you will see something like this in your terminal

Parsing files in output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed1/log.txt. accuracy: 91.81%. error: 8.19%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed2/log.txt. accuracy: 92.01%. error: 7.99%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed3/log.txt. accuracy: 92.17%. error: 7.83%.
===
Summary of directory: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
* accuracy: 92.00% +- 0.15%
* error: 8.00% +- 0.15%
===

How to initialize the context tokens with pre-trained word vectors? Specify the words for the parameter TRAINER.COOP.CTX_INIT in your config file. In our paper, we use configs/trainers/rn50_ctxv1.yaml (give this file to --config-file, see scripts/main.sh), which uses "a photo of a" as the initialization words.

How to visualize nearest words for the learned context tokens? All you need is interpret_prompt.py. Say the learned tokens are saved in a/b/c/prompt_learner/model.pth.tar and you would like to see the top-3 nearest words for each token. In this case, run python interpret_prompt.py a/b/c/prompt_learner/model.pth.tar 3

Robustness to Distribution Shift

To reproduce the robustness experiments, you can simply load the models learned on ImageNet and evaluate them on the following datasets: imagenetv2, imagenet-sketch, imagenet-a and imagenet-r.

The command is provided in CoOp/scripts/eval.sh. The key arguments are --model-dir, --load-epoch and --eval-only. --model-dir indicates the directory where the models are saved (i.e. the entire folder containing log.txt, the tensorboard file and prompt_learner/). --load-epoch tells the code to load the model saved at a specific epoch, like --load-epoch 50 for ImageNet (see the source code for more details).

For example, to evaluate CLIP + CoOp (M=16, end) on ImageNetV2, you can do

# Don't need to use rn5_ep50 here as no training is performed
bash eval.sh imagenetv2 rn50

The default setting is SHOTS=16. Feel free to modify the script.

Again, you can use parse_test_res.py to automate the calculation of average performance. This time you should append --test-log, e.g., python parse_test_res.py directory --test-log.

Zero-Shot CLIP

See CoOp/scripts/zeroshot.sh.

Linear Probe CLIP

Please move to lpclip/.

How to Cite CoOp

If you use this code in your research, please kindly cite the following paper

@article{zhou2021coop,
    title={Learning to Prompt for Vision-Language Models},
    author={Zhou, Kaiyang and Yang, Jingkang and Loy, Chen Change and Liu, Ziwei},
    journal={arXiv preprint arXiv:2109.01134},
    year={2021}
}
Owner
Kaiyang
Kaiyang
EMNLP 2021 paper Models and Datasets for Cross-Lingual Summarisation.

This repository contains data and code for our EMNLP 2021 paper Models and Datasets for Cross-Lingual Summarisation. Please contact me at

9 Oct 28, 2022
Repository for reproducing `Model-Based Robust Deep Learning`

Model-Based Robust Deep Learning (MBRDL) In this repository, we include the code necessary for reproducing the code used in Model-Based Robust Deep Le

Alex Robey 16 Sep 19, 2022
This repo contains source code and materials for the TEmporally COherent GAN SIGGRAPH project.

TecoGAN This repository contains source code and materials for the TecoGAN project, i.e. code for a TEmporally COherent GAN for video super-resolution

Nils Thuerey 5.2k Jan 02, 2023
A PyTorch implementation of "CoAtNet: Marrying Convolution and Attention for All Data Sizes".

CoAtNet Overview This is a PyTorch implementation of CoAtNet specified in "CoAtNet: Marrying Convolution and Attention for All Data Sizes", arXiv 2021

Justin Wu 268 Jan 07, 2023
We simulate traveling back in time with a modern camera to rephotograph famous historical subjects.

[SIGGRAPH Asia 2021] Time-Travel Rephotography [Project Website] Many historical people were only ever captured by old, faded, black and white photos,

298 Jan 02, 2023
Anomaly detection in multi-agent trajectories: Code for training, evaluation and the OpenAI highway simulation.

Anomaly Detection in Multi-Agent Trajectories for Automated Driving This is the official project page including the paper, code, simulation, baseline

12 Dec 02, 2022
Split Variational AutoEncoder

Split-VAE Split Variational AutoEncoder Introduction This repository contains and implemementation of a Split Variational AutoEncoder (SVAE). In a SVA

Andrea Asperti 2 Sep 02, 2022
Tree-based Search Graph for Approximate Nearest Neighbor Search

TBSG: Tree-based Search Graph for Approximate Nearest Neighbor Search. TBSG is a graph-based algorithm for ANNS based on Cover Tree, which is also an

Fanxbin 2 Dec 27, 2022
Modification of convolutional neural net "UNET" for image segmentation in Keras framework

ZF_UNET_224 Pretrained Model Modification of convolutional neural net "UNET" for image segmentation in Keras framework Requirements Python 3.*, Keras

209 Nov 02, 2022
A Neural Net Training Interface on TensorFlow, with focus on speed + flexibility

Tensorpack is a neural network training interface based on TensorFlow. Features: It's Yet Another TF high-level API, with speed, and flexibility built

Tensorpack 6.2k Jan 01, 2023
JstDoS - HTTP Protocol Stack Remote Code Execution Vulnerability

jstDoS If you are going to skid that, please give credits ! ^^ ¿How works? This

apolo 4 Feb 11, 2022
Torchlight2 lan game server tool - A message forwarding tool for Torchlight 2 lan game

Torchlight 2 Lan Game Server Tool A message forwarding tool for Torchlight 2 lan

Huaijun Jiang 3 Nov 01, 2022
A repository for the updated version of CoinRun used to collect MUGEN, a multimodal video-audio-text dataset.

A repository for the updated version of CoinRun used to collect MUGEN, a multimodal video-audio-text dataset. This repo contains scripts to train RL agents to navigate the closed world and collect vi

MUGEN 11 Oct 22, 2022
MAGMA - a GPT-style multimodal model that can understand any combination of images and language

MAGMA -- Multimodal Augmentation of Generative Models through Adapter-based Finetuning Authors repo (alphabetical) Constantin (CoEich), Mayukh (Mayukh

Aleph Alpha GmbH 331 Jan 03, 2023
Implementation for ACProp ( Momentum centering and asynchronous update for adaptive gradient methdos, NeurIPS 2021)

This repository contains code to reproduce results for submission NeurIPS 2021, "Momentum Centering and Asynchronous Update for Adaptive Gradient Meth

Juntang Zhuang 15 Jun 11, 2022
Using this you can control your PC/Laptop volume by Hand Gestures (pinch-in, pinch-out) created with Python.

Hand Gesture Volume Controller Using this you can control your PC/Laptop volume by Hand Gestures (pinch-in, pinch-out). Code Firstly I have created a

Tejas Prajapati 16 Sep 11, 2021
Distributed Evolutionary Algorithms in Python

DEAP DEAP is a novel evolutionary computation framework for rapid prototyping and testing of ideas. It seeks to make algorithms explicit and data stru

Distributed Evolutionary Algorithms in Python 4.9k Jan 05, 2023
Image augmentation library in Python for machine learning.

Augmentor is an image augmentation library in Python for machine learning. It aims to be a standalone library that is platform and framework independe

Marcus D. Bloice 4.8k Jan 07, 2023
Deformable DETR is an efficient and fast-converging end-to-end object detector.

Deformable DETR: Deformable Transformers for End-to-End Object Detection.

2k Jan 05, 2023
official implemntation for "Contrastive Learning with Stronger Augmentations"

CLSA CLSA is a self-supervised learning methods which focused on the pattern learning from strong augmentations. Copyright (C) 2020 Xiao Wang, Guo-Jun

Lab for MAchine Perception and LEarning (MAPLE) 47 Nov 29, 2022