What can linearized neural networks actually say about generalization?

This is the source code to reproduce the experiments of the NeurIPS 2021 paper "What can linearized neural networks actually say about generalization?" by Guillermo Ortiz-Jimenez, Seyed-Mohsen Moosavi-Dezfooli and Pascal Frossard.

Dependencies

To run the code, please install all its dependencies by running:

$ pip install -r requirements.txt

This assumes that you have access to a Linux machine with an NVIDIA GPU and CUDA>=11.1. Otherwise, please follow the instructions for installing JAX on your setup in the official JAX repository.

All scripts are parameterized using Hydra, and their configuration files can be found in the config/ folder.
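Hydra lets you override entries of a YAML config from the command line with key=value pairs, where dots address nested keys (so data.dataset=cifar10 sets dataset inside the data group, while a top-level key such as model may instead select a config-group file). The plain-Python snippet below is a simplified, hypothetical illustration of the nested-key mapping only. It is not Hydra's implementation and ignores config groups, Hydra's full value grammar and its override prefixes; the default values are made up.

```python
import copy
from typing import Any, Dict, List


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Return a copy of `config` with dotted `key=value` overrides applied.

    Simplified illustration only: values stay strings, except the literals
    True/False, which are converted to booleans.
    """
    result = copy.deepcopy(config)
    for item in overrides:
        key, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value, got {item!r}")
        value: Any = {"True": True, "False": False}.get(raw_value, raw_value)
        *parents, leaf = key.split(".")
        node = result
        for name in parents:
            # Walk (or create) intermediate groups, e.g. "data" in "data.dataset".
            node = node.setdefault(name, {})
        node[leaf] = value
    return result


# Hypothetical defaults; the real ones live in files such as config/train-ntk/config.yaml.
defaults = {"model": "mlp", "data": {"dataset": "mnist"}, "linearize": True}
cfg = apply_overrides(defaults, ["model=lenet", "data.dataset=cifar10", "linearize=False"])
print(cfg)
```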

Experiments

The repository contains code to reproduce the following experiments:

Spectral decomposition of NTK

To generate our new benchmark, consisting of the eigenfunctions of the NTK at initialization, run the Python script compute_ntk.py, selecting a model (e.g., mlp, lenet or resnet18) and a supporting dataset (e.g., cifar10 or mnist). This can be done by running

$ python compute_ntk.py model=lenet data.dataset=cifar10

This script will save the eigenvalues, eigenfunctions and weights of the model under artifacts/eigenfunctions/{data.dataset}/{model}/.

For other configuration options, please consult the configuration file config/compute-ntk/config.yaml.
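To make the computation concrete, here is a small NumPy sketch of the same idea for a toy one-hidden-layer ReLU network f(x) = vᵀ relu(Wx) / √m on random inputs. It builds the empirical NTK Gram matrix Θ = J Jᵀ from the Jacobian J of the outputs with respect to all parameters, eigendecomposes it, and reads off the eigenfunctions evaluated on the samples. The network, its width and the data are hypothetical stand-ins; the repository itself works with JAX models such as lenet or resnet18 on real datasets.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, m = 64, 10, 256                          # samples, input dim, hidden width (toy sizes)
X = rng.standard_normal((n, d))
W = rng.standard_normal((m, d)) / np.sqrt(d)   # hidden-layer weights
v = rng.standard_normal(m)                     # output weights


def jacobian(X, W, v):
    """Jacobian of f(x) = v . relu(W x) / sqrt(m) w.r.t. (W, v), one row per sample."""
    width = W.shape[0]
    pre = X @ W.T                              # (n, m) pre-activations
    act = np.maximum(pre, 0.0)                 # relu
    mask = (pre > 0).astype(X.dtype)           # relu derivative
    dv = act / np.sqrt(width)                                       # df/dv, (n, m)
    dW = (mask * v)[:, :, None] * X[:, None, :] / np.sqrt(width)    # df/dW, (n, m, d)
    return np.concatenate([dW.reshape(len(X), -1), dv], axis=1)     # (n, m*d + m)


J = jacobian(X, W, v)
ntk = J @ J.T                                  # empirical NTK Gram matrix, (n, n)

# eigh returns ascending eigenvalues of a symmetric matrix; flip to descending.
eigvals, eigvecs = np.linalg.eigh(ntk)
eigvals, eigvecs = eigvals[::-1], eigvecs[:, ::-1]

# Column k of eigvecs is the k-th eigenfunction evaluated on the n samples.
print(eigvals[:5])
```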

Warning

Note that for large models this computation can take a very long time. For example, it took us two days to compute the full eigenvalue decomposition of the NTK of a single randomly initialized ResNet18 using 4 NVIDIA V100 GPUs. For the MLP or the LeNet, on the other hand, the eigenvectors can be estimated in a matter of minutes, depending on the number of available GPUs and the selected batch_size.
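A back-of-the-envelope calculation shows why the large models are expensive. Assuming the kernel is formed on the n = 50,000 images of the CIFAR-10 training set (the number of samples the script actually uses may differ), merely storing the dense n × n Gram matrix in float32 takes about 10 GB, and a dense eigendecomposition scales as O(n³). This is before counting the Jacobian contractions needed to fill each entry for a network with millions of parameters.

```python
def gram_matrix_bytes(n_samples: int, bytes_per_entry: int = 4) -> int:
    """Memory needed to store a dense n x n kernel matrix (float32 by default)."""
    return n_samples * n_samples * bytes_per_entry


n = 50_000                        # assumed: full CIFAR-10 training set
size = gram_matrix_bytes(n)
print(f"{size / 1e9:.1f} GB")     # -> 10.0 GB
print(f"~{n**3:.2e} operations for an O(n^3) dense eigendecomposition")
```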

Training on binary eigenfunctions

Once you have estimated the eigenfunctions of the NTK, you can train on any of them. To that end, select the desired label_idx (i.e., eigenfunction index), model and dataset, and run

$ python train_ntk.py label_idx=100 model=lenet data.dataset=cifar10 linearize=False

You can choose to train either the original non-linear network or its linear approximation with the flag linearize. For the non-linear models, this script also computes the final alignment of the NTK at the end of training with the target function, and stores it under artifacts/eigenfunctions/{data.dataset}/{model}/alignment_plots/.
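The linear approximation is the first-order Taylor expansion of the network around its initial parameters θ₀, f_lin(x; θ) = f(x; θ₀) + ∇_θ f(x; θ₀)ᵀ(θ − θ₀), which is linear in θ but still non-linear in x. The NumPy sketch below checks this on a hypothetical toy model: the gap between f and f_lin shrinks quadratically as θ approaches θ₀ and grows as the parameters move away. The toy tanh network is an assumption for illustration only; the repository builds its linearized models in JAX.

```python
import numpy as np

rng = np.random.default_rng(1)
d, m = 5, 32
x = rng.standard_normal(d)


def f(theta):
    """Toy network f(x) = v . tanh(W x) / sqrt(m); theta packs (W, v)."""
    W, v = theta[: m * d].reshape(m, d), theta[m * d :]
    return v @ np.tanh(W @ x) / np.sqrt(m)


def grad_f(theta):
    """Analytic gradient of f with respect to theta = (W, v)."""
    W, v = theta[: m * d].reshape(m, d), theta[m * d :]
    h = np.tanh(W @ x)
    dW = (v * (1 - h**2))[:, None] * x[None, :] / np.sqrt(m)
    dv = h / np.sqrt(m)
    return np.concatenate([dW.ravel(), dv])


theta0 = rng.standard_normal(m * d + m)
g0 = grad_f(theta0)


def f_lin(theta):
    """First-order Taylor expansion of f around theta0."""
    return f(theta0) + g0 @ (theta - theta0)


direction = rng.standard_normal(theta0.size)
direction /= np.linalg.norm(direction)
gaps = {}
for step in (1e-3, 1e-2, 1e-1, 1.0):
    theta = theta0 + step * direction
    gaps[step] = abs(f(theta) - f_lin(theta))
    print(f"|delta|={step:g}  gap={gaps[step]:.2e}")
```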

To see the different supported training options, please consult the configuration file config/train-ntk/config.yaml.
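The alignment plots mentioned above measure how well a kernel matches the target labels. A standard definition of kernel-target alignment is A(Θ, y) = ⟨Θ, yyᵀ⟩_F / (‖Θ‖_F ‖yyᵀ‖_F) = yᵀΘy / (‖Θ‖_F ‖y‖²), which lies in [0, 1] for a positive semi-definite Θ. The sketch below computes it for a random PSD matrix standing in for an NTK, with targets taken as the signs of its eigenvectors. Both the exact normalization used by the script and the sign binarization are assumptions here.

```python
import numpy as np


def kernel_target_alignment(K: np.ndarray, y: np.ndarray) -> float:
    """Alignment <K, y y^T>_F / (||K||_F * ||y y^T||_F) = y^T K y / (||K||_F ||y||^2)."""
    return float(y @ K @ y / (np.linalg.norm(K, "fro") * (y @ y)))


rng = np.random.default_rng(2)
A = rng.standard_normal((40, 100))
K = A @ A.T                                          # random PSD kernel (stand-in for an NTK)

eigvals, eigvecs = np.linalg.eigh(K)
eigvals, eigvecs = eigvals[::-1], eigvecs[:, ::-1]   # descending order

# Assumed binarization: labels are the signs of the k-th eigenvector.
for k in (0, 20, 39):
    y = np.sign(eigvecs[:, k])
    print(k, round(kernel_target_alignment(K, y), 3))
```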

Estimation of NADs

We also provide code to compute the NADs (neural anisotropy directions) of a CNN architecture (e.g., lenet or resnet18) using their alignment with the NTK at initialization. To do so, please run

$ python compute_nads.py model=lenet

This script will save the eigenvalues, NADs and weights of the model under artifacts/nads/{model}/.

For other configuration options, please consult the configuration file config/compute-nads/config.yaml.

Training on linearly separable datasets

Once you have estimated the NADs of a network, you can train on linearly separable datasets that use a single NAD as their discriminative feature. To that end, select the desired label_idx (i.e., NAD index) and model, and run

$ python train_nads.py label_idx=100 model=lenet linearize=False

You can choose to train either the original non-linear network or its linear approximation with the flag linearize.

To see the different supported training options, please consult the configuration file config/train-nads/config.yaml.
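A dataset that is linearly separable along a single direction can be sketched as follows, assuming samples of the form x = ε·y·u + n, where u is a unit-norm direction (e.g., a NAD), y ∈ {−1, +1} is the label and n is Gaussian noise with its component along u removed. The values of ε and σ, the input size and the noise model are hypothetical; the exact construction used by train_nads.py may differ.

```python
import numpy as np


def make_directional_dataset(u, n_samples, eps=1.0, sigma=1.0, seed=0):
    """Sample x = eps * y * u + n, with n ~ N(0, sigma^2 I) projected orthogonally to u."""
    rng = np.random.default_rng(seed)
    u = u / np.linalg.norm(u)
    y = rng.choice([-1.0, 1.0], size=n_samples)
    noise = sigma * rng.standard_normal((n_samples, u.size))
    noise -= np.outer(noise @ u, u)            # remove the noise component along u
    X = eps * y[:, None] * u[None, :] + noise
    return X, y


d = 32 * 32                                    # hypothetical flattened input size
u = np.zeros(d)
u[0] = 1.0                                     # stand-in for a NAD
X, y = make_directional_dataset(u, n_samples=500, eps=0.5, sigma=2.0)

# Projecting onto u recovers eps * y, so the classes are linearly separable.
print(np.all(np.sign(X @ u) == y))             # -> True
```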

Comparison of training dynamics with pretrained NTK

We also provide code to compare the training dynamics of the linearized network at initialization and after non-linear pretraining, when trained to fit a particular eigenfunction of the NTK at initialization. To do this, please run

$ python pretrained_ntk_comparison.py label_idx=100 model=lenet data.dataset=cifar10

To see the different supported training options, please consult the configuration file config/pretrained_ntk_comparison/config.yaml.
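For a linearized network trained with full-batch gradient descent on the squared loss ½‖f − y‖², the training-set predictions evolve exactly as f_{t+1} = f_t − η Θ (f_t − y), where Θ is the fixed NTK Gram matrix. Comparing two linearizations therefore amounts to running this recursion with two different kernels. The sketch below does so for two hypothetical kernels that share eigenvectors but differ in the eigenvalue of the target direction; the learning rate, spectra and target are illustrative assumptions, not the script's settings.

```python
import numpy as np


def kernel_gd_losses(K, y, f0, lr, steps):
    """Training losses of f_{t+1} = f_t - lr * K (f_t - y), i.e. GD on a linearized model."""
    f = f0.copy()
    losses = []
    for _ in range(steps):
        residual = f - y
        losses.append(0.5 * float(residual @ residual))
        f = f - lr * K @ residual
    return np.array(losses)


rng = np.random.default_rng(3)
n = 50
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))   # shared orthonormal eigenbasis
y = np.sign(Q[:, 10])                               # target: a binarized eigenvector

spectrum = np.linspace(1.0, 0.1, n)                 # hypothetical eigenvalues
K_init = Q @ np.diag(spectrum) @ Q.T
boosted = spectrum.copy()
boosted[10] = 5.0                                   # more weight on the target direction
K_pre = Q @ np.diag(boosted) @ Q.T

f0 = np.zeros(n)
lr = 0.1                                            # stable: lr * max eigenvalue < 2
loss_init = kernel_gd_losses(K_init, y, f0, lr, steps=100)
loss_pre = kernel_gd_losses(K_pre, y, f0, lr, steps=100)
print(loss_init[-1], loss_pre[-1])
```

In the eigenbasis each residual component decays by a factor (1 − ηλᵢ) per step, so the kernel with the larger eigenvalue along the target fits it faster.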

Training on CIFAR2

Finally, you can train a neural network and its linearized approximation on the binary version of CIFAR10, i.e., CIFAR2. To do this, please run

$ python train_cifar.py model=lenet linearize=False

To see the different supported training options, please consult the configuration file config/binary-cifar/config.yaml.
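Assuming CIFAR2 is built by keeping two of the ten CIFAR-10 classes and relabeling them as ±1, the filtering step can be sketched as below. Which classes train_cifar.py actually uses is not stated here, so the class indices are hypothetical placeholders, and the arrays are toy stand-ins for the real images and labels.

```python
import numpy as np


def make_binary_subset(images, labels, positive_class, negative_class):
    """Keep two classes and map them to labels +1 / -1."""
    keep = (labels == positive_class) | (labels == negative_class)
    binary_labels = np.where(labels[keep] == positive_class, 1.0, -1.0)
    return images[keep], binary_labels


# Toy stand-in for CIFAR-10: 8 "images" with integer labels 0..9.
labels = np.array([0, 1, 2, 1, 0, 9, 3, 1])
images = np.arange(len(labels) * 3).reshape(len(labels), 3).astype(np.float32)

# Hypothetical class choice; the classes used by train_cifar.py may differ.
x2, y2 = make_binary_subset(images, labels, positive_class=0, negative_class=1)
print(y2)   # labels +1, -1, -1, +1, -1
```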

Reference

If you use this code, please cite the following paper:

@InCollection{Ortiz-JimenezNeurIPS2021,
  title = {What can linearized neural networks actually say about generalization?},
  author = {{Ortiz-Jimenez}, Guillermo and {Moosavi-Dezfooli}, Seyed-Mohsen and Frossard, Pascal},
  booktitle = {Advances in Neural Information Processing Systems 34},
  month = Dec,
  year = {2021}
}