Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation


This is the PyTorch code repository for "Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation" [paper, to appear] [slides, to appear] [poster, to appear]. If you use any content of this repo in your work, please cite the following bib entry:

@misc{Proto-CAT,
  author = {Yi-Kai Zhang},
  title = {Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/ZhangYikaii/Proto-CAT}},
  commit = {main}
}

Prototype-based Co-Adaptation with Transformer

Illustration of Proto-CAT. The model transforms the classification space using a transformation $\mathcal{T}$ built on two kinds of audio-visual prototypes (class centers): (1) the base training categories (colored blue, green, and pink); and (2) the additional novel test categories (colored with a burning transition). Proto-CAT learns and generalizes on novel test categories from limited labeled examples while maintaining performance on the base training ones. $\mathcal{T}$ consists of audio-visual-level and category-level prototype-based co-adaptation. From left to right, larger coverage and brighter colors indicate a more reliable classification space.
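
For intuition, here is a minimal sketch of how per-modality prototypes (class centers) can be computed from support-set embeddings. The tensor names, shapes, and feature dimension below are illustrative assumptions, not the repo's actual API:

import torch

def compute_prototypes(support_emb, support_labels, n_way):
    # The prototype of a class is the mean embedding of its support examples.
    return torch.stack([
        support_emb[support_labels == c].mean(dim=0) for c in range(n_way)
    ])  # shape: (n_way, d)

# Toy 5-way 1-shot support set with a 512-dim feature per modality.
n_way, d = 5, 512
labels = torch.arange(n_way)  # one support example per class
audio_protos = compute_prototypes(torch.randn(n_way, d), labels, n_way)
video_protos = compute_prototypes(torch.randn(n_way, d), labels, n_way)
# Proto-CAT's transformation then co-adapts these prototypes at the
# audio-visual level and the category level.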

 

Results

| Method | LRW Audio H-mean | LRW Video H-mean | LRW Audio-Video Base | LRW Audio-Video Novel | LRW Audio-Video H-mean | LRW-1000 Audio-Video Base | LRW-1000 Audio-Video Novel | LRW-1000 Audio-Video H-mean |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| LSTM-based | 32.20 | 8.00 | 97.09 | 23.76 | 37.22 | 71.34 | 0.03 | 0.07 |
| GRU-based | 37.01 | 10.58 | **97.44** | 27.35 | 41.71 | 71.34 | 0.05 | 0.09 |
| MS-TCN-based | 62.29 | 19.06 | 80.96 | 51.28 | 61.76 | 71.55 | 0.33 | 0.63 |
| ProtoNet-GFSL | 39.95 | 14.40 | 96.33 | 39.23 | 54.79 | 69.33 | 0.76 | 1.47 |
| FEAT-GFSL | 49.90 | 25.75 | 96.26 | 54.52 | 68.83 | **71.69** | 2.62 | 4.89 |
| DFSL | 72.13 | 42.56 | 66.10 | 84.62 | 73.81 | 31.68 | **68.72** | 42.56 |
| CASTLE | 75.48 | 34.68 | 73.50 | 90.20 | 80.74 | 11.13 | 54.07 | 17.84 |
| Proto-CAT (Ours) | **84.18** | **74.55** | 93.37 | **91.20** | **92.13** | 49.70 | 38.27 | 42.25 |
| Proto-CAT+ (Ours) | – | – | 93.18 | 90.16 | 91.49 | 54.55 | 38.16 | **43.88** |

Audio-visual generalized few-shot learning classification performance (in %; measured over 10,000 rounds; higher is better) of 5-way 1-shot training tasks on the LRW and LRW-1000 datasets. The best result in each scenario is in bold. Base and Novel denote mean accuracy on the base and novel classes, respectively. Their harmonic mean, H-mean = 2 · Base · Novel / (Base + Novel), is a better measure of generalized few-shot performance than either alone.

 

Prerequisites

Environment

Please refer to requirements.txt and run:

pip install -r requirements.txt

Dataset

  • Use preprocessed data (suggested):

    The licenses of LRW and LRW-1000 forbid us from directly sharing the preprocessed data.

  • Use raw data and preprocess it yourself:

    Download the LRW dataset and unzip it into the following layout:

    /your data_path set in .sh file
    ├── lipread_mp4
    │   ├── [ALL CLASS FOLDER]
    │   ├── ...
    

    Run prepare_lrw_audio.py and prepare_lrw_video.py to preprocess the audio and video modalities, respectively. Please modify the data path in these preprocessing scripts in advance.

    Similarly, download the LRW-1000 dataset and unzip it, then run prepare_lrw1000_audio.py and prepare_lrw1000_video.py to preprocess it.

Pretrained Weights

We provide pretrained weights on the LRW and LRW-1000 datasets. Download them from Google Drive or Baidu Yun (password: 3ad2) and place them as follows:

/your init_weights set in .sh file
├── Conv1dResNetGRU_LRW-pre.pth
├── Conv3dResNetLSTM_LRW-pre.pth
├── Conv1dResNetGRU_LRW1000-pre.pth
├── Conv3dResNetLSTM_LRW1000-pre.pth
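
To sanity-check a downloaded checkpoint before training, you can inspect its contents with plain PyTorch. A minimal sketch (the wrapper-dict handling is an assumption about the checkpoint format):

import torch

# Path follows the layout above; adjust it to your init_weights directory.
ckpt = torch.load("init_weights/Conv3dResNetLSTM_LRW-pre.pth", map_location="cpu")
# Checkpoints commonly store either the state dict itself or a wrapper dict.
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))  # e.g., first conv kernels of the backbone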

 

How to Train Proto-CAT

For the LRW dataset, adjust the parameters in run/protocat_lrw.sh and run:

cd ./Proto-CAT/run
bash protocat_lrw.sh

Similarly, run bash protocat_lrw1000.sh for the LRW-1000 dataset.

Run bash protocat_plus_lrw.sh / bash protocat_plus_lrw1000.sh to train Proto-CAT+.
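
These scripts ultimately invoke main.py with the arguments documented in the Arguments section below. A hypothetical invocation (the flag values and paths are illustrative placeholders, not the shipped configuration) might look like:

python ../main.py --do_train --gfsl_train --dataset LRW --data_path /your/data_path --init_weights /your/init_weights --train_way 5 --train_shot 1 --gpu 0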

How to Reproduce the Result of Proto-CAT

Download the trained models from Google Drive or Baidu Yun (password: swzd) and run:

bash test_protocat_lrw.sh

Run bash test_protocat_lrw1000.sh, bash test_protocat_plus_lrw.sh, or bash test_protocat_plus_lrw1000.sh to evaluate other models.
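
As with training, the test scripts call main.py in testing mode. A hypothetical invocation (paths are placeholders) could be:

python ../main.py --do_test --gfsl_test --dataset LRW --test_model_filepath /your/model_save_path/your_model.pth --acc_per_class --gpu 0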

 

Code Structures

Proto-CAT's entry point is main.py. It calls the manager class Trainer in models/train.py, which contains the main training logic. Within Trainer, prepare_handle.prepare_dataloader combined with train_prepare_batch loads and preprocesses data in the generalized few-shot style, fit_handle controls forward and backward propagation, and callbacks handles the behaviors at each stage.
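
Putting those pieces together, the control flow is roughly the following. This is a simplified, self-contained sketch based on the names above; the stand-in functions are assumptions, and the real definitions live in models/train.py:

def prepare_dataloader():
    # Stand-in for prepare_handle.prepare_dataloader: yields GFSL episodes.
    return [{"support": None, "query": None} for _ in range(3)]

def train_prepare_batch(raw_batch):
    # Stand-in: device placement / reshaping of one episode.
    return raw_batch

def fit_handle(batch):
    # Stand-in: forward pass, loss computation, and backward propagation.
    return 0.0

class Trainer:
    def __init__(self):
        self.loader = prepare_dataloader()

    def fit(self, max_epoch):
        for epoch in range(max_epoch):
            for raw_batch in self.loader:
                batch = train_prepare_batch(raw_batch)
                loss = fit_handle(batch)  # callbacks would hook in around here
            print(f"epoch {epoch} done, last loss {loss:.4f}")

Trainer().fit(max_epoch=2)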

Arguments

All parameters are defined in models/utils.py. We list the main ones below:

  • do_train, do_test: Store-true switches for training and testing, respectively.
  • data_path: Data directory to be set.
  • model_save_path: Directory where the best model is saved.
  • init_weights: Pretrained weights to be set.
  • dataset: Option for the dataset.
  • model_class: Option for the top model.
  • backend_type: Option list for the backend type.
  • train_way, val_way, test_way, train_shot, val_shot, test_shot, train_query, val_query, test_query: Task settings for generalized few-shot learning.
  • gfsl_train, gfsl_test: Switches for training/testing in the generalized few-shot way, i.e., whether additional base-class data is included.
  • mm_list: Participating modalities.
  • lr_scheduler: List of learning rate schedulers.
  • loss_fn: Option for the loss function.
  • max_epoch: Maximum number of training epochs.
  • episodes_per_train_epoch, episodes_per_val_epoch, episodes_per_test_epoch: Number of sampled episodes per epoch.
  • num_tasks: Number of tasks per episode.
  • meta_batch_size: Batch size of each task.
  • test_model_filepath: Path of the trained .pth weights file when testing a model.
  • gpu: Multi-GPU option, e.g., --gpu 0,1,2,3.
  • logger_filename: Logger file save directory.
  • time_str: Token identifying each run; generated automatically if empty.
  • acc_per_class: Switch for measuring per-class accuracy with base, novel, and harmonic-mean summaries.
  • verbose, epoch_verbose: Switches for message output and progress-bar output, respectively.
  • torch_seed, cuda_seed, np_seed, random_seed: Seeds for random number generation (a seeding sketch follows this list).
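
For reproducibility, the four seed options map onto the four standard RNG families. A minimal sketch of how such seeds are typically applied (the exact wiring in models/utils.py may differ):

import random
import numpy as np
import torch

def seed_everything(torch_seed, cuda_seed, np_seed, random_seed):
    torch.manual_seed(torch_seed)          # CPU-side torch ops
    torch.cuda.manual_seed_all(cuda_seed)  # all visible GPUs
    np.random.seed(np_seed)                # NumPy sampling
    random.seed(random_seed)               # Python's built-in RNG

seed_everything(0, 0, 0, 0)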

 

Acknowledgment

We thank the authors of the repositories whose components and functions were helpful in our work.
