An implementation of DeepMind's Relational Recurrent Neural Networks in PyTorch.

Overview

relational-rnn-pytorch

An implementation of DeepMind's Relational Recurrent Neural Networks (Santoro et al. 2018) in PyTorch.

Relational Memory Core (RMC) module is originally from official Sonnet implementation. However, currently they do not provide a full language modeling benchmark code.

This repo is a port of RMC with additional comments. It features a full-fledged word language modeling benchmark vs. traditional LSTM.

It supports any arbitrary word token-based text dataset, including WikiText-2 & WikiText-103.

Both RMC & LSTM models support adaptive softmax for much lower memory usage of large vocabulary dataset. RMC supports PyTorch's DataParallel, so you can easily experiment with a multi-GPU setup.

benchmark codes are hard-forked from official PyTorch word-language-model example

It also features an N-th farthest synthetic task from the paper (see below).

Requirements

PyTorch 0.4.1 or later (Tested on 1.0.0) & Python 3.6

Examples

python train_rmc.py --cuda for full training & test run of RMC with GPU.

python train_rmc.py --cuda --adaptivesoftmax --cutoffs 1000 5000 20000 if using large vocabulary dataset (like WikiText-103) to fit all the tensors in the VRAM.

python generate_rmc.py --cuda for generating sentences from the trained model.

python train_rnn.py --cuda for full training & test run of traditional RNN with GPU.

All default hyperparameters of RMC & LSTM are results from a two-week experiment using WikiText-2.

Data Preparation

Tested with WikiText-2 and WikiText-103. WikiText-2 is bundled.

Create a subfolder inside ./data and place word-level train.txt, valid.txt, and test.txt inside the subfolder.

Specify --data=(subfolder name) and you are good to go.

The code performs tokenization at the first training run, and the corpus is saved as pickle. The code will load the pickle file after the first run.

WikiText-2 Benchmark Results

Both RMC & LSTM have ~11M parameters. Please refer to the training code for details on hyperparameters.

Models Valid Perplexity Test Perplexity Forward pass ms/batch (TITAN Xp) Forward pass ms/batch (TITAN V)
LSTM (CuDNN) 111.31 105.56 26~27 40~41
LSTM (For Loop) Same as CuDNN Same as CuDNN 30~31 60~61
RMC 112.77 107.21 110~130 220~230

RMC can reach a comparable performance to LSTM (with heavy hyperparameter search), but it turns out that the RMC is very slow. The multi-head self-attention at every time step may be the culprit here. Using LSTMCell with for loop (which is more "fair" benchmark for RMC) slows down the forward pass, but it's still much faster.

Please also note that the hyperparameter for RMC is a worst-case scenario in terms of speed, because it used a single memory slot (as described in the paper) and did not benefit from a row-wise weight sharing from multi-slot memory.

Interesting to note here is that the speed is slower in TITAN V than TITAN Xp. The reason might be that the models are relatively small and the model calls small linear operations frequently.

