Re-implementation of 'Grokking: Generalization beyond overfitting on small algorithmic datasets'

Related tags

Deep Learninggrokking
Overview

Re-implementation of the paper 'Grokking: Generalization beyond overfitting on small algorithmic datasets'

Paper

Original paper can be found here

Datasets

I'm not super clear on how they defined their division. I am using integer division:

  • $$x\circ y = (x // y) mod p$$, for some prime $$p$$ and $$0\leq x,y \leq p$$
  • $$x\circ y = (x // y) mod p$$ if y is odd else (x - y) mod p, for some prime $$p$$ and $$0\leq x,y \leq p$$

Hyperparameters

The default hyperparameters are from the paper, but can be adjusted via the command line when running train.py

Running experiments

To run with default settings, simply run python train.py. The first time you train on any dataset you have to specify --force_data.

Arguments:

optimizer args

  • "--lr", type=float, default=1e-3
  • "--weight_decay", type=float, default=1
  • "--beta1", type=float, default=0.9
  • "--beta2", type=float, default=0.98

model args

  • "--num_heads", type=int, default=4
  • "--layers", type=int, default=2
  • "--width", type=int, default=128

data args

  • "--data_name", type=str, default="perm", choices=[
    • "perm_xy", # permutation composition x * y
    • "perm_xyx1", # permutation composition x * y * x^-1
    • "perm_xyx", # permutation composition x * y * x
    • "plus", # x + y
    • "minus", # x - y
    • "div", # x / y
    • "div_odd", # x / y if y is odd else x - y
    • "x2y2", # x^2 + y^2
    • "x2xyy2", # x^2 + y^2 + xy
    • "x2xyy2x", # x^2 + y^2 + xy + x
    • "x3xy", # x^3 + y
    • "x3xy2y" # x^3 + xy^2 + y ]
  • "--num_elements", type=int, default=5 (choose 5 for permutation data, 97 for arithmetic data)
  • "--data_dir", type=str, default="./data"
  • "--force_data", action="store_true", help="Whether to force dataset creation."

training args

  • "--batch_size", type=int, default=512
  • "--steps", type=int, default=10**5
  • "--train_ratio", type=float, default=0.5
  • "--seed", type=int, default=42
  • "--verbose", action="store_true"
  • "--log_freq", type=int, default=10
  • "--num_workers", type=int, default=4
Owner
Tom Lieberum
Master student in AI at the University of Amsterdam. Effective altruist, rationalist, and transhumanist. Got my B.Sc. in Physics from RWTH Aachen Uni
Tom Lieberum
Dashboard for the COVID19 spread

COVID-19 Data Explorer App A streamlit Dashboard for the COVID-19 spread. The app is live at: [https://covid19.cwerner.ai]. New data is queried from G

Christian Werner 22 Sep 29, 2022
Official implementation of the paper 'High-Resolution Photorealistic Image Translation in Real-Time: A Laplacian Pyramid Translation Network' in CVPR 2021

LPTN Paper | Supplementary Material | Poster High-Resolution Photorealistic Image Translation in Real-Time: A Laplacian Pyramid Translation Network Ji

372 Dec 26, 2022
Deep Learning as a Cloud API Service.

Deep API Deep Learning as Cloud APIs. This project provides pre-trained deep learning models as a cloud API service. A web interface is available as w

Wu Han 4 Jan 06, 2023
The mini-MusicNet dataset

mini-MusicNet A music-domain dataset for multi-label classification Music transcription is sequence-to-sequence prediction problem: given an audio per

John Thickstun 4 Nov 09, 2022
City-Scale Multi-Camera Vehicle Tracking Guided by Crossroad Zones Code

City-Scale Multi-Camera Vehicle Tracking Guided by Crossroad Zones Requirements Python 3.8 or later with all requirements.txt dependencies installed,

88 Dec 12, 2022
Code for models used in Bashiri et al., "A Flow-based latent state generative model of neural population responses to natural images".

A Flow-based latent state generative model of neural population responses to natural images Code for "A Flow-based latent state generative model of ne

Sinz Lab 5 Aug 26, 2022
TICC is a python solver for efficiently segmenting and clustering a multivariate time series

TICC TICC is a python solver for efficiently segmenting and clustering a multivariate time series. It takes as input a T-by-n data matrix, a regulariz

406 Dec 12, 2022
[TOG 2021] PyTorch implementation for the paper: SofGAN: A Portrait Image Generator with Dynamic Styling.

This repository contains the official PyTorch implementation for the paper: SofGAN: A Portrait Image Generator with Dynamic Styling. We propose a SofGAN image generator to decouple the latent space o

Anpei Chen 694 Dec 23, 2022
Learning What and Where to Draw

###Learning What and Where to Draw Scott Reed, Zeynep Akata, Santosh Mohan, Samuel Tenka, Bernt Schiele, Honglak Lee This is the code for our NIPS 201

Scott Ellison Reed 337 Nov 18, 2022
Dungeons and Dragons randomized content generator

Component based Dungeons and Dragons generator Supports Entity/Monster Generation NPC Generation Weapon Generation Encounter Generation Environment Ge

Zac 3 Dec 04, 2021
Convolutional neural network web app trained to track our infant’s sleep schedule using our Google Nest camera.

Machine Learning Sleep Schedule Tracker What is it? Convolutional neural network web app trained to track our infant’s sleep schedule using our Google

g-parki 7 Jul 15, 2022
Implementing SYNTHESIZER: Rethinking Self-Attention in Transformer Models using Pytorch

Implementing SYNTHESIZER: Rethinking Self-Attention in Transformer Models using Pytorch Reference Paper URL Author: Yi Tay, Dara Bahri, Donald Metzler

Myeongjun Kim 66 Nov 30, 2022
The Ludii general game system, developed as part of the ERC-funded Digital Ludeme Project.

The Ludii General Game System Ludii is a general game system being developed as part of the ERC-funded Digital Ludeme Project (DLP). This repository h

Digital Ludeme Project 50 Jan 04, 2023
Official implementation for CVPR 2021 paper: Adaptive Class Suppression Loss for Long-Tail Object Detection

Adaptive Class Suppression Loss for Long-Tail Object Detection This repo is the official implementation for CVPR 2021 paper: Adaptive Class Suppressio

CASIA-IVA-Lab 67 Dec 04, 2022
Official repository of the paper Privacy-friendly Synthetic Data for the Development of Face Morphing Attack Detectors

SMDD-Synthetic-Face-Morphing-Attack-Detection-Development-dataset Official repository of the paper Privacy-friendly Synthetic Data for the Development

10 Dec 12, 2022
OpenMMLab Computer Vision Foundation

English | 简体中文 Introduction MMCV is a foundational library for computer vision research and supports many research projects as below: MMCV: OpenMMLab

OpenMMLab 4.6k Jan 09, 2023
Code release for Local Light Field Fusion at SIGGRAPH 2019

Local Light Field Fusion Project | Video | Paper Tensorflow implementation for novel view synthesis from sparse input images. Local Light Field Fusion

1.1k Dec 27, 2022
MQBench: Towards Reproducible and Deployable Model Quantization Benchmark

MQBench: Towards Reproducible and Deployable Model Quantization Benchmark We propose a benchmark to evaluate different quantization algorithms on vari

494 Dec 29, 2022
Apache Spark - A unified analytics engine for large-scale data processing

Apache Spark Spark is a unified analytics engine for large-scale data processing. It provides high-level APIs in Scala, Java, Python, and R, and an op

The Apache Software Foundation 34.7k Jan 04, 2023
Efficiently Disentangle Causal Representations

Efficiently Disentangle Causal Representations Install dependency pip install -r requirements.txt Main experiments Causality direction prediction cd

4 Apr 01, 2022