Jaxtorch (a jax nn library)

Related tags

Deep Learningjaxtorch
Overview

Jaxtorch (a jax nn library)

This is my jax based nn library. I created this because I was annoyed by the complexity and 'magic'-ness of the popular jax frameworks (flax, haiku).

The objective is to enable pytorch-like model definition and training with a minimum of magic. Simple example:

import jax
import jax.numpy as jnp
import jaxlib
import jaxtorch

# Modules are just classes that inherit jaxtorch.Module
class Linear(jaxtorch.Module):
    # They can accept any constructor parameters
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # Parameters are represented by a Param type, which identifies
        # them, and specifies how to initialize them.
        self.weight = jaxtorch.init.glorot_normal(out_features, in_features)
        assert type(self.weight) is jaxtorch.Param
        if bias:
            self.bias = jaxtorch.init.zeros(out_features)
        else:
            self.bias = None

    # The forward function accepts cx, a Context object as the first argument
    # always. This provides random number generation as well as the parameters.
    def forward(self, cx: jaxtorch.Context, x):
        # Parameters are looked up in the context using the stored identifier.
        y = x @ jnp.transpose(cx[self.weight])
        if self.bias:
            y = y + cx[self.bias]
        return y

model = Linear(3, 3)

# You initialize the weights by passing a RNG key.
# Calling init_weights also names all the parameters in the Module tree.
params = model.init_weights(jax.random.PRNGKey(0))

# Parameters are stored in a dictionary by name.
assert type(params) is dict
assert type(params[model.weight.name]) is jaxlib.xla_extension.DeviceArray
assert model.weight.name == 'weight'

def loss(params, key):
    cx = jaxtorch.Context(params, key)
    x = jnp.array([1.0,2.0,3.0])
    y = jnp.array([4.0,5.0,6.0])
    return jnp.mean((model(cx, x) - y)**2)
f_grad = jax.value_and_grad(loss)

for _ in range(100):
    (loss, grad) = f_grad(params, jax.random.PRNGKey(0))
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grad)
print(loss)
# 4.7440533e-08
Owner
nshepperd
nshepperd
Numerical-computing-is-fun - Learning numerical computing with notebooks for all ages.

As much as this series is to educate aspiring computer programmers and data scientists of all ages and all backgrounds, it is also a reminder to mysel

EKA foundation 758 Dec 25, 2022
PSTR: End-to-End One-Step Person Search With Transformers (CVPR2022)

PSTR (CVPR2022) This code is an official implementation of "PSTR: End-to-End One-Step Person Search With Transformers (CVPR2022)". End-to-end one-step

Jiale Cao 28 Dec 13, 2022
Compressed Video Action Recognition

