Turning SymPy expressions into JAX functions

Overview

sympy2jax


Turn SymPy expressions into parametrized, differentiable, vectorizable JAX functions.

All SymPy floats become trainable input parameters, and SymPy symbols become columns of the input matrix you pass.
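To make the mapping concrete, here is a minimal sketch of how such a conversion could work, assuming only SymPy's public expression-tree API (`args`, `func`, `Float`, `Symbol`). It is not sympy2jax's actual implementation: the name `to_jax`, the `_JAX_FUNCS` table, and the small set of supported node types are hypothetical. It walks the tree, turns every `sympy.Float` into an entry of a parameter vector, and turns every symbol into a column of the input matrix.

```python
import functools
import operator

import jax
import jax.numpy as jnp
import sympy

# Hypothetical table mapping SymPy function classes to jax.numpy equivalents.
_JAX_FUNCS = {sympy.cos: jnp.cos, sympy.sin: jnp.sin, sympy.exp: jnp.exp, sympy.log: jnp.log}


def to_jax(expr, symbols):
    """Hypothetical sketch: floats -> entries of params, symbols -> columns of X."""
    floats = []  # initial parameter values, collected in tree-traversal order

    def build(node):
        if isinstance(node, sympy.Float):
            idx = len(floats)
            floats.append(float(node))
            return lambda X, p: p[idx]  # trainable parameter
        if isinstance(node, sympy.Symbol):
            col = symbols.index(node)
            return lambda X, p: X[:, col]  # feature column, in the order of `symbols`
        if isinstance(node, sympy.Rational):  # exact numbers (incl. integers) stay fixed
            value = float(node)
            return lambda X, p: value
        children = [build(arg) for arg in node.args]
        if isinstance(node, sympy.Add):
            return lambda X, p: functools.reduce(operator.add, (c(X, p) for c in children))
        if isinstance(node, sympy.Mul):
            return lambda X, p: functools.reduce(operator.mul, (c(X, p) for c in children))
        if isinstance(node, sympy.Pow):
            return lambda X, p: children[0](X, p) ** children[1](X, p)
        if node.func in _JAX_FUNCS:
            fn = _JAX_FUNCS[node.func]
            return lambda X, p: fn(children[0](X, p))
        raise NotImplementedError(f"unsupported node: {node.func}")

    f = build(expr)
    return f, jnp.array(floats)


x, y = sympy.symbols("x y")
f, params = to_jax(1.0 * sympy.cos(x) + 3.2 * y, [x, y])
X = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
values = f(X, params)  # one value per row, shape (10,)
```

In this sketch the order of `params` follows SymPy's internal argument order, so it is not guaranteed to match the order in which the floats were written.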

Installation

pip install git+https://github.com/MilesCranmer/sympy2jax.git

Example

import sympy
from sympy import symbols
import jax
import jax.numpy as jnp
from jax import random
from sympy2jax import sympy2jax

Let's create an expression in SymPy:

x, y = symbols('x y')
expression = 1.0 * sympy.cos(x) + 3.2 * y

Let's get the JAX version. We pass the expression and the symbols it requires:

f, params = sympy2jax(expression, [x, y])

The order in which you supply the symbols is the order in which you should supply the features (columns) when calling the function f, whose input has shape [nrows, nfeatures]. In this case, nfeatures=2, for x and y. The params will be jnp.array([1.0, 3.2]). You pass these parameters when calling the function, which lets you change them and take gradients with respect to them.
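The calling convention can be illustrated with a hand-written stand-in for this f. It is an assumption written out by hand (`f_manual` is a hypothetical name, not code generated by sympy2jax) that only mirrors the documented layout: column 0 of X holds x, column 1 holds y, and params holds [1.0, 3.2].

```python
import jax
import jax.numpy as jnp


def f_manual(X, params):
    """Hypothetical hand-written equivalent of 1.0*cos(x) + 3.2*y."""
    x, y = X[:, 0], X[:, 1]  # columns follow the symbol order [x, y]
    return params[0] * jnp.cos(x) + params[1] * y


X = jax.random.normal(jax.random.PRNGKey(0), (10, 2))  # shape [nrows, nfeatures]
params = jnp.array([1.0, 3.2])

out = f_manual(X, params)  # one value per row, shape (10,)

# Changing params changes the coefficients: [0.0, 1.0] reduces the expression to y.
only_y = f_manual(X, jnp.array([0.0, 1.0]))
```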

Let's generate some JAX data to pass:

key = random.PRNGKey(0)
X = random.normal(key, (10, 2))

We can call the function with:

f(X, params)

#> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
#                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
#                3.5427954 , -2.7479894 ], dtype=float32)
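Because the parameters are an ordinary array argument, a function with this signature can also be vectorized over several parameter sets with `jax.vmap`. The sketch below uses a hand-written stand-in, `f_manual` (hypothetical, mirroring the documented behaviour of f for 1.0*cos(x) + 3.2*y); with `in_axes=(None, 0)` the data matrix is shared while the leading axis of the parameter batch is mapped over.

```python
import jax
import jax.numpy as jnp


def f_manual(X, params):
    """Hypothetical hand-written equivalent of 1.0*cos(x) + 3.2*y."""
    return params[0] * jnp.cos(X[:, 0]) + params[1] * X[:, 1]


X = jax.random.normal(jax.random.PRNGKey(0), (10, 2))

# Three candidate parameter vectors stacked along a new leading axis.
param_batch = jnp.array([[1.0, 3.2], [0.5, 1.0], [2.0, -1.0]])

# Share X across the batch (None) and map over the rows of param_batch (0).
batched_f = jax.vmap(f_manual, in_axes=(None, 0))
outs = batched_f(X, param_batch)  # shape (3, 10): one row of outputs per parameter set
```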

We can now use JAX to take gradients with respect to the parameters for each row:

jac_f = jax.jacobian(f, argnums=1)
jac_f(X, params)

#> DeviceArray([[ 0.49364874, -0.9692889 ],
#               [ 0.8283714 , -0.0318858 ],
#               [-0.7447336 , -1.8784496 ],
#               [ 0.70755106, -0.3137085 ],
#               [ 0.944834  ,  1.767703  ],
#               [ 0.51673377,  1.4111717 ],
#               [ 0.87347716, -0.52637756],
#               [ 0.8760679 ,  1.0549792 ],
#               [ 0.9961824 ,  0.79581654],
#               [-0.88465923, -0.5822907 ]], dtype=float32)
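Since the example expression is linear in its parameters, each row of this Jacobian should equal [cos(x_i), y_i], the partial derivatives of 1.0*cos(x) + 3.2*y with respect to its two coefficients. The printed numbers are consistent with that: from the first Jacobian row, 0.49364874 + 3.2 * (-0.9692889) ≈ -2.608, the first value returned by f(X, params). The check can be sketched with a hand-written stand-in (`f_manual`, a hypothetical function assumed to mirror the documented f):

```python
import jax
import jax.numpy as jnp


def f_manual(X, params):
    """Hypothetical hand-written equivalent of 1.0*cos(x) + 3.2*y."""
    return params[0] * jnp.cos(X[:, 0]) + params[1] * X[:, 1]


X = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
params = jnp.array([1.0, 3.2])

# Row i holds d f(X)[i] / d params, so the shape is (nrows, nparams).
jac = jax.jacobian(f_manual, argnums=1)(X, params)

# Analytic derivatives: d/da (a*cos(x) + b*y) = cos(x) and d/db = y.
expected = jnp.stack([jnp.cos(X[:, 0]), X[:, 1]], axis=1)
```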

We can also JIT-compile our function:

compiled_f = jax.jit(f)
compiled_f(X, params)

#> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
#                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
#                3.5427954 , -2.7479894 ], dtype=float32)
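Because the floats are exposed as parameters, `jax.jit` and `jax.grad` can be combined to fit the expression's coefficients to data. The following is a minimal sketch using a hand-written stand-in (`f_manual`, hypothetical, mirroring the documented f); the synthetic data, noise level, learning rate, and step count are arbitrary illustrative choices, not taken from sympy2jax.

```python
import jax
import jax.numpy as jnp


def f_manual(X, params):
    """Hypothetical hand-written equivalent of a*cos(x) + b*y with params = [a, b]."""
    return params[0] * jnp.cos(X[:, 0]) + params[1] * X[:, 1]


# Synthetic data generated from known coefficients plus a little noise.
key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (100, 2))
true_params = jnp.array([1.0, 3.2])
y_obs = f_manual(X, true_params) + 0.01 * jax.random.normal(key_noise, (100,))


def loss(params):
    """Mean squared error between predictions and observations."""
    return jnp.mean((f_manual(X, params) - y_obs) ** 2)


LEARNING_RATE = 0.1


@jax.jit
def step(params):
    """One step of plain gradient descent, compiled with XLA."""
    return params - LEARNING_RATE * jax.grad(loss)(params)


params = jnp.zeros(2)  # start from zero coefficients
for _ in range(500):
    params = step(params)
```

Plain gradient descent is used only to keep the sketch free of extra dependencies; any gradient-based optimizer could be substituted.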
Owner
Miles Cranmer
Astro PhD candidate @princeton trying to accelerate astrophysics with AI. I build interpretable ML algorithms.