An implementation of DeepMind's Relational Recurrent Neural Networks in PyTorch.

Overview

relational-rnn-pytorch

An implementation of DeepMind's Relational Recurrent Neural Networks (Santoro et al. 2018) in PyTorch.

The Relational Memory Core (RMC) module is originally from the official Sonnet implementation. However, Sonnet does not currently provide full language modeling benchmark code.

This repo is a port of RMC with additional comments. It features a full-fledged word-level language modeling benchmark against a traditional LSTM.

It supports any word token-based text dataset, including WikiText-2 and WikiText-103.

Both RMC and LSTM models support adaptive softmax, which greatly reduces memory usage on large-vocabulary datasets. RMC supports PyTorch's DataParallel, so you can easily experiment with a multi-GPU setup.

The benchmark code is hard-forked from the official PyTorch word-language-model example.

It also features the N-th farthest synthetic task from the paper (see below).

Requirements

PyTorch 0.4.1 or later (Tested on 1.0.0) & Python 3.6

Examples

python train_rmc.py --cuda for a full training and test run of RMC on a GPU.

python train_rmc.py --cuda --adaptivesoftmax --cutoffs 1000 5000 20000 when using a large-vocabulary dataset (like WikiText-103), so that all tensors fit in VRAM.

python generate_rmc.py --cuda to generate sentences from the trained model.

python train_rnn.py --cuda for a full training and test run of the traditional RNN on a GPU.

All default hyperparameters of RMC and LSTM are the result of a two-week experiment on WikiText-2.

Data Preparation

Tested with WikiText-2 and WikiText-103. WikiText-2 is bundled.

Create a subfolder inside ./data and place word-level train.txt, valid.txt, and test.txt inside the subfolder.

Specify --data=(subfolder name) and you are good to go.

The code tokenizes the corpus on the first training run and saves it as a pickle file. Later runs load the pickle file instead of tokenizing again.
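A minimal sketch of the tokenize-once, then-load-the-pickle pattern, assuming each line of train.txt, valid.txt, and test.txt is split on whitespace and terminated with an <eos> token (as in the PyTorch word-language-model example). Corpus, load_corpus, and the corpus.pkl filename are hypothetical names, not necessarily what this repo uses.

```python
import os
import pickle
import tempfile


class Corpus:
    """Hypothetical word-level corpus: builds a vocabulary and maps each split to token ids."""

    def __init__(self, path):
        self.word2idx = {}
        self.idx2word = []
        self.train = self.tokenize(os.path.join(path, "train.txt"))
        self.valid = self.tokenize(os.path.join(path, "valid.txt"))
        self.test = self.tokenize(os.path.join(path, "test.txt"))

    def add_word(self, word):
        # Assign the next free id to unseen words.
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        return self.word2idx[word]

    def tokenize(self, filename):
        ids = []
        with open(filename, encoding="utf-8") as f:
            for line in f:
                # Whitespace tokenization, with an end-of-sentence marker per line.
                for word in line.split() + ["<eos>"]:
                    ids.append(self.add_word(word))
        return ids


def load_corpus(path, cache_name="corpus.pkl"):
    """Tokenize on the first call and pickle the result; later calls load the pickle."""
    cache_path = os.path.join(path, cache_name)
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    corpus = Corpus(path)
    with open(cache_path, "wb") as f:
        pickle.dump(corpus, f)
    return corpus


# Usage: a tiny dataset in a temporary folder standing in for ./data/<subfolder>.
with tempfile.TemporaryDirectory() as data_dir:
    for split in ("train", "valid", "test"):
        with open(os.path.join(data_dir, split + ".txt"), "w", encoding="utf-8") as f:
            f.write("the cat sat\nthe dog\n")
    first = load_corpus(data_dir)    # tokenizes and writes corpus.pkl
    cached = os.path.exists(os.path.join(data_dir, "corpus.pkl"))
    second = load_corpus(data_dir)   # loads the pickle instead
```

Caching the whole corpus object (vocabulary plus id lists) keeps the word-to-id mapping identical across runs, which matters when reloading a trained model.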

WikiText-2 Benchmark Results

Both RMC & LSTM have ~11M parameters. Please refer to the training code for details on hyperparameters.

| Model | Valid perplexity | Test perplexity | Forward pass ms/batch (TITAN Xp) | Forward pass ms/batch (TITAN V) |
| --- | --- | --- | --- | --- |
| LSTM (CuDNN) | 111.31 | 105.56 | 26~27 | 40~41 |
| LSTM (for loop) | Same as CuDNN | Same as CuDNN | 30~31 | 60~61 |
| RMC | 112.77 | 107.21 | 110~130 | 220~230 |
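Perplexity here is assumed to be the exponential of the average per-token cross-entropy in nats, which is how the PyTorch word-language-model example reports it. A quick sketch for converting between loss and perplexity:

```python
import math


def perplexity(total_nll, num_tokens):
    """Perplexity = exp(mean per-token negative log-likelihood, in nats)."""
    return math.exp(total_nll / num_tokens)


# Going the other way: average test loss implied by the perplexities in the table.
lstm_test_loss = math.log(105.56)     # ~4.659 nats/token
rmc_test_loss = math.log(107.21)      # ~4.675 nats/token
gap = rmc_test_loss - lstm_test_loss  # ~0.016 nats/token
```

Read this way, the ~1.6 test-perplexity gap between LSTM and RMC corresponds to only about 0.016 nats per token.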

RMC can reach performance comparable to LSTM (after a heavy hyperparameter search), but it turns out to be very slow: its forward pass takes roughly 4~6x longer than the CuDNN LSTM's. The multi-head self-attention at every time step may be the culprit. Running LSTMCell in a for loop (a "fairer" benchmark for RMC) slows down the LSTM's forward pass, but it is still much faster than RMC.
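To illustrate the two LSTM variants in the table, here is a sketch (not the repo's train_rnn.py) showing that a fused nn.LSTM and an nn.LSTMCell stepped in a Python loop compute the same outputs when they share weights. The loop version simply issues separate small operations at every time step, the same per-step pattern RMC follows. All sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 35, 4, 16, 32

lstm = nn.LSTM(input_size, hidden_size)       # fused; uses CuDNN when run on a GPU
cell = nn.LSTMCell(input_size, hidden_size)   # one time step per call

# Copy the fused layer's weights into the cell so both compute the same function.
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(seq_len, batch, input_size)   # (seq_len, batch, features)

# Fused: a single call processes the whole sequence.
out_fused, _ = lstm(x)

# For loop: explicit per-step updates of (h, c).
h = torch.zeros(batch, hidden_size)
c = torch.zeros(batch, hidden_size)
steps = []
for t in range(seq_len):
    h, c = cell(x[t], (h, c))
    steps.append(h)
out_loop = torch.stack(steps)                 # (seq_len, batch, hidden_size)
```

Since the two variants are numerically equivalent, their perplexities match and only the speed differs, as in the table.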

Please also note that the RMC hyperparameters are a worst case in terms of speed: they use a single memory slot (as described in the paper), so the model does not benefit from row-wise weight sharing across a multi-slot memory.
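A small sketch of what row-wise weight sharing means, assuming the memory is stored as a (batch, mem_slots, mem_size) tensor and the per-slot MLP is built from nn.Linear layers (the sizes here are hypothetical). Because nn.Linear acts on the last dimension, the same weights process every slot, so extra slots add work per call but no parameters and no extra calls.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, mem_size = 2, 8  # hypothetical sizes

# A per-slot MLP built from nn.Linear layers; nn.Linear acts on the last dimension.
mlp = nn.Sequential(nn.Linear(mem_size, mem_size), nn.ReLU(), nn.Linear(mem_size, mem_size))

one_slot = torch.randn(batch, 1, mem_size)    # (batch, mem_slots=1, mem_size)
four_slots = torch.randn(batch, 4, mem_size)  # (batch, mem_slots=4, mem_size)

out_one = mlp(one_slot)     # one row per example
out_four = mlp(four_slots)  # four rows per example, same weights, same number of calls
n_params = sum(p.numel() for p in mlp.parameters())  # independent of mem_slots
```

With one slot, each of these small matrix multiplies does very little work, which is one plausible reason the single-slot configuration is the slowest per unit of computation.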

Interestingly, the models run slower on the TITAN V than on the TITAN Xp. The reason might be that the models are relatively small and frequently call small linear operations.

Maybe the TITAN Xp (~1,900 MHz unlocked CUDA clock speed vs. the TITAN V's 1,335 MHz limit) benefits from this kind of workload. Or maybe the TITAN V's CUDA kernel launch latency is higher for the ops in the model.

I'm not an expert in details of CUDA. Please share your results!
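If you want to compare your numbers with the table, note that CUDA kernel launches are asynchronous, so the clock should only be read after torch.cuda.synchronize(). A minimal timing sketch under that assumption; forward_ms_per_batch is a hypothetical helper, not the script used to produce the table, and the nn.LSTM is just a stand-in model.

```python
import time

import torch
import torch.nn as nn


def forward_ms_per_batch(model, inputs, n_warmup=3, n_iters=10):
    """Average forward-pass wall-clock time in milliseconds (hypothetical helper)."""
    model.eval()
    use_cuda = inputs.is_cuda
    with torch.no_grad():
        for _ in range(n_warmup):   # warm up kernels and the caching allocator
            model(inputs)
        if use_cuda:
            torch.cuda.synchronize()  # wait for queued kernels before starting the clock
        start = time.perf_counter()
        for _ in range(n_iters):
            model(inputs)
        if use_cuda:
            torch.cuda.synchronize()  # make sure all timed work has actually finished
        elapsed = time.perf_counter() - start
    return 1000.0 * elapsed / n_iters


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.LSTM(16, 32).to(device)
x = torch.randn(35, 4, 16, device=device)  # (seq_len, batch, features)
ms = forward_ms_per_batch(model, x)
```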

RMC Hyperparameter Search Results

Attention parameters tend to overfit WikiText-2. Reducing the attention hyperparameters (key_size) can combat the overfitting.

Applying dropout at the output logit before the softmax (as in the LSTM) helped prevent overfitting.

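| Embed & head size | # heads | Attention MLP layers | Key size | Dropout at output | Memory slots | Test ppl |
| --- | --- | --- | --- | --- | --- | --- |
| 128 | 4 | 3 | 128 | No | 1 | 128.81 |
| 128 | 4 | 3 | 128 | No | 1 | 128.81 |
| 128 | 8 | 3 | 128 | No | 1 | 141.84 |
| 128 | 4 | 3 | 32 | No | 1 | 123.26 |
| 128 | 4 | 3 | 32 | Yes | 1 | 112.4 |
| 128 | 4 | 3 | 64 | No | 1 | 124.44 |
| 128 | 4 | 3 | 64 | Yes | 1 | 110.16 |
| 128 | 4 | 2 | 64 | Yes | 1 | 111.67 |
| 64 | 4 | 3 | 64 | Yes | 1 | 133.68 |
| 64 | 4 | 3 | 32 | Yes | 1 | 135.93 |
| 64 | 4 | 3 | 64 | Yes | 4 | 137.93 |
| 192 | 4 | 3 | 64 | Yes | 1 | 107.21 |
| 192 | 4 | 3 | 64 | Yes | 4 | 114.85 |
| 256 | 4 | 3 | 256 | No | 1 | 194.73 |
| 256 | 4 | 3 | 64 | Yes | 1 | 126.39 |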
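A sketch of what "dropout at output" could look like, assuming it mirrors the LSTM setup of the PyTorch word-language-model example: dropout on the recurrent output right before the projection to vocabulary logits and the softmax. OutputHead and all sizes are hypothetical; the exact placement in train_rmc.py may differ.

```python
import torch
import torch.nn as nn


class OutputHead(nn.Module):
    """Hypothetical decoder head: dropout on the recurrent output, then vocab projection."""

    def __init__(self, hidden_size, vocab_size, dropout=0.5):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, rnn_output):
        # rnn_output: (seq_len, batch, hidden_size)
        logits = self.decoder(self.drop(rnn_output))
        return torch.log_softmax(logits, dim=-1)


torch.manual_seed(0)
head = OutputHead(hidden_size=32, vocab_size=100, dropout=0.5)
h = torch.randn(35, 4, 32)

head.train()
train_out = head(h)                 # dropout active: random units zeroed
head.eval()
eval_a, eval_b = head(h), head(h)   # dropout disabled: deterministic outputs
```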

About WikiText-103

The original RMC paper presents WikiText-103 results with a larger model and batch size (6 Tesla P100 GPUs, each with a batch size of 64, for a total of 384. Ouch).

Using a full softmax easily exhausts VRAM, so --adaptivesoftmax is highly recommended. When using --adaptivesoftmax, --cutoffs must be set appropriately. Please refer to the original API description.
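A minimal sketch using PyTorch's built-in nn.AdaptiveLogSoftmaxWithLoss with the cutoffs from the train_rmc.py example command. This is an assumption about how --cutoffs maps onto an adaptive softmax; the repo may use its own implementation. The vocabulary size and hidden size are made up, and word ids are assumed to be sorted by frequency so that the most common words land in the head cluster.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 64
vocab_size = 30000              # assumed; must be larger than the last cutoff
cutoffs = [1000, 5000, 20000]   # ids [0, 1000) form the head; the rest fall into 3 tail clusters

# Assumes word ids are sorted by frequency, so id 0 is the most common word.
asm = nn.AdaptiveLogSoftmaxWithLoss(hidden_size, vocab_size, cutoffs, div_value=4.0)

hidden = torch.randn(8, hidden_size)          # (batch * seq_len, hidden_size)
targets = torch.randint(0, vocab_size, (8,))
result = asm(hidden, targets)                 # named tuple with .output and .loss
log_probs = asm.log_prob(hidden)              # full (8, vocab_size) log-distribution
```

Each tail cluster first projects the hidden state down to a smaller size (in_features // div_value ** (i + 1) for the i-th tail cluster) before its own output layer, which is where the memory savings over a full softmax come from.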

I don't have such hardware, and my resources are too limited to run these experiments. Benchmark results or any other contributions are very welcome!

Nth Farthest Task

The objective of the task is: Given k randomly labelled (from 1 to k) D-dimensional vectors, identify which is the Nth farthest vector from vector M. (The answer is an integer from 1 to k.)

The specific task in the paper is: given 8 labelled 16-dimensional vectors, which is the Nth farthest vector from vector M? The vectors are labelled randomly, so the model has to recognise that the Mth vector is the vector labelled as M, as opposed to the vector in the Mth position in the input.

The input to the model comprises eight 40-dimensional vectors for each example. Each of these 40-dimensional vectors (16 + 8 + 8 + 8 = 40) is structured like this:

[(the vector itself, 16 dims) (label: which vector it is, from 1 to 8, one-hot encoded, 8 dims) (N, one-hot encoded, 8 dims) (M, one-hot encoded, 8 dims)]
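A sketch of how one example could be generated under this layout, assuming vector entries drawn uniformly from [-1, 1], zero-based labels, N and M, and Euclidean distance. make_example is a hypothetical helper, not the repo's generator. Note that the reference vector itself is one of the candidates (distance 0, so always the nearest); whether the original task excludes it is not stated here.

```python
import math
import random

NUM_VECTORS = 8   # k: vectors per example
VECTOR_DIM = 16   # D: dimensionality of each vector


def one_hot(index, size):
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def make_example(rng):
    """Hypothetical generator: returns 8 rows of 40 features and the answer label (0-based)."""
    vectors = [[rng.uniform(-1.0, 1.0) for _ in range(VECTOR_DIM)]
               for _ in range(NUM_VECTORS)]
    labels = rng.sample(range(NUM_VECTORS), NUM_VECTORS)  # random label per position
    n = rng.randrange(NUM_VECTORS)  # N (0-based: 0 means farthest)
    m = rng.randrange(NUM_VECTORS)  # M is a *label*, not a position
    ref = vectors[labels.index(m)]  # the vector labelled M
    dists = [math.sqrt(sum((a - b) ** 2 for a, b in zip(v, ref))) for v in vectors]
    # Positions ordered from farthest to nearest; pick the Nth and report its label.
    order = sorted(range(NUM_VECTORS), key=lambda i: dists[i], reverse=True)
    answer = labels[order[n]]
    rows = [vectors[i] + one_hot(labels[i], NUM_VECTORS)
            + one_hot(n, NUM_VECTORS) + one_hot(m, NUM_VECTORS)
            for i in range(NUM_VECTORS)]
    return rows, answer


rows, answer = make_example(random.Random(0))
```

Because labels are shuffled independently of positions, the model can only find vector M by reading the label one-hot, which is exactly the relational step the task is designed to test.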

Example

python train_nth_farthest.py --cuda for training and testing on the Nth Farthest Task with GPU(s).

This uses the RelationalMemory class in relational_rnn_general.py, which is a version of relational_rnn_models.py without the language-modelling-specific code.

Please refer to train_nth_farthest.py for details on hyperparameter values. These are taken from Appendix A1 in the paper, and from the Sonnet implementation where the paper does not give a value.

Note: new examples are generated per epoch as in the Sonnet implementation. This seems to be consistent with the paper, which does not specify the number of examples used.

Experiment results

The model was trained on a single TITAN Xp GPU for a very long time, until it reached 91% test accuracy. Below are the results of 3 independent runs:

The model does break the 25% barrier if trained long enough, but the wall-clock time is roughly 2~3x longer than that reported in the paper.

TODO

Experiment with different hyperparameters

Owner
Sang-gil Lee
Ph.D. student in ML/AI @ Seoul National University, South Korea. I do deep learning for sequence & generative models.