flatironinstitute/jax-finufft

JAX bindings to FINUFFT

This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to directly expose the FINUFFT library to JAX's XLA backend, as well as implementing differentiation rules for the transforms.

Included features

This library includes CPU and GPU (CUDA) support. GPU support is implemented through the cuFINUFFT interface of the FINUFFT library.

Type 1 and type 2 transforms are supported in 1, 2, and 3 dimensions. All of these functions support forward, reverse, and higher-order differentiation, as well as batching using vmap.

Installation

The easiest way to install jax-finufft is to use a pre-compiled binary from PyPI or conda-forge. If you need GPU support or want tuned performance, follow the instructions to install from source as described below.

Install binary from PyPI

Note

Only the CPU-enabled build of jax-finufft is available as a binary wheel on PyPI. For a GPU-enabled build, you'll need to build from source as described below.

To install a binary wheel from PyPI using pip, run the following commands:

python -m pip install "jax[cpu]"
python -m pip install jax-finufft

If this fails, you may need to use a conda-forge binary, or install from source.

Install binary from conda-forge

Note

Only the CPU-enabled build of jax-finufft is available as a binary from conda-forge. For a GPU-enabled build, you'll need to build from source as described below.

To install using mamba (or conda), run:

mamba install -c conda-forge jax-finufft

Install from source

Dependencies

Unsurprisingly, a key dependency is JAX, which can be installed following the directions in the JAX documentation. If you want to run on a GPU, make sure that you install the appropriate JAX build.

The non-Python dependencies that you'll need are:

  • FFTW,
  • OpenMP (for CPU, optional),
  • CUDA (for GPU, >= 11.8), and
  • cuDNN (for GPU).

Older versions of CUDA may work, but they are untested.

Below we provide some example workflows for installing the required dependencies:

Install CPU dependencies with mamba or conda

mamba create -n jax-finufft -c conda-forge python jax fftw cxx-compiler
mamba activate jax-finufft

Install GPU dependencies with mamba or conda

For a GPU build, while the CUDA libraries and compiler are nominally available through conda, our experience trying to install them this way suggests that the "traditional" way of obtaining the CUDA Toolkit directly from NVIDIA may work best (see related advice for Horovod). After installing the CUDA Toolkit, one can set up the rest of the dependencies with:

mamba create -n gpu-jax-finufft -c conda-forge python numpy scipy fftw 'gxx<12'
mamba activate gpu-jax-finufft
export CMAKE_PREFIX_PATH=$CONDA_PREFIX:$CMAKE_PREFIX_PATH
python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Other ways of installing JAX are given on the JAX website; the "local CUDA" install methods are preferred for jax-finufft as this ensures the CUDA extensions are compiled with the same Toolkit version as the CUDA runtime.

Install GPU dependencies using Flatiron module system

ml modules/2.2
ml gcc
ml python/3.11
ml fftw
ml cuda/11
ml cudnn
ml nccl

export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"

GPU build configuration

You'll need to configure your build to select the appropriate CUDA architecture(s) using the environment variable CMAKE_ARGS. To query your GPU's CUDA architecture (compute capability), you can run:

$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
7.0

This corresponds to CMAKE_CUDA_ARCHITECTURES=70, i.e.:

export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
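The mapping from the compute capability reported by nvidia-smi to the CMake value is just "drop the dot". As a minimal sketch, the helper below builds the flag string for one or more GPUs. The function names cmake_cuda_arch and cmake_args_for are hypothetical and are not part of jax-finufft; only the two -D flags come from the instructions above.

```python
def cmake_cuda_arch(compute_cap):
    """Convert a compute capability such as "7.0" to the CMake form "70"."""
    major, minor = compute_cap.strip().split(".")
    return f"{int(major)}{int(minor)}"


def cmake_args_for(compute_caps):
    """Build the -D flags for one or more compute capabilities (hypothetical helper)."""
    # De-duplicate and sort numerically so the result is stable.
    archs = ";".join(sorted({cmake_cuda_arch(cap) for cap in compute_caps}, key=int))
    return f"-DCMAKE_CUDA_ARCHITECTURES={archs} -DJAX_FINUFFT_USE_CUDA=ON"


print(cmake_args_for(["7.0"]))
```

When several architectures are listed they are separated by semicolons, so the value must stay inside quotes when it is exported in a shell, as in the Flatiron example above.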

Note that the pip installation below uses CMake, so CMAKE_ARGS has to be set before then, but is not needed at runtime.

At runtime, you may also need:

export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"

If CUDA_PATH isn't set, you'll need to replace it with the path to your CUDA installation in the above line, often something like /usr/local/cuda.

Install source from PyPI

The source code for all released versions of jax-finufft is available on PyPI, and it can be installed using:

python -m pip install --no-binary jax-finufft jax-finufft

Install source from GitHub

Alternatively, you can check out the source repository from GitHub:

git clone --recurse-submodules https://github.com/flatironinstitute/jax-finufft
cd jax-finufft

Note

Don't forget the --recurse-submodules argument when cloning the repo because the upstream FINUFFT library is included as a git submodule. If you do forget, you can run git submodule update --init --recursive in your local copy to checkout the submodule after the initial clone.

After cloning the repository, you can install the local copy using:

python -m pip install -e .

where the optional -e flag requests an "editable" install.

Usage

This library provides two high-level functions (and these should be all that you generally need to interact with): nufft1 and nufft2 (for the two "types" of transforms). If you're already familiar with the Python interface to FINUFFT, please note that the function signatures here are different!

For example, here's how you can do a 1-dimensional type 1 transform:

import numpy as np
from jax_finufft import nufft1

M = 100000
N = 200000

x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)
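To make concrete what a type 1 transform computes, here is a brute-force NumPy reference for the 1D case. It is a sketch that does not call jax-finufft: the helper name nufft1_direct is hypothetical, and the mode range and ordering (integer modes from -floor(N/2) upward, in increasing order) are assumed to follow the default convention described in the FINUFFT docs.

```python
import numpy as np


def nufft1_direct(N, c, x, iflag=1):
    """Brute-force 1D type 1 sum: f[k] = sum_j c[j] * exp(iflag * 1j * k * x[j])."""
    # Integer modes k = -floor(N/2), ..., floor((N-1)/2), in increasing order
    # (assumed to match FINUFFT's default mode ordering).
    k = np.arange(-(N // 2), (N - 1) // 2 + 1)
    # The outer product builds an (N, M) matrix of phases, so this costs O(N * M).
    return np.exp(iflag * 1j * np.outer(k, x)) @ c


rng = np.random.default_rng(0)
M, N = 50, 8
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f_ref = nufft1_direct(N, c, x)
```

On a small problem like this one, the output of nufft1 should agree with f_ref to roughly the requested eps if the ordering assumption holds. The direct sum is far too slow for the sizes in the example above, so it is only useful as a correctness check.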

Warning

As described in the FINUFFT documentation, the non-uniform points must lie within the range [-3pi, 3pi], but this is not checked, because JAX currently doesn't have a good interface for runtime value checking. Unexpected crashes may occur if this condition is not met.
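Because the modes are integers, each term exp(i k x) is 2π-periodic in x, so points can be wrapped into [-π, π) before calling the transform without changing the result. A minimal sketch of such a preprocessing step in plain NumPy follows. The helper wrap_to_pi is hypothetical and not part of jax-finufft; inside jit-compiled code the same arithmetic would have to be written with jax.numpy.

```python
import numpy as np


def wrap_to_pi(x):
    """Map points to the equivalent location in [-pi, pi)."""
    return (np.asarray(x) + np.pi) % (2 * np.pi) - np.pi


x_raw = np.array([-20.0, -np.pi, 0.5, 12.0, 100.0])
x_wrapped = wrap_to_pi(x_raw)

# The phases seen by the transform are unchanged for integer modes k.
k = np.arange(-3, 4)
same = np.allclose(np.exp(1j * np.outer(k, x_raw)), np.exp(1j * np.outer(k, x_wrapped)))
```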

Note that eps and iflag are optional, and that (for good reason, I promise!) the order of the positional arguments is reversed from the finufft Python package.

The syntax for a 2- or 3-dimensional transform is:

f = nufft1((Nx, Ny), c, x, y)  # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z)  # 3D

The syntax for a type 2 transform is (also allowing optional iflag and eps parameters):

c = nufft2(f, x)  # 1D
c = nufft2(f, x, y)  # 2D
c = nufft2(f, x, y, z)  # 3D
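For comparison with type 1, a type 2 transform evaluates a Fourier series with given coefficients at the non-uniform points. The brute-force NumPy sketch below shows the 1D sum. It does not call jax-finufft: nufft2_direct is a hypothetical name, the mode ordering is assumed to follow the FINUFFT default, and the iflag=-1 default is simply a choice made for this sketch.

```python
import numpy as np


def nufft2_direct(f, x, iflag=-1):
    """Brute-force 1D type 2 sum: c[j] = sum_k f[k] * exp(iflag * 1j * k * x[j])."""
    N = f.shape[0]
    # Same assumed mode ordering as for type 1: k = -floor(N/2), ..., floor((N-1)/2).
    k = np.arange(-(N // 2), (N - 1) // 2 + 1)
    # (M, N) matrix of phases applied to the N coefficients.
    return np.exp(iflag * 1j * np.outer(x, k)) @ f


rng = np.random.default_rng(1)
N, M = 8, 5
f = rng.standard_normal(N) + 1j * rng.standard_normal(N)
x = 2 * np.pi * rng.uniform(size=M)
c_ref = nufft2_direct(f, x)
```

Note that the output has one entry per non-uniform point, so its length is set by x rather than by a separate size argument, which is why nufft2 takes no N.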

All of these functions support batching using vmap, and forward and reverse mode differentiation.

Advanced usage

The tuning parameters for the library can be set using the opts parameter to nufft1 and nufft2. For example, to explicitly set the CPU up-sampling factor that FINUFFT should use, you can update the example from above as follows:

from jax_finufft import options

opts = options.Opts(upsampfac=2.0)
nufft1(N, c, x, opts=opts)

The corresponding option for the GPU is gpu_upsampfac. In fact, all options for the GPU are prefixed with gpu_.

One complication here is that the vector-Jacobian product for a NUFFT requires evaluating a NUFFT of a different type. This means that you might want to separately tune the options for the forward and backward pass. This can be achieved using the options.NestedOpts interface. For example, to use a different up-sampling factor for the forward and backward passes, the code from above becomes:

import jax

opts = options.NestedOpts(
  forward=options.Opts(upsampfac=2.0),
  backward=options.Opts(upsampfac=1.25),
)
jax.grad(lambda args: nufft1(N, *args, opts=opts).real.sum())((c, x))

or, in this case equivalently:

opts = options.NestedOpts(
  type1=options.Opts(upsampfac=2.0),
  type2=options.Opts(upsampfac=1.25),
)
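The reason the backward pass involves a transform of the other type can be checked with small dense matrices. In the NumPy sketch below (which does not call jax-finufft and ignores the details of JAX's complex-derivative conventions), a 1D type 1 transform with iflag=+1 is written as a matrix A, and its adjoint applied to a cotangent g turns out to be exactly a type 2 sum with iflag=-1. The mode ordering is an assumption following the FINUFFT default.

```python
import numpy as np

rng = np.random.default_rng(2)
M, N = 30, 7
x = 2 * np.pi * rng.uniform(size=M)
k = np.arange(-(N // 2), (N - 1) // 2 + 1)

# Type 1 (iflag=+1) as a dense (N, M) matrix: f = A @ c.
A = np.exp(1j * np.outer(k, x))

c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
g = rng.standard_normal(N) + 1j * rng.standard_normal(N)  # an incoming cotangent

# The adjoint maps uniform-grid values back to the non-uniform points:
# (A^H g)[j] = sum_k g[k] * exp(-1j * k * x[j]), i.e. a type 2 sum with iflag=-1.
type2_of_g = np.exp(-1j * np.outer(x, k)) @ g

# np.vdot conjugates its first argument, so these are <g, A c> and <A^H g, c>.
lhs = np.vdot(g, A @ c)
rhs = np.vdot(type2_of_g, c)
```

Here lhs and rhs agree to rounding error, which is the adjoint identity that pairs each type 1 transform with a type 2 transform. This is why forward/backward and type1/type2 are interchangeable ways of assigning options in the nufft1 example above.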

See the FINUFFT docs for descriptions of all the CPU tuning parameters. The corresponding GPU parameters are currently only listed in source code form in cufinufft_opts.h.

Similar libraries

  • finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
  • mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.

License & attribution

This package, developed by Dan Foreman-Mackey, is licensed under the Apache License, Version 2.0, with the following copyright:

Copyright 2021, 2022, 2023 The Simons Foundation, Inc.

If you use this software, please cite the primary references listed on the FINUFFT docs.