Maybe TITAN Xp (~1,900Mhz unlocked CUDA clock speed vs. TITAN V's 1,335Mhz limit) benefits from these kind of workload. Or maybe TITAN V's CUDA kernel launch latency is higher for the ops in the model.

I'm not an expert in details of CUDA. Please share your results!

RMC Hyperparameter Search Results

Attention parameters tend to overfit the WikiText-2. reducing the hyperparmeters for attention (key_size) can combat the overfitting.

Applying dropout at the output logit before the softmax (like the LSTM one) helped preventing the overfitting.

embed & head size # heads attention MLP layers key size dropout at output memory slots test ppl
128 4 3 128 No 1 128.81
128 4 3 128 No 1 128.81
128 8 3 128 No 1 141.84
128 4 3 32 No 1 123.26
128 4 3 32 Yes 1 112.4
128 4 3 64 No 1 124.44
128 4 3 64 Yes 1 110.16
128 4 2 64 Yes 1 111.67
64 4 3 64 Yes 1 133.68
64 4 3 32 Yes 1 135.93
64 4 3 64 Yes 4 137.93
192 4 3 64 Yes 1 107.21
192 4 3 64 Yes 4 114.85
256 4 3 256 No 1 194.73
256 4 3 64 Yes 1 126.39

About WikiText-103

The original RMC paper presents WikiText-103 results with a larger model & batch size (6 Tesla P100, each with 64 batch size, so a total of 384. Ouch).

Using a full softmax easily blows up the VRAM. Using --adaptivesoftmax is highly recommended. If using --adaptivesoftmax, --cutoffs should be properly provided. Please refer to the original API description

I don't have such hardware and my resource is too limited to do the experiments. Benchmark result, or any other contributions are very welcome!

Nth Farthest Task

The objective of the task is: Given k randomly labelled (from 1 to k) D-dimensional vectors, identify which is the Nth farthest vector from vector M. (The answer is an integer from 1 to k.)

The specific task in the paper is: given 8 labelled 16-dimensional vectors, which is the Nth farthest vector from vector M? The vectors are labelled randomly so the model has to recognise that the Mth vector is the vector labelled as M as opposed to the vector in the Mth position in the input.

The input to the model comprises 8 40-dimensional vectors for each example. Each of these 40-dimensional vectors is structured like this:

[(vector 1) (label: which vector is it, from 1 to 8, one-hot encoded) (N, one-hot encoded) (M, one-hot encoded)] 

Example

python train_nth_farthest.py --cuda for training and testing on the Nth Farthest Task with GPU(s).

This uses the RelationalMemory class in relational_rnn_general.py, which is a version of relational_rnn_models.py without the language-modelling specific code.

Please refer totrain_nth_farthest.py for details on hyperparameter values. These are taken from Appendix A1 in the paper and from the Sonnet implementation when the hyperparameter values are not given in the paper.

Note: new examples are generated per epoch as in the Sonnet implementation. This seems to be consistent with the paper, which does not specify the number of examples used.

Experiment results

The model has been trained with a single TITAN Xp GPU for forever until it reaches 91% test accuracy. Below are the results with 3 independent runs:

The model does break the 25% barrier if trained long enough, but the wall clock time is roughly over 2~3x longer than those reported in the paper.

TODO

Experiment with different hyperparameters

Owner
Sang-gil Lee
Ph.D. student in ML/AI @ Seoul National University, South Korea. I do deep learning for sequence & generative models.
Sang-gil Lee
Python based framework for Automatic AI for Regression and Classification over numerical data.

Python based framework for Automatic AI for Regression and Classification over numerical data. Performs model search, hyper-parameter tuning, and high-quality Jupyter Notebook code generation.

BlobCity, Inc 141 Dec 21, 2022
Image-to-image translation with conditional adversarial nets

pix2pix Project | Arxiv | PyTorch Torch implementation for learning a mapping from input images to output images, for example: Image-to-Image Translat

Phillip Isola 9.3k Jan 08, 2023
Towers of Babel: Combining Images, Language, and 3D Geometry for Learning Multimodal Vision. ICCV 2021.

Towers of Babel: Combining Images, Language, and 3D Geometry for Learning Multimodal Vision Download links and PyTorch implementation of "Towers of Ba

Blakey Wu 40 Dec 14, 2022
ManimML is a project focused on providing animations and visualizations of common machine learning concepts with the Manim Community Library.

ManimML ManimML is a project focused on providing animations and visualizations of common machine learning concepts with the Manim Community Library.

259 Jan 04, 2023
bio_inspired_min_nets_improve_the_performance_and_robustness_of_deep_networks

Code Submission for: Bio-inspired Min-Nets Improve the Performance and Robustness of Deep Networks Run with docker To build a docker environment, chan

0 Dec 09, 2021
A Pytorch Implementation of Domain adaptation of object detector using scissor-like networks

A Pytorch Implementation of Domain adaptation of object detector using scissor-like networks Please follow Faster R-CNN and DAF to complete the enviro

2 Oct 07, 2022
Easy to use Audio Tagging in PyTorch

Audio Classification, Tagging & Sound Event Detection in PyTorch Progress: Fine-tune on audio classification Fine-tune on audio tagging Fine-tune on s

sithu3 15 Dec 22, 2022
Code accompanying our NeurIPS 2021 traffic4cast challenge

Traffic forecasting on traffic movie snippets This repo contains all code to reproduce our approach to the IARAI Traffic4cast 2021 challenge. In the c

Nina Wiedemann 2 Aug 09, 2022
Python and C++ implementation of "MarkerPose: Robust real-time planar target tracking for accurate stereo pose estimation". Accepted at LXCV @ CVPR 2021.

MarkerPose: Robust real-time planar target tracking for accurate stereo pose estimation This is a PyTorch and LibTorch implementation of MarkerPose: a

Jhacson Meza 47 Nov 18, 2022
Code for "Typilus: Neural Type Hints" PLDI 2020

Typilus A deep learning algorithm for predicting types in Python. Please find a preprint here. This repository contains its implementation (src/) and

47 Nov 08, 2022
implement of SwiftNet:Real-time Video Object Segmentation

SwiftNet The official PyTorch implementation of SwiftNet:Real-time Video Object Segmentation, which has been accepted by CVPR2021. Requirements Python

haochen wang 64 Dec 14, 2022
TinyML Cookbook, published by Packt

TinyML Cookbook This is the code repository for TinyML Cookbook, published by Packt. Author: Gian Marco Iodice Publisher: Packt About the book This bo

Packt 93 Dec 29, 2022
Bayesian Meta-Learning Through Variational Gaussian Processes

vmgp This is the repository of Vivek Myers and Nikhil Sardana for our CS 330 final project, Bayesian Meta-Learning Through Variational Gaussian Proces

Vivek Myers 2 Nov 17, 2022
Official code of the paper "ReDet: A Rotation-equivariant Detector for Aerial Object Detection" (CVPR 2021)

ReDet: A Rotation-equivariant Detector for Aerial Object Detection ReDet: A Rotation-equivariant Detector for Aerial Object Detection (CVPR2021), Jiam

csuhan 334 Dec 23, 2022
NeRF visualization library under construction

NeRF visualization library using PlenOctrees, under construction pip install nerfvis Docs will be at: https://nerfvis.readthedocs.org import nerfvis s

Alex Yu 196 Jan 04, 2023
CityLearn Challenge Multi-Agent Reinforcement Learning for Intelligent Energy Management, 2020, PikaPika team

Citylearn Challenge This is the PyTorch implementation for PikaPika team, CityLearn Challenge Multi-Agent Reinforcement Learning for Intelligent Energ

bigAIdream projects 10 Oct 10, 2022
Ground truth data for the Optical Character Recognition of Historical Classical Commentaries.

OCR Ground Truth for Historical Commentaries The dataset OCR ground truth for historical commentaries (GT4HistComment) was created from the public dom

Ajax Multi-Commentary 3 Sep 08, 2022
coldcuts is an R package to automatically generate and plot segmentation drawings in R

coldcuts coldcuts is an R package that allows you to draw and plot automatically segmentations from 3D voxel arrays. The name is inspired by one of It

2 Sep 03, 2022
We will see a basic program that is basically a hint to brute force attack to crack passwords. In other words, we will make a program to Crack Any Password Using Python. Show some ❤️ by starring this repository!

Crack Any Password Using Python We will see a basic program that is basically a hint to brute force attack to crack passwords. In other words, we will

Ananya Chatterjee 11 Dec 03, 2022
Supplemental learning materials for "Fourier Feature Networks and Neural Volume Rendering"

Fourier Feature Networks and Neural Volume Rendering This repository is a companion to a lecture given at the University of Cambridge Engineering Depa

Matthew A Johnson 133 Dec 26, 2022