Pytorch Implementation of paper "Noisy Natural Gradient as Variational Inference"

Overview

Noisy Natural Gradient as Variational Inference

PyTorch implementation of Noisy Natural Gradient as Variational Inference.

Requirements

  • Python 3
  • Pytorch
  • visdom

Comments

  • This paper is about how to optimize bayesian neural network which has matrix variate gaussian distribution.
  • This implementation contains Noisy Adam optimizer which is for Fully Factorized Gaussian(FFG) distribution, and Noisy KFAC optimizer which is for Matrix Variate Gaussian(MVG) distribution.
  • These optimizers only work with bayesian network which has specific structure that I will mention below.
  • Currently only linear layer is available.

Experimental comments

  • I addded a lr scheduler to noisy KFAC because loss is exploded during training. I guess this happens because of slight approximation.
  • For MNIST training noisy KFAC is 15-20x slower than noisy Adam, as mentioned in paper.
  • I guess the noisy KFAC needs more epochs to train simple neural network structure like 2 linear layers.

Usage

Currently only MNIST dataset are currently supported, and only fully connected layer is implemented.

Options

  • model : Fully Factorized Gaussian(FFG) or Matrix Variate Gaussian(MVG)
  • n : total train dataset size. need this value for optimizer.
  • eps : parameter for optimizer. Default to 1e-8.
  • initial_size : initial input tensor size. Default to 784, size of MNIST data.
  • label_size : label size. Default to 10, size of MNIST label.

More details in option_parser.py

Train

$ python train.py --model=FFG --batch_size=100 --lr=1e-3 --dataset=MNIST
$ python train.py --model=MVG --batch_size=100 --lr=1e-2 --dataset=MNIST --n=60000

Visualize

  • To visualize intermediate results and loss plots, run python -m visdom.server and go to the URL http://localhost:8097

Test

$ python test.py --epoch=20

Training Graphs

1. MNIST

  • network is consist of 2 linear layers.
  • FFG optimized by noisy Adam : epoch 20, lr 1e-3

  • MVG optimized by noisy KFAC : epoch 100, lr 1e-2, decay 0.1 for every 30 epochs
  • Need to tune learning rate.

Implementation detail

  • Optimizing parameter procedure is consists of 2 steps, Calculating gradient and Applying to bayeisan parameters.
  • Before forward, network samples parameters with means & variances.
  • Usually calling step function updates parameters, but not this case. After calling step function, you have to update bayesian parameters. Look at the ffg_model.py

TODOs

  • More benchmark cases
  • Supports bayesian convolution
  • Implement Block Tridiagonal Covariance, which is dependent between layers.

Code reference

Visualization code(visualizer.py, utils.py) references to pytorch-CycleGAN-and-pix2pix(https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) by Jun-Yan Zhu

Author

Tony Kim

Owner
Tony JiHyun Kim
CEO/Tech Lead @PostAlpine Co., Ltd.
Tony JiHyun Kim
Image processing in Python

scikit-image: Image processing in Python Website (including documentation): https://scikit-image.org/ Mailing list: https://mail.python.org/mailman3/l

Image Processing Toolbox for SciPy 5.2k Dec 31, 2022
Code for "MetaMorph: Learning Universal Controllers with Transformers", Gupta et al, ICLR 2022

MetaMorph: Learning Universal Controllers with Transformers This is the code for the paper MetaMorph: Learning Universal Controllers with Transformers

Agrim Gupta 50 Jan 03, 2023
PatchMatch-RL: Deep MVS with Pixelwise Depth, Normal, and Visibility

PatchMatch-RL: Deep MVS with Pixelwise Depth, Normal, and Visibility Jae Yong Lee, Joseph DeGol, Chuhang Zou, Derek Hoiem Installation To install nece

31 Apr 19, 2022
Developing your First ML Workflow of the AWS Machine Learning Engineer Nanodegree Program

Exercises and project documentation for the 3. Developing your First ML Workflow of the AWS Machine Learning Engineer Nanodegree Program

Simona Mircheva 1 Jan 13, 2022
This is the workbook I created while I was studying for the Qiskit Associate Developer exam. I hope this becomes useful to others as it was for me :)

