A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules

Overview

CapsNet-Tensorflow

Contributions welcome License Gitter

A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules

capsVSneuron

Notes:

  1. The current version supports MNIST and Fashion-MNIST datasets. The current test accuracy for MNIST is 99.64%, and Fashion-MNIST 90.60%, see details in the Results section
  2. See dist_version for multi-GPU support
  3. Here(知乎) is an article explaining my understanding of the paper. It may be helpful in understanding the code.

Important:

If you need to apply CapsNet model to your own datasets or build up a new model with the basic block of CapsNet, please follow my new project CapsLayer, which is an advanced library for capsule theory, aiming to integrate capsule-relevant technologies, provide relevant analysis tools, develop related application examples, and promote the development of capsule theory. For example, you can use capsule layer block in your code easily with the API capsLayer.layers.fully_connected and capsLayer.layers.conv2d

Requirements

  • Python
  • NumPy
  • Tensorflow>=1.3
  • tqdm (for displaying training progress info)
  • scipy (for saving images)

Usage

Step 1. Download this repository with git or click the download ZIP button.

$ git clone https://github.com/naturomics/CapsNet-Tensorflow.git
$ cd CapsNet-Tensorflow

Step 2. Download MNIST or Fashion-MNIST dataset. In this step, you have two choices:

  • a) Automatic downloading with download_data.py script
$ python download_data.py   (for mnist dataset)
$ python download_data.py --dataset fashion-mnist --save_to data/fashion-mnist (for fashion-mnist dataset)
  • b) Manual downloading with wget or other tools, move and extract dataset into data/mnist or data/fashion-mnist directory, for example:
