Fast and Easy Infinite Neural Networks in Python

Neural Tangents

ICLR 2020 Video | Paper | Quickstart | Install guide | Reference docs | Release notes

Overview

Neural Tangents is a high-level neural network API for specifying complex, hierarchical, neural networks of both finite and infinite width. Neural Tangents allows researchers to define, train, and evaluate infinite networks as easily as finite ones.

Infinite (in width or channel count) neural networks are Gaussian Processes (GPs) with a kernel function determined by their architecture (see References for details and nuances of this correspondence).

Neural Tangents allows you to construct a neural network model with the usual building blocks like convolutions, pooling, residual connections, nonlinearities etc. and obtain not only the finite model, but also the kernel function of the respective GP.

The library is written in Python using JAX and leverages XLA to run out-of-the-box on CPU, GPU, or TPU. Kernel computation is highly optimized for speed and memory efficiency, and can be automatically distributed over multiple accelerators with near-perfect scaling.
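
For example, a kernel computation can be split into batches and spread over the available accelerators with nt.batch. Below is a minimal sketch; the batch size and input shapes are illustrative and should tile evenly across devices.

import neural_tangents as nt
from neural_tangents import stax
from jax import random

_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

# Compute the kernel in 20x20 blocks, distributed over all available devices.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=20, device_count=-1)

x = random.normal(random.PRNGKey(1), (100, 128))
k = batched_kernel_fn(x, None, 'ntk')  # (100, 100) NTK, assembled from batches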

Neural Tangents is a work in progress. We happily welcome contributions!

Colab Notebooks

An easy way to get started with Neural Tangents is by playing around with the following interactive notebooks in Colaboratory. They demo the major features of Neural Tangents and show how it can be used in research.

Installation

To use GPU, first follow JAX's GPU installation instructions. Otherwise, install JAX on CPU by running

pip install jax jaxlib --upgrade

Once JAX is installed, install Neural Tangents by running

pip install neural-tangents

or, to use the bleeding-edge version from GitHub source,

git clone https://github.com/google/neural-tangents; cd neural-tangents
pip install -e .

You can now run the examples (using tensorflow_datasets) and tests by calling:

pip install tensorflow tensorflow-datasets more-itertools --upgrade

python examples/infinite_fcn.py
python examples/weight_space.py
python examples/function_space.py

set -e; for f in tests/*.py; do python $f; done

5-Minute intro

See this Colab for a detailed tutorial. Below is a very quick introduction.

Our library closely follows JAX's API for specifying neural networks, stax. In stax, a network is defined by a pair of functions (init_fn, apply_fn), which initialize the trainable parameters and compute the outputs of the network respectively. Below is an example of defining a 3-layer network and computing its outputs y given inputs x.

from jax import random
from jax.experimental import stax

init_fn, apply_fn = stax.serial(
    stax.Dense(512), stax.Relu,
    stax.Dense(512), stax.Relu,
    stax.Dense(1)
)

key = random.PRNGKey(1)
x = random.normal(key, (10, 100))
_, params = init_fn(key, input_shape=x.shape)

y = apply_fn(params, x)  # (10, 1) np.ndarray outputs of the neural network

Neural Tangents is designed to serve as a drop-in replacement for stax, extending the (init_fn, apply_fn) tuple to a triple (init_fn, apply_fn, kernel_fn), where kernel_fn is the kernel function of the infinite network (GP) of the given architecture. Below is an example of computing the covariances of the GP between two batches of inputs x1 and x2.

from jax import random
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1)
)

key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))

kernel = kernel_fn(x1, x2, 'nngp')

Note that kernel_fn can compute two covariance matrices corresponding to the Neural Network Gaussian Process (NNGP) and Neural Tangent (NT) kernels respectively. The NNGP kernel corresponds to the Bayesian infinite neural network [1-5]. The NTK corresponds to the (continuous) gradient descent trained infinite network [10]. In the above example, we compute the NNGP kernel but we could compute the NTK or both:

# Get kernel of a single type
nngp = kernel_fn(x1, x2, 'nngp') # (10, 20) np.ndarray
ntk = kernel_fn(x1, x2, 'ntk') # (10, 20) np.ndarray

# Get kernels as a namedtuple
both = kernel_fn(x1, x2, ('nngp', 'ntk'))
both.nngp == nngp  # True
both.ntk == ntk  # True

# Unpack the kernels namedtuple
nngp, ntk = kernel_fn(x1, x2, ('nngp', 'ntk'))

Additionally, if no third argument is specified, kernel_fn will return a Kernel namedtuple that contains additional metadata. This can be useful for composing applications of kernel_fn as follows:

kernel = kernel_fn(x1, x2)
kernel = kernel_fn(kernel)
print(kernel.nngp)

Doing inference with infinite networks trained on MSE loss reduces to classical GP inference, for which we also provide convenient tools:

import neural_tangents as nt

x_train, x_test = x1, x2
y_train = random.uniform(key1, shape=(10, 1))  # training targets

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                      y_train)

y_test_nngp = predict_fn(x_test=x_test, get='nngp')
# (20, 1) np.ndarray test predictions of an infinite Bayesian network

y_test_ntk = predict_fn(x_test=x_test, get='ntk')
# (20, 1) np.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)

# Get predictions as a namedtuple
both = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
both.nngp == y_test_nngp  # True
both.ntk == y_test_ntk  # True

# Unpack the predictions namedtuple
y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))

Infinitely WideResnet

We can define a more complex, (infinitely) Wide Residual Network [14] using the same nt.stax building blocks:

from neural_tangents import stax

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(
      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())

def WideResnetGroup(n, channels, strides=(1, 1)):
  blocks = []
  blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [WideResnetBlock(channels, (1, 1))]
  return stax.serial(*blocks)

def WideResnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(block_size, int(16 * k)),
      WideResnetGroup(block_size, int(32 * k), (2, 2)),
      WideResnetGroup(block_size, int(64 * k), (2, 2)),
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.))

init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)

Package description

The neural_tangents (nt) package contains the following modules and functions:

  • stax - primitives to construct neural networks like Conv, Relu, serial, parallel etc.

  • predict - predictions with infinite networks:

    • predict.gradient_descent_mse - inference with a single infinite width / linearized network trained on MSE loss with continuous gradient descent for an arbitrary finite or infinite (t=None) time. Computed in closed form.

    • predict.gradient_descent - inference with a single infinite width / linearized network trained on arbitrary loss with continuous (momentum) gradient descent for an arbitrary finite time. Computed using an ODE solver.

    • predict.gradient_descent_mse_ensemble - inference with an infinite ensemble of infinite width networks, either fully Bayesian (get='nngp') or inference with MSE loss using continuous gradient descent (get='ntk'). Finite-time Bayesian inference (e.g. t=1., get='nngp') is interpreted as gradient descent on the top layer only [11], since it converges to exact Gaussian process inference with NNGP (t=None, get='nngp'). Computed in closed form.

    • predict.gp_inference - exact closed form Gaussian process inference using NNGP (get='nngp'), NTK (get='ntk'), or both (get=('nngp', 'ntk')). Equivalent to predict.gradient_descent_mse_ensemble with t=None (infinite training time), but has a slightly different API (accepting precomputed kernel matrix k_train_train instead of kernel_fn and x_train).

  • monte_carlo_kernel_fn - compute a Monte Carlo kernel estimate of any (init_fn, apply_fn), not necessarily specified via nt.stax, enabling the kernel computation of infinite networks without closed-form expressions (see the sketch after this list).

  • Tools to investigate training dynamics of wide but finite neural networks, like linearize, taylor_expand, empirical_kernel_fn and more. See Training dynamics of wide but finite networks for details.
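
For example, a Monte Carlo estimate of the NTK of the 3-layer network from the 5-minute intro can be obtained as follows (a minimal sketch; n_samples is illustrative):

import neural_tangents as nt
from neural_tangents import stax
from jax import random

init_fn, apply_fn, _ = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1)
)

key, key1, key2 = random.split(random.PRNGKey(1), 3)
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))

# Average empirical kernels over `n_samples` random finite-width networks.
kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key=key, n_samples=128)
ntk_estimate = kernel_fn(x1, x2, get='ntk')  # (10, 20) Monte Carlo NTK estimate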

Technical gotchas

nt.stax vs jax.experimental.stax

We note the following differences between our library and jax.experimental.stax:

  • All nt.stax layers are instantiated with a function call, i.e. nt.stax.Relu() vs jax.experimental.stax.Relu.
  • All layers with trainable parameters use the NTK parameterization by default (see [10], Remark 1). However, Dense and Conv layers also support the standard parameterization via a parameterization keyword argument (see [15] and the sketch after this list).
  • nt.stax and jax.experimental.stax may have different layers and options available (for example, nt.stax layers support CIRCULAR padding and have LayerNorm, but no BatchNorm).
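
As an illustration of the parameterization point above, a layer can be switched to the standard parameterization roughly as follows (a minimal sketch; the W_std and b_std values are illustrative):

from neural_tangents import stax

# Default: NTK parameterization.
_, _, kernel_fn_ntk = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Relu(), stax.Dense(1))

# Standard parameterization (see [15]), selected per layer.
_, _, kernel_fn_std = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05, parameterization='standard'),
    stax.Relu(),
    stax.Dense(1, parameterization='standard'))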

CPU and TPU performance

For CNNs with pooling, our CPU and TPU performance is suboptimal, due to low core utilization on CPU (10-20%; this appears to be an XLA:CPU issue) and excessive padding on TPU. We will look into improving performance, but recommend NVIDIA GPUs in the meantime. See Performance.

Training dynamics of wide but finite networks

The kernel of an infinite network, kernel_fn(x1, x2).ntk, combined with nt.predict.gradient_descent_mse makes it possible to analytically track the outputs of an infinitely wide neural network trained on MSE loss throughout training. Here we discuss the implications for wide but finite neural networks and present tools to study their evolution in weight space (trainable parameters of the network) and function space (outputs of the network).

Weight space

Continuous gradient descent in an infinite network has been shown in [11] to correspond to training a linear (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.

For this, we provide two convenient functions:

  • nt.linearize, and
  • nt.taylor_expand,

which allow you to linearize or compute an arbitrary-order Taylor expansion of any function apply_fn(params, x) around some initial parameters params_0, e.g. apply_fn_lin = nt.linearize(apply_fn, params_0).

You can use apply_fn_lin(params, x) exactly as you would any other function (including as an input to JAX optimizers). This makes it easy to compare the training trajectory of a neural network with that of its linearization. Previous theory and experiments have examined the linearization of neural networks from inputs to logits or pre-activations, rather than from inputs to post-activations, which are substantially more nonlinear.

Example:

import jax.numpy as np
import neural_tangents as nt

def apply_fn(params, x):
  W, b = params
  return np.dot(x, W) + b

W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))

apply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))
W = np.array([[1.5, 0.2], [0.1, 0.9]])
b = b_0 + 0.2

x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
logits = apply_fn_lin((W, b), x)  # (3, 2) np.ndarray

Function space

Outputs of a linearized model evolve identically to those of an infinite one [11] but with a different kernel - specifically, the Neural Tangent Kernel [10] evaluated on the specific apply_fn of the finite network given the specific params_0 that the network is initialized with. For this we provide the nt.empirical_kernel_fn function that accepts any apply_fn and returns a kernel_fn(x1, x2, get, params) that allows you to compute the empirical NTK and/or NNGP (based on get) kernels for specific params.

Example:

import jax.random as random
import jax.numpy as np
import neural_tangents as nt

def apply_fn(params, x):
  W, b = params
  return np.dot(x, W) + b

W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
params = (W_0, b_0)

key1, key2 = random.split(random.PRNGKey(1), 2)
x_train = random.normal(key1, (3, 2))
x_test = random.normal(key2, (4, 2))
y_train = random.uniform(key1, shape=(3, 2))

kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
mse_predictor = nt.predict.gradient_descent_mse(ntk_train_train, y_train)

t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0, ntk_test_train)
# (3, 2) and (4, 2) np.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent

What to Expect

The success or failure of the linear approximation is highly architecture dependent. However, some rules of thumb that we've observed are:

  1. Convergence as the network size increases.

    • For fully-connected networks one generally observes very strong agreement by the time the layer-width is 512 (RMSE of about 0.05 at the end of training).

    • For convolutional networks one generally observes reasonable agreement by the time the number of channels is 512.

  2. Convergence at small learning rates.

With a new model it is therefore advisable to start with a very large model on a small dataset using a small learning rate.

Performance

In the table below we measure time to compute a single NTK entry in a 21-layer CNN (3x3 filters, no strides, SAME padding, ReLU) on inputs of shape 3x32x32. Precisely:

layers = []
for _ in range(21):
  layers += [stax.Conv(1, (3, 3), (1, 1), 'SAME'), stax.Relu()]

CNN with pooling

Top layer is stax.GlobalAvgPool():

_, _, kernel_fn = stax.serial(*(layers + [stax.GlobalAvgPool()]))
Platform                    | Precision | Milliseconds / NTK entry | Max batch size (NxN)
----------------------------|-----------|--------------------------|---------------------
CPU, >56 cores, >700 Gb RAM | 32        | 112.90                   | >= 128
CPU, >56 cores, >700 Gb RAM | 64        | 258.55                   | 95 (fastest - 72)
TPU v2                      | 32/16     | 3.2550                   | 16
TPU v3                      | 32/16     | 2.3022                   | 24
NVIDIA P100                 | 32        | 5.9433                   | 26
NVIDIA P100                 | 64        | 11.349                   | 18
NVIDIA V100                 | 32        | 2.7001                   | 26
NVIDIA V100                 | 64        | 6.2058                   | 18

CNN without pooling

Top layer is stax.Flatten():

_, _, kernel_fn = stax.serial(*(layers + [stax.Flatten()]))
Platform                    | Precision | Milliseconds / NTK entry | Max batch size (NxN)
----------------------------|-----------|--------------------------|---------------------------------
CPU, >56 cores, >700 Gb RAM | 32        | 0.12013                  | 2048 <= N < 4096 (fastest - 512)
CPU, >56 cores, >700 Gb RAM | 64        | 0.3414                   | 2048 <= N < 4096 (fastest - 256)
TPU v2                      | 32/16     | 0.0015722                | 512 <= N < 1024
TPU v3                      | 32/16     | 0.0010647                | 512 <= N < 1024
NVIDIA P100                 | 32        | 0.015171                 | 512 <= N < 1024
NVIDIA P100                 | 64        | 0.019894                 | 512 <= N < 1024
NVIDIA V100                 | 32        | 0.0046510                | 512 <= N < 1024
NVIDIA V100                 | 64        | 0.010822                 | 512 <= N < 1024

Tested using version 0.2.1. All GPU results are per single accelerator. Note that runtime is proportional to the depth of your network. If your performance differs significantly, please file a bug!

Myrtle network

The Colab notebook Performance Benchmark demonstrates how one would construct and benchmark kernels. To demonstrate flexibility, we took the architecture from [16] as an example. With an NVIDIA V100 and 64-bit precision, nt took 316/330/508 GPU-hours on the full 60k CIFAR-10 dataset for the Myrtle-5/7/10 kernels respectively.

Papers

Neural Tangents has been used in the following papers:

  1. Correlated Weights in Infinite Limits of Deep Convolutional Neural Networks
  2. Dataset Meta-Learning from Kernel Ridge-Regression
  3. Deep learning versus kernel learning: an empirical study of loss landscape geometry and the time evolution of the Neural Tangent Kernel
  4. Stable ResNet
  5. Label-Aware Neural Tangent Kernel: Toward Better Generalization and Local Elasticity
  6. Semi-supervised Batch Active Learning via Bilevel Optimization
  7. Temperature check: theory and practice for training models with softmax-cross-entropy losses
  8. Experimental Design for Overparameterized Learning with Application to Single Shot Deep Active Learning
  9. How Neural Networks Extrapolate: From Feedforward to Graph Neural Networks
  10. Exploring the Uncertainty Properties of Neural Networks’ Implicit Priors in the Infinite-Width Limit
  11. Cold Posteriors and Aleatoric Uncertainty
  12. Asymptotics of Wide Convolutional Neural Networks
  13. Finite Versus Infinite Neural Networks: an Empirical Study
  14. Bayesian Deep Ensembles via the Neural Tangent Kernel
  15. The Surprising Simplicity of the Early-Time Learning Dynamics of Neural Networks
  16. When Do Neural Networks Outperform Kernel Methods?
  17. Statistical Mechanics of Generalization in Kernel Regression
  18. Exact posterior distributions of wide Bayesian neural networks
  19. Infinite attention: NNGP and NTK for deep attention networks
  20. Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains
  21. Finding trainable sparse networks through Neural Tangent Transfer
  22. Coresets via Bilevel Optimization for Continual Learning and Streaming
  23. On the Neural Tangent Kernel of Deep Networks with Orthogonal Initialization
  24. The large learning rate phase of deep learning: the catapult mechanism
  25. Spectrum Dependent Learning Curves in Kernel Regression and Wide Neural Networks
  26. Taylorized Training: Towards Better Approximation of Neural Network Training at Finite Width
  27. On the Infinite Width Limit of Neural Networks with a Standard Parameterization
  28. Disentangling Trainability and Generalization in Deep Learning
  29. Information in Infinite Ensembles of Infinitely-Wide Neural Networks
  30. Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel
  31. Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent
  32. Bayesian Deep Convolutional Networks with Many Channels are Gaussian Processes

Please let us know if you make use of the code in a publication and we'll add it to the list!

Citation

If you use the code in a publication, please cite our ICLR 2020 paper:

@inproceedings{neuraltangents2020,
    title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python},
    author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
    booktitle={International Conference on Learning Representations},
    year={2020},
    url={https://github.com/google/neural-tangents}
}

References

[1] Priors for Infinite Networks
[2] Exponential expressivity in deep neural networks through transient chaos
[3] Toward deeper understanding of neural networks: The power of initialization and a dual view on expressivity
[4] Deep Information Propagation
[5] Deep Neural Networks as Gaussian Processes
[6] Gaussian Process Behaviour in Wide Deep Neural Networks
[7] Dynamical Isometry and a Mean Field Theory of CNNs: How to Train 10,000-Layer Vanilla Convolutional Neural Networks
[8] Bayesian Deep Convolutional Networks with Many Channels are Gaussian Processes
[9] Deep Convolutional Networks as shallow Gaussian Processes
[10] Neural Tangent Kernel: Convergence and Generalization in Neural Networks
[11] Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent
[12] Scaling Limits of Wide Neural Networks with Weight Sharing: Gaussian Process Behavior, Gradient Independence, and Neural Tangent Kernel Derivation
[13] Mean Field Residual Networks: On the Edge of Chaos
[14] Wide Residual Networks
[15] On the Infinite Width Limit of Neural Networks with a Standard Parameterization
[16] Neural Kernels Without Tangents
Comments
  • Memory and running time issues for CNN

    Hi,

    I currently use Neural Tangents to compute the kernel for CIFAR-10 images. I need to compute the kernel matrix for 10000 images x 10000 images, where each image has 3x32x32 pixels. If I use a 2-layer feedforward NN with the input reshaped to 3072, it takes about 3G of memory and several minutes to compute the kernel.

    However, if I use a simple CNN (a one-layer CNN), it fails with the error "failed to allocate request 381T memory". I can only reduce the minibatch size, but that makes the computation much slower. And this is just a one-layer CNN; I expect a multilayer CNN to cost even more time. Even for one batch (100 images), it still costs much more time than the 2-layer feedforward NN.

    Another strange thing is that I expected to be able to compute the kernel matrix with batch size 200 (out of 10000) each time, because the server has 394G of memory. But it still runs out of memory (manually checked) after running for several minutes and is killed without an error prompt.

    So I am wondering how to use your tools to compute the kernel matrix for CNNs. It either costs too much memory or too much time on my end. Do you have any suggestions for dealing with this issue? I am not sure about the underlying mechanism you use to compute the kernel for CNNs, but I expect it shouldn't cost so much memory and run so slowly, because [Arora et al. 2019](https://arxiv.org/pdf/1904.11955.pdf) compute the kernel for a 21-layer CNN.
    
    It is really a good tool but I hope that you can help with the CNN memory and running time issue.
    

    Thanks, Hangfeng

    enhancement 
    opened by HornHehhf 11
  • value_and_grad(kernel_fn) not equal to kernel_fn with standard parameterization

    I am confused by the behavior of the following snippet of code (the WideResNet from the README with standard parameterization):

    import jax
    from neural_tangents import stax
    
    
    def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
        main = stax.serial(
            stax.Relu(),
            stax.Conv(
                channels, (3, 3), strides, padding="SAME", parameterization="standard"
            ),
            stax.Relu(),
            stax.Conv(channels, (3, 3), padding="SAME", parameterization="standard"),
        )
        shortcut = (
            stax.Identity()
            if not channel_mismatch
            else stax.Conv(
                channels, (3, 3), strides, padding="SAME", parameterization="standard"
            )
        )
        return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut), stax.FanInSum())
    
    
    def WideResnetGroup(n, channels, strides=(1, 1)):
        blocks = []
        blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
        for _ in range(n - 1):
            blocks += [WideResnetBlock(channels, (1, 1))]
        return stax.serial(*blocks)
    
    
    def WideResnet(block_size, k, num_classes):
        return stax.serial(
            stax.Conv(16, (3, 3), padding="SAME", parameterization="standard"),
            WideResnetGroup(block_size, int(16 * k)),
            WideResnetGroup(block_size, int(32 * k), (2, 2)),
            WideResnetGroup(block_size, int(64 * k), (2, 2)),
            stax.AvgPool((8, 8)),
            stax.Flatten(),
            stax.Dense(num_classes, 1.0, 0.0, parameterization="standard"),
        )
    
    
    _, _, kernel_fn = WideResnet(block_size=4, k=1, num_classes=1)
    
    def kernel_scalar(x, y):
        return kernel_fn(x, y, "ntk")[0, 0]
    
    z = jax.numpy.zeros((1, 32, 32, 3))
    print(jax.value_and_grad(kernel_scalar)(z, z)[0])
    print(kernel_scalar(z, z))
    

    My understanding is that the two printed values should be the same. However, when I run it, I get two totally different values:

    34.41480472358908
    64.62813414153004
    

    Is my understanding correct? I have not yet found a simpler network that features this behavior.

    Versions:

    • jax 0.2.20
    • jaxlib 0.1.71+cuda111
    • neural-tangents 0.3.7
    bug 
    opened by PythonNut 8
  • Can this be used to compute the NTK for a finite-width neural network?

    Suppose I have a regular old neural network with its weights set to some values. Then the NTK k(x, y) is well-defined as the dot product of df/dw at each input, that is, the dot product of the gradients of the network's output with respect to the weights. In some of my own code I'm computing this kernel using Keras with TensorFlow's automatic differentiation capabilities, but it chokes on even moderate-sized models (trying to compute the train-train kernel with 1000 neurons and 10k training inputs).

    I've been looking at neural_tangents.utils.empirical, but I thought I'd ask - does this codebase contain some magical code that will allow me to compute my Gram matrix in a reasonable amount of time and memory?
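
    For reference, this is roughly the quantity I mean, written directly in jax for a generic apply_fn with a scalar output (a sketch, not what I run at scale):

    import jax
    import jax.numpy as np

    def empirical_ntk(apply_fn, params, x1, x2):
      # Per-example Jacobians df/dw, flattened over all parameters.
      def flat_jacobian(x):
        jac = jax.jacobian(lambda p: apply_fn(p, x))(params)
        leaves = jax.tree_util.tree_leaves(jac)
        return np.concatenate([j.reshape(x.shape[0], -1) for j in leaves], axis=1)
      j1, j2 = flat_jacobian(x1), flat_jacobian(x2)
      return j1 @ j2.T  # k(x, y) = <df/dw(x), df/dw(y)>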

    question 
    opened by geajack 8
  • A type error in predict.gradient_descent

    I wrote some simple code with the monte_carlo_kernel_fn and gradient_descent modules, but it raised an unidentifiable type error even though I never manipulated any types in the code. Basically, I followed the examples shown in the source code, except that I used jax.experimental.stax.Tanh to build a two-layer neural network with a hyperbolic tangent activation.

    The code I ran was as follows:

    import neural_tangents as nt
    import jax.experimental.stax as ostax
    from jax import random as jrandom
    import jax.numpy as np
    
    key = jrandom.PRNGKey(0)
    
    def gen_key():
        global key
        key, k = jrandom.split(key, 2)
        return k
    
    def cross_entropy(fx, y_hat):
        return -np.mean(ostax.logsoftmax(fx)*y_hat)
    
    x_train = jrandom.normal(gen_key(), (20, 784))
    x_test = jrandom.normal(gen_key(), (20, 784))
    y_train = jrandom.normal(gen_key(), (20, 50))
    
    init_fn, apply_fn = ostax.serial(
        ostax.Dense(200), ostax.Tanh, ostax.Dense(50))
    
    _, params = init_fn(gen_key(), x_train.shape)
    
    kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key=gen_key(), n_samples=100)
    
    k_train_train = kernel_fn(x_train, None, get='ntk')
    k_test_train = kernel_fn(x_test, x_train, get='ntk')
    
    predict_fn = nt.predict.gradient_descent(
        cross_entropy, k_train_train, y_train, 1e-2, 0.9)
    fx_train_0 = apply_fn(params, x_train)
    fx_test_0 = apply_fn(params, x_test)
    
    t = 1e-7
    fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0, k_test_train)
    print(fx_train_t)
    print(fx_test_t)
    

    The raised error was as follows:

    Traceback (most recent call last):
      File "nt-practice.py", line 42, in <module>
        fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0, k_test_train)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/neural_tangents/predict.py", line 472, in predict_fn
        state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/experimental/ode.py", line 173, in odeint
        return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/api.py", line 338, in cache_miss
        donated_invars=donated_invars)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/core.py", line 1402, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/core.py", line 1393, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/core.py", line 1405, in process
        return trace.process_call(self, fun, tracers, params)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/core.py", line 600, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/xla.py", line 577, in _xla_call_impl
        *unsafe_map(arg_spec, args))
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/linear_util.py", line 260, in memoized_fun
        ans = call(fun, *args)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/xla.py", line 652, in _xla_callable
        jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1209, in trace_to_jaxpr_final
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/experimental/ode.py", line 179, in _odeint_wrapper
        out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/custom_derivatives.py", line 485, in __call__
        out_trees=out_trees)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/custom_derivatives.py", line 566, in bind
        out_trees=out_trees)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1137, in process_custom_vjp_call
        fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/experimental/ode.py", line 214, in _odeint
        _, ys = lax.scan(scan_fun, init_carry, ts[1:])
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 1276, in scan
        init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 1263, in _create_jaxpr
        jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals, "scan")
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/util.py", line 185, in wrapper
        return cached(bool(config.x64_enabled), *args, **kwargs)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/util.py", line 178, in cached
        return f(*args, **kwargs)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 77, in _initial_style_jaxpr
        transform_name)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/util.py", line 185, in wrapper
        return cached(bool(config.x64_enabled), *args, **kwargs)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/util.py", line 178, in cached
        return f(*args, **kwargs)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 70, in _initial_style_open_jaxpr
        transform_name=transform_name)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1178, in trace_to_jaxpr_dynamic
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/experimental/ode.py", line 204, in scan_fun
        _, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 301, in while_loop
        in_tree_children[0], init_avals)
      File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 1940, in _check_tree_and_avals
        f"{what} must have identical types, got\n"
    TypeError: body_fun output and input must have identical types, got
    [ShapedArray(int64[], weak_type=True), ShapedArray(float64[4000]), ShapedArray(float64[4000]), ShapedArray(float64[]), ShapedArray(float64[]), ShapedArray(float32[]), ShapedArray(float64[5,4000])]
    and
    [ShapedArray(int64[], weak_type=True), ShapedArray(float64[4000]), ShapedArray(float64[4000]), ShapedArray(float32[]), ShapedArray(float64[]), ShapedArray(float32[]), ShapedArray(float64[5,4000])].
    

    Does anyone have an idea how to address this problem?

    bug 
    opened by kim-hyunsu 7
  • Calculating Marginal log-likelihood of NNGP

    Hi,

    I've been using the neural-tangents library a lot over the past few months, it's been extremely helpful.

    I just had a question about calculating the marginal log-likelihood for NNGPs, which I came across when reading the Neural Tangents library paper (e.g. figure 3, figure 7).

    I have tried to calculate the NLL on the CIFAR-10 dataset as well and have linked my Jupyter notebook. The problem I'm getting is that as I increase the training set size, the training NLL increases as well, which is the opposite of the results in the paper. Could you please point out the errors in my code/calculation, or perhaps share the code for the calculations?

    Thanks

    question 
    opened by TZeng20 7
  • Question about `test_composition_conv` in Stax Tests

    Hi, sorry to bother you. In the test_composition_conv_avg_pool test cases, some outer products of the covariance matrices are performed while doing the kernel transformation. In the outer product function, there is an interleave_ones operation which adds ones to the covariance dimensions:

    def outer_prod(x, y, start_axis, end_axis, prod_op):
      if y is None:
        y = x
      x = interleave_ones(x, start_axis, end_axis, True)
      y = interleave_ones(y, start_axis, end_axis, False)
      tf.print("x: {}, y: {}".format(x.shape, y.shape), output_stream=sys.stdout)
      return prod_op(x, y)
    

    When I print out the shapes after interleave_ones, some shapes are like x: (5, 1, 8, 1, 8, 1), y: (1, 5, 1, 8, 1, 8), which obviously do not match. In this case, would you mind explaining the role of interleave_ones and how the mismatched shapes can be multiplied together? Thanks!

    opened by DarrenZhang01 7
  • Using gp_inference to predict the posterior distribution - What is nngp_test_test?

    I want to use neural_tangents.predict.gp_inference to predict a distribution (mean and variance) on a test set. The documentation says that it returns a function predict_fn(get, k_test_train, nngp_test_test), and the docstring says that nngp_test_test is "a test-test NNGP array," but I see no mention of this NNGP array anywhere else in the documentation. How do I produce this?
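
    My current guess is that it is just the NNGP kernel evaluated between the test points themselves, i.e. something like the following (with kernel_fn being the infinite-width kernel from stax), but I could not find this spelled out anywhere:

    k_test_train = kernel_fn(x_test, x_train, 'nngp')
    nngp_test_test = kernel_fn(x_test, None, 'nngp')  # my guess: test-test NNGP matrix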

    question 
    opened by samuelkim314 6
  • Analytic kernel evaluated on sparse inputs

    Hi!

    A bug seems to occur when I try to evaluate analytic NTKs using sparse input data -- the evaluated kernel contains nan entries. This can be reproduced with the following lines of code:

    from jax import random
    from neural_tangents import stax
    
    key = random.PRNGKey(1)
    
    # a batch of dense inputs 
    x_dense = random.normal(key, (3, 32, 32, 3))
    
    # a batch of sparse inputs 
    x_sparse = x_dense * (abs(x_dense) > 1.2)
    
    
    # A CNN architecture
    init_fn, apply_fn, kernel_fn = stax.serial(
         stax.Conv(128, (3, 3)),
         stax.Relu(),
         stax.Flatten(),
         stax.Dense(10) )
    
    # Evaluate the analytic NTK upon dense inputs
    
    print('NTK evaluated w/ dense inputs: \n', kernel_fn(x_dense, x_dense, 'ntk')) # the outputs look fine.
    
    print('\n')
    
    # Evaluate the analytic NTK upon sparse inputs
    
    print('NTK evaluated w/ sparse inputs: \n', kernel_fn(x_sparse, x_sparse, 'ntk')) # the outputs contains nan
    

    The output of the above script should be:

    NTK evaluated w/ dense inputs: 
     [[0.97102666 0.16131128 0.16714054]
     [0.16131128 0.9743941  0.17580226]
     [0.16714054 0.17580226 1.0097454 ]]
    
    
    NTK evaluated w/ sparse inputs: 
     [[       nan        nan        nan]
     [       nan 0.66292834        nan]
     [       nan        nan        nan]]
    
    

    Thanks for your time in advance!

    opened by liutianlin0121 6
  • Error using batch with jit

    I get the error "Too many leaves for PyTreeDef; expected 6." when I try to run the following code:

    def get_network(W_std=1):
        init_fun, apply_fun, ker_fun = stax.serial(
            stax.Dense(1, W_std=W_std, b_std=0.1)
        )
        ker_fun =jit(batch(ker_fun, batch_size=25, device_count=0))
        kdd = ker_fun(train_xs, None)
        return 0
    jit(get_network)(2.0)
    
    bug 
    opened by ravidziv 6
  • Getting NaN for predict.gradient_descent_mse

    Hi,

    As the title says, I am getting NaN values from predict.gradient_descent_mse (t=None, diag_reg=0). According to my diagnosis, it is because the cho_factor call in predict._get_cho_solve outputs NaNs, which is weird since cho_factor should not give NaNs for a PSD matrix, and an empirical tangent kernel (= ntk_fn(X_train, None, params)) is supposed to be a PSD matrix. If I set diag_reg > 0, it does not blow up to NaNs, but it still returns a large number. I could increase diag_reg to avoid this at the cost of accuracy. What would be the best way to get around this problem?
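
    For concreteness, this is roughly what I am running, together with the regularized variant I fall back to (the diag_reg value is just an example; X_train and Y_train are my training inputs and targets):

    import neural_tangents as nt

    ntk_train_train = ntk_fn(X_train, None, params)  # empirical NTK, PSD in theory

    # t=None, diag_reg=0: cho_factor produces NaNs.
    predict_fn = nt.predict.gradient_descent_mse(ntk_train_train, Y_train)

    # With a small diagonal regularizer it stays finite, but the error is large.
    predict_fn_reg = nt.predict.gradient_descent_mse(ntk_train_train, Y_train,
                                                     diag_reg=1e-4)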

    question 
    opened by won-bae 5
  • Significant difference in empirical NTK for batched and non-batched versions

    After updating my environment to work with a more recent version of JAX and FLAX, I have noticed that the empirical NTK Gram matrices computed using nt.batch applied to nt.empirical_kernel_fn are significantly different depending on the batch size.

    The code to reproduce this error is:

    import jax.numpy as jnp
    import flax.linen as nn
    import functools
    import jax
    import neural_tangents as nt
    
    
    class LeNet(nn.Module):
        kernel_size = (5, 5)
        strides = (2, 2)
        window_shape = (2, 2)
        num_classes = 1
        features = (6, 16, 120, 84, 1)
        pooling = True
        padding = "SAME"
    
        @nn.compact
        def __call__(self, x):
            conv = functools.partial(nn.Conv, padding=self.padding)
            x = conv(features=self.features[0], kernel_size=tuple(self.kernel_size))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=tuple(self.window_shape), strides=tuple(self.strides))
    
            x = conv(features=self.features[1], kernel_size=tuple(self.kernel_size))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=tuple(self.window_shape), strides=tuple(self.strides))
    
            x = x.reshape((x.shape[0], -1))
            x = nn.Dense(self.features[2])(x)
            x = nn.relu(x)
            x = nn.Dense(self.features[3])(x)
            x = nn.relu(x)
    
            x = nn.Dense(self.num_classes)(x)
            return x
    
    model_key, data_key = jax.random.split(jax.random.PRNGKey(42))
    data = jax.random.normal(data_key, [500, 32, 32, 3])
    model = LeNet()
    init_params = model.init(model_key, jnp.zeros([1, 32, 32, 3]))
    
    # Compute NTK Gram matrix using the fully parallel version
    kernel_full_fn = nt.batch(
        nt.empirical_kernel_fn(model.apply, vmap_axes=0, implementation=2, trace_axes=()),
        batch_size=500,
        device_count=-1,
        store_on_device=False,
    )
    K_full = kernel_full_fn(data, None, "ntk", init_params)
    
    # Compute NTK Gram matrix using minibatches
    kernel_batch_fn = nt.batch(
        nt.empirical_kernel_fn(model.apply, vmap_axes=0, implementation=2, trace_axes=()),
        batch_size=100,
        device_count=-1,
        store_on_device=False,
    )
    K_batch = kernel_batch_fn(data, None, "ntk", init_params)
    
    # Compute difference between two matrices. It should technically be 0.
    print("Average error per entry:",  jnp.linalg.norm(K_full - K_batch) / K_full.size)
    

    Surprisingly, if I run this with my old environment I get an average error of the order of 1e-8, while with the new environment the error is of the order of 1e-1. Also, this error remains exactly the same as long as batch_size<data.shape[0].

    My old enviornment consisted of:

    python=3.7.4
    jax=0.2.8
    jaxlib=0.1.57+cuda102
    flax=0.3.0
    neural-tangents=0.3.7
    

    and my new environment has:

    python=3.7.4
    jax=0.2.19
    jaxlib=0.1.70+cuda102
    flax=0.3.4
    neural-tangents=0.3.7
    
    bug 
    opened by gortizji 5
  • Cases where NTK and NNGP are very different

    Hello everyone, I'm looking for types of data sets where the NTK prediction and the NNGP prediction will be very different. I calculate both according to theory as k(x_test, X_training) * inverse(K(X_training, X_training)) * Y, where Y is the label vector, k(x_test, X_training) is the kernel between the test point and the training set, and X_training is the training matrix. For the NTK prediction I use the NTK kernel and for the NNGP prediction I use the NNGP kernel. Is there a problem set for which these values will be very different?
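
    Concretely, I compute the predictions roughly like this (a sketch using the infinite-width kernel_fn from stax, with no diagonal regularization):

    import jax.numpy as np

    def gp_mean(kernel_fn, x_test, x_train, y_train, get):
      # k(x_test, X_train) K(X_train, X_train)^{-1} Y
      k_test_train = kernel_fn(x_test, x_train, get)
      k_train_train = kernel_fn(x_train, None, get)
      return k_test_train @ np.linalg.solve(k_train_train, y_train)

    ntk_pred = gp_mean(kernel_fn, x_test, x_train, y_train, 'ntk')
    nngp_pred = gp_mean(kernel_fn, x_test, x_train, y_train, 'nngp')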

    question 
    opened by YehonatanAvidan 0
  • Excessive memory consumption for deep networks

    The LLVM compiler pass uses excessive amounts of memory for deep networks that are constructed like this:

    stax.serial(*([my_layer] * depth))
    

    In fact, the compilation may eventually OOM.

    The reason is that the serial combinator internally relies on a python for loop (with carry) to support mixed input sequences.

    It would be nice to have a specialization for the case in which the same layer is repeated n times, which could then use jax.lax.scan() to save compilation time by avoiding loop unrolling.

    Suggestion:

    import jax.example_libraries.stax as ostax
    from neural_tangents._src.utils.typing import Layer, InternalLayer, NTTree
    from neural_tangents._src.stax.requirements import get_req, requires, layer
    from neural_tangents._src.utils.kernel import Kernel
    from jax.lax import scan
    import jax.numpy as np
    
    @layer
    def repeat(layer: Layer, n: int) -> InternalLayer:
      """Combinator for repeating the same layers `n` times.
    
      Based on :obj:`jax.example_libraries.stax.serial`.
    
      Args:
        layer:
          a single layer, each an `(init_fn, apply_fn, kernel_fn)` triple.
    
        n:
          the number of iterations
    
      Returns:
        A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple,
        representing the composition of `n` layers.
      """
      init_fn, apply_fn, kernel_fn = layer
    
      init_fn, apply_fn = ostax.serial(*zip([init_fn] * n, [apply_fn] * n))
      @requires(**get_req(kernel_fn))
      def kernel_fn_scan(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
        # TODO(xlc): if we drop `x1_is_x2` and use `rng` instead, need split key
        # inside kernel functions here and parallel below.
        k, _ = scan(lambda carry, x: (kernel_fn(carry, **kwargs), None), k, np.arange(n))
        return k
    
      return init_fn, apply_fn, kernel_fn_scan
    

    Use like this

    repeat(my_layer, depth)
    
    enhancement 
    opened by jglaser 2
  • Questions about memory consumption of infinitely wide NTK

    I am working on a simple MNIST example. I found that I could not compute the NTK for the entire dataset without running out of memory. Below is the code snippet I used:

    import neural_tangents as nt
    from neural_tangents import stax
    from examples import datasets
    from jax import random, jit
    import jax.numpy as jnp
    
    def FC(depth=1, num_classes=10, W_std=1.0, b_std=0.0):
        layers = [stax.Flatten()]
        for _ in range(depth):
            layers += [stax.Dense(1, W_std, b_std), stax.Relu()]
        layers += [stax.Dense(num_classes, W_std, b_std)]
        return stax.serial(*layers)
    
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist', data_dir="./data", permute_train=True)
    
    key = random.PRNGKey(0)
    init_fn, apply_fn, kernel_fn = FC()
    _, params = init_fn(key, (-1, 784))
    
    apply_fn = jit(apply_fn)
    kernel_fn = jit(kernel_fn, static_argnums=(2,))
    
    batched_kernel_fn = nt.batch(kernel_fn, 1000, store_on_device=False)
    
    k_train_train = kernel_fn(x_train, None, 'ntk')
    k_test_train = kernel_fn(x_test, x_train, 'ntk')
    predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
    fx_train_0 = apply_fn(params, x_train)
    fx_test_0 = apply_fn(params, x_test)
    fx_train_inf, fx_test_inf = predict_fn(fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)
    

    I am running this on two RTX 3090s, each with 24GB of memory. Is there something I'm doing wrong, or is it normal for the NTK to consume so much memory? Thank you!

    bug 
    opened by jasonli0707 6
  • Does neural-tangents work for custom layer?

    I have built a custom layer (KerasLayer) using a class in Python (say, class NewLayer). Can I use something like stax.NewLayer to use neural-tangents with this custom layer?

    question 
    opened by Shuhul24 1
  • Draw Phase Diagram for CNTK

    I'm curious about the initialization for the CNTK, so I replaced the kernel_fn in the c_map(W_var, b_var) function in the Colab with:

    # Create a single layer of a network as an affine transformation composed
    # with an Erf nonlinearity.
    # kernel_fn = stax.serial(stax.Dense(1024, W_std, b_std), stax.Erf())[2]
    kernel_fn = stax.serial(
          stax.Conv(out_chan=1024, filter_shape=(3, 3), strides=None, padding='SAME', W_std=W_std, b_std=b_std),
          stax.Relu(),
          stax.Flatten(),
          stax.Dense(10, W_std=W_std, b_std=b_std, parameterization='ntk')
    )[2]
    

    However, it seems that an error is raised from a lower layer when I try to plot, with the error message as follows:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    /tmp/ipykernel_54123/684423119.py in <module>
    ----> 1 plt.contourf(W_var, b_var, c_star(W_var, b_var))
          2 plt.colorbar()
          3 plt.title('$C^*$ as a function of weight and bias variance', fontsize=14)
          4 
          5 format_plot('$\\sigma_w^2$', '$\\sigma_b^2$')
    
        [... skipping hidden 18 frame]
    
    /tmp/ipykernel_54123/2333898834.py in <lambda>(W_var, b_var)
         51   return c_map_fn
         52 
    ---> 53 c_star = lambda W_var, b_var: fixed_point(c_map(W_var, b_var), 0.1, 1e-7)
         54 chi = lambda c, W_var, b_var: grad(c_map(W_var, b_var))(c)
         55 chi_1 = partial(chi, 1.)
    
    /tmp/ipykernel_54123/2333898834.py in c_map(W_var, b_var)
         42     return kernel_fn(Kernel(np.array([[q]]))).nngp[0, 0]
         43 
    ---> 44   qstar = fixed_point(q_map_fn, 1.0, 1e-7)
         45 
         46   def c_map_fn(c):
    
    /tmp/ipykernel_54123/3420146269.py in fixed_point(f, initial_value, threshold)
         38     return x - g(x) / dg(x), x
         39 
    ---> 40   return lax.while_loop(cond_fn, body_fn, (initial_value, 0.0))[0]
    
        [... skipping hidden 12 frame]
    
    /tmp/ipykernel_54123/3420146269.py in body_fn(x)
         36   def body_fn(x):
         37     x, _ = x
    ---> 38     return x - g(x) / dg(x), x
         39 
         40   return lax.while_loop(cond_fn, body_fn, (initial_value, 0.0))[0]
    
    /tmp/ipykernel_54123/3420146269.py in <lambda>(x)
         27 def fixed_point(f, initial_value, threshold):
         28   """Find fixed-points of a function f:R->R using Newton's method."""
    ---> 29   g = lambda x: f(x) - x
         30   dg = grad(g)
         31 
    
    /tmp/ipykernel_54123/2333898834.py in q_map_fn(q)
         40   def q_map_fn(q):
         41     print(q)
    ---> 42     return kernel_fn(Kernel(np.array([[q]]))).nngp[0, 0]
         43 
         44   qstar = fixed_point(q_map_fn, 1.0, 1e-7)
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
        174     @functools.wraps(f)
        175     def h(*args, **kwargs):
    --> 176       return g(*args, **kwargs)
        177 
        178     h.__signature__ = inspect.signature(f)
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in getter_fn(*args, **kwargs)
        208                                                           len(args)])
        209 
    --> 210       fn_out = fn(*canonicalized_args, **kwargs)
        211 
        212       @nt_tree_fn()
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_any(x1_or_kernel, x2, get, pattern, mask_constant, diagonal_batch, diagonal_spatial, **kwargs)
       4293     """
       4294     if utils.is_nt_tree_of(x1_or_kernel, Kernel):
    -> 4295       return kernel_fn_kernel(x1_or_kernel,
       4296                               pattern=pattern,
       4297                               diagonal_batch=diagonal_batch,
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_kernel(kernel, **kwargs)
       4212 
       4213   def kernel_fn_kernel(kernel, **kwargs):
    -> 4214     out_kernel = kernel_fn(kernel, **kwargs)
       4215     return _set_shapes(init_fn, apply_fn, kernel, out_kernel, **kwargs)
       4216 
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
        174     @functools.wraps(f)
        175     def h(*args, **kwargs):
    --> 176       return g(*args, **kwargs)
        177 
        178     h.__signature__ = inspect.signature(f)
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in new_kernel_fn(k, **kwargs)
        191               pass
        192 
    --> 193       return kernel_fn(k, **kwargs)
        194 
        195     setattr(new_kernel_fn, _INPUT_REQ, frozendict.frozendict(static_reqs))
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn(k, **kwargs)
        325     # inside kernel functions here and parallel below.
        326     for f in kernel_fns:
    --> 327       k = f(k, **kwargs)
        328     return k
        329 
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
        174     @functools.wraps(f)
        175     def h(*args, **kwargs):
    --> 176       return g(*args, **kwargs)
        177 
        178     h.__signature__ = inspect.signature(f)
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in getter_fn(*args, **kwargs)
        208                                                           len(args)])
        209 
    --> 210       fn_out = fn(*canonicalized_args, **kwargs)
        211 
        212       @nt_tree_fn()
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_any(x1_or_kernel, x2, get, pattern, mask_constant, diagonal_batch, diagonal_spatial, **kwargs)
       4293     """
       4294     if utils.is_nt_tree_of(x1_or_kernel, Kernel):
    -> 4295       return kernel_fn_kernel(x1_or_kernel,
       4296                               pattern=pattern,
       4297                               diagonal_batch=diagonal_batch,
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_kernel(kernel, **kwargs)
       4212 
       4213   def kernel_fn_kernel(kernel, **kwargs):
    -> 4214     out_kernel = kernel_fn(kernel, **kwargs)
       4215     return _set_shapes(init_fn, apply_fn, kernel, out_kernel, **kwargs)
       4216 
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn_with_masking(k, **user_reqs)
        277         mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)
        278 
    --> 279         k = kernel_fn(k, **user_reqs)  # type: Kernel
        280 
        281         if remask_kernel:
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/utils/utils.py in h(*args, **kwargs)
        174     @functools.wraps(f)
        175     def h(*args, **kwargs):
    --> 176       return g(*args, **kwargs)
        177 
        178     h.__signature__ = inspect.signature(f)
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in new_kernel_fn(k, **kwargs)
        191               pass
        192 
    --> 193       return kernel_fn(k, **kwargs)
        194 
        195     setattr(new_kernel_fn, _INPUT_REQ, frozendict.frozendict(static_reqs))
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in kernel_fn(k, **kwargs)
       1506       return out
       1507 
    -> 1508     cov1 = conv(cov1, 1 if k.diagonal_batch else 2)
       1509     cov2 = conv(cov2, 1 if k.diagonal_batch else 2)
       1510 
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in conv(lhs, batch_ndim)
       1502 
       1503     def conv(lhs, batch_ndim):
    -> 1504       out = conv_unscaled(lhs, batch_ndim)
       1505       out = affine(out, W_std**2, b_std**2, batch_ndim)
       1506       return out
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in conv_unscaled(lhs, batch_ndim)
       1477 
       1478     def conv_unscaled(lhs, batch_ndim):
    -> 1479       lhs = conv_kernel(lhs,
       1480                         filter_shape_kernel,
       1481                         strides_kernel,
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in _conv_kernel_full_spatial_shared(lhs, filter_shape, strides, padding, batch_ndim)
       4759     return n_channels
       4760 
    -> 4761   out = _conv_kernel_full_spatial_loop(lhs, filter_shape, strides, padding,
       4762                                        lax_conv, get_n_channels)
       4763   return out
    
    ~/anaconda3/envs/jax/lib/python3.8/site-packages/neural_tangents/stax.py in _conv_kernel_full_spatial_loop(lhs, filter_shape, strides, padding, lax_conv, get_n_channels)
       4912     spatial_i = (i - batch_ndim) // 2
       4913 
    -> 4914     lhs = np.moveaxis(lhs, (i - 1, i), (-2, -1))
       4915     preshape = lhs.shape[:-2]
       4916     n_channels = get_n_channels(utils.size_at(preshape))
    
    ~/jax/jax/_src/numpy/lax_numpy.py in moveaxis(a, source, destination)
       1535     destination_axes = tuple(cast(Sequence[int], destination))
       1536   source_axes = tuple(_canonicalize_axis(i, ndim(a)) for i in source_axes)
    -> 1537   destination_axes = tuple(_canonicalize_axis(i, ndim(a))
       1538                            for i in destination_axes)
       1539   if len(source_axes) != len(destination_axes):
    
    ~/jax/jax/_src/numpy/lax_numpy.py in <genexpr>(.0)
       1535     destination_axes = tuple(cast(Sequence[int], destination))
       1536   source_axes = tuple(_canonicalize_axis(i, ndim(a)) for i in source_axes)
    -> 1537   destination_axes = tuple(_canonicalize_axis(i, ndim(a))
       1538                            for i in destination_axes)
       1539   if len(source_axes) != len(destination_axes):
    
    ~/jax/jax/_src/util.py in canonicalize_axis(axis, num_dims)
        275   axis = operator.index(axis)
        276   if not -num_dims <= axis < num_dims:
    --> 277     raise ValueError(
        278         "axis {} is out of bounds for array of dimension {}".format(
        279             axis, num_dims))
    
    ValueError: axis -2 is out of bounds for array of dimension 1
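
    For reference, stax.Conv kernel functions operate on inputs that carry spatial axes (e.g. shape (batch, height, width, channels)); a covariance array of dimension 1, as in the error above, is what flat inputs without spatial dimensions would produce. Below is a minimal sketch of a convolutional kernel_fn call on image-shaped inputs (the shapes and hyperparameters are illustrative assumptions, not taken from the issue):

    from jax import random
    from neural_tangents import stax

    # Illustrative CNN; W_std, b_std and all shapes below are placeholders.
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(out_chan=64, filter_shape=(3, 3), padding='SAME',
                  W_std=1.5, b_std=0.05),
        stax.Erf(),
        stax.Flatten(),
        stax.Dense(1, W_std=1.5, b_std=0.05)
    )

    key1, key2 = random.split(random.PRNGKey(1))
    x1 = random.normal(key1, (4, 8, 8, 3))  # (batch, height, width, channels)
    x2 = random.normal(key2, (2, 8, 8, 3))

    nngp = kernel_fn(x1, x2, 'nngp')  # (4, 2) covariance between the two batches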
    

    Is there something I am misunderstanding about the Phase Diagram? (Is the CNTK phase diagram fundamentally impossible to draw?) Also, I've found that there is no difference at all in the phase diagram when I simply make an FC network deeper, e.g.

    def DenseGroup(n, neurons, W_std, b_std):
        blocks = []
        for _ in range(n):
            blocks += [stax.Dense(neurons, W_std, b_std), stax.Erf()]
        return stax.serial(*blocks)
    
    for layer in range(1,11):
        def c_map(W_var, b_var):
            ...
            kernel_fn = stax.serial(DenseGroup(layer, 1024, W_std, b_std))[2]
            ...
        c_star = lambda W_var, b_var: fixed_point(c_map(W_var, b_var), 0.1, 1e-7)
        chi = lambda c, W_var, b_var: grad(c_map(W_var, b_var))(c)
        chi_1 = partial(chi, 1.)
        
        c_star = jit(vectorize_over_sw_sb(c_star))
        chi_1 = jit(vectorize_over_sw_sb(chi_1))
    
        plt.contourf(W_var, b_var, c_star(W_var, b_var))
    

    Does this mean that the depth of the network doesn't affect the initialization?
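
    For reference, one way to see why the c* contours can look identical at every depth: stacking identical Dense + Erf layers composes the single-layer correlation map with itself, and iterating the composed map converges to the same fixed point as iterating the single-layer map, so depth changes how fast correlations converge, not where they converge to. A toy sketch (the affine map and the iterate_to_fixed_point helper below are illustrative stand-ins, not the real c-map from the Colab):

    import functools

    def iterate_to_fixed_point(f, x0, tol=1e-7, max_iter=10000):
        # Plain fixed-point iteration x_{t+1} = f(x_t).
        x = x0
        for _ in range(max_iter):
            x_new = f(x)
            if abs(x_new - x) < tol:
                return x_new
            x = x_new
        return x

    # Toy single-layer correlation map standing in for the Dense + Erf c-map.
    c_map_1 = lambda c: 0.9 * c + 0.05

    # Depth-n map: n-fold composition of the single-layer map.
    c_map_n = lambda n: (lambda c: functools.reduce(
        lambda cc, _: c_map_1(cc), range(n), c))

    print(iterate_to_fixed_point(c_map_1, 0.1))      # ~0.5
    print(iterate_to_fixed_point(c_map_n(10), 0.1))  # same fixed point, regardless of depth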

    enhancement 
    opened by hhorace 1
  • Question: simple example poor performance, what am I doing wrong?


    Dear team, great package! I'm very excited to use it.

    However, I tried a simple case and failed miserably to get decent performance.

    I generate a multi-dimensional dataset whose target is a relatively simple function of a single feature:

    import numpy as np

    # Create some fake data
    np.random.seed(0)
    m = 1000
    n = 10
    noise_std = 1.
    X = 80 * np.random.uniform(size=(m, n)) - 40
    y = np.abs(X[:, 6] - 4.0) + noise_std * np.random.normal(size=m)
    

    And I followed your examples:

    import neural_tangents as nt
    from neural_tangents import stax
    from sklearn.model_selection import train_test_split
    
    x_train, x_test, y_train, y_test = train_test_split(
        X, y.reshape(-1, 1), test_size=0.4, random_state=42)
    
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(256), stax.Relu(),
        stax.Dense(1)
    )
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn, 
        x_train,
        y_train)
    
    # Unpack the predictions namedtuple
    y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'), compute_cov=True)
    

    Visual inspection shows terrible predictions, and loss values are large:

    import jax.numpy as jnp

    loss = lambda ypred, y_hat: 0.5 * jnp.mean((ypred - y_hat) ** 2)
    print("loss_nngp = {}".format(loss(y_test_nngp.mean, y_test)))
    print("loss_ntk = {}".format(loss(y_test_ntk.mean, y_test)))

    # Output:
    # loss_nngp = 6.877374649047852
    # loss_ntk = 6.610106468200684
    

    I varied the network architecture in many ways and fiddled with learning_rate and diag_reg, but it hardly changed anything.

    I'm sure I am doing something wrong, but I cannot see what it is. Any obvious mistake?

    Thanks for your help.
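
    For reference, a common adjustment in this kind of setup (not part of the original issue) is to standardize the inputs and targets, since the raw inputs here span roughly [-40, 40] and the targets reach magnitudes of ~40, far from unit scale, and to pass a small diag_reg to account for the noisy targets. A minimal sketch reusing the variables from the snippets above (the exact diag_reg value is a guess to be tuned):

    # Reuses X, y, kernel_fn, nt, jnp and train_test_split from the snippets above.
    x_train, x_test, y_train, y_test = train_test_split(
        X, y.reshape(-1, 1), test_size=0.4, random_state=42)

    # Standardize with training-set statistics only.
    x_mean, x_std = x_train.mean(axis=0), x_train.std(axis=0)
    y_mean, y_std = y_train.mean(), y_train.std()
    x_train_s, x_test_s = (x_train - x_mean) / x_std, (x_test - x_mean) / x_std
    y_train_s = (y_train - y_mean) / y_std

    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train_s, y_train_s, diag_reg=1e-3)  # small ridge for noisy targets

    y_pred_s = predict_fn(x_test=x_test_s, get='nngp', compute_cov=False)
    y_pred = y_pred_s * y_std + y_mean  # back to the original target scale

    print(0.5 * jnp.mean((y_pred - y_test) ** 2))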

    question 
    opened by mfouesneau 5
Releases (v0.6.1)