Geometric Algebra package for JAX

Overview

JAXGA - JAX Geometric Algebra

Build status PyPI

GitHub | Docs

JAXGA is a Geometric Algebra package on top of JAX. It can handle high dimensional algebras by storing only the non-zero basis blade coefficients. It makes use of JAX's just-in-time (JIT) compilation by first precomputing blade indices and signs and then JITting the function doing the actual calculations.

Installation

Install using pip: pip install jaxga

Requirements:

Usage

Unlike most other Geometric Algebra packages, it is not necessary to pre-specify an algebra. JAXGA can either be used with the MultiVector class or by using lower-level functions which is useful for example when using JAX's jit or automatic differentaition.

The MultiVector class provides operator overloading and is constructed with an array of values and their corresponding basis blades. The basis blades are encoded as tuples, for example the multivector 2 e_1 + 4 e_23 would have the values [2, 4] and the basis blade tuple ((1,), (2, 3)).

MultiVector example

import jax.numpy as jnp
from jaxga.mv import MultiVector

a = MultiVector(
    values=2 * jnp.ones([1], dtype=jnp.float32),
    indices=((1,),)
)
# Alternative: 2 * MultiVector.e(1)

b = MultiVector(
    values=4 * jnp.ones([2], dtype=jnp.float32),
    indices=((2, 3),)
)
# Alternative: 4 * MultiVector.e(2, 3)

c = a * b
print(c)

Output: Multivector(8.0 e_{1, 2, 3})

The lower-level functions also deal with values and blades. Functions are provided that take the blades and return a function that does the actual calculation. The returned function is JITted and can also be automatically differentiated with JAX. Furthermore, some operations like the geometric product take a signature function that takes a basis vector index and returns their square.

Lower-level function example

import jax.numpy as jnp
from jaxga.signatures import positive_signature
from jaxga.ops.multiply import get_mv_multiply

a_values = 2 * jnp.ones([1], dtype=jnp.float32)
a_indices = ((1,),)

b_values = 4 * jnp.ones([1], dtype=jnp.float32)
b_indices = ((2, 3),)

mv_multiply, c_indices = get_mv_multiply(a_indices, b_indices, positive_signature)
c_values = mv_multiply(a_values, b_values)
print("C indices:", c_indices, "C values:", c_values)

Output: C indices: ((1, 2, 3),) C values: [8.]

Some notes

  • Both the MultiVector and lower-level function approaches support batches: the axes after the first one (which indexes the basis blades) are treated as batch indices.
  • The MultiVector class can also take a signature in its constructor (default is square to 1 for all basis vectors). Doing operations with MultiVectors with different signatures is undefined.
  • The jaxga.signatures submodule contains a few predefined signature functions.
  • get_mv_multiply and similar functions cache their result by their inputs.
  • The flaxmodules submodule provides flax (a popular neural network library for jax) modules with Geometric Algebra operations.
  • Because we don't deal with a specific algebra, the dual needs an input that specifies the dimensionality of the space in which we want to find the dual element.

Benchmarks

N-d vector * N-d vector, batch size 100, N=range(1, 10), CPU

JaxGA stores only the non-zero basis blade coefficients. TFGA and Clifford on the other hand store all GA elements as full multivectors including all zeros. As a result, JaxGA does better than these for high dimensional algebras.

Below is a benchmark of the geometric product of two vectors with increasing dimensionality from 1 to 9. 100 vectors are multiplied at a time.

JAXGA (CPU) tfga (CPU) clifford
benchmark-results benchmark-results benchmark-results

N-d vector * N-d vector, batch size 100, N=range(1, 50, 5), CPU

Below is a benchmark for higher dimensions that TFGA and Clifford could not handle. Note that the X axis isn't sorted naturally.

benchmark-results

Owner
Robin Kahlow
Software / Machine Learning Engineer
Robin Kahlow
Robust fine-tuning of zero-shot models

Robust fine-tuning of zero-shot models This repository contains code for the paper Robust fine-tuning of zero-shot models by Mitchell Wortsman*, Gabri

224 Dec 29, 2022
x-transformers-paddle 2.x version

x-transformers-paddle x-transformers-paddle 2.x version paddle 2.x版本 https://github.com/lucidrains/x-transformers 。 requirements paddlepaddle-gpu==2.2

yujun 7 Dec 08, 2022
HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis

HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis Jungil Kong, Jaehyeon Kim, Jaekyoung Bae In our paper, we p

Rishikesh (ऋषिकेश) 31 Dec 08, 2022
Low-dose Digital Mammography with Deep Learning

Impact of loss functions on the performance of a deep neural network designed to restore low-dose digital mammography ====== This repository contains

WANG-AXIS 6 Dec 13, 2022
This repository is based on Ultralytics/yolov5, with adjustments to enable polygon prediction boxes.

Polygon-Yolov5 This repository is based on Ultralytics/yolov5, with adjustments to enable polygon prediction boxes. Section I. Description The codes a

xinzelee 226 Jan 05, 2023
Road Crack Detection Using Deep Learning Methods

Road-Crack-Detection-Using-Deep-Learning-Methods This is my Diploma Thesis ¨Road Crack Detection Using Deep Learning Methods¨ under the supervision of

Aggelos Katsaliros 3 May 03, 2022
FedMM: Saddle Point Optimization for Federated Adversarial Domain Adaptation

This repository contains the code accompanying the paper " FedMM: Saddle Point Optimization for Federated Adversarial Domain Adaptation" Paper link: R

20 Jun 29, 2022
ThunderGBM: Fast GBDTs and Random Forests on GPUs

Documentations | Installation | Parameters | Python (scikit-learn) interface What's new? ThunderGBM won 2019 Best Paper Award from IEEE Transactions o

Xtra Computing Group 647 Jan 04, 2023
PERIN is Permutation-Invariant Semantic Parser developed for MRP 2020

PERIN: Permutation-invariant Semantic Parsing David Samuel & Milan Straka Charles University Faculty of Mathematics and Physics Institute of Formal an

ÚFAL 40 Jan 04, 2023
A deep learning network built with TensorFlow and Keras to classify gender and estimate age.

Convolutional Neural Network (CNN). This repository contains a source code of a deep learning network built with TensorFlow and Keras to classify gend

Pawel Dziemiach 1 Dec 18, 2021
Learning Representations that Support Robust Transfer of Predictors

Transfer Risk Minimization (TRM) Code for Learning Representations that Support Robust Transfer of Predictors Prepare the Datasets Preprocess the Scen

Yilun Xu 15 Dec 07, 2022
Official pytorch implementation of DeformSyncNet: Deformation Transfer via Synchronized Shape Deformation Spaces

DeformSyncNet: Deformation Transfer via Synchronized Shape Deformation Spaces Minhyuk Sung*, Zhenyu Jiang*, Panos Achlioptas, Niloy J. Mitra, Leonidas

Zhenyu Jiang 21 Aug 30, 2022
Binary Passage Retriever (BPR) - an efficient passage retriever for open-domain question answering

BPR Binary Passage Retriever (BPR) is an efficient neural retrieval model for open-domain question answering. BPR integrates a learning-to-hash techni

Studio Ousia 147 Dec 07, 2022
PyTorch implementation of our paper How robust are discriminatively trained zero-shot learning models?

How robust are discriminatively trained zero-shot learning models? This repository contains the PyTorch implementation of our paper How robust are dis

Mehmet Kerim Yucel 5 Feb 04, 2022
Prototypical Networks for Few shot Learning in PyTorch

Prototypical Networks for Few shot Learning in PyTorch Simple alternative Implementation of Prototypical Networks for Few Shot Learning (paper, code)

Orobix 835 Jan 08, 2023
The PyTorch improved version of TPAMI 2017 paper: Face Alignment in Full Pose Range: A 3D Total Solution.

Face Alignment in Full Pose Range: A 3D Total Solution By Jianzhu Guo. [Updates] 2020.8.30: The pre-trained model and code of ECCV-20 are made public

Jianzhu Guo 3.4k Jan 02, 2023
Keras documentation, hosted live at keras.io

Keras.io documentation generator This repository hosts the code used to generate the keras.io website. Generating a local copy of the website pip inst

Keras 2k Jan 08, 2023
Hierarchical Time Series Forecasting with a familiar API

scikit-hts Hierarchical Time Series with a familiar API. This is the result from not having found any good implementations of HTS on-line, and my work

Carlo Mazzaferro 204 Dec 17, 2022
Viperdb - A tiny log-structured key-value database written in pure Python

ViperDB 🐍 ViperDB is a lightweight embedded key-value store written in pure Pyt

17 Oct 17, 2022
A toolkit for Lagrangian-based constrained optimization in Pytorch

Cooper About Cooper is a toolkit for Lagrangian-based constrained optimization in Pytorch. This library aims to encourage and facilitate the study of

Cooper 34 Jan 01, 2023