BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation

Overview

This is a demo implementation of BYOL for Audio (BYOL-A), a self-supervised learning method for general-purpose audio representation. It includes:

  • Training code that can train models with arbitrary audio files.
  • Evaluation code that can evaluate trained models with downstream tasks.
  • Pretrained weights.

If you find BYOL-A useful in your research, please use the following BibTeX entry for citation.

@misc{niizumi2021byol-a,
      title={BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation}, 
      author={Daisuke Niizumi and Daiki Takeuchi and Yasunori Ohishi and Noboru Harada and Kunio Kashino},
      booktitle = {2021 International Joint Conference on Neural Networks, {IJCNN} 2021},
      year={2021},
      eprint={2103.06695},
      archivePrefix={arXiv},
      primaryClass={eess.AS}
}

Getting Started

  1. Download external source files, and apply a patch. Our implementation uses the following.

    curl -O https://raw.githubusercontent.com/lucidrains/byol-pytorch/2aa84ee18fafecaf35637da4657f92619e83876d/byol_pytorch/byol_pytorch.py
    patch < byol_a/byol_pytorch.diff
    mv byol_pytorch.py byol_a
    curl -O https://raw.githubusercontent.com/daisukelab/general-learning/7b31d31637d73e1a74aec3930793bd5175b64126/MLP/torch_mlp_clf.py
    mv torch_mlp_clf.py utils
  2. Install PyTorch 1.7.1, torchaudio, and the other dependencies listed in requirements.txt.

Evaluating BYOL-A Representations

Downstream Task Evaluation

The following steps perform a downstream task evaluation in a linear-probe fashion. This example uses SPCV2 (Speech Commands dataset v2).

  1. Preprocess the metadata (.csv file) and audio files. Processed files will be stored under the folder work.

    # usage: python -m utils.preprocess_ds <downstream task> <path to its dataset>
    python -m utils.preprocess_ds spcv2 /path/to/speech_commands_v0.02
  2. Run the evaluation. This first converts all .wav audio to representation embeddings, then trains a linear layer network, and finally calculates accuracy as the result.

    python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth spcv2

You can also run an evaluation multiple times and take the average result. The following evaluates on UrbanSound8K 10 times with a unit audio duration of 4.0 seconds.

# usage: python evaluate.py <your weight> <downstream task> <unit duration sec.> <# of iteration>
python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth us8k 4.0 10

Evaluating Representations In Your Tasks

The following example calculates a feature vector for an audio sample.

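import torch
import torchaudio

from byol_a.common import *
from byol_a.augmentations import PrecomputedNorm
from byol_a.models import AudioNTT2020


# Fall back to the CPU when no CUDA device is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')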
cfg = load_yaml_config('config.yaml')
print(cfg)

# Mean and standard deviation of the log-mel spectrogram of input audio samples, pre-computed.
# See calc_norm_stats in evaluate.py for your reference.
stats = [-5.4919195,  5.0389895]

# Preprocessor and normalizer.
to_melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=cfg.sample_rate,
    n_fft=cfg.n_fft,
    win_length=cfg.win_length,
    hop_length=cfg.hop_length,
    n_mels=cfg.n_mels,
    f_min=cfg.f_min,
    f_max=cfg.f_max,
)
normalizer = PrecomputedNorm(stats)

# Load pretrained weights.
model = AudioNTT2020(d=cfg.feature_d)
model.load_weight('pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth', device)

# Load your audio file.
wav, sr = torchaudio.load('work/16k/spcv2/one/00176480_nohash_0.wav') # a sample from SPCV2 for now
assert sr == cfg.sample_rate, "Let's convert the audio sampling rate in advance, or do it here online."

# Convert to a log-mel spectrogram, then normalize.
lms = normalizer((to_melspec(wav) + torch.finfo(torch.float).eps).log())

# Now, convert the audio to the representation.
features = model(lms.unsqueeze(0))

Training From Scratch

You can also train models. The following is an example of training on FSD50K.

  1. Convert all samples to 16kHz. This will convert all FSD50K files to a folder work/16k/fsd50k while preserving folder structure.

    python -m utils.convert_wav /path/to/fsd50k work/16k/fsd50k
  2. Start training. This example trains with all development set audio samples from FSD50K.

    python train.py work/16k/fsd50k/FSD50K.dev_audio

Refer to Table VI in our paper for the performance of a model trained on FSD50K.

Pretrained Weights

We include 3 pretrained weights of our encoder network.

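| Method | Dim.   | Filename                          | NSynth | US8K  | VoxCeleb1 | VoxForge | SPCV2/12 | SPCV2 | Average |
|--------|--------|-----------------------------------|--------|-------|-----------|----------|----------|-------|---------|
| BYOL-A | 512-d  | AudioNTT2020-BYOLA-64x96d512.pth  | 69.1%  | 78.2% | 33.4%     | 83.5%    | 86.5%    | 88.9% | 73.3%   |
| BYOL-A | 1024-d | AudioNTT2020-BYOLA-64x96d1024.pth | 72.7%  | 78.2% | 38.0%     | 88.5%    | 90.1%    | 91.4% | 76.5%   |
| BYOL-A | 2048-d | AudioNTT2020-BYOLA-64x96d2048.pth | 74.1%  | 79.1% | 40.1%     | 90.2%    | 91.0%    | 92.2% | 77.8%   |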

License

This implementation is provided for your evaluation of the BYOL-A paper; see LICENSE for details.

Acknowledgements

BYOL-A is built on top of byol-pytorch, a BYOL implementation by Phil Wang (@lucidrains). We thank Phil for open-sourcing such sophisticated code.

@misc{wang2020byol-pytorch,
  author =       {Phil Wang},
  title =        {Bootstrap Your Own Latent (BYOL), in Pytorch},
  howpublished = {\url{https://github.com/lucidrains/byol-pytorch}},
  year =         {2020}
}

Comments
  • Question for reproducing results

    Hi,

    Thanks for sharing this great work! I tried to reproduce the results using the official guidance but I failed.

    After processing the data, I ran the following commands:

    CUDA_VISIBLE_DEVICES=0 python -W ignore train.py work/16k/fsd50k/FSD50K.dev_audio
    cp lightning_logs/version_4/checkpoints/epoch\=99-step\=16099.ckpt AudioNTT2020-BYOLA-64x96d2048.pth
    CUDA_VISIBLE_DEVICES=4 python evaluate.py AudioNTT2020-BYOLA-64x96d2048.pth spcv2
    

    However, the results are far from the reported ones.

    Did I miss something important? Thank you very much.

    question 
    opened by ChenyangLEI 15
  • Evaluation on voxforge

    Hi,

    Thank you so much for your contribution. This work is very interesting and your code is easy for me to follow. However, one of the downstream datasets, VoxForge, is missing from preprocess_ds.py. Could you please release the code for that dataset, too?

    Thank you again for your time.

    Best regards

    opened by Huiimin5 9
  • A mistake in RunningMean

    Thank you for the fascinating paper and the code to reproduce it!

    I think there might be a problem in RunningMean. The current formula (the same in v1 and v2) looks like this:

    $$ m_n = m_{n - 1} + \frac{a_n - m_{n - 1}}{n - 1}, $$

    which is inconsistent with the correct formula listed on StackOverflow:

    $$ m_n = m_{n - 1} + \frac{a_n - m_{n - 1}}{n}. $$

    The problem is that self.n is incremented after the new mean is computed. Could you please either correct me if I am wrong or correct the code?

    opened by WhiteTeaDragon 4
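
The incremental-mean formula quoted in this issue can be checked with a few lines of Python. This is a minimal sketch, not the repository's RunningMean class: the class name IncrementalMean and its update method are hypothetical. It only illustrates that incrementing the count before the update gives the divisor n from the second formula.

```python
class IncrementalMean:
    """Incremental mean: m_n = m_{n-1} + (a_n - m_{n-1}) / n."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0

    def update(self, value):
        # Increment the count first so that the divisor is n, not n - 1.
        self.n += 1
        self.mean += (value - self.mean) / self.n
        return self.mean


running = IncrementalMean()
for sample in [2.0, 4.0, 9.0]:
    running.update(sample)

# The incremental result matches the batch mean of the same samples.
batch_mean = sum([2.0, 4.0, 9.0]) / 3
```

Incrementing the count after the update would instead divide by n - 1, which fails on the first sample and weights later samples too heavily.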
  • A basic question: torch.randn(): argument 'size' must be tuple of ints, but found element of type list at pos 3

    Traceback (most recent call last):
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 2066, in <module>
        main()
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 2060, in main
        globals = debugger.run(setup['file'], None, None, is_module)
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 1411, in run
        return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 1418, in _exec
        pydev_imports.execfile(file, globals, locals)  # execute the script
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
        exec(compile(contents+"\n", file, 'exec'), glob, loc)
      File "E:/pythonSpace/byol-a/train.py", line 132, in <module>
        main(audio_dir=base_path + '1/', epochs=100)
      File "E:/pythonSpace/byol-a/train.py", line 112, in main
        learner = BYOLALearner(model, cfg.lr, cfg.shape,
      File "E:/pythonSpace/byol-a/train.py", line 56, in __init__
        self.learner = BYOL(model, image_size=shape, **kwargs)
      File "D:\min\envs\torch1_7_1\lib\site-packages\byol_pytorch\byol_pytorch.py", line 211, in __init__
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))
    TypeError: randn(): argument 'size' must be tuple of ints, but found element of type list at pos 3
    
    Not_an_issue 
    opened by a1030076395 3
  • Question about comments in the train.py

    https://github.com/nttcslab/byol-a/blob/master/train.py

    At line 67, there are comments about the shape of the input.

            # in fact, it should be (B, 1, F, T), e.g. (256, 1, 64, 96) where 64 is the number of mel bins
            paired_inputs = torch.cat(paired_inputs) # [(B,1,T,F), (B,1,T,F)] -> (2*B,1,T,F)
    

    However, this differs from the description in the config.yaml file:

    # Shape of log-mel spectrogram [F, T].
    shape: [64, 96]
    
    bug 
    opened by ChenyangLEI 2
  • Doubt in paper

    Hi there,

    Section 4, subsection A, part 1 from your paper says:

     The number of frames, T, in one segment was 96 in pretraining, which corresponds to 1,014ms. 
    

    However, the previous line says the hop size used was 10ms. So according to this 96 would mean 960ms?

    Am I understanding something wrong here?

    Thank You in advance!

    question 
    opened by Sreyan88 2
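
The two durations in this question can be compared with a short calculation. The sketch below assumes a 64 ms analysis window and no padding at the segment edges. Neither value is stated in this text: the window length is taken as an assumption from the paper's settings, and the function name is hypothetical. Under those assumptions, T frames span (T - 1) hops plus one full window.

```python
def segment_duration_ms(num_frames, hop_ms, win_ms):
    """Duration covered by num_frames windows without edge padding."""
    return (num_frames - 1) * hop_ms + win_ms


# Counting hops only, as in the question.
hops_only_ms = 96 * 10

# Counting the last window in full (assumed 64 ms window, 10 ms hop).
with_window_ms = segment_duration_ms(num_frames=96, hop_ms=10, win_ms=64)
```

With these assumed values, hops_only_ms is 960 and with_window_ms is 1014, which is consistent with the figure quoted from the paper.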
  • Random crop is not working.

    https://github.com/nttcslab/byol-a/blob/60cebdc514951e6b42e18e40a2537a01a39ad47b/byol_a/dataset.py#L80-L82

    If len(wav) > self.unit_length, length_adj will be negative, so start will be 0. If wav (before padding) is shorter than the unit length, length_adj == 0 after padding, so start is again 0. Either way start is always 0, and the crop always covers the same region from 0 to self.unit_length (cropped_wav == wav[0: self.unit_length]) instead of a random one.

    So I think line 80 should be changed to length_adj = len(wav) - self.unit_length.

    bug 
    opened by JUiscoming 2
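
The behavior described in this issue can be sketched in plain Python. This is an illustration under assumptions, not the code in byol_a/dataset.py: the function name pad_and_random_crop is hypothetical, the waveform is a plain list rather than a tensor, and centered zero padding is assumed for short clips. The point is that the crop offset is drawn from len(wav) - unit_length, measured after padding.

```python
import random


def pad_and_random_crop(wav, unit_length, rng=random):
    """Zero-pad clips shorter than unit_length, then crop at a random offset."""
    shortfall = unit_length - len(wav)
    if shortfall > 0:
        # Assumed padding scheme: split the zeros between both ends.
        left = shortfall // 2
        wav = [0.0] * left + list(wav) + [0.0] * (shortfall - left)

    # After padding, len(wav) >= unit_length, so max_start is never negative.
    max_start = len(wav) - unit_length
    start = rng.randint(0, max_start) if max_start > 0 else 0
    return wav[start:start + unit_length]


rng = random.Random(0)
long_wav = list(range(100))
crops = [pad_and_random_crop(long_wav, 10, rng) for _ in range(50)]
short_crop = pad_and_random_crop([1.0, 2.0], 6)
```

For clips longer than unit_length, the start offsets now vary between calls; clips that needed padding are returned whole.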
  • Doubt in RunningNorm

    Hi There, great repo!

    I think I may have misunderstood something about the RunningNorm function. The function expects the size of an epoch; however, your implementation passes the size of the entire dataset.

    Is it a bug? Or is there a problem with my understanding?

    Thank You!

    question 
    opened by Sreyan88 2
  • How to interpret the performance

    Hi, it's great work, but how should I understand the performance metric? For example, VoxCeleb1 is usually used for speaker verification; shouldn't we measure EER?

    opened by ranchlai 2
  • Finetuning of BYOL-A

    Hi,

    Your paper is super interesting. I have a question regarding the downstream tasks. If I understand the paper correctly, you used a single linear layer for the downstream tasks, which took only the sum of the mean and max of the representation over time as input.

    Did you try to finetune BYOL-A end-to-end on the downstream tasks after pretraining? In the case of TRILL, they were able to improve the performance even further by finetuning the whole model end-to-end. Is there a specific reason why this is not possible with BYOL-A?

    questions 
    opened by mschiwek 1
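
The temporal pooling mentioned in this question, the sum of the mean and the max over time, can be sketched as follows. This is a plain-Python illustration with a hypothetical function name and a list-of-lists layout (T time steps, each with D feature values); it is not taken from the repository's model code.

```python
def pool_mean_plus_max(frames):
    """Collapse T frames of D features into one D-dimensional vector."""
    num_frames = len(frames)
    pooled = []
    # zip(*frames) iterates over feature dimensions across all time steps.
    for dim_values in zip(*frames):
        pooled.append(sum(dim_values) / num_frames + max(dim_values))
    return pooled


# Three time steps, two feature dimensions.
frames = [[1.0, -2.0], [3.0, 0.0], [5.0, 2.0]]
pooled = pool_mean_plus_max(frames)
```

Here the first dimension has mean 3.0 and max 5.0, and the second has mean 0.0 and max 2.0, so pooled is [8.0, 2.0]. A linear probe would then take this single fixed-size vector as its input.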
  • Missing scaling of validation samples in evaluate.py

    https://github.com/nttcslab/byol-a/blob/master/evaluate.py#L112

    It also needs X_val = scaler.transform(X_val); otherwise the validation accuracy and loss will be invalid. This could be one of the reasons why I saw lower performance when I tried to reproduce the official results...

    bug 
    opened by daisukelab 0
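
The point of this issue is that statistics fitted on the training set must also be applied to the validation set. The sketch below assumes the scaler performs per-feature standardization (zero mean, unit variance). fit_scaler and transform are hypothetical stand-ins written with the standard library, not the scaler object used in evaluate.py.

```python
from statistics import mean, pstdev


def fit_scaler(rows):
    """Compute per-feature mean and population std from training rows."""
    columns = list(zip(*rows))
    means = [mean(col) for col in columns]
    # Guard against zero variance so the division below stays defined.
    stds = [pstdev(col) or 1.0 for col in columns]
    return means, stds


def transform(rows, scaler):
    """Standardize rows with statistics that were fitted elsewhere."""
    means, stds = scaler
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]


X_train = [[1.0, 10.0], [3.0, 30.0]]
X_val = [[2.0, 20.0]]

scaler = fit_scaler(X_train)          # fit on training data only
X_train = transform(X_train, scaler)
X_val = transform(X_val, scaler)      # the step the issue says is missing
```

Without the last line, the classifier would be trained on standardized features but validated on raw ones, so the validation metrics would not be comparable.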
Releases(v2.0.0)
Owner
NTT Communication Science Laboratories