Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation

The code repository for "Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation" [paper, to appear] [slides, to appear] [poster, to appear] in PyTorch. If you use any content of this repo for your work, please cite the following bib entry:

@misc{Proto-CAT,
  author = {Yi-Kai Zhang},
  title = {Audio-Visual Generalized Few-Shot Learning with Prototype-Based Co-Adaptation},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/ZhangYikaii/Proto-CAT}},
  commit = {main}
}

Prototype-based Co-Adaptation with Transformer

Illustration of Proto-CAT. The model transforms the classification space based on two kinds of audio-visual prototypes (class centers): (1) the base training categories (shown in blue, green, and pink); and (2) the additional novel test categories (shown with a burning color gradient). Proto-CAT learns and generalizes to novel test categories from limited labeled examples while maintaining performance on the base training categories. The transformation includes prototype-based co-adaptation at both the audio-visual level and the category level. From left to right, wider coverage and brighter colors represent a more reliable classification space.
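
To make the notion of prototypes (class centers) concrete, the sketch below averages the embeddings of each class and assigns a query to the nearest center. It is a plain nearest-prototype classifier on toy 2-D vectors, not the Proto-CAT co-adaptation itself; the function names `class_prototypes` and `nearest_prototype` and all data are made up for illustration.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def class_prototypes(embeddings: Sequence[Sequence[float]],
                     labels: Sequence[str]) -> Dict[str, List[float]]:
    """Average the embeddings of each class to get one prototype per class."""
    grouped = defaultdict(list)
    for vector, label in zip(embeddings, labels):
        grouped[label].append(vector)
    return {
        label: [sum(dim) / len(vectors) for dim in zip(*vectors)]
        for label, vectors in grouped.items()
    }


def nearest_prototype(query: Sequence[float],
                      prototypes: Dict[str, List[float]]) -> str:
    """Return the label whose prototype is closest in squared Euclidean distance."""
    def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(prototypes, key=lambda label: sq_dist(query, prototypes[label]))


# Toy embeddings (hypothetical): two base classes with two examples each,
# and one novel class with a single labeled example (1-shot).
support = [[0.0, 0.0], [0.2, 0.0], [1.0, 1.0], [1.0, 0.8], [3.0, 3.0]]
support_labels = ["base_a", "base_a", "base_b", "base_b", "novel_c"]

prototypes = class_prototypes(support, support_labels)
prediction = nearest_prototype([2.6, 2.9], prototypes)
```

In the generalized setting, base and novel prototypes compete in the same classification space, which is why a query can be assigned to either group.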

 

Results

Dataset            | LRW                                        | LRW-1000
Data source        | Audio  | Video  | Audio-Video            | Audio-Video
Method             | H-mean | H-mean | Base  | Novel | H-mean | Base  | Novel | H-mean
-------------------+--------+--------+-------+-------+--------+-------+-------+-------
LSTM-based         |  32.20 |   8.00 | 97.09 | 23.76 |  37.22 | 71.34 |  0.03 |   0.07
GRU-based          |  37.01 |  10.58 | 97.44 | 27.35 |  41.71 | 71.34 |  0.05 |   0.09
MS-TCN-based       |  62.29 |  19.06 | 80.96 | 51.28 |  61.76 | 71.55 |  0.33 |   0.63
ProtoNet-GFSL      |  39.95 |  14.40 | 96.33 | 39.23 |  54.79 | 69.33 |  0.76 |   1.47
FEAT-GFSL          |  49.90 |  25.75 | 96.26 | 54.52 |  68.83 | 71.69 |  2.62 |   4.89
DFSL               |  72.13 |  42.56 | 66.10 | 84.62 |  73.81 | 31.68 | 68.72 |  42.56
CASTLE             |  75.48 |  34.68 | 73.50 | 90.20 |  80.74 | 11.13 | 54.07 |  17.84
Proto-CAT (Ours)   |  84.18 |  74.55 | 93.37 | 91.20 |  92.13 | 49.70 | 38.27 |  42.25
Proto-CAT+ (Ours)  |      - |      - | 93.18 | 90.16 |  91.49 | 54.55 | 38.16 |  43.88

Audio-visual generalized few-shot learning classification performance (in %; measured over 10,000 rounds; higher is better) of 5-way 1-shot training tasks on the LRW and LRW-1000 datasets. The best result of each scenario is in bold font. The performance measure on base and novel classes (Base and Novel in the table) is mean accuracy. The harmonic mean (H-mean) of the two is a better measure of generalized few-shot learning performance, because it is high only when both base and novel accuracy are high.
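
As a small illustration of why the harmonic mean is used, the sketch below computes it from a base and a novel accuracy. The function name `harmonic_mean` is chosen here for illustration. The table values are measured over 10,000 rounds, so how they were aggregated is not stated here, and they need not match this formula applied to the averaged Base and Novel columns exactly.

```python
def harmonic_mean(base_acc: float, novel_acc: float) -> float:
    """Harmonic mean of base-class and novel-class accuracy (both in %)."""
    if base_acc + novel_acc == 0:
        return 0.0  # avoid dividing by zero when both accuracies are 0
    return 2 * base_acc * novel_acc / (base_acc + novel_acc)


# Balanced accuracies keep the score high ...
balanced = harmonic_mean(50.0, 50.0)
# ... while a model that ignores novel classes scores near zero,
# no matter how high its base accuracy is.
base_only = harmonic_mean(97.0, 0.0)
```

This is why a method with very high Base accuracy but almost no Novel accuracy ends up with a low H-mean.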

 

Prerequisites

Environment

Please refer to requirements.txt and run:

pip install -r requirements.txt

Dataset

  • Use preprocessed data (suggested):

    LRW and LRW-1000 forbid sharing the preprocessed data directly.

  • Use raw data and preprocess it:

    Download the LRW dataset and unzip it so that the directory looks like:

    /your data_path set in .sh file
    ├── lipread_mp4
    │   ├── [ALL CLASS FOLDER]
    │   ├── ...
    

    Run prepare_lrw_audio.py and prepare_lrw_video.py to preprocess the audio and video modalities, respectively. Please modify the data path in these preprocessing files in advance.

    Similarly, download the LRW-1000 dataset and unzip it. Run prepare_lrw1000_audio.py and prepare_lrw1000_video.py to preprocess it.
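
    Before running the preprocessing scripts, it can help to confirm that the unzipped LRW data matches the layout shown above. The following is a minimal sketch under that assumption; `list_class_folders` is a hypothetical helper, not part of this repo.

```python
from pathlib import Path
from typing import List


def list_class_folders(data_path: str) -> List[str]:
    """Return the sorted class folder names found under <data_path>/lipread_mp4."""
    root = Path(data_path) / "lipread_mp4"
    if not root.is_dir():
        raise FileNotFoundError(f"Expected directory not found: {root}")
    # Each class is assumed to live in its own sub-directory.
    return sorted(p.name for p in root.iterdir() if p.is_dir())
```

    Calling `list_class_folders` with the same data_path that is set in the .sh file should return one entry per class folder; an empty list or a FileNotFoundError suggests the archive was unzipped to a different location.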

Pretrained Weights

We provide pretrained weights for the LRW and LRW-1000 datasets. Download them from Google Drive or Baidu Yun (password: 3ad2) and place them as follows:

/your init_weights set in .sh file
├── Conv1dResNetGRU_LRW-pre.pth
├── Conv3dResNetLSTM_LRW-pre.pth
├── Conv1dResNetGRU_LRW1000-pre.pth
├── Conv3dResNetLSTM_LRW1000-pre.pth
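
A quick way to check that all four files are in place before training is sketched below. The file names are taken from the listing above; `missing_weights` is a hypothetical helper, not part of this repo.

```python
from pathlib import Path
from typing import List

# File names copied from the directory listing above.
EXPECTED_WEIGHTS = [
    "Conv1dResNetGRU_LRW-pre.pth",
    "Conv3dResNetLSTM_LRW-pre.pth",
    "Conv1dResNetGRU_LRW1000-pre.pth",
    "Conv3dResNetLSTM_LRW1000-pre.pth",
]


def missing_weights(init_weights_dir: str) -> List[str]:
    """Return the expected weight files that are absent from the directory."""
    folder = Path(init_weights_dir)
    return [name for name in EXPECTED_WEIGHTS if not (folder / name).is_file()]
```

An empty list means every pretrained file was found in the directory passed as init_weights.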

 

How to Train Proto-CAT

For the LRW dataset, adjust the parameters in run/protocat_lrw.sh and run:

cd ./Proto-CAT/run
bash protocat_lrw.sh

Similarly, run bash protocat_lrw1000.sh for the LRW-1000 dataset.

Run bash protocat_plus_lrw.sh / bash protocat_plus_lrw1000.sh to train Proto-CAT+.

How to Reproduce the Result of Proto-CAT

Download the trained models from Google Drive or Baidu Yun (password: swzd) and run:

bash test_protocat_lrw.sh

Run bash test_protocat_lrw1000.sh, bash test_protocat_plus_lrw.sh, or bash test_protocat_plus_lrw1000.sh to evaluate other models.

 

Code Structures

Proto-CAT's entry point is main.py. It calls the manager class Trainer in models/train.py, which contains the main training logic. Inside Trainer, prepare_handle.prepare_dataloader together with train_prepare_batch loads and preprocesses the data in generalized few-shot style, fit_handle controls forward and backward propagation, and callbacks handles the behavior at each stage of training.
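
To illustrate what few-shot style data means, the sketch below samples one N-way K-shot episode from a toy pool of clip identifiers. It is not the repo's prepare_dataloader or train_prepare_batch; `sample_episode` and all names in the pool are hypothetical. In the generalized setting, queries from base classes are added as well, which this sketch omits.

```python
import random
from typing import Dict, List, Tuple


def sample_episode(items_by_class: Dict[str, List[str]], way: int, shot: int,
                   query: int, rng: random.Random
                   ) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """Sample `way` classes, then `shot` support and `query` query items per class."""
    classes = rng.sample(sorted(items_by_class), way)
    support: Dict[str, List[str]] = {}
    queries: Dict[str, List[str]] = {}
    for name in classes:
        # Draw without replacement so support and query items never overlap.
        picked = rng.sample(items_by_class[name], shot + query)
        support[name] = picked[:shot]
        queries[name] = picked[shot:]
    return support, queries


# Toy pool: 8 classes with 10 clip identifiers each (all names are made up).
pool = {f"word_{c}": [f"word_{c}/clip_{i}" for i in range(10)] for c in range(8)}
support, queries = sample_episode(pool, way=5, shot=1, query=3,
                                  rng=random.Random(0))
```

Here `way`, `shot`, and `query` play the same role as the train_way, train_shot, and train_query arguments listed below.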

Arguments

All parameters are defined in models/utils.py. We list the main ones below:

  • do_train, do_test: Store-true switch for whether to train or test.
  • data_path: Data directory to be set.
  • model_save_path: Optimal model save directory to be set.
  • init_weights: Pretrained weights to be set.
  • dataset: Option for the dataset.
  • model_class: Option for the top model.
  • backend_type: Option list for the backend type.
  • train_way, val_way, test_way, train_shot, val_shot, test_shot, train_query, val_query, test_query: Task settings for generalized few-shot learning.
  • gfsl_train, gfsl_test: Switches for whether to train or test in the generalized few-shot learning setting, i.e., whether additional base class data is included.
  • mm_list: Participating modalities.
  • lr_scheduler: List of learning rate schedulers.
  • loss_fn: Option for the loss function.
  • max_epoch: Maximum training epoch.
  • episodes_per_train_epoch, episodes_per_val_epoch, episodes_per_test_epoch: Number of sampled episodes per epoch.
  • num_tasks: Number of tasks per episode.
  • meta_batch_size: Batch size of each task.
  • test_model_filepath: Trained weights .pth file path when testing a model.
  • gpu: Multi-GPU option like --gpu 0,1,2,3.
  • logger_filename: Logger file save directory.
  • time_str: Token identifying each run; generated automatically if left empty.
  • acc_per_class: Switch for whether to measure the accuracy of each class with base, novel, and harmonic mean.
  • verbose, epoch_verbose: Switches for whether to output messages or a progress bar.
  • torch_seed, cuda_seed, np_seed, random_seed: Seeds of random number generation.
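
As a rough picture of how a few of these options could be parsed, here is a minimal argparse sketch. The argument names come from the list above, but the types and defaults are assumptions made for illustration; the authoritative definitions are in models/utils.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Illustrative subset of the arguments; types and defaults are assumed."""
    parser = argparse.ArgumentParser(description="Illustrative argument subset")
    parser.add_argument("--do_train", action="store_true")  # store-true switch
    parser.add_argument("--do_test", action="store_true")   # store-true switch
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--train_way", type=int, default=5)
    parser.add_argument("--train_shot", type=int, default=1)
    parser.add_argument("--gpu", type=str, default="0")     # e.g. "0,1,2,3"
    return parser


# Parse an explicit list instead of sys.argv so the example is self-contained.
args = build_parser().parse_args(
    ["--do_train", "--train_way", "5", "--train_shot", "1", "--gpu", "0,1,2,3"]
)
gpu_ids = [int(i) for i in args.gpu.split(",")]
```

Store-true switches such as do_train are False unless the flag is given, so a test-only run would pass --do_test and omit --do_train.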

 

Acknowledgment

We thank the following repos for providing helpful components and functions used in our work.
