Unofficial JAX implementations of Deep Learning models

Overview

JAX Models

license-shield release-shield python-shield code-style

Table of Contents
  1. About The Project
  2. Getting Started
  3. Contributing
  4. License
  5. Contact

About The Project

The JAX Models repository aims to provide open sourced JAX/Flax implementations for research papers originally without code or code written with frameworks other than JAX. The goal of this project is to make a collection of models, layers, activations and other utilities that are most commonly used for research. All papers and derived or translated code is cited in either the README or the docstrings. If you think that any citation is missed then please raise an issue.

All implementations provided here are available on Papers With Code.


Available model implementations for JAX are:
  1. MetaFormer is Actually What You Need for Vision (Weihao Yu et al., 2021)
  2. Augmenting Convolutional networks with attention-based aggregation (Hugo Touvron et al., 2021)
  3. MPViT : Multi-Path Vision Transformer for Dense Prediction (Youngwan Lee et al., 2021)
  4. MLP-Mixer: An all-MLP Architecture for Vision (Ilya Tolstikhin et al., 2021)
  5. Patches Are All You Need (Anonymous et al., 2021)
  6. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (Enze Xie et al., 2021)
  7. A ConvNet for the 2020s (Zhuang Liu et al., 2021)
  8. Masked Autoencoders Are Scalable Vision Learners (Kaiming He et al., 2021)

Available layers for out-of-the-box integration:
  1. DropPath (Stochastic Depth) (Gao Huang et al., 2021)
  2. Squeeze-and-Excitation Layer (Jie Hu et al. 2019)
  3. Depthwise Convolution (François Chollet, 2017)

Prerequisites

Prerequisites can be installed separately through the requirements.txt file in the main directory using:

pip install -r requirements.txt

The use of a virtual environment is highly recommended to avoid version incompatibilites.

Installation

This project is built with Python 3 for the latest JAX/Flax versions and can be directly installed via pip.

pip install jax-models

If you wish to use the latest version then you can directly clone the repository too.

git clone https://github.com/DarshanDeshpande/jax-models.git

Usage

To see all model architectures available:

from jax_models.models.model_registry import list_models
from pprint import pprint

pprint(list_models())

To load your desired model:

from jax_models.models.model_registry import load_model
load_model('mpvit-base', attach_head=True, num_classes=1000, dropout=0.1)

Contributing

Please raise an issue if any implementation gives incorrect results, crashes unexpectedly during training/inference or if any citation is missing.

You can contribute to jax_models by supporting me with compute resources or by contributing your own resources to provide pretrained weights.

If you wish to donate to this inititative then please drop me a mail here.

License

Distributed under the Apache 2.0 License. See LICENSE for more information.

Contact

Feel free to reach out for any issues or requests related to these implementations

Darshan Deshpande - Email | Twitter | LinkedIn

You might also like...
Very deep VAEs in JAX/Flax
Very deep VAEs in JAX/Flax

Very Deep VAEs in JAX/Flax Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on I

Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX
Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX

CQL-JAX This repository implements Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX (FLAX). Implementation is built on

PyTorch implementations of neural network models for keyword spotting
PyTorch implementations of neural network models for keyword spotting

Honk: CNNs for Keyword Spotting Honk is a PyTorch reimplementation of Google's TensorFlow convolutional neural networks for keyword spotting, which ac

Unofficial implementation of Proxy Anchor Loss for Deep Metric Learning
Unofficial implementation of Proxy Anchor Loss for Deep Metric Learning

Proxy Anchor Loss for Deep Metric Learning Unofficial pytorch, tensorflow and mxnet implementations of Proxy Anchor Loss for Deep Metric Learning. Not

Time-series-deep-learning - Developing Deep learning LSTM, BiLSTM models, and NeuralProphet for multi-step time-series forecasting of stock price.
Time-series-deep-learning - Developing Deep learning LSTM, BiLSTM models, and NeuralProphet for multi-step time-series forecasting of stock price.

Stock Price Prediction Using Deep Learning Univariate Time Series Predicting stock price using historical data of a company using Neural networks for

FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX.

FedJAX: Federated learning with JAX What is FedJAX? FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX priori

