Neural Turing Machines (NTM) - PyTorch Implementation

Overview

PyTorch Neural Turing Machine (NTM)

PyTorch implementation of Neural Turing Machines (NTM).

An NTM is a memory augumented neural network (attached to external memory) where the interactions with the external memory (address, read, write) are done using differentiable transformations. Overall, the network is end-to-end differentiable and thus trainable by a gradient based optimizer.

The NTM is processing input in sequences, much like an LSTM, but with additional benfits: (1) The external memory allows the network to learn algorithmic tasks easier (2) Having larger capacity, without increasing the network's trainable parameters.

The external memory allows the NTM to learn algorithmic tasks, that are much harder for LSTM to learn, and to maintain an internal state much longer than traditional LSTMs.

A PyTorch Implementation

This repository implements a vanilla NTM in a straight forward way. The following architecture is used:

NTM Architecture

Features

  • Batch learning support
  • Numerically stable
  • Flexible head configuration - use X read heads and Y write heads and specify the order of operation
  • copy and repeat-copy experiments agree with the paper

Copy Task

The Copy task tests the NTM's ability to store and recall a long sequence of arbitrary information. The input to the network is a random sequence of bits, ending with a delimiter. The sequence lengths are randomised between 1 to 20.

Training

Training convergence for the copy task using 4 different seeds (see the notebook for details)

NTM Convergence

The following plot shows the cost per sequence length during training. The network was trained with seed=10 and shows fast convergence. Other seeds may not perform as well but should converge in less than 30K iterations.

NTM Convergence

Evaluation

Here is an animated GIF that shows how the model generalize. The model was evaluated after every 500 training samples, using the target sequence shown in the upper part of the image. The bottom part shows the network output at any given training stage.

Copy Task

The following is the same, but with sequence length = 80. Note that the network was trained with sequences of lengths 1 to 20.

Copy Task


Repeat Copy Task

The Repeat Copy task tests whether the NTM can learn a simple nested function, and invoke it by learning to execute a for loop. The input to the network is a random sequence of bits, followed by a delimiter and a scalar value that represents the number of repetitions to output. The number of repetitions, was normalized to have zero mean and variance of one (as in the paper). Both the length of the sequence and the number of repetitions are randomised between 1 to 10.

Training

Training convergence for the repeat-copy task using 4 different seeds (see the notebook for details)

NTM Convergence

Evaluation

The following image shows the input presented to the network, a sequence of bits + delimiter + num-reps scalar. Specifically the sequence length here is eight and the number of repetitions is five.

Repeat Copy Task

And here's the output the network had predicted:

Repeat Copy Task

Here's an animated GIF that shows how the network learns to predict the targets. Specifically, the network was evaluated in each checkpoint saved during training with the same input sequence.

Repeat Copy Task

Installation

The NTM can be used as a reusable module, currently not packaged though.

  1. Clone repository
  2. Install PyTorch
  3. pip install -r requirements.txt

Usage

Execute ./train.py

usage: train.py [-h] [--seed SEED] [--task {copy,repeat-copy}] [-p PARAM]
                [--checkpoint-interval CHECKPOINT_INTERVAL]
                [--checkpoint-path CHECKPOINT_PATH]
                [--report-interval REPORT_INTERVAL]

optional arguments:
  -h, --help            show this help message and exit
  --seed SEED           Seed value for RNGs
  --task {copy,repeat-copy}
                        Choose the task to train (default: copy)
  -p PARAM, --param PARAM
                        Override model params. Example: "-pbatch_size=4
                        -pnum_heads=2"
  --checkpoint-interval CHECKPOINT_INTERVAL
                        Checkpoint interval (default: 1000). Use 0 to disable
                        checkpointing
  --checkpoint-path CHECKPOINT_PATH
                        Path for saving checkpoint data (default: './')
  --report-interval REPORT_INTERVAL
                        Reporting interval
Owner
Guy Zana
I make things, author of Curated Papers
Guy Zana
Python implementation of Lightning-rod Agent, the Stack4Things board-side probe

Iotronic Lightning-rod Agent Python implementation of Lightning-rod Agent, the Stack4Things board-side probe. Free software: Apache 2.0 license Websit

2 May 19, 2022
Yolact-keras实例分割模型在keras当中的实现

Yolact-keras实例分割模型在keras当中的实现 目录 性能情况 Performance 所需环境 Environment 文件下载 Download 训练步骤 How2train 预测步骤 How2predict 评估步骤 How2eval 参考资料 Reference 性能情况 训练数

