Implement Decoupled Neural Interfaces using Synthetic Gradients in Pytorch

Overview

disclaimer: this code is modified from pytorch-tutorial

Image classification with synthetic gradient in Pytorch

I implement the Decoupled Neural Interfaces using Synthetic Gradients in pytorch. The paper uses synthetic gradient to decouple the layers among the network, which is pretty interesting since we won't suffer from update lock anymore. I test my model in mnist and almost the same performance, compared to the model updated with backpropagation.

Requirement

  • pytorch
  • python 3.5
  • torchvision
  • seaborn (optional)
  • matplotlib (optional)

TODO

  • use multi-threading on gpu to analyze the speed

What's synthetic gradients?

We ofter optimize NN by backpropogation, which is usually implemented in some well-known framework. However, is there another way for the layers in NN to communicate with other layers? Here comes the synthetic gradients! It gives us a way to allow neural networks to communicate, to learn to send messages between themselves, in a decoupled, scalable manner paving the way for multiple neural networks to communicate with each other or improving the long term temporal dependency of recurrent networks.
The neuron in each layer will automatically produces an error signal(δa_head) from synthetic-layers and do the optimzation. And how did the error signal generated? Actually, the network still does the backpropogation. While the error signal(δa) from the objective function is not used to optimize the neuron in the network, it is used to optimize the error signal(δa_head) produced by the synthetic-layer. The following is the illustration from the paper:

Result

Feed-Forward Network

Achieve accuracy=96% (compared to the original model, which with accuracy=97%)

classify loss gradient loss(log level)
cDNI classify loss cDNI gradient loss(log level)

Convolutional Neural Network

Achieve accuracy=96%, (compared to the original model, which with accuracy=98%)

classify loss gradient loss(log level)

Usage

Right now I just implement the FCN, CNN versions, which are set as the default network structure.

Run network with synthetic gradient:

python main.py --model_type mlp

or

python main.py --model_type cnn

Run network with conditioned synthetic gradient:

python main.py --model_type mlp --conditioned True

Run vanilla network, from pytorch-tutorial

python mlp.py

or

python cnn.py

Reference

Owner
Andrew
Andrew
Learning to Draw: Emergent Communication through Sketching

Learning to Draw: Emergent Communication through Sketching This is the official code for the paper "Learning to Draw: Emergent Communication through S

19 Jul 22, 2022
code for TCL: Vision-Language Pre-Training with Triple Contrastive Learning, CVPR 2022

Vision-Language Pre-Training with Triple Contrastive Learning, CVPR 2022 News (03/16/2022) upload retrieval checkpoints finetuned on COCO and Flickr T

187 Jan 02, 2023
Let's create a tool to convert Thailand budget from PDF to CSV.

thailand-budget-pdf2csv Let's create a tool to convert Thailand Government Budgeting from PDF to CSV! รวมพลัง Dev แปลงงบ จาก PDF สู่ Machine-readable

Kao.Geek 88 Dec 19, 2022
An educational tool to introduce AI planning concepts using mobile manipulator robots.

JEDAI Explains Decision-Making AI Virtual Machine Image The recommended way of using JEDAI is to use pre-configured Virtual Machine image that is avai

Autonomous Agents and Intelligent Robots 13 Nov 15, 2022
RGB-stacking 🛑 🟩 🔷 for robotic manipulation

RGB-stacking 🛑 🟩 🔷 for robotic manipulation BLOG | PAPER | VIDEO Beyond Pick-and-Place: Tackling Robotic Stacking of Diverse Shapes, Alex X. Lee*,

DeepMind 95 Dec 23, 2022
Jetson Nano-based smart camera system that measures crowd face mask usage in real-time.

MaskCam MaskCam is a prototype reference design for a Jetson Nano-based smart camera system that measures crowd face mask usage in real-time, with all

BDTI 212 Dec 29, 2022
Multi Task Vision and Language

12-in-1: Multi-Task Vision and Language Representation Learning Please cite the following if you use this code. Code and pre-trained models for 12-in-

Facebook Research 712 Dec 19, 2022
Weighing Counts: Sequential Crowd Counting by Reinforcement Learning

LibraNet This repository includes the official implementation of LibraNet for crowd counting, presented in our paper: Weighing Counts: Sequential Crow

Hao Lu 18 Nov 05, 2022
4D Human Body Capture from Egocentric Video via 3D Scene Grounding

4D Human Body Capture from Egocentric Video via 3D Scene Grounding [Project] [Paper] Installation: Our method requires the same dependencies as SMPLif

Miao Liu 37 Nov 08, 2022
An example of semantic segmentation using tensorflow in eager execution.

Semantic segmentation using Tensorflow eager execution Requirement Python 2.7+ Tensorflow-gpu OpenCv H5py Scikit-learn Numpy Imgaug Train with eager e

Iñigo Alonso Ruiz 25 Sep 29, 2022
Repo for my Tensorflow/Keras CV experiments. Mostly revolving around the Danbooru20xx dataset

SW-CV-ModelZoo Repo for my Tensorflow/Keras CV experiments. Mostly revolving around the Danbooru20xx dataset Framework: TF/Keras 2.7 Training SQLite D

20 Dec 27, 2022
PyTorch implementation for "Mining Latent Structures with Contrastive Modality Fusion for Multimedia Recommendation"

MIRCO PyTorch implementation for paper: Latent Structures Mining with Contrastive Modality Fusion for Multimedia Recommendation Dependencies Python 3.

Big Data and Multi-modal Computing Group, CRIPAC 9 Dec 08, 2022
Repo for the Video Person Clustering dataset, and code for the associated paper

Video Person Clustering Repo for the Video Person Clustering dataset, and code for the associated paper. This reporsitory contains the Video Person Cl

Andrew Brown 47 Nov 02, 2022
Prososdy Morph: A python library for manipulating pitch and duration in an algorithmic way, for resynthesizing speech.

ProMo (Prosody Morph) Questions? Comments? Feedback? Chat with us on gitter! A library for manipulating pitch and duration in an algorithmic way, for

Tim 71 Jan 02, 2023
Reference models and tools for Cloud TPUs.

Cloud TPUs This repository is a collection of reference models and tools used with Cloud TPUs. The fastest way to get started training a model on a Cl

5k Jan 05, 2023
PyTorch reimplementation of REALM and ORQA

PyTorch reimplementation of REALM and ORQA

Li-Huai (Allan) Lin 17 Aug 20, 2022
Fusion-DHL: WiFi, IMU, and Floorplan Fusion for Dense History of Locations in Indoor Environments

Fusion-DHL: WiFi, IMU, and Floorplan Fusion for Dense History of Locations in Indoor Environments Paper: arXiv (ICRA 2021) Video : https://youtu.be/CC

Sachini Herath 68 Jan 03, 2023
This repository contains the code for our fast polygonal building extraction from overhead images pipeline.

Polygonal Building Segmentation by Frame Field Learning We add a frame field output to an image segmentation neural network to improve segmentation qu

Nicolas Girard 186 Jan 04, 2023
PN-Net a neural field-based framework for depth estimation from single-view RGB images.

PN-Net We present a neural field-based framework for depth estimation from single-view RGB images. Rather than representing a 2D depth map as a single

1 Oct 02, 2021
PyTorch reimplementation of minimal-hand (CVPR2020)

Minimal Hand Pytorch Unofficial PyTorch reimplementation of minimal-hand (CVPR2020). you can also find in youtube or bilibili bare hand youtube or bil

Hao Meng 228 Dec 29, 2022