Jaxtorch (a jax nn library)

Related tags

Deep Learningjaxtorch
Overview

Jaxtorch (a jax nn library)

This is my jax based nn library. I created this because I was annoyed by the complexity and 'magic'-ness of the popular jax frameworks (flax, haiku).

The objective is to enable pytorch-like model definition and training with a minimum of magic. Simple example:

import jax
import jax.numpy as jnp
import jaxlib
import jaxtorch

# Modules are just classes that inherit jaxtorch.Module
class Linear(jaxtorch.Module):
    # They can accept any constructor parameters
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # Parameters are represented by a Param type, which identifies
        # them, and specifies how to initialize them.
        self.weight = jaxtorch.init.glorot_normal(out_features, in_features)
        assert type(self.weight) is jaxtorch.Param
        if bias:
            self.bias = jaxtorch.init.zeros(out_features)
        else:
            self.bias = None

    # The forward function accepts cx, a Context object as the first argument
    # always. This provides random number generation as well as the parameters.
    def forward(self, cx: jaxtorch.Context, x):
        # Parameters are looked up in the context using the stored identifier.
        y = x @ jnp.transpose(cx[self.weight])
        if self.bias:
            y = y + cx[self.bias]
        return y

model = Linear(3, 3)

# You initialize the weights by passing a RNG key.
# Calling init_weights also names all the parameters in the Module tree.
params = model.init_weights(jax.random.PRNGKey(0))

# Parameters are stored in a dictionary by name.
assert type(params) is dict
assert type(params[model.weight.name]) is jaxlib.xla_extension.DeviceArray
assert model.weight.name == 'weight'

def loss(params, key):
    cx = jaxtorch.Context(params, key)
    x = jnp.array([1.0,2.0,3.0])
    y = jnp.array([4.0,5.0,6.0])
    return jnp.mean((model(cx, x) - y)**2)
f_grad = jax.value_and_grad(loss)

for _ in range(100):
    (loss, grad) = f_grad(params, jax.random.PRNGKey(0))
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grad)
print(loss)
# 4.7440533e-08
Owner
nshepperd
nshepperd
Neighborhood Contrastive Learning for Novel Class Discovery

Neighborhood Contrastive Learning for Novel Class Discovery This repository contains the official implementation of our paper: Neighborhood Contrastiv

Zhun Zhong 56 Dec 09, 2022
Gesture Volume Control v.2

Gesture volume control v.2 In this project I am going to learn how to use Gesture Control to change the volume of a computer. I first look into hand t

Pavel Dat 23 Dec 26, 2022
Paper Code:A Self-adaptive Weighted Differential Evolution Approach for Large-scale Feature Selection

1. SaWDE.m is the main function 2. DataPartition.m is used to randomly partition the original data into training sets and test sets with a ratio of 7

wangxb 14 Dec 08, 2022
Self-training for Few-shot Transfer Across Extreme Task Differences

Self-training for Few-shot Transfer Across Extreme Task Differences (STARTUP) Introduction This repo contains the official implementation of the follo

Cheng Perng Phoo 33 Oct 31, 2022
Camera calibration & 3D pose estimation tools for AcinoSet

AcinoSet: A 3D Pose Estimation Dataset and Baseline Models for Cheetahs in the Wild Daniel Joska, Liam Clark, Naoya Muramatsu, Ricardo Jericevich, Fre

African Robotics Unit 42 Nov 16, 2022
Implementation of ConvMixer for "Patches Are All You Need? 🤷"

Patches Are All You Need? 🤷 This repository contains an implementation of ConvMixer for the ICLR 2022 submission "Patches Are All You Need?" by Asher

CMU Locus Lab 934 Jan 08, 2023
Implementation of popular SOTA self-supervised learning algorithms as Fastai Callbacks.

Self Supervised Learning with Fastai Implementation of popular SOTA self-supervised learning algorithms as Fastai Callbacks. Install pip install self-

Kerem Turgutlu 276 Dec 23, 2022
PyTorch Implementation of SSTNs for hyperspectral image classifications from the IEEE T-GRS paper "Spectral-Spatial Transformer Network for Hyperspectral Image Classification: A FAS Framework."

PyTorch Implementation of SSTN for Hyperspectral Image Classification Paper links: SSTN published on IEEE T-GRS. Also, you can directly find the imple

Zilong Zhong 54 Dec 19, 2022
NeoDTI: Neural integration of neighbor information from a heterogeneous network for discovering new drug-target interactions

NeoDTI NeoDTI: Neural integration of neighbor information from a heterogeneous network for discovering new drug-target interactions (Bioinformatics).

62 Nov 26, 2022
Official implementation of "An Image is Worth 16x16 Words, What is a Video Worth?" (2021 paper)

An Image is Worth 16x16 Words, What is a Video Worth? paper Official PyTorch Implementation Gilad Sharir, Asaf Noy, Lihi Zelnik-Manor DAMO Academy, Al

213 Nov 12, 2022
official implemntation for "Contrastive Learning with Stronger Augmentations"

CLSA CLSA is a self-supervised learning methods which focused on the pattern learning from strong augmentations. Copyright (C) 2020 Xiao Wang, Guo-Jun

Lab for MAchine Perception and LEarning (MAPLE) 47 Nov 29, 2022
Stream images from a connected camera over MQTT, view using Streamlit, record to file and sqlite

mqtt-camera-streamer Summary: Publish frames from a connected camera or MJPEG/RTSP stream to an MQTT topic, and view the feed in a browser on another

Robin Cole 183 Dec 16, 2022
AirCode: A Robust Object Encoding Method

AirCode This repo contains source codes for the arXiv preprint "AirCode: A Robust Object Encoding Method" Demo Object matching comparison when the obj

Chen Wang 30 Dec 09, 2022
Winners of the Facebook Image Similarity Challenge

Winners of the Facebook Image Similarity Challenge

DrivenData 111 Jan 05, 2023
Inteligência artificial criada para realizar interação social com idosos.

IA SONIA 4.0 A SONIA foi inspirada no assistente mais famoso do mundo e muito bem conhecido JARVIS. Todo mundo algum dia ja sonhou em ter o seu própri

Vinícius Azevedo 2 Oct 21, 2021
基于深度强化学习的原神自动钓鱼AI

原神自动钓鱼AI由YOLOX, DQN两部分模型组成。使用迁移学习,半监督学习进行训练。 模型也包含一些使用opencv等传统数字图像处理方法实现的不可学习部分。

4.2k Jan 01, 2023
Udacity Suse Cloud Native Foundations Scholarship Course Walkthrough

SUSE Cloud Native Foundations Scholarship Udacity is collaborating with SUSE, a global leader in true open source solutions, to empower developers and

Shivansh Srivastava 34 Oct 18, 2022
SwinIR: Image Restoration Using Swin Transformer

SwinIR: Image Restoration Using Swin Transformer This repository is the official PyTorch implementation of SwinIR: Image Restoration Using Shifted Win

Jingyun Liang 2.4k Jan 08, 2023
HugsVision is a easy to use huggingface wrapper for state-of-the-art computer vision

HugsVision is an open-source and easy to use all-in-one huggingface wrapper for computer vision. The goal is to create a fast, flexible and user-frien

Labrak Yanis 166 Nov 27, 2022
Pytorch implementation of our method for high-resolution (e.g. 2048x1024) photorealistic video-to-video translation.

vid2vid Project | YouTube(short) | YouTube(full) | arXiv | Paper(full) Pytorch implementation for high-resolution (e.g., 2048x1024) photorealistic vid

NVIDIA Corporation 8.1k Jan 01, 2023