Objax Apache-2Objax (🥉19 · ⭐ 580) - Objax is a machine learning framework that provides an Object.. Apache-2 jax

Objax Tutorials | Install | Documentation | Philosophy This is not an officially supported Google product. Objax is an open source machine learning fr

Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX
Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX

coax is built on top of JAX, but it doesn't have an explicit dependence on the jax python package. The reason is that your version of jaxlib will depend on your CUDA version.

JAX code for the paper
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"

Optimal Model Design for Reinforcement Learning This repository contains JAX code for the paper Control-Oriented Model-Based Reinforcement Learning wi

Comments
  • Missing Axis Swap in ExtractPatches and MergePatches

    Missing Axis Swap in ExtractPatches and MergePatches

    In patch_utils.py, the modules ExtractPatches and MergePatches are missing an axis swap between the reshapes, resulting in the extracted patches becoming horizontal stripes. For example, if we follow the code in ExtractPatches:

    >>> inputs = jnp.arange(16).reshape(1, 4, 4, 1)
    >>> inputs[0, :, :, 0]
    
    DeviceArray([[ 0,  1,  2,  3],
                 [ 4,  5,  6,  7],
                 [ 8,  9, 10, 11],
                 [12, 13, 14, 15]], dtype=int32)
    
    >>> patch_size = 2
    >>> batch, height, width, channels = inputs.shape
    >>> height, width = height // patch_size, width // patch_size
    >>> x = jnp.reshape(inputs, (batch, height, patch_size, width, patch_size, channels))
    >>> x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
    >>> x[0, 0, :]
    
    DeviceArray([0, 1, 2, 3], dtype=int32)
    

    We see that the first patch extracted is not the patch containing [0, 1, 4, 5], but the horizontal stripe [0, 1, 2, 3]. To fix this problem, we should add an axis swap. For ExtractPatches, this should be:

    batch, height, width, channels = inputs.shape
    height, width = height // patch_size, width // patch_size
    x = jnp.reshape(
        inputs, (batch, height, patch_size, width, patch_size, channels)
    )
    x = jnp.swapaxes(x, 2, 3)
    x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
    

    For MergePatches, this should be:

    batch, length, _ = inputs.shape
    height = width = int(length**0.5)
    x = jnp.reshape(inputs, (batch, height, width, patch_size, patch_size, -1))
    x = jnp.swapaxes(x, 2, 3)
    x = jnp.reshape(x, (batch, height * patch_size, width * patch_size, -1))
    
    bug 
    opened by young-geng 4
  • fix convnext to make it work with jax.jit

    fix convnext to make it work with jax.jit

    Hey, first of all, thanks for the nice codebase. When doing inference using the convnext model, I noticed the following issue:

    Calling x.item() will call float(x), which breaks the jit tracer. We can remove the list comprehension in unnecessary conversion to make jax.jit work. Without jax.jit, the model is very slow for me, running with only ~30% GPU utilization (RTX 3090).

    This issue could apply to other models as well, maybe it is a good idea to include a test for applying jax.jit to each model?

    opened by maxidl 1
Releases(v0.5-van)
Owner
Helping Machines Learn Better 💻😃
This repository contains code demonstrating the methods outlined in Path Signature Area-Based Causal Discovery in Coupled Time Series presented at Causal Analysis Workshop 2021.

signed-area-causal-inference This repository contains code demonstrating the methods outlined in Path Signature Area-Based Causal Discovery in Coupled

Will Glad 1 Mar 11, 2022
Human Pose Detection on EdgeTPU

Coral PoseNet Pose estimation refers to computer vision techniques that detect human figures in images and video, so that one could determine, for exa

google-coral 476 Dec 31, 2022
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data

Real-ESRGAN Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data Ported from https://github.com/xinntao/Real-ESRGAN Depend

Holy Wu 44 Dec 27, 2022
A collection of educational notebooks on multi-view geometry and computer vision.

Multiview notebooks This is a collection of educational notebooks on multi-view geometry and computer vision. Subjects covered in these notebooks incl

Max 65 Dec 09, 2022
Surrogate-Assisted Genetic Algorithm for Wrapper Feature Selection

