Unofficial JAX implementations of Deep Learning models

Overview

JAX Models

license-shield release-shield python-shield code-style

Table of Contents
  1. About The Project
  2. Getting Started
  3. Contributing
  4. License
  5. Contact

About The Project

The JAX Models repository aims to provide open sourced JAX/Flax implementations for research papers originally without code or code written with frameworks other than JAX. The goal of this project is to make a collection of models, layers, activations and other utilities that are most commonly used for research. All papers and derived or translated code is cited in either the README or the docstrings. If you think that any citation is missed then please raise an issue.

All implementations provided here are available on Papers With Code.


Available model implementations for JAX are:
  1. MetaFormer is Actually What You Need for Vision (Weihao Yu et al., 2021)
  2. Augmenting Convolutional networks with attention-based aggregation (Hugo Touvron et al., 2021)
  3. MPViT : Multi-Path Vision Transformer for Dense Prediction (Youngwan Lee et al., 2021)
  4. MLP-Mixer: An all-MLP Architecture for Vision (Ilya Tolstikhin et al., 2021)
  5. Patches Are All You Need (Anonymous et al., 2021)
  6. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (Enze Xie et al., 2021)
  7. A ConvNet for the 2020s (Zhuang Liu et al., 2021)
  8. Masked Autoencoders Are Scalable Vision Learners (Kaiming He et al., 2021)

Available layers for out-of-the-box integration:
  1. DropPath (Stochastic Depth) (Gao Huang et al., 2021)
  2. Squeeze-and-Excitation Layer (Jie Hu et al. 2019)
  3. Depthwise Convolution (François Chollet, 2017)

Prerequisites

Prerequisites can be installed separately through the requirements.txt file in the main directory using:

pip install -r requirements.txt

The use of a virtual environment is highly recommended to avoid version incompatibilites.

Installation

This project is built with Python 3 for the latest JAX/Flax versions and can be directly installed via pip.

pip install jax-models

If you wish to use the latest version then you can directly clone the repository too.

git clone https://github.com/DarshanDeshpande/jax-models.git

Usage

To see all model architectures available:

from jax_models.models.model_registry import list_models
from pprint import pprint

pprint(list_models())

To load your desired model:

from jax_models.models.model_registry import load_model
load_model('mpvit-base', attach_head=True, num_classes=1000, dropout=0.1)

Contributing

Please raise an issue if any implementation gives incorrect results, crashes unexpectedly during training/inference or if any citation is missing.

You can contribute to jax_models by supporting me with compute resources or by contributing your own resources to provide pretrained weights.

If you wish to donate to this inititative then please drop me a mail here.

License

Distributed under the Apache 2.0 License. See LICENSE for more information.

Contact

Feel free to reach out for any issues or requests related to these implementations

Darshan Deshpande - Email | Twitter | LinkedIn

You might also like...
Very deep VAEs in JAX/Flax
Very deep VAEs in JAX/Flax

Very Deep VAEs in JAX/Flax Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on I

Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX
Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX

CQL-JAX This repository implements Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX (FLAX). Implementation is built on

PyTorch implementations of neural network models for keyword spotting
PyTorch implementations of neural network models for keyword spotting

Honk: CNNs for Keyword Spotting Honk is a PyTorch reimplementation of Google's TensorFlow convolutional neural networks for keyword spotting, which ac

Unofficial implementation of Proxy Anchor Loss for Deep Metric Learning
Unofficial implementation of Proxy Anchor Loss for Deep Metric Learning

Proxy Anchor Loss for Deep Metric Learning Unofficial pytorch, tensorflow and mxnet implementations of Proxy Anchor Loss for Deep Metric Learning. Not

Time-series-deep-learning - Developing Deep learning LSTM, BiLSTM models, and NeuralProphet for multi-step time-series forecasting of stock price.
Time-series-deep-learning - Developing Deep learning LSTM, BiLSTM models, and NeuralProphet for multi-step time-series forecasting of stock price.

Stock Price Prediction Using Deep Learning Univariate Time Series Predicting stock price using historical data of a company using Neural networks for

FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX.

FedJAX: Federated learning with JAX What is FedJAX? FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX priori

Objax Apache-2Objax (🥉19 · ⭐ 580) - Objax is a machine learning framework that provides an Object.. Apache-2 jax

Objax Tutorials | Install | Documentation | Philosophy This is not an officially supported Google product. Objax is an open source machine learning fr

Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX
Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX

coax is built on top of JAX, but it doesn't have an explicit dependence on the jax python package. The reason is that your version of jaxlib will depend on your CUDA version.

JAX code for the paper
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"

Optimal Model Design for Reinforcement Learning This repository contains JAX code for the paper Control-Oriented Model-Based Reinforcement Learning wi

Comments
  • Missing Axis Swap in ExtractPatches and MergePatches

    Missing Axis Swap in ExtractPatches and MergePatches

    In patch_utils.py, the modules ExtractPatches and MergePatches are missing an axis swap between the reshapes, resulting in the extracted patches becoming horizontal stripes. For example, if we follow the code in ExtractPatches:

    >>> inputs = jnp.arange(16).reshape(1, 4, 4, 1)
    >>> inputs[0, :, :, 0]
    
    DeviceArray([[ 0,  1,  2,  3],
                 [ 4,  5,  6,  7],
                 [ 8,  9, 10, 11],
                 [12, 13, 14, 15]], dtype=int32)
    
    >>> patch_size = 2
    >>> batch, height, width, channels = inputs.shape
    >>> height, width = height // patch_size, width // patch_size
    >>> x = jnp.reshape(inputs, (batch, height, patch_size, width, patch_size, channels))
    >>> x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
    >>> x[0, 0, :]
    
    DeviceArray([0, 1, 2, 3], dtype=int32)
    

    We see that the first patch extracted is not the patch containing [0, 1, 4, 5], but the horizontal stripe [0, 1, 2, 3]. To fix this problem, we should add an axis swap. For ExtractPatches, this should be:

    batch, height, width, channels = inputs.shape
    height, width = height // patch_size, width // patch_size
    x = jnp.reshape(
        inputs, (batch, height, patch_size, width, patch_size, channels)
    )
    x = jnp.swapaxes(x, 2, 3)
    x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
    

    For MergePatches, this should be:

    batch, length, _ = inputs.shape
    height = width = int(length**0.5)
    x = jnp.reshape(inputs, (batch, height, width, patch_size, patch_size, -1))
    x = jnp.swapaxes(x, 2, 3)
    x = jnp.reshape(x, (batch, height * patch_size, width * patch_size, -1))
    
    bug 
    opened by young-geng 4
  • fix convnext to make it work with jax.jit

    fix convnext to make it work with jax.jit

    Hey, first of all, thanks for the nice codebase. When doing inference using the convnext model, I noticed the following issue:

    Calling x.item() will call float(x), which breaks the jit tracer. We can remove the list comprehension in unnecessary conversion to make jax.jit work. Without jax.jit, the model is very slow for me, running with only ~30% GPU utilization (RTX 3090).

    This issue could apply to other models as well, maybe it is a good idea to include a test for applying jax.jit to each model?

    opened by maxidl 1
Releases(v0.5-van)
Owner
Helping Machines Learn Better 💻😃
Generate pixel-style avatars with python.

face2pixel Generate pixel-style avatars with python. Run: Clone the project: git clone https://github.com/theodorecooper/face2pixel install requiremen

Theodore Cooper 2 May 11, 2022
Certis - Certis, A High-Quality Backtesting Engine

Certis - Backtesting For y'all Certis is a powerful, lightweight, simple backtes

Yeachan-Heo 46 Oct 30, 2022
PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-supervised ViT.

MAE for Self-supervised ViT Introduction This is an unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-sup

36 Oct 30, 2022
This is the codebase for the ICLR 2021 paper Trajectory Prediction using Equivariant Continuous Convolution

Trajectory Prediction using Equivariant Continuous Convolution (ECCO) This is the codebase for the ICLR 2021 paper Trajectory Prediction using Equivar

