This package requires jax, tensorflow, and numpy. Either tensorflow or scikit-learn can be used for loading data.
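Before running anything, a quick check can confirm that the packages listed above are importable. This is a minimal sketch and not part of the package: it uses only the standard library's `importlib.util.find_spec`, and the names `REQUIRED`, `DATA_BACKENDS`, `missing_requirements` and `installed_data_backends` are hypothetical. Note that scikit-learn is imported under the name `sklearn`.

```python
import importlib.util
from typing import List

# Package names taken from the requirements above (hypothetical constants).
REQUIRED = ("jax", "tensorflow", "numpy")
# Either of these can load data; "sklearn" is scikit-learn's import name.
DATA_BACKENDS = ("tensorflow", "sklearn")


def is_installed(name: str) -> bool:
    # find_spec returns None when a top-level package cannot be found
    return importlib.util.find_spec(name) is not None


def missing_requirements() -> List[str]:
    """Return the required packages that are not installed."""
    return [name for name in REQUIRED if not is_installed(name)]


def installed_data_backends() -> List[str]:
    """Return whichever data-loading backends are installed."""
    return [name for name in DATA_BACKENDS if is_installed(name)]


if __name__ == "__main__":
    print("missing:", missing_requirements())
    print("data backends:", installed_data_backends())
```

An empty `missing` list means the requirements are met; at least one entry in `data backends` is needed to load data.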
To run in a nix-shell with the required packages (at the specific versions used), run:
nix-shell
Results are generated by running main.py with the required arguments, e.g. python main.py --lr. The results described in the paper are stored as CSV files in the results subfolder.
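For inspecting the stored results outside analysis.py, a sketch like the following reads every CSV file in the results subfolder with the standard library's `csv.DictReader`. The function name `load_results` is hypothetical, and nothing is assumed about the columns main.py writes: column names come from each file's own header row, and all values are returned as strings.

```python
import csv
from pathlib import Path
from typing import Dict, List


def load_results(results_dir: str = "results") -> Dict[str, List[Dict[str, str]]]:
    """Read every CSV in results_dir into a list of row dicts.

    Keys of the returned mapping are file names without the ".csv" suffix.
    """
    tables: Dict[str, List[Dict[str, str]]] = {}
    # sorted() makes the iteration order deterministic across platforms
    for path in sorted(Path(results_dir).glob("*.csv")):
        with path.open(newline="") as f:
            tables[path.stem] = list(csv.DictReader(f))
    return tables


if __name__ == "__main__":
    for name, rows in load_results().items():
        print(f"{name}: {len(rows)} rows")
```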
Figures in the paper can be reproduced by running analysis.py. To plot the bounds and errors on the same scale (as described in the appendix), set the variable BOUND_SCALE_AXIS in analysis.py to False.
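To illustrate what plotting "on the same scale" means, the sketch below computes y-axis limits for a bounds series and an errors series either independently or shared. It is a hypothetical illustration and not how analysis.py implements BOUND_SCALE_AXIS: the function `y_limits` and its `bound_scale_axis` parameter are assumptions, chosen so that `False` mirrors the README's description (one shared range).

```python
from typing import Dict, Sequence, Tuple

Limits = Tuple[float, float]


def y_limits(bounds: Sequence[float], errors: Sequence[float],
             bound_scale_axis: bool) -> Dict[str, Limits]:
    """Return (min, max) y-limits for the bounds and errors series.

    bound_scale_axis=True: each series gets its own range.
    bound_scale_axis=False: both share one range, so they are directly comparable.
    """
    bound_lim = (min(bounds), max(bounds))
    error_lim = (min(errors), max(errors))
    if bound_scale_axis:
        return {"bounds": bound_lim, "errors": error_lim}
    # Shared range spans both series
    shared = (min(bound_lim[0], error_lim[0]), max(bound_lim[1], error_lim[1]))
    return {"bounds": shared, "errors": shared}


print(y_limits([0.5, 2.0], [0.1, 0.3], bound_scale_axis=False))
# → {'bounds': (0.1, 2.0), 'errors': (0.1, 2.0)}
```

With separate ranges, small errors fill their own axis and look comparable to the bounds; a shared range shows their true relative size.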