Turning SymPy expressions into JAX functions

Overview

sympy2jax

.github/workflows/CI.yml

Turn SymPy expressions into parametrized, differentiable, vectorizable, JAX functions.

All SymPy floats become trainable input parameters. SymPy symbols become columns of a passed matrix.

Installation

pip install git+https://github.com/MilesCranmer/sympy2jax.git

Example

import sympy
from sympy import symbols
import jax
import jax.numpy as jnp
from jax import random
from sympy2jax import sympy2jax

Let's create an expression in SymPy:

x, y = symbols('x y')
expression = 1.0 * sympy.cos(x) + 3.2 * y

Let's get the JAX version. We pass the equation, and the symbols required.

f, params = sympy2jax(expression, [x, y])

The order you supply the symbols is the same order you should supply the features when calling the function f (shape [nrows, nfeatures]). In this case, features=2 for x and y. The params in this case will be jnp.array([1.0, 3.2]). You pass these parameters when calling the function, which will let you change them and take gradients.

Let's generate some JAX data to pass:

key = random.PRNGKey(0)
X = random.normal(key, (10, 2))

We can call the function with:

f(X, params)

#> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
#                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
#                3.5427954 , -2.7479894 ], dtype=float32)

We can take gradients with respect to the parameters for each row with JAX gradient parameters now:

jac_f = jax.jacobian(f, argnums=1)
jac_f(X, params)

#> DeviceArray([[ 0.49364874, -0.9692889 ],
#               [ 0.8283714 , -0.0318858 ],
#               [-0.7447336 , -1.8784496 ],
#               [ 0.70755106, -0.3137085 ],
#               [ 0.944834  ,  1.767703  ],
#               [ 0.51673377,  1.4111717 ],
#               [ 0.87347716, -0.52637756],
#               [ 0.8760679 ,  1.0549792 ],
#               [ 0.9961824 ,  0.79581654],
#               [-0.88465923, -0.5822907 ]], dtype=float32)

We can also JIT-compile our function:

compiled_f = jax.jit(f)
compiled_f(X, params)

#> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
#                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
#                3.5427954 , -2.7479894 ], dtype=float32)
Owner
Miles Cranmer
Astro PhD candidate @princeton trying to accelerate astrophysics with AI. I build interpretable ML algorithms.
Miles Cranmer
WRENCH: Weak supeRvision bENCHmark

🔧 What is it? Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common and easy framework for development

Jieyu Zhang 176 Dec 28, 2022
Reimplementation of Learning Mesh-based Simulation With Graph Networks

Pytorch Implementation of Learning Mesh-based Simulation With Graph Networks This is the unofficial implementation of the approach described in the pa

Jingwei Xu 33 Dec 14, 2022
ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives

Status: Under development (expect bug fixes and huge updates) ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectiv

37 Dec 28, 2022
This project is based on RIFE and aims to make RIFE more practical for users by adding various features and design new models

CPM 项目描述 CPM(Chinese Pretrained Models)模型是北京智源人工智能研究院和清华大学发布的中文大规模预训练模型。官方发布了三种规模的模型,参数量分别为109M、334M、2.6B,用户需申请与通过审核,方可下载。 由于原项目需要考虑大模型的训练和使用,需要安装较为复杂

hzwer 190 Jan 08, 2023
Voxel-based Network for Shape Completion by Leveraging Edge Generation (ICCV 2021, oral)

Voxel-based Network for Shape Completion by Leveraging Edge Generation This is the PyTorch implementation for the paper "Voxel-based Network for Shape

10 Dec 04, 2022
TICC is a python solver for efficiently segmenting and clustering a multivariate time series

TICC TICC is a python solver for efficiently segmenting and clustering a multivariate time series. It takes as input a T-by-n data matrix, a regulariz

406 Dec 12, 2022
Tensorflow python implementation of "Learning High Fidelity Depths of Dressed Humans by Watching Social Media Dance Videos"

Learning High Fidelity Depths of Dressed Humans by Watching Social Media Dance Videos This repository is the official tensorflow python implementation

Yasamin Jafarian 287 Jan 06, 2023
Using LSTM write Tang poetry

本教程将通过一个示例对LSTM进行介绍。通过搭建训练LSTM网络,我们将训练一个模型来生成唐诗。本文将对该实现进行详尽的解释,并阐明此模型的工作方式和原因。并不需要过多专业知识,但是可能需要新手花一些时间来理解的模型训练的实际情况。为了节省时间,请尽量选择GPU进行训练。

56 Dec 15, 2022
Recreate CenternetV2 based on MMDET.

Introduction This project is trying to Recreate CenternetV2 based on MMDET, which is proposed in paper Probabilistic two-stage detection. This project

25 Dec 09, 2022
Low-dose Digital Mammography with Deep Learning

Impact of loss functions on the performance of a deep neural network designed to restore low-dose digital mammography ====== This repository contains

WANG-AXIS 6 Dec 13, 2022
TF2 implementation of knowledge distillation using the "function matching" hypothesis from the paper Knowledge distillation: A good teacher is patient and consistent by Beyer et al.

FunMatch-Distillation TF2 implementation of knowledge distillation using the "function matching" hypothesis from the paper Knowledge distillation: A g

Sayak Paul 67 Dec 20, 2022
A certifiable defense against adversarial examples by training neural networks to be provably robust

DiffAI v3 DiffAI is a system for training neural networks to be provably robust and for proving that they are robust. The system was developed for the

SRI Lab, ETH Zurich 202 Dec 13, 2022
OMNIVORE is a single vision model for many different visual modalities

Omnivore: A Single Model for Many Visual Modalities [paper][website] OMNIVORE is a single vision model for many different visual modalities. It learns

Meta Research 451 Dec 27, 2022
Using deep learning to predict gene structures of the coding genes in DNA sequences of Arabidopsis thaliana

DeepGeneAnnotator: A tool to annotate the gene in the genome The master thesis of the "Using deep learning to predict gene structures of the coding ge

Ching-Tien Wang 3 Sep 09, 2022
Code for SyncTwin: Treatment Effect Estimation with Longitudinal Outcomes (NeurIPS 2021)

SyncTwin: Treatment Effect Estimation with Longitudinal Outcomes (NeurIPS 2021) SyncTwin is a treatment effect estimation method tailored for observat

Zhaozhi Qian 3 Nov 03, 2022
Zen-NAS: A Zero-Shot NAS for High-Performance Deep Image Recognition

Zen-NAS: A Zero-Shot NAS for High-Performance Deep Image Recognition How Fast Compare to Other Zero-Shot NAS Proxies on CIFAR-10/100 Pre-trained Model

190 Dec 29, 2022
Clustering is a popular approach to detect patterns in unlabeled data

Visual Clustering Clustering is a popular approach to detect patterns in unlabeled data. Existing clustering methods typically treat samples in a data

Tarek Naous 24 Nov 11, 2022
A library for finding knowledge neurons in pretrained transformer models.

knowledge-neurons An open source repository replicating the 2021 paper Knowledge Neurons in Pretrained Transformers by Dai et al., and extending the t

EleutherAI 96 Dec 21, 2022
Small-bets - Ergodic Experiment With Python

Ergodic Experiment Based on this video. Run this experiment with this command: p

Michael Brant 3 Jan 11, 2022
[ICCV 2021 Oral] NerfingMVS: Guided Optimization of Neural Radiance Fields for Indoor Multi-view Stereo

NerfingMVS Project Page | Paper | Video | Data NerfingMVS: Guided Optimization of Neural Radiance Fields for Indoor Multi-view Stereo Yi Wei, Shaohui

Yi Wei 369 Dec 24, 2022