Pretrained models for Jax/Haiku; MobileNet, ResNet, VGG, Xception.

Overview

Pre-trained image classification models for Jax/Haiku

Jax/Haiku Applications are deep learning models that are made available alongside pre-trained weights. These models can be used for prediction, feature extraction, and fine-tuning.

Available Models

  • MobileNetV1
  • ResNet, ResNetV2
  • VGG16, VGG19
  • Xception

Planned Releases

  • MobileNetV2, MobileNetV3
  • InceptionResNetV2, InceptionV3
  • EfficientNetV1, EfficientNetV2

Installation

Haikumodels requires Python 3.7 or later.

  1. The required libraries can be installed using "installation.txt".
  2. If Jax GPU support is desired, it must be installed separately according to your system's needs.
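After installation, one way to confirm which backend Jax picked up is to query it directly. This is a minimal check that uses only standard Jax calls and assumes nothing about Haikumodels itself.

```python
import jax

# Name of the backend Jax selected, typically "cpu", "gpu", or "tpu".
backend = jax.default_backend()
# Devices visible to Jax on that backend.
devices = jax.devices()

print("Backend:", backend)
print("Devices:", devices)
```

If the backend is reported as "cpu" on a machine with a GPU, the GPU build of Jax is most likely not installed.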

Usage examples for image classification models

Classify ImageNet classes with ResNet50

import haiku as hk
import jax
import jax.numpy as jnp
from PIL import Image

import haikumodels as hm

rng = jax.random.PRNGKey(42)


def _model(images, is_training):
  net = hm.ResNet50()
  return net(images, is_training)


model = hk.transform_with_state(_model)

img_path = "elephant.jpg"
img = Image.open(img_path).resize((224, 224))

x = jnp.asarray(img, dtype=jnp.float32)
x = jnp.expand_dims(x, axis=0)
x = hm.resnet.preprocess_input(x)

params, state = model.init(rng, x, is_training=True)

preds, _ = model.apply(params, state, None, x, is_training=False)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print("Predicted:", hm.decode_predictions(preds, top=3)[0])
# Predicted:
# [('n02504013', 'Indian_elephant', 0.8784022),
# ('n01871265', 'tusker', 0.09620289),
# ('n02504458', 'African_elephant', 0.025362419)]
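The decoding step can be pictured as a top-k selection over the class probabilities. The sketch below is not the library's `decode_predictions`; it only shows the index selection on a made-up probability array, and the name `top_k_indices` is hypothetical.

```python
import jax
import jax.numpy as jnp


def top_k_indices(preds: jnp.ndarray, top: int = 3):
  """Returns (probabilities, class indices) of the `top` best scores per sample."""
  # `jax.lax.top_k` works along the last axis and sorts in descending order.
  values, indices = jax.lax.top_k(preds, top)
  return values, indices


# Made-up batch with one sample and five classes (probabilities sum to 1).
dummy_preds = jnp.array([[0.05, 0.60, 0.10, 0.20, 0.05]])
values, indices = top_k_indices(dummy_preds, top=3)
print(indices)  # [[1 3 2]]
```

Mapping each index to an ImageNet class id and description is the part that `hm.decode_predictions` handles.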

Extract features with VGG16

import haiku as hk
import jax
import jax.numpy as jnp
from PIL import Image

import haikumodels as hm

rng = jax.random.PRNGKey(42)

model = hk.without_apply_rng(hk.transform(hm.VGG16(include_top=False)))

img_path = "elephant.jpg"
img = Image.open(img_path).resize((224, 224))

x = jnp.asarray(img, dtype=jnp.float32)
x = jnp.expand_dims(x, axis=0)
x = hm.vgg.preprocess_input(x)

params = model.init(rng, x)

features = model.apply(params, x)
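With `include_top=False` the model returns a spatial feature map rather than class scores. A common next step is to pool it into one vector per image, as the Xception fine-tuning example does with `jnp.mean(out, axis=(1, 2))`. The sketch below uses a placeholder array; the shape `(1, 7, 7, 512)` is an assumption for a 224x224 input in batch, height, width, channels layout, not a value taken from the library.

```python
import jax.numpy as jnp

# Placeholder standing in for `features`; the shape is an assumption,
# laid out as (batch, height, width, channels).
features = jnp.ones((1, 7, 7, 512), dtype=jnp.float32)

# Global average pooling over the two spatial axes gives one vector per image.
pooled = jnp.mean(features, axis=(1, 2))
print(pooled.shape)  # (1, 512)
```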

Fine-tune Xception on a new set of classes

from typing import Any, Callable, NamedTuple, Optional, Sequence

import optax
import haiku as hk
import jax
import jax.numpy as jnp

import haikumodels as hm

rng = jax.random.PRNGKey(42)


class Freezable_TrainState(NamedTuple):
  trainable_params: hk.Params
  non_trainable_params: hk.Params
  state: hk.State
  opt_state: optax.OptState


# create your custom top layers and include the desired pretrained model
class ft_xception(hk.Module):

  def __init__(
      self,
      classes: int,
      classifier_activation: Callable[[jnp.ndarray],
                                      jnp.ndarray] = jax.nn.softmax,
      with_bias: bool = True,
      w_init: Optional[Callable[[Sequence[int], Any], jnp.ndarray]] = None,
      b_init: Optional[Callable[[Sequence[int], Any], jnp.ndarray]] = None,
      name: Optional[str] = None,
  ):
    super().__init__(name=name)
    self.classifier_activation = classifier_activation

    self.xception_no_top = hm.Xception(include_top=False)
    self.dense_layer = hk.Linear(
        output_size=1024,
        with_bias=with_bias,
        w_init=w_init,
        b_init=b_init,
        name="trainable_dense_layer",
    )
    self.top_layer = hk.Linear(
        output_size=classes,
        with_bias=with_bias,
        w_init=w_init,
        b_init=b_init,
        name="trainable_top_layer",
    )

  def __call__(self, inputs: jnp.ndarray, is_training: bool):
    out = self.xception_no_top(inputs, is_training)
    out = jnp.mean(out, axis=(1, 2))
    out = self.dense_layer(out)
    out = jax.nn.relu(out)
    out = self.top_layer(out)
    out = self.classifier_activation(out)
    return out


# use `transform_with_state` if the model contains batch norm layers;
# otherwise use `transform` followed by `without_apply_rng`
def _model(images, is_training):
  net = ft_xception(classes=200)
  return net(images, is_training)


model = hk.transform_with_state(_model)

# create your desired optimizer using Optax or alternatives
opt = optax.rmsprop(learning_rate=1e-4, momentum=0.90)


# initialize params and state, then split params into trainable and
# non-trainable sets by matching `nonfreeze_key` against module names
def initial_state(x_y, nonfreeze_key="trainable"):
  x, _ = x_y
  params, state = model.init(rng, x, is_training=True)

  trainable_params, non_trainable_params = hk.data_structures.partition(
      lambda m, n, p: nonfreeze_key in m, params)

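  # the optimizer only ever sees gradients of the trainable params,
  # so its state must be built from the same subset
  opt_state = opt.init(trainable_params)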

  return Freezable_TrainState(trainable_params, non_trainable_params, state,
                              opt_state)


train_state = initial_state(next(gen_x_y))


# create your own custom loss function as desired
def loss_function(trainable_params, non_trainable_params, state, x_y):
  x, y = x_y
  params = hk.data_structures.merge(trainable_params, non_trainable_params)
  y_, state = model.apply(params, state, None, x, is_training=True)

  cce = categorical_crossentropy(y, y_)

  return cce, state


# to update params and optimizer, a train_step function must be created
@jax.jit
def train_step(train_state: Freezable_TrainState, x_y):
  trainable_params, non_trainable_params, state, opt_state = train_state
  trainable_params_grads, _ = jax.grad(loss_function,
                                       has_aux=True)(trainable_params,
                                                     non_trainable_params,
                                                     state, x_y)

  updates, new_opt_state = opt.update(trainable_params_grads, opt_state)
  new_trainable_params = optax.apply_updates(trainable_params, updates)

  train_state = Freezable_TrainState(new_trainable_params, non_trainable_params,
                                     state, new_opt_state)
  return train_state


# train the model on the new data for a few epochs
train_state = train_step(train_state, next(gen_x_y))

# after training is complete, it is possible to merge the
# trainable and non_trainable params to use for prediction
trainable_params, non_trainable_params, state, _ = train_state
params = hk.data_structures.merge(trainable_params, non_trainable_params)
preds, _ = model.apply(params, state, None, x, is_training=False)
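The listing above calls `categorical_crossentropy` and draws batches from `gen_x_y` without defining them. A minimal sketch of both follows, assuming one-hot labels, model outputs that are already softmax probabilities (as with the default `classifier_activation`), and random placeholder data in place of a real dataset. The helper name `make_gen_x_y`, the batch size, and the 299x299 input size are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def categorical_crossentropy(y_true: jnp.ndarray, y_pred: jnp.ndarray,
                             eps: float = 1e-7) -> jnp.ndarray:
  """Mean cross-entropy between one-hot labels and predicted probabilities."""
  # Clip to avoid log(0) when a probability is exactly zero.
  y_pred = jnp.clip(y_pred, eps, 1.0)
  return -jnp.mean(jnp.sum(y_true * jnp.log(y_pred), axis=-1))


def make_gen_x_y(rng, batch_size=2, image_size=299, classes=200):
  """Yields endless (images, one-hot labels) batches of random placeholder data."""
  while True:
    rng, x_key, y_key = jax.random.split(rng, 3)
    x = jax.random.uniform(x_key, (batch_size, image_size, image_size, 3))
    labels = jax.random.randint(y_key, (batch_size,), 0, classes)
    yield x, jax.nn.one_hot(labels, classes)


gen_x_y = make_gen_x_y(jax.random.PRNGKey(0))
x, y = next(gen_x_y)

# Sanity check: uniform predictions over 200 classes give a loss of log(200).
uniform = jnp.full((2, 200), 1.0 / 200)
print(categorical_crossentropy(y, uniform))  # approximately 5.2983
```

In practice the generator would be replaced by a real data pipeline that applies the model's own preprocessing to each batch.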
Owner
Alper Baris CELIK