Unofficial PyTorch Implementation of WaveGrad 2

WaveGrad 2: Iterative Refinement for Text-to-Speech Synthesis
Unofficial PyTorch + Lightning implementation of WaveGrad 2 by Chen et al. (JHU, Google Brain).
Audio Samples: https://mindslab-ai.github.io/wavegrad2/

TODO

  • More training for WaveGrad-Base setup
  • Checkpoint release
  • WaveGrad-Large Decoder
  • Inference by reduced sampling steps

Requirements

Datasets

The supported datasets are:

  • LJSpeech: a single-speaker English dataset consisting of 13,100 short audio clips of a female speaker reading passages from 7 non-fiction books, approximately 24 hours in total.
  • AISHELL-3: a Mandarin TTS dataset with 218 male and female speakers, roughly 85 hours in total.
  • etc.

We take LJSpeech as an example hereafter.

Preprocessing

  • Adjust preprocess.yaml, especially the path section.
path:
  corpus_path: '/DATA1/LJSpeech-1.1' # LJSpeech corpus path
  lexicon_path: 'lexicon/librispeech-lexicon.txt'
  raw_path: './raw_data/LJSpeech'
  preprocessed_path: './preprocessed_data/LJSpeech'
  • Run prepare_align.py to prepare the corpus for alignment.
python prepare_align.py -c preprocess.yaml
  • Montreal Forced Aligner (MFA) is used to obtain the alignments between the utterances and the phoneme sequences. Alignments for the LJSpeech and AISHELL-3 datasets are provided here. Unzip the files into preprocessed_data/LJSpeech/TextGrid/.

  • After that, run preprocess.py.

python preprocess.py -c preprocess.yaml
  • Alternatively, you can align the corpus yourself.
  • Download the official MFA package and run it to align the corpus.
./montreal-forced-aligner/bin/mfa_align raw_data/LJSpeech/ lexicon/librispeech-lexicon.txt english preprocessed_data/LJSpeech

or

./montreal-forced-aligner/bin/mfa_train_and_align raw_data/LJSpeech/ lexicon/librispeech-lexicon.txt preprocessed_data/LJSpeech
  • Then run preprocess.py.
python preprocess.py -c preprocess.yaml
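
Before running preprocess.py, it can help to confirm that the alignment files are where the steps above expect them. The following is a minimal sketch, not part of the repository: the helper name count_textgrids is hypothetical, and it assumes the TextGrid files sit somewhere under the TextGrid folder of preprocessed_path (possibly in per-speaker subfolders).

```python
from pathlib import Path


def count_textgrids(preprocessed_path):
    """Count .TextGrid files under <preprocessed_path>/TextGrid (hypothetical helper)."""
    textgrid_dir = Path(preprocessed_path) / "TextGrid"
    if not textgrid_dir.is_dir():
        # Nothing has been unzipped or aligned yet.
        return 0
    # Search recursively, since files may be nested in per-speaker subfolders.
    return sum(1 for _ in textgrid_dir.rglob("*.TextGrid"))
```

For example, count_textgrids('./preprocessed_data/LJSpeech') returning 0 would suggest that the alignments have not been unzipped or generated yet.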

Training

  • Adjust hparameter.yaml, especially the train section.
train:
  batch_size: 12 # Dependent on GPU memory size
  adam:
    lr: 3e-4
    weight_decay: 1e-6
  decay:
    rate: 0.05
    start: 25000
    end: 100000
  num_workers: 16 # Dependent on CPU cores
  gpus: 2 # number of GPUs
  loss_rate:
    dur: 1.0
  • To train with another dataset, adjust the data section in hparameter.yaml.
data:
  lang: 'eng'
  text_cleaners: ['english_cleaners'] # korean_cleaners, english_cleaners, chinese_cleaners
  speakers: ['LJSpeech']
  train_dir: 'preprocessed_data/LJSpeech'
  train_meta: 'train.txt'  # relative path of metadata file from train_dir
  val_dir: 'preprocessed_data/LJSpeech'
  val_meta: 'val.txt'  # relative path of metadata file from val_dir
  lexicon_path: 'lexicon/librispeech-lexicon.txt'
  • Run trainer.py.
python trainer.py
  • To resume training from a checkpoint, check the options defined by the argument parser.
import argparse

parser = argparse.ArgumentParser()
# Epoch number of the checkpoint to resume from
parser.add_argument('-r', '--resume_from', type=int,
                    required=False, help="Resume Checkpoint epoch number")
# Flag to set after a significant change
parser.add_argument('-s', '--restart', action="store_true",
                    required=False, help="Significant change occurred, use this")
# Start from the EMA checkpoint
parser.add_argument('-e', '--ema', action="store_true",
                    required=False, help="Start from ema checkpoint")
args = parser.parse_args()
  • During training, the TensorBoard logger records the loss, spectrograms, and audio.
tensorboard --logdir=./tensorboard --bind_all
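
As an illustration of how the resume options listed above are parsed, the sketch below rebuilds the same parser inside a hypothetical build_parser function and feeds it an explicit argument list. The epoch number 10 is an arbitrary example value, and the sketch says nothing about how trainer.py uses the parsed values afterwards.

```python
import argparse


def build_parser():
    """Rebuild the parser shown above (wrapper function is hypothetical)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--resume_from', type=int, required=False,
                        help="Resume Checkpoint epoch number")
    parser.add_argument('-s', '--restart', action="store_true",
                        help="Significant change occurred, use this")
    parser.add_argument('-e', '--ema', action="store_true",
                        help="Start from ema checkpoint")
    return parser


# Equivalent to the command line: python trainer.py -r 10 -e
args = build_parser().parse_args(['-r', '10', '-e'])
print(args.resume_from, args.restart, args.ema)  # 10 False True
```

When no flags are given, resume_from is None and both boolean flags are False.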

Inference

  • Run inference.py.
python inference.py -c <checkpoint_path> --text <'text'>

Checkpoint file will be released!

Note

Since this repo is an unofficial implementation and the WaveGrad 2 paper does not provide several details, slight differences from the paper may exist.
The modifications and arbitrary setups are listed below:

  • Normal LSTM without ZoneOut is applied for the encoder.
  • g2p_en is applied instead of Google's unspecified G2P.
  • Trained with the LJSpeech dataset instead of Google's proprietary dataset.
    • Due to the dataset replacement, the output audio's sampling rate becomes 22.05kHz instead of 24kHz.
  • MT + SpecAug are not implemented.
  • Hyperparameters
    • train.batch_size: 12 for 2 A100 (40GB) GPUs
    • train.adam.lr: 3e-4 and train.adam.weight_decay: 1e-6
    • train.decay: learning rate decay is applied during training
    • train.loss_rate: 1, i.e. total_loss = 1 * L1_loss + 1 * duration_loss
    • ddpm.ddpm_noise_schedule: torch.linspace(1e-6, 0.01, hparams.ddpm.max_step)
    • encoder.channel is reduced to 512 from 1024 or 2048
  • The current sample page only contains samples from the WaveGrad-Base decoder.
  • The items in the TODO list above.
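
To give a feel for the linear noise schedule listed above, the sketch below builds the same linearly spaced betas in pure Python and derives per-step noise levels as the square root of the cumulative product of (1 - beta), which is the usual DDPM-style convention. Several things here are assumptions: max_step = 1000 is a placeholder for hparams.ddpm.max_step, the linspace helper is a stand-in for torch.linspace, and whether this repository derives its noise levels in exactly this way is not confirmed by this README.

```python
import math


def linspace(start, stop, num):
    """Pure-Python stand-in for torch.linspace(start, stop, num)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def noise_levels(betas):
    """Return sqrt(prod_{s<=t} (1 - beta_s)) for every step t."""
    levels, alpha_bar = [], 1.0
    for beta in betas:
        alpha_bar *= 1.0 - beta  # running cumulative product
        levels.append(math.sqrt(alpha_bar))
    return levels


max_step = 1000  # assumed value; the real one comes from hparams.ddpm.max_step
betas = linspace(1e-6, 0.01, max_step)
levels = noise_levels(betas)
print(f"first level: {levels[0]:.6f}, last level: {levels[-1]:.6f}")
```

Under these assumptions the level starts close to 1 (almost clean signal) and decreases monotonically towards a small value at the final step.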

Tree

.
├── Dockerfile
├── README.md
├── dataloader.py
├── docs
│   ├── spec.png
│   ├── tb.png
│   └── tblogger.png
├── hparameter.yaml
├── inference.py
├── lexicon
│   ├── librispeech-lexicon.txt
│   └── pinyin-lexicon-r.txt
├── lightning_model.py
├── model
│   ├── base.py
│   ├── downsampling.py
│   ├── encoder.py
│   ├── gaussian_upsampling.py
│   ├── interpolation.py
│   ├── layers.py
│   ├── linear_modulation.py
│   ├── nn.py
│   ├── resampling.py
│   ├── upsampling.py
│   └── window.py
├── prepare_align.py
├── preprocess.py
├── preprocess.yaml
├── preprocessor
│   ├── ljspeech.py
│   └── preprocessor.py
├── text
│   ├── __init__.py
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   └── symbols.py
├── trainer.py
├── utils
│   ├── mel.py
│   ├── stft.py
│   ├── tblogger.py
│   └── utils.py
└── wavegrad2_tester.ipynb

Author

This code is implemented by

Special thanks to

References

This implementation uses code from the following repositories:

The webpage for the audio samples uses a template from:

The audio samples on our webpage (TBD) are partially derived from:

  • LJSpeech: a single-speaker English dataset consisting of 13,100 short audio clips of a female speaker reading passages from 7 non-fiction books, approximately 24 hours in total.
  • WaveGrad2 Official Github.io