Bubbliiiing 11 Dec 26, 2022
List of awesome things around semantic segmentation 🎉

Awesome Semantic Segmentation List of awesome things around semantic segmentation 🎉 Semantic segmentation is a computer vision task in which we label

Dam Minh Tien 18 Nov 26, 2022
Out-of-boundary View Synthesis towards Full-frame Video Stabilization

Out-of-boundary View Synthesis towards Full-frame Video Stabilization Introduction | Update | Results Demo | Introduction This repository contains the

25 Oct 10, 2022
DAT4 - General Assembly's Data Science course in Washington, DC

DAT4 Course Repository Course materials for General Assembly's Data Science course in Washington, DC (12/15/14 - 3/16/15). Instructors: Sinan Ozdemir

Kevin Markham 779 Dec 25, 2022
RoMa: A lightweight library to deal with 3D rotations in PyTorch.

RoMa: A lightweight library to deal with 3D rotations in PyTorch. RoMa (which stands for Rotation Manipulation) provides differentiable mappings betwe

NAVER 90 Dec 27, 2022
The official implementation of CVPR 2021 Paper: Improving Weakly Supervised Visual Grounding by Contrastive Knowledge Distillation.

Improving Weakly Supervised Visual Grounding by Contrastive Knowledge Distillation This repository is the official implementation of CVPR 2021 paper:

9 Nov 14, 2022
HarDNeXt: Official HarDNeXt repository

HarDNeXt-Pytorch HarDNeXt: A Stage Receptive Field and Connectivity Aware Convolution Neural Network HarDNeXt-MSEG for Medical Image Segmentation in 0

5 May 26, 2022
PyTorch implementation of DirectCLR from paper Understanding Dimensional Collapse in Contrastive Self-supervised Learning

DirectCLR DirectCLR is a simple contrastive learning model for visual representation learning. It does not require a trainable projector as SimCLR. It

Meta Research 49 Dec 21, 2022
A template repository for submitting a job to the Slurm Cluster installed at the DISI - University of Bologna

Cluster di HPC con GPU per esperimenti di calcolo (draft version 1.0) Per poter utilizzare il cluster il primo passo è abilitare l'account istituziona

20 Dec 16, 2022
Implementation of the Triangle Multiplicative module, used in Alphafold2 as an efficient way to mix rows or columns of a 2d feature map, as a standalone package for Pytorch

Triangle Multiplicative Module - Pytorch Implementation of the Triangle Multiplicative module, used in Alphafold2 as an efficient way to mix rows or c

Phil Wang 22 Oct 28, 2022
Keras Model Implementation Walkthrough

Keras Model Implementation Walkthrough

Luke Wood 17 Sep 27, 2022
A complete, self-contained example for training ImageNet at state-of-the-art speed with FFCV

ffcv ImageNet Training A minimal, single-file PyTorch ImageNet training script designed for hackability. Run train_imagenet.py to get... ...high accur

FFCV 92 Dec 31, 2022
Implementation of Convolutional enhanced image Transformer

CeiT : Convolutional enhanced image Transformer This is an unofficial PyTorch implementation of Incorporating Convolution Designs into Visual Transfor

Rishikesh (ऋषिकेश) 82 Dec 13, 2022
A testcase generation tool for Persistent Memory Programs.

PMFuzz PMFuzz is a testcase generation tool to generate high-value tests cases for PM testing tools (XFDetector, PMDebugger, PMTest and Pmemcheck) If

Systems Research at ShiftLab 14 Jul 24, 2022
This repository implements Douzero's interface to IGCA.

douzero-interface-for-ICGA This repository implements Douzero's interface to ICGA. ./douzero: This directory stores Doudizhu AI projects. ./interface:

zhanggenjin 4 Aug 07, 2022
Weakly Supervised Posture Mining with Reverse Cross-entropy for Fine-grained Classification

Fine-grainedImageClassification Weakly Supervised Posture Mining with Reverse Cross-entropy for Fine-grained Classification We trained model here: lin

ZhenchaoTang 14 Oct 21, 2022
AutoPentest-DRL: Automated Penetration Testing Using Deep Reinforcement Learning

AutoPentest-DRL: Automated Penetration Testing Using Deep Reinforcement Learning AutoPentest-DRL is an automated penetration testing framework based o

Cyber Range Organization and Design Chair 217 Jan 01, 2023
Tello Drone Trajectory Tracking

With this library you can track the trajectory of your tello drone or swarm of drones in real time.

Kamran Asgarov 2 Oct 12, 2022