A Workbook for the Qiskit Developer Certification Exam Hello everyone! This is Bartu, a fellow Qiskitter. I have recently taken the Certification exam

Bartu Bisgin 66 Dec 10, 2022
Implementation of popular bandit algorithms in batch environments.

batch-bandits Implementation of popular bandit algorithms in batch environments. Source code to our paper "The Impact of Batch Learning in Stochastic

Danil Provodin 2 Sep 11, 2022
Py-faster-rcnn - Faster R-CNN (Python implementation)

py-faster-rcnn has been deprecated. Please see Detectron, which includes an implementation of Mask R-CNN. Disclaimer The official Faster R-CNN code (w

Ross Girshick 7.8k Jan 03, 2023
Deep Federated Learning for Autonomous Driving

FADNet: Deep Federated Learning for Autonomous Driving Abstract Autonomous driving is an active research topic in both academia and industry. However,

AIOZ AI 12 Dec 01, 2022
A pytorch implementation of Reading Wikipedia to Answer Open-Domain Questions.

DrQA A pytorch implementation of the ACL 2017 paper Reading Wikipedia to Answer Open-Domain Questions (DrQA). Reading comprehension is a task to produ

Runqi Yang 394 Nov 08, 2022
Implementation of Change-Based Exploration Transfer (C-BET)

Implementation of Change-Based Exploration Transfer (C-BET), as presented in Interesting Object, Curious Agent: Learning Task-Agnostic Exploration.

Simone Parisi 29 Dec 04, 2022
CS583: Deep Learning

CS583: Deep Learning

Shusen Wang 2.6k Dec 30, 2022
Implementation of Stochastic Image-to-Video Synthesis using cINNs.

Stochastic Image-to-Video Synthesis using cINNs Official PyTorch implementation of Stochastic Image-to-Video Synthesis using cINNs accepted to CVPR202

CompVis Heidelberg 135 Dec 28, 2022
Retinal Vessel Segmentation with Pixel-wise Adaptive Filters (ISBI 2022)

Official code of Retinal Vessel Segmentation with Pixel-wise Adaptive Filters and Consistency Training (ISBI 2022)

anonymous 14 Oct 27, 2022
Tensorflow Implementation of ECCV'18 paper: Multimodal Human Motion Synthesis

MT-VAE for Multimodal Human Motion Synthesis This is the code for ECCV 2018 paper MT-VAE: Learning Motion Transformations to Generate Multimodal Human

Xinchen Yan 36 Oct 02, 2022
discovering subdomains, hidden paths, extracting unique links

python-website-crawler discovering subdomains, hidden paths, extracting unique links pip install -r requirements.txt discover subdomain: You can give

merve 4 Sep 05, 2022
PyTorch reimplementation of hand-biomechanical-constraints (ECCV2020)

Hand Biomechanical Constraints Pytorch Unofficial PyTorch reimplementation of Hand-Biomechanical-Constraints (ECCV2020). This project reimplement foll

Hao Meng 59 Dec 20, 2022
[NeurIPS 2021] Towards Better Understanding of Training Certifiably Robust Models against Adversarial Examples | ⛰️⚠️

Towards Better Understanding of Training Certifiably Robust Models against Adversarial Examples This repository is the official implementation of "Tow

Sungyoon Lee 4 Jul 12, 2022
Reinforcement learning library in JAX.

Reinforcement learning library in JAX.

Yicheng Luo 96 Oct 30, 2022
The challenge for Quantum Coalition Hackathon 2021

Qchack 2021 Google Challenge This is a challenge for the brave 2021 qchack.io participants. Instructions Hello, intrepid qchacker, welcome to the G|o

quantumlib 18 May 04, 2022
Implementation of SSMF: Shifting Seasonal Matrix Factorization

SSMF Implementation of SSMF: Shifting Seasonal Matrix Factorization, Koki Kawabata, Siddharth Bhatia, Rui Liu, Mohit Wadhwa, Bryan Hooi. NeurIPS, 2021

Koki Kawabata 9 Jun 10, 2022