Spatiotemporal Machine Learning 45 Jul 22, 2022
22 Oct 14, 2022
Video Frame Interpolation with Transformer (CVPR2022)

VFIformer Official PyTorch implementation of our CVPR2022 paper Video Frame Interpolation with Transformer Dependencies python = 3.8 pytorch = 1.8.0

DV Lab 63 Dec 16, 2022
PICARD - Parsing Incrementally for Constrained Auto-Regressive Decoding from Language Models

This is the official implementation of the following paper: Torsten Scholak, Nathan Schucher, Dzmitry Bahdanau. PICARD - Parsing Incrementally for Con

ElementAI 217 Jan 01, 2023
The PASS dataset: pretrained models and how to get the data - PASS: Pictures without humAns for Self-Supervised Pretraining

The PASS dataset: pretrained models and how to get the data - PASS: Pictures without humAns for Self-Supervised Pretraining

Yuki M. Asano 249 Dec 22, 2022
Official implementation of the paper Momentum Capsule Networks (MoCapsNet)

Momentum Capsule Network Official implementation of the paper Momentum Capsule Networks (MoCapsNet). Abstract Capsule networks are a class of neural n

8 Oct 20, 2022
iNAS: Integral NAS for Device-Aware Salient Object Detection

iNAS: Integral NAS for Device-Aware Salient Object Detection Introduction Integral search design (jointly consider backbone/head structures, design/de

顾宇超 77 Dec 02, 2022
Learning High-Speed Flight in the Wild

Learning High-Speed Flight in the Wild This repo contains the code associated to the paper Learning Agile Flight in the Wild. For more information, pl

Robotics and Perception Group 391 Dec 29, 2022
Official Pytorch Implementation of: "ImageNet-21K Pretraining for the Masses"(2021) paper

ImageNet-21K Pretraining for the Masses Paper | Pretrained models Official PyTorch Implementation Tal Ridnik, Emanuel Ben-Baruch, Asaf Noy, Lihi Zelni

574 Jan 02, 2023
Sky Computing: Accelerating Geo-distributed Computing in Federated Learning

Sky Computing Introduction Sky Computing is a load-balanced framework for federated learning model parallelism. It adaptively allocate model layers to

HPC-AI Tech 72 Dec 27, 2022
Official PyTorch implementation of the paper: Improving Graph Neural Network Expressivity via Subgraph Isomorphism Counting.

Improving Graph Neural Network Expressivity via Subgraph Isomorphism Counting Official PyTorch implementation of the paper: Improving Graph Neural Net

Giorgos Bouritsas 58 Dec 31, 2022
Rest API Written In Python To Classify NSFW Images.

Rest API Written In Python To Classify NSFW Images.

Wahyusaputra 2 Dec 23, 2021
Minimal deep learning library written from scratch in Python, using NumPy/CuPy.

SmallPebble Project status: experimental, unstable. SmallPebble is a minimal/toy automatic differentiation/deep learning library written from scratch

Sidney Radcliffe 92 Dec 30, 2022
Implementation of Kronecker Attention in Pytorch

Kronecker Attention Pytorch Implementation of Kronecker Attention in Pytorch. Results look less than stellar, but if someone found some context where

Phil Wang 16 May 06, 2022
An addernet CUDA version

Training addernet accelerated by CUDA Usage cd adder_cuda python setup.py install cd .. python main.py Environment pytorch 1.10.0 CUDA 11.3 benchmark

LingXY 4 Jun 20, 2022
iPOKE: Poking a Still Image for Controlled Stochastic Video Synthesis

iPOKE: Poking a Still Image for Controlled Stochastic Video Synthesis iPOKE: Poking a Still Image for Controlled Stochastic Video Synthesis Andreas Bl

CompVis Heidelberg 36 Dec 25, 2022
Implementation of gaze tracking and demo

Predicting Customer Demand by Using Gaze Detecting and Object Tracking This project is the integration of gaze detecting and object tracking. Predict

2 Oct 20, 2022