Compressed Video Action Recognition Chao-Yuan Wu, Manzil Zaheer, Hexiang Hu, R. Manmatha, Alexander J. Smola, Philipp Krähenbühl. In CVPR, 2018. [Proj

Chao-Yuan Wu 479 Dec 26, 2022
An executor that loads ONNX models and embeds documents using the ONNX runtime.

ONNXEncoder An executor that loads ONNX models and embeds documents using the ONNX runtime. Usage via Docker image (recommended) from jina import Flow

Jina AI 2 Mar 15, 2022
MediaPipeのPythonパッケージのサンプルです。2020/12/11時点でPython実装のある4機能(Hands、Pose、Face Mesh、Holistic)について用意しています。

mediapipe-python-sample MediaPipeのPythonパッケージのサンプルです。 2020/12/11時点でPython実装のある以下4機能について用意しています。 Hands Pose Face Mesh Holistic Requirement mediapipe 0.

KazuhitoTakahashi 217 Dec 12, 2022
MCMC samplers for Bayesian estimation in Python, including Metropolis-Hastings, NUTS, and Slice

Sampyl May 29, 2018: version 0.3 Sampyl is a package for sampling from probability distributions using MCMC methods. Similar to PyMC3 using theano to

Mat Leonard 304 Dec 25, 2022
Deep Illuminator is a data augmentation tool designed for image relighting. It can be used to easily and efficiently generate a wide range of illumination variants of a single image.

Deep Illuminator Deep Illuminator is a data augmentation tool designed for image relighting. It can be used to easily and efficiently generate a wide

George Chogovadze 52 Nov 29, 2022
More Photos are All You Need: Semi-Supervised Learning for Fine-Grained Sketch Based Image Retrieval

More Photos are All You Need: Semi-Supervised Learning for Fine-Grained Sketch Based Image Retrieval, CVPR 2021. Ayan Kumar Bhunia, Pinaki nath Chowdh

Ayan Kumar Bhunia 22 Aug 27, 2022
Unofficial PyTorch implementation of Neural Additive Models (NAM) by Agarwal, et al.

nam-pytorch Unofficial PyTorch implementation of Neural Additive Models (NAM) by Agarwal, et al. [abs, pdf] Installation You can access nam-pytorch vi

Rishabh Anand 11 Mar 14, 2022
Deep metric learning methods implemented in Chainer

Deep Metric Learning Implementation of several methods for deep metric learning in Chainer v4.2.0. Proxy-NCA: No Fuss Distance Metric Learning using P

ronekko 156 Nov 28, 2022
[ICCV'21] NEAT: Neural Attention Fields for End-to-End Autonomous Driving

NEAT: Neural Attention Fields for End-to-End Autonomous Driving Paper | Supplementary | Video | Poster | Blog This repository is for the ICCV 2021 pap

254 Jan 02, 2023
Point cloud processing tool library.

Point Cloud ToolBox This point cloud processing tool library can be used to process point clouds, 3d meshes, and voxels. Environment python 3.7.5 Dep

ZhangXinyun 40 Dec 09, 2022
converts nominal survey data into a numerical value based on a dictionary lookup.

SWAP RATE Converts nominal survey data into a numerical values based on a dictionary lookup. It allows the user to switch nominal scale data from text

Jake Rhodes 1 Jan 18, 2022
Caffe-like explicit model constructor. C(onfig)Model

cmodel Caffe-like explicit model constructor. C(onfig)Model Installation pip install git+https://github.com/bonlime/cmodel Usage In order to allow usi

1 Feb 18, 2022
Rethinking of Pedestrian Attribute Recognition: A Reliable Evaluation under Zero-Shot Pedestrian Identity Setting

Pytorch Pedestrian Attribute Recognition: A strong PyTorch baseline of pedestrian attribute recognition and multi-label classification.

Jian 79 Dec 18, 2022
Code for paper "Learning to Reweight Examples for Robust Deep Learning"

learning-to-reweight-examples Code for paper Learning to Reweight Examples for Robust Deep Learning. [arxiv] Environment We tested the code on tensorf

Uber Research 261 Jan 01, 2023
ECCV18 Workshops - Enhanced SRGAN. Champion PIRM Challenge on Perceptual Super-Resolution. The training codes are in BasicSR.

ESRGAN (Enhanced SRGAN) [ 🚀 BasicSR] [Real-ESRGAN] ✨ New Updates. We have extended ESRGAN to Real-ESRGAN, which is a more practical algorithm for rea

Xintao 4.7k Jan 02, 2023
Data augmentation for NLP, accepted at EMNLP 2021 Findings

AEDA: An Easier Data Augmentation Technique for Text Classification This is the code for the EMNLP 2021 paper AEDA: An Easier Data Augmentation Techni

Akbar Karimi 81 Dec 09, 2022
Brain tumor detection using CNN (InceptionResNetV2 Model)

Brain-Tumor-Detection Building a detection model using a convolutional neural network in Tensorflow & Keras. Used brain MRI images. InceptionResNetV2

1 Feb 13, 2022
Structural Constraints on Information Content in Human Brain States

Structural Constraints on Information Content in Human Brain States Code accompanying the paper "The information content of brain states is explained

Leon Weninger 3 Sep 07, 2022