OpenAI's DALL-E for large-scale training in Mesh-TensorFlow.

Overview

DALL-E in Mesh-Tensorflow [WIP]

If this is as efficient as GPT-Neo, this repo should be able to train models up to, and beyond, the size of OpenAI's DALL-E (12B params).

No pretrained models... Yet.

Thanks to Ben Wang for the tf vae implementation as well as getting the mtf version working, and Aran Komatsuzaki for help building the mtf VAE and input pipeline.

Setup

git clone https://github.com/EleutherAI/GPTNeo
cd GPTNeo
pip3 install -r requirements.txt

Training Setup

Runs on TPUs. It is untested on GPUs, but should work in theory. The example configs are designed to run on a TPU v3-32 pod.

To set up TPUs, sign up for Google Cloud Platform, and create a storage bucket.

Create your VM through a Google Cloud Shell (https://ssh.cloud.google.com/) with ctpu up --vm-only so that it can connect to your Google bucket and TPUs, then set up the repo as above.

VAE pretraining

DALL-E needs a pretrained VAE to compress images to tokens. To run the VAE pretraining, edit the dataset params in configs/vae_example.json: set train_path and eval_path to glob paths pointing to a dataset of jpgs, and set image_size to the size all images should be cropped / padded to.

  "dataset": {
    "train_path": "gs://neo-datasets/CIFAR-10-images/train/**/*.jpg",
    "eval_path": "gs://neo-datasets/CIFAR-10-images/test/**/*.jpg",
    "image_size": 32
  }
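If you prefer to script this edit rather than change the file by hand, a minimal sketch could look like the following. The helper names (with_dataset, update_config_file) and the gs://my-bucket paths are hypothetical and not part of the repo; the sketch only assumes that the config is plain JSON with a "dataset" block shaped like the one above.

```python
import json
from pathlib import Path


def with_dataset(config, train_glob, eval_glob, image_size):
    """Return a copy of `config` whose "dataset" block points at new data.

    The input dict is left untouched so the original config can be reused.
    """
    new_config = dict(config)
    new_config["dataset"] = {
        **config.get("dataset", {}),
        "train_path": train_glob,
        "eval_path": eval_glob,
        "image_size": image_size,
    }
    return new_config


def update_config_file(config_path, train_glob, eval_glob, image_size):
    """Rewrite a JSON config file in place with the new dataset settings."""
    path = Path(config_path)
    config = json.loads(path.read_text())
    new_config = with_dataset(config, train_glob, eval_glob, image_size)
    path.write_text(json.dumps(new_config, indent=2) + "\n")
    return new_config


# Hypothetical usage (bucket name and image size are placeholders):
# update_config_file(
#     "configs/vae_example.json",
#     "gs://my-bucket/images/train/**/*.jpg",
#     "gs://my-bucket/images/test/**/*.jpg",
#     64,
# )
```

Returning a copy from with_dataset keeps the helper easy to test without touching any files.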

Once this is all set up, create your TPU, then run:

python train_vae_tf.py --tpu your_tpu_name --model vae_example

Training logs image tensors and loss values. To check progress, you can run:

tensorboard --logdir your_model_dir

Dataset Creation [DALL-E]

Once the VAE is pretrained, you can move on to DALL-E.

Currently we are training on a dummy dataset. A public, large-scale dataset for DALL-E is in the works. In the meantime, to generate some dummy data, run:

python src/data/create_tfrecords.py

This should download CIFAR-10, and generate some random captions to act as text inputs.

Custom datasets should be formatted in a folder, with a jsonl file in the root folder containing caption data and paths to the respective images, as follows:

Folder structure:

        data_folder
            jsonl_file
            folder_1
                img1
                img2
                ...
            folder_2
                img1
                img2
                ...
            ...

jsonl structure:
    {"image_path": "folder_1/img1", "caption": "some words"}
    {"image_path": "folder_2/img2", "caption": "more words"}
    ...
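As an illustration of this layout, the sketch below writes such a jsonl file for an existing image folder. It is not part of the repo: the function name write_caption_index, the default file name captions.jsonl, and the set of image extensions are all assumptions, and it assumes image paths are stored relative to the dataset root, as in the example above.

```python
import json
from pathlib import Path

# Assumption: which file suffixes count as images.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def write_caption_index(data_folder, captions, jsonl_name="captions.jsonl"):
    """Write one {"image_path", "caption"} JSON object per line.

    `captions` maps image paths relative to `data_folder` (with forward
    slashes) to caption strings. Images without a caption are skipped.
    Returns the number of lines written.
    """
    root = Path(data_folder)
    written = 0
    with (root / jsonl_name).open("w", encoding="utf-8") as out:
        for image in sorted(root.rglob("*")):
            if not image.is_file() or image.suffix.lower() not in IMAGE_EXTENSIONS:
                continue
            relative = image.relative_to(root).as_posix()
            if relative not in captions:
                continue
            record = {"image_path": relative, "caption": captions[relative]}
            out.write(json.dumps(record) + "\n")
            written += 1
    return written
```

Using json.dumps for each line guarantees that paths and captions are quoted and escaped correctly, which hand-written lines can easily get wrong.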

You can then use the create_paired_dataset function in src/data/create_tfrecords.py to encode the dataset into tfrecords for use in training.

Once the dataset is created, copy it over to your bucket with gsutil:

gsutil cp -r DALLE-tfrecords gs://neo-datasets/

Finally, run training with:

python train_dalle.py --tpu your_tpu_name --model dalle_example

Config Guide

VAE:

{
  "model_type": "vae",
  "dataset": {
    "train_path": "gs://neo-datasets/CIFAR-10-images/train/**/*.jpg", # glob path to training images
    "eval_path": "gs://neo-datasets/CIFAR-10-images/test/**/*.jpg", # glob path to eval images
    "image_size": 32 # size of images (all images will be cropped / padded to this size)
  },
  "train_batch_size": 32, 
  "eval_batch_size": 32,
  "predict_batch_size": 32,
  "steps_per_checkpoint": 1000, # how often to save a checkpoint
  "iterations": 500, # number of batches to infeed to the tpu at a time. Must be < steps_per_checkpoint
  "train_steps": 100000, # total training steps
  "eval_steps": 0, # run evaluation for this many steps every steps_per_checkpoint
  "model_path": "gs://neo-models/vae_test2/", # directory in which to save the model
  "mesh_shape": "data:16,model:2", # mapping of processors to named dimensions - see mesh-tensorflow repo for more info
  "layout": "batch_dim:data", # which named dimensions of the model to split across the mesh - see mesh-tensorflow repo for more info
  "num_tokens": 512, # vocab size
  "dim": 512, 
  "hidden_dim": 64, # size of hidden dim
  "n_channels": 3, # number of input channels
  "bf_16": false, # if true, the model is trained with bfloat16 precision
  "lr": 0.001, # learning rate [by default learning rate starts at this value, then decays to 10% of this value over the course of the training]
  "num_layers": 3, # number of blocks in the encoder / decoder
  "train_gumbel_hard": true, # whether to use hard or soft gumbel_softmax
  "eval_gumbel_hard": true
}

DALL-E:

{
  "model_type": "dalle",
  "dataset": {
    "train_path": "gs://neo-datasets/DALLE-tfrecords/*.tfrecords", # glob path to tfrecords data
    "eval_path": "gs://neo-datasets/DALLE-tfrecords/*.tfrecords",
    "image_size": 32 # size of images (all images will be cropped / padded to this size)
  },
  "train_batch_size": 32, # see above
  "eval_batch_size": 32,
  "predict_batch_size": 32,
  "steps_per_checkpoint": 1000,
  "iterations": 500,
  "train_steps": 100000,
  "predict_steps": 0,
  "eval_steps": 0,
  "n_channels": 3,
  "bf_16": false,
  "lr": 0.001,
  "model_path": "gs://neo-models/dalle_test/",
  "mesh_shape": "data:16,model:2",
  "layout": "batch_dim:data",
  "n_embd": 512, # size of embedding dim
  "text_vocab_size": 50258, # vocabulary size of the text tokenizer
  "image_vocab_size": 512, # vocabulary size of the vae - should equal num_tokens above
  "text_seq_len": 256, # length of text inputs (all inputs longer / shorter will be truncated / padded)
  "n_layers": 6, 
  "n_heads": 4, # number of attention heads. For best performance, n_embd / n_heads should equal 128
  "vae_model": "vae_example" # path to or name of vae model config
}
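
Several of the values above depend on each other. The sketch below collects those relationships into a small pre-flight check. It is not part of the repo: the function names are hypothetical, the configs are assumed to be already loaded as Python dicts (without the # comments shown above), and the rule that the mesh dimensions multiply out to the number of TPU cores (32 on a v3-32) is our reading of mesh-tensorflow rather than something this README states.

```python
def parse_mesh_shape(mesh_shape):
    """Turn a string like "data:16,model:2" into {"data": 16, "model": 2}."""
    dims = {}
    for part in mesh_shape.split(","):
        name, size = part.split(":")
        dims[name.strip()] = int(size)
    return dims


def check_dalle_config(dalle, vae, num_cores=32):
    """Return a list of human-readable problems; an empty list means no issue found."""
    problems = []

    # Assumption: the product of all mesh dimensions should equal the core count.
    total = 1
    for size in parse_mesh_shape(dalle["mesh_shape"]).values():
        total *= size
    if total != num_cores:
        problems.append(f"mesh_shape covers {total} cores, expected {num_cores}")

    # From the guide: iterations must be < steps_per_checkpoint.
    if dalle["iterations"] >= dalle["steps_per_checkpoint"]:
        problems.append("iterations must be smaller than steps_per_checkpoint")

    # From the guide: image_vocab_size should equal the VAE's num_tokens.
    if dalle["image_vocab_size"] != vae["num_tokens"]:
        problems.append("image_vocab_size does not match the VAE's num_tokens")

    # From the guide: n_embd / n_heads should equal 128 for best performance.
    if dalle["n_embd"] % dalle["n_heads"] != 0:
        problems.append("n_embd is not divisible by n_heads")
    elif dalle["n_embd"] // dalle["n_heads"] != 128:
        problems.append("n_embd / n_heads is not 128 (recommended, not required)")

    return problems
```

With the example values above (mesh_shape "data:16,model:2", iterations 500, steps_per_checkpoint 1000, n_embd 512, n_heads 4, and 512 tokens on both sides), the check reports no problems.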