$ mkdir -p data/mnist
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
$ gunzip data/mnist/*.gz

Step 3. Start the training(Using the MNIST dataset by default):

$ python main.py
$ # or training for fashion-mnist dataset
$ python main.py --dataset fashion-mnist
$ # If you need to monitor the training process, open tensorboard with this command
$ tensorboard --logdir=logdir
$ # or use `tail` command on linux system
$ tail -f results/val_acc.csv

Step 4. Calculate test accuracy

$ python main.py --is_training=False
$ # for fashion-mnist dataset
$ python main.py --dataset fashion-mnist --is_training=False

Note: The default parameters of batch size is 128, and epoch 50. You may need to modify the config.py file or use command line parameters to suit your case, e.g. set batch size to 64 and do once test summary every 200 steps: python main.py --test_sum_freq=200 --batch_size=48

Results

The pictures here are plotted by tensorboard and my tool plot_acc.R

  • training loss

total_loss margin_loss reconstruction_loss

Here are the models I trained and my talk and something else:

Baidu Netdisk(password:ahjs)

  • The best val error(using reconstruction)
Routing iteration 1 3 4
val error 0.36 0.36 0.41
Paper 0.29 0.25 -

test_acc

My simple comments for capsule

  1. A new version neural unit(vector in vector out, not scalar in scalar out)
  2. The routing algorithm is similar to attention mechanism
  3. Anyway, a great potential work, a lot to be built upon

My weChat:

my_wechat

Reference

Owner
Huadong Liao
Explore Nature from an Omics Perspective
Huadong Liao
D2LV: A Data-Driven and Local-Verification Approach for Image Copy Detection

Facebook AI Image Similarity Challenge: Matching Track —— Team: imgFp This is the source code of our 3rd place solution to matching track of Image Sim

16 Dec 25, 2022
Deep Dual Consecutive Network for Human Pose Estimation (CVPR2021)

Beanie - is an asynchronous ODM for MongoDB, based on Motor and Pydantic. It uses an abstraction over Pydantic models and Motor collections to work wi

295 Dec 29, 2022
Python scripts form performing stereo depth estimation using the HITNET model in ONNX.

ONNX-HITNET-Stereo-Depth-estimation Python scripts form performing stereo depth estimation using the HITNET model in ONNX. Stereo depth estimation on

Ibai Gorordo 30 Nov 08, 2022
Styleformer - Official Pytorch Implementation

Styleformer -- Official PyTorch implementation Styleformer: Transformer based Generative Adversarial Networks with Style Vector(https://arxiv.org/abs/

Jeeseung Park 159 Dec 12, 2022
An SMPC companion library for Syft

SyMPC A library that extends PySyft with SMPC support SyMPC /ˈsɪmpəθi/ is a library which extends PySyft ≥0.3 with SMPC support. It allows computing o

Arturo Marquez Flores 0 Oct 13, 2021
Free Book about Deep-Learning approaches for Chess (like AlphaZero, Leela Chess Zero and Stockfish NNUE)

Free Book about Deep-Learning approaches for Chess (like AlphaZero, Leela Chess Zero and Stockfish NNUE)

Dominik Klein 189 Dec 21, 2022
Deep deconfounded recommender (Deep-Deconf) for paper "Deep causal reasoning for recommendations"

Deep Causal Reasoning for Recommender Systems The codes are associated with the following paper: Deep Causal Reasoning for Recommendations, Yaochen Zh

Yaochen Zhu 22 Oct 15, 2022
Cross Quality LFW: A database for Analyzing Cross-Resolution Image Face Recognition in Unconstrained Environments

Cross-Quality Labeled Faces in the Wild (XQLFW) Here, we release the database, evaluation protocol and code for the following paper: Cross Quality LFW

Martin Knoche 10 Dec 12, 2022
HMLLDB is a collection of LLDB commands to assist in the debugging of iOS apps.

HMLLDB is a collection of LLDB commands to assist in the debugging of iOS apps. 中文介绍 Features Non-intrusive. Your iOS project does not need to be modi

mao2020 47 Oct 22, 2022
[ICCV 2021] Our work presents a novel neural rendering approach that can efficiently reconstruct geometric and neural radiance fields for view synthesis.

MVSNeRF Project page | Paper This repository contains a pytorch lightning implementation for the ICCV 2021 paper: MVSNeRF: Fast Generalizable Radiance

Anpei Chen 529 Dec 30, 2022
MarcoPolo is a clustering-free approach to the exploration of bimodally expressed genes along with group information in single-cell RNA-seq data

MarcoPolo is a method to discover differentially expressed genes in single-cell RNA-seq data without depending on prior clustering Overview MarcoPolo

Chanwoo Kim 13 Dec 18, 2022
Code for "NeRS: Neural Reflectance Surfaces for Sparse-View 3D Reconstruction in the Wild," in NeurIPS 2021

Code for Neural Reflectance Surfaces (NeRS) [arXiv] [Project Page] [Colab Demo] [Bibtex] This repo contains the code for NeRS: Neural Reflectance Surf

Jason Y. Zhang 234 Dec 30, 2022
Supplementary code for the paper "Meta-Solver for Neural Ordinary Differential Equations" https://arxiv.org/abs/2103.08561

Meta-Solver for Neural Ordinary Differential Equations Towards robust neural ODEs using parametrized solvers. Main idea Each Runge-Kutta (RK) solver w

Julia Gusak 25 Aug 12, 2021
Multi-Scale Progressive Fusion Network for Single Image Deraining

Multi-Scale Progressive Fusion Network for Single Image Deraining (MSPFN) This is an implementation of the MSPFN model proposed in the paper (Multi-Sc

Kuijiang 128 Nov 21, 2022
Pytorch implementation of the paper Progressive Growing of Points with Tree-structured Generators (BMVC 2021)

PGpoints Pytorch implementation of the paper Progressive Growing of Points with Tree-structured Generators (BMVC 2021) Hyeontae Son, Young Min Kim Pre

Hyeontae Son 9 Jun 06, 2022
Improving Compound Activity Classification via Deep Transfer and Representation Learning

Improving Compound Activity Classification via Deep Transfer and Representation Learning This repository is the official implementation of Improving C

NingLab 2 Nov 24, 2021
Official code of "Mitigating the Mutual Error Amplification for Semi-Supervised Object Detection"

CrossTeaching-SSOD 0. Introduction Official code of "Mitigating the Mutual Error Amplification for Semi-Supervised Object Detection" This repo include

Bruno Ma 9 Nov 29, 2022
As a part of the HAKE project, includes the reproduced SOTA models and the corresponding HAKE-enhanced versions (CVPR2020).

HAKE-Action HAKE-Action (TensorFlow) is a project to open the SOTA action understanding studies based on our Human Activity Knowledge Engine. It inclu

Yong-Lu Li 94 Nov 18, 2022
Parameter Efficient Deep Probabilistic Forecasting

PEDPF Parameter Efficient Deep Probabilistic Forecasting (PEDPF) is a repository containing code to run experiments for several deep learning based pr

Olivier Sprangers 10 Jun 13, 2022