SAGA Surrogate-Assisted Genetic Algorithm for Wrapper Feature Selection Please refer to the Jupyter notebook (Example.ipynb) for an example of using t

9 Dec 28, 2022
a project for 3D multi-object tracking

a project for 3D multi-object tracking

155 Jan 04, 2023
Hierarchical Few-Shot Generative Models

Hierarchical Few-Shot Generative Models Giorgio Giannone, Ole Winther This repo contains code and experiments for the paper Hierarchical Few-Shot Gene

Giorgio Giannone 6 Dec 12, 2022
A PyTorch Lightning solution to training OpenAI's CLIP from scratch.

train-CLIP 📎 A PyTorch Lightning solution to training CLIP from scratch. Goal ⚽ Our aim is to create an easy to use Lightning implementation of OpenA

Cade Gordon 396 Dec 30, 2022
Code for Low-Cost Algorithmic Recourse for Users With Uncertain Cost Functions

EMS-COLS-recourse Initial Code for Low-Cost Algorithmic Recourse for Users With Uncertain Cost Functions Folder structure: data folder contains raw an

Prateek Yadav 1 Nov 25, 2022
Python package provinding tools for artistic interactive applications using AI

Documentation redrawing Python package provinding tools for artistic interactive applications using AI Created by ReDrawing Campinas team for the Open

ReDrawing Campinas 1 Sep 30, 2021
[MICCAI'20] AlignShift: Bridging the Gap of Imaging Thickness in 3D Anisotropic Volumes

AlignShift NEW: Code for our new MICCAI'21 paper "Asymmetric 3D Context Fusion for Universal Lesion Detection" will also be pushed to this repository

Medical 3D Vision 42 Jan 06, 2023
Trainable Bilateral Filter Layer (PyTorch)

Trainable Bilateral Filter Layer (PyTorch) This repository contains our GPU-accelerated trainable bilateral filter layer (three spatial and one range

FabianWagner 26 Dec 25, 2022
Computational Pathology Toolbox developed by TIA Centre, University of Warwick.

TIA Toolbox Computational Pathology Toolbox developed at the TIA Centre Getting Started All Users This package is for those interested in digital path

Tissue Image Analytics (TIA) Centre 156 Jan 08, 2023
Generalized and Efficient Blackbox Optimization System.

OpenBox Doc | OpenBox中文文档 OpenBox: Generalized and Efficient Blackbox Optimization System OpenBox is an efficient and generalized blackbox optimizatio

DAIR Lab 238 Dec 29, 2022
Shuffle Attention for MobileNetV3

SA-MobileNetV3 Shuffle Attention for MobileNetV3 Train Run the following command for train model on your own dataset: python train.py --dataset mnist

Sajjad Aemmi 36 Dec 28, 2022
Unoffical implementation about Image Super-Resolution via Iterative Refinement by Pytorch

Image Super-Resolution via Iterative Refinement Paper | Project Brief This is a unoffical implementation about Image Super-Resolution via Iterative Re

LiangWei Jiang 2.5k Jan 02, 2023
Galileo library for large scale graph training by JD

近年来,图计算在搜索、推荐和风控等场景中获得显著的效果,但也面临超大规模异构图训练,与现有的深度学习框架Tensorflow和PyTorch结合等难题。 Galileo(伽利略)是一个图深度学习框架,具备超大规模、易使用、易扩展、高性能、双后端等优点,旨在解决超大规模图算法在工业级场景的落地难题,提

JD Galileo Team 128 Nov 29, 2022
Franka Emika Panda manipulator kinematics&dynamics simulation

pybullet_sim_panda Pybullet simulation environment for Franka Emika Panda Dependency pybullet, numpy, spatial_math_mini Simple example (please check s

0 Jan 20, 2022
ICML 21 - Voice2Series: Reprogramming Acoustic Models for Time Series Classification

Voice2Series-Reprogramming Voice2Series: Reprogramming Acoustic Models for Time Series Classification International Conference on Machine Learning (IC

49 Jan 03, 2023
Defending graph neural networks against adversarial attacks (NeurIPS 2020)

GNNGuard: Defending Graph Neural Networks against Adversarial Attacks Authors: Xiang Zhang ( Zitnik Lab @ Harvard 44 Dec 07, 2022