JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library

Overview

This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to directly expose the FINUFFT library to JAX's XLA backend, and it implements differentiation rules for the transforms.

Included features

This library is currently CPU-only, but GPU support is in the works using the cuFINUFFT library.

Type 1 and 2 transforms are supported in 1, 2, and 3 dimensions. All of these functions support forward-mode, reverse-mode, and higher-order differentiation, as well as batching using vmap.
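
To give a feel for what a differentiation rule for these transforms looks like, here is a minimal NumPy sketch for the 1D type 1 transform. It is not this package's implementation: nudft1 and nudft1_jvp_points are hypothetical helpers that evaluate the sums directly in O(NM) time, and the mode range -N/2 <= k <= (N-1)/2 is assumed from FINUFFT's default ordering. The point is that the derivative with respect to the points is itself a type 1 transform (of c * dx), scaled by i * iflag * k, which the sketch checks against a finite difference.

```python
import numpy as np

def nudft1(N, c, x, iflag=1):
    # Direct type 1 sum: f_k = sum_j c_j * exp(i * iflag * k * x_j).
    # Mode range is an assumption based on FINUFFT's default ordering.
    k = np.arange(-(N // 2), (N - 1) // 2 + 1)
    return np.exp(1j * iflag * np.outer(k, x)) @ c

def nudft1_jvp_points(N, c, x, dx, iflag=1):
    # Directional derivative of nudft1 with respect to x along dx.
    # Differentiating the sum gives another type 1 sum with strengths c * dx.
    k = np.arange(-(N // 2), (N - 1) // 2 + 1)
    return 1j * iflag * k * nudft1(N, c * dx, x, iflag=iflag)

rng = np.random.default_rng(0)
M, N = 50, 16
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
dx = rng.standard_normal(M)

# Central finite difference along dx as an independent check.
h = 1e-6
fd = (nudft1(N, c, x + h * dx) - nudft1(N, c, x - h * dx)) / (2 * h)
jvp = nudft1_jvp_points(N, c, x, dx)
max_err = np.max(np.abs(fd - jvp))
```

Because the tangent reuses the same kind of transform, a fast NUFFT can in principle be used for derivatives as well as for the forward pass.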

Installation

For now, only a source build is supported.

For building, you should only need a recent version of Python (>3.6) and FFTW. At runtime, you'll need numpy, scipy, and jax. To set up such an environment, you can use conda (but you're welcome to use whatever workflow works for you!):

conda create -n jax-finufft -c conda-forge python=3.9 numpy scipy fftw
python -m pip install "jax[cpu]"

Then you can install from source using (don't forget the --recursive flag because FINUFFT is included as a submodule):

git clone --recursive https://github.com/dfm/jax-finufft
cd jax-finufft
python -m pip install .

Usage

This library provides two high-level functions (and these should be all that you generally need to interact with): nufft1 and nufft2 (for the two "types" of transforms). If you're already familiar with the Python interface to FINUFFT, please note that the function signatures here are different!

For example, here's how you can do a 1-dimensional type 1 transform:

import numpy as np
from jax_finufft import nufft1

M = 100000
N = 200000

x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)

Note that eps and iflag are optional, and that (for good reason, I promise!) the order of the positional arguments is reversed from the finufft Python package.

The syntax for a 2- or 3-dimensional transform is:

f = nufft1((Nx, Ny), c, x, y)  # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z)  # 3D

The syntax for a type 2 transform is (also allowing optional iflag and eps parameters):

c = nufft2(f, x)  # 1D
c = nufft2(f, x, y)  # 2D
c = nufft2(f, x, y, z)  # 3D
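
For reference, the sums that the two transform types approximate can be written out directly. The following is a slow O(NM) NumPy sketch in 1D, not the library's code: modes, nudft1, and nudft2 are hypothetical helper names, the default iflag values are this sketch's own choice, and the mode range -N/2 <= k <= (N-1)/2 is assumed from FINUFFT's default ordering. It also shows why a type 2 call does not take N: the number of modes is read off the shape of f.

```python
import numpy as np

def modes(N):
    # Integer frequencies -N/2 <= k <= (N-1)/2, lowest first (assumed ordering).
    return np.arange(-(N // 2), (N - 1) // 2 + 1)

def nudft1(N, c, x, iflag=1):
    # Type 1: strengths c at M nonuniform points x -> N uniform modes.
    return np.exp(1j * iflag * np.outer(modes(N), x)) @ c

def nudft2(f, x, iflag=-1):
    # Type 2: N uniform modes f -> values at M nonuniform points x.
    return np.exp(1j * iflag * np.outer(x, modes(f.shape[0]))) @ f

rng = np.random.default_rng(0)
M, N = 40, 8
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f = rng.standard_normal(N) + 1j * rng.standard_normal(N)

f1 = nudft1(N, c, x)  # shape (N,)
c2 = nudft2(f, x)     # shape (M,)

# Sanity check: on a uniform grid, type 2 with iflag=-1 is an ordinary FFT
# once the modes are reordered to NumPy's convention.
grid = 2 * np.pi * np.arange(N) / N
on_grid = nudft2(f, grid)
fft_ref = np.fft.fft(np.fft.ifftshift(f))
```

Comparing a fast transform against direct sums like these on a small problem is a simple way to confirm sign (iflag) and ordering conventions before scaling up.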

Similar libraries

  • finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
  • mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.

License & attribution

This package, developed by Dan Foreman-Mackey, is licensed under the Apache License, Version 2.0, with the following copyright:

Copyright 2021 The Simons Foundation, Inc.

If you use this software, please cite the primary references listed in the FINUFFT docs.

Comments
  • batching issue

    Hi,

    I get the following error when I try to batch nufft2 in Jax.

    process_primitive(self, primitive, tracers, params)
        161     frame = self.get_frame(vals_in, dims_in)
        162     batched_primitive = self.get_primitive_batcher(primitive, frame)
    --> 163     val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
        164     if primitive.multiple_results:
        165       return map(partial(BatchTracer, self), val_out, dim_out)

    TypeError: batch() got an unexpected keyword argument 'output_shape'

    it seems like this is caused by

    nufft2(source, iflag, eps, *points)
         57 
         58     return jnp.reshape(
    ---> 59         nufft2_p.bind(source, *points, output_shape=None, iflag=iflag, eps=eps),
         60         expected_output_shape,
         61     )

    Is there something I am doing wrong?

    Thanks for your help!

    opened by samaktbo 13
  • problem batching functions that have multiple arguments

    Hi,

    I am posting this here to give a clearer description of the problem that I am having.

    When I run the following snippet, I would like A to be a 4 by 100 array whose i-th row is the output of linear_func(q, X[i, :]).

    import numpy as np
    import jax.numpy as jnp
    from jax import vmap
    from jax_finufft import nufft2
    
    rng = np.random.default_rng(seed=314)
    
    d=10 
    L_tilde = 10
    L = 100
    
    qr = rng.standard_normal((d, L_tilde+1))
    qi = rng.standard_normal((d, L_tilde+1))
    q = jnp.array(qr + 1j * qi)
    X = jnp.array(rng.uniform(low=0.0, high=1.0, size=(4, L)))
    
    def linear_func(q, x):
      v = jnp.ones(shape=(1, L_tilde))
    
      return jnp.matmul(v, nufft2(q, x, eps=1e-6, iflag=-1))
    
    batched = vmap(linear_func, in_axes=(None, 0), out_axes=0)
    
    A = batched(q, X)
    

    However, when I run the snippet I get the error posted below. You had said last time that there could be a workaround for unbatched arguments, but I could not figure it out.

    Here is the error:

    UnfilteredStackTrace                      Traceback (most recent call last)
    <ipython-input-25-c23bc4fa7eb4> in <module>()
    ----> 1 test = batched(q, X)
    
    30 frames
    UnfilteredStackTrace: TypeError: '<' not supported between instances of 'NoneType' and 'int'
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/jax_finufft/ops.py in <genexpr>(.0)
        281     else:
        282         mx = args[0].ndim - ndim - 1
    --> 283     assert all(a < mx for a in axes)
        284     assert all(a == axes[0] for a in axes[1:])
        285     return prim.bind(*args, **kwargs), axes[0]
    
    TypeError: '<' not supported between instances of 'NoneType' and 'int'
    

    I hope this gives a clearer picture than what I had last time. Thanks so much for your help!

    opened by samaktbo 5
  • Installation issue

    I am having trouble installing jax-finufft even with the instructions on the home page. The installation fails with:

    ERROR: Failed building wheel for jax-finufft
    Failed to build jax-finufft
    ERROR: Could not build wheels for jax-finufft, which is required to install pyproject.toml-based projects

    opened by samaktbo 5
  • Error when differentiating nufft1 with respect to points only

    Hi Dan,

    First, thank you for releasing this package, I was very glad to find it!

    I noted the error below when attempting to differentiate nufft1 with respect to points. I am very new to JAX so I could be mistaken, but I don't believe this is the intended behavior:

    from jax_finufft import nufft1, nufft2
    import numpy as np
    import jax.numpy as jnp
    from jax import grad
    M = 100000
    N = 200000
    
    x = 2 * np.pi * np.random.uniform(size=M)
    c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
    
    def norm_nufft1(c,x):
        f = nufft1(N, c, x, eps=1e-6, iflag=1)
        return jnp.linalg.norm(f)
    
    def norm_nufft2(c,x):
        f = nufft2( c, x, eps=1e-6, iflag=1)
        return jnp.linalg.norm(f)
    
    grad(norm_nufft2,argnums =(1))(c,x) # Works fine
    grad(norm_nufft1,argnums =(0,))(c,x) # Works fine
    grad(norm_nufft1,argnums =(0,1))(c,x) # Works fine
    grad(norm_nufft1,argnums =(1))(c,x) # Throws error
    

    The error is below:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    /var/folders/dg/zzj57d7d1gs1k9l8sdfh043m0000gn/T/ipykernel_13430/3134504975.py in <module>
         22 grad(norm_nufft1,argnums =(0,))(c,x) # Works fine
         23 grad(norm_nufft1,argnums =(0,1))(c,x) # Works fine
    ---> 24 grad(norm_nufft1,argnums =(1))(c,x) # Throws error
         25 
         26 
    
        [... skipping hidden 10 frame]
    
    /var/folders/dg/zzj57d7d1gs1k9l8sdfh043m0000gn/T/ipykernel_13430/3134504975.py in norm_nufft1(c, x)
         11 
         12 def norm_nufft1(c,x):
    ---> 13     f = nufft1(N, c, x, eps=1e-6, iflag=1)
         14     return jnp.linalg.norm(f)
         15 
    
        [... skipping hidden 21 frame]
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax_finufft/ops.py in nufft1(output_shape, source, iflag, eps, *points)
         39 
         40     return jnp.reshape(
    ---> 41         nufft1_p.bind(source, *points, output_shape=output_shape, iflag=iflag, eps=eps),
         42         expected_output_shape,
         43     )
    
        [... skipping hidden 3 frame]
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax_finufft/ops.py in jvp(type_, prim, args, tangents, output_shape, iflag, eps)
        248 
        249         axis = -2 if type_ == 2 else -ndim - 1
    --> 250         output_tangent *= jnp.concatenate(jnp.broadcast_arrays(*scales), axis=axis)
        251 
        252         expand_shape = (
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in concatenate(arrays, axis)
       3405   if hasattr(arrays[0], "concatenate"):
       3406     return arrays[0].concatenate(arrays[1:], axis)
    -> 3407   axis = _canonicalize_axis(axis, ndim(arrays[0]))
       3408   arrays = _promote_dtypes(*arrays)
       3409   # lax.concatenate can be slow to compile for wide concatenations, so form a
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax/_src/util.py in canonicalize_axis(axis, num_dims)
        277   axis = operator.index(axis)
        278   if not -num_dims <= axis < num_dims:
    --> 279     raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
        280   if axis < 0:
        281     axis = axis + num_dims
    
    ValueError: axis -2 is out of bounds for array of dimension 1
    
    bug 
    opened by ma-gilles 4
  • Speed up `vmap`s where none of the points are batched

    If none of the points are batched in a vmap, we should be able to get a faster computation by stacking the transforms into a single transform and then reshaping. It might be worth making the stacked axes into an explicit parameter rather than just trying to infer it, but that would take some work on the interface.

    opened by dfm 0
  • WIP: Adding support for non-batched dimensions in `vmap`

    In most cases, we'll just need to broadcast all the inputs out to the right shapes (this shouldn't be too hard), but when none of the points get mapped, we can get a bit of a speed up by stacking the transforms. This starts to implement that logic, but it's not quite ready yet.

    opened by dfm 0
  • Find a publication venue

    @lgarrison and I have been chatting about the possibility of writing a paper describing what we're doing here. Something like "Differentiable programming with NUFFTs" or "NUFFTs for machine learning applications" or something. This issue is here to remind me to look into possible venues for this.

    opened by dfm 5
  • Starting to add GPU support using cuFINUFFT

    So far I just have the CMake definitions to compile cuFINUFFT when nvcc is found, but I haven't started writing the boilerplate needed to loop it into XLA. Coming soon!

    Keeping @lgarrison in the loop.

    opened by dfm 10
  • Support Type 3?

    There aren't currently any plans to support the Type 3 transform for a few reasons:

    • I'm not totally sure of the use cases,
    • The logic will probably be somewhat more complicated than the existing implementations, and
    • cuFINUFFT doesn't seem to support Type 3.

    Do we want to work around these issues?

    opened by dfm 0
  • Add support for handling errors

    Currently, if any of the finufft methods fail with a non-zero error we just ignore it and keep on trucking. How should we handle this? It looks like the JAX convention is currently to set everything to NaN when an op fails, since propagating errors up from XLA can be a bit of a pain.

    opened by dfm 0
  • Add GPU support

    Via cufinufft. We'll need to figure out how to get XLA's handling of CUDA streams to play nice with cufinufft (this is way above my pay grade). Some references:

    1. A simple example of how to write an XLA compatible CUDA kernel: https://github.com/dfm/extending-jax/blob/main/lib/kernels.cc.cu
    2. The source code for how JAX wraps cuBLAS: https://github.com/google/jax/blob/main/jaxlib/cublas_kernels.cc
    3. There is a tensorflow implementation that might have some useful context: https://github.com/mrphys/tensorflow-nufft/blob/master/tensorflow_nufft/cc/kernels/nufft_kernels.cu.cc

    Perhaps @lgarrison is interested :D

    opened by dfm 0
Releases(v0.0.3)
  • v0.0.3(Dec 10, 2021)

    What's Changed

    • Fix segfault when batching multiple transforms by @dfm in https://github.com/dfm/jax-finufft/pull/11
    • Generalize the behavior of vmap by @dfm in https://github.com/dfm/jax-finufft/pull/12

    Full Changelog: https://github.com/dfm/jax-finufft/compare/v0.0.2...v0.0.3

  • v0.0.2(Nov 12, 2021)

    • Faster differentiation using stacked transforms
    • Better error checking for vmap

    Full Changelog: https://github.com/dfm/jax-finufft/compare/v0.0.1...v0.0.2

  • v0.0.1(Nov 8, 2021)

Owner
Dan Foreman-Mackey