GAN JAX - A toy project to generate images from GANs with JAX

This project aims to bring the power of JAX, a Python framework developed by Google and DeepMind, to the training of Generative Adversarial Networks for image generation.

JAX

JAX is a framework developed by Google and DeepMind that lets you build machine learning models in a more powerful (XLA compilation) and flexible way than its counterpart TensorFlow, using an API almost entirely based on NumPy's ndarray (but stored on the GPU, or the TPU if available). It also provides utilities for gradient computation (per-sample gradients, Jacobians in forward and reverse mode, Hessians...), an explicit random key system (for reproducibility), and a tool to batch complicated operations automatically and efficiently.
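As a rough illustration of the utilities listed above, the sketch below uses a hypothetical single-sample `loss` function (not taken from this repository) to show explicit random keys, `jax.grad`, per-sample gradients with `jax.vmap`, Jacobians and Hessians, and XLA compilation with `jax.jit`.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    """Squared error of a linear model on a single sample (toy example)."""
    return (jnp.dot(x, w) - y) ** 2

# Explicit random keys: the same key always produces the same numbers.
key = jax.random.PRNGKey(0)
key_w, key_x, key_y = jax.random.split(key, 3)
w = jax.random.normal(key_w, (3,))
xs = jax.random.normal(key_x, (5, 3))  # batch of 5 samples
ys = jax.random.normal(key_y, (5,))

# Gradient with respect to the first argument (reverse mode).
grad_loss = jax.grad(loss)

# vmap batches the single-sample gradient: one gradient per sample.
per_sample_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))(w, xs, ys)

# Jacobian in forward and reverse mode, and Hessian, on one sample.
jac_fwd = jax.jacfwd(loss)(w, xs[0], ys[0])
jac_rev = jax.jacrev(loss)(w, xs[0], ys[0])
hess = jax.hessian(loss)(w, xs[0], ys[0])

# jit compiles the batched loss with XLA.
batched_loss = jax.vmap(loss, in_axes=(None, 0, 0))
mean_loss = jax.jit(lambda w, xs, ys: jnp.mean(batched_loss(w, xs, ys)))

print(per_sample_grads.shape, hess.shape, mean_loss(w, xs, ys))
```

Because `loss` is written for one sample, `jax.vmap` is what adds the batch dimension; the function itself never needs to know about batching.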

GitHub link: https://github.com/google/jax

GAN

Generative adversarial networks (GANs) are algorithmic architectures that pit two neural networks against each other (hence "adversarial") to generate new, synthetic instances of data that can pass for real data. They are widely used in image, video and voice generation. GANs were introduced in a 2014 paper by Ian Goodfellow and other researchers at the University of Montreal, including Yoshua Bengio. Referring to GANs, Facebook’s AI research director Yann LeCun called adversarial training "the most interesting idea in the last 10 years in ML". (source)
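The adversarial setup can be sketched as two loss functions, one per network. The code below is a minimal illustration, not this repository's implementation: `generator` and `discriminator` are hypothetical toy linear models, and the generator uses the non-saturating loss suggested in the original paper.

```python
import jax
import jax.numpy as jnp

def generator(g_params, z):
    """Toy generator: latent vectors (batch, 2) -> samples (batch, 3)."""
    return jnp.tanh(z @ g_params)

def discriminator(d_params, x):
    """Toy discriminator: samples (batch, 3) -> one logit per sample."""
    return (x @ d_params).squeeze(-1)

def discriminator_loss(real_logits, fake_logits):
    # D is rewarded for scoring real data high and generated data low.
    real_term = jax.nn.log_sigmoid(real_logits)   # log D(x)
    fake_term = jax.nn.log_sigmoid(-fake_logits)  # log(1 - D(G(z)))
    return -(jnp.mean(real_term) + jnp.mean(fake_term))

def generator_loss(fake_logits):
    # Non-saturating variant: G is rewarded when D scores its samples high.
    return -jnp.mean(jax.nn.log_sigmoid(fake_logits))

def d_objective(d_params, g_params, real, z):
    fake = generator(g_params, z)
    return discriminator_loss(discriminator(d_params, real),
                              discriminator(d_params, fake))

def g_objective(g_params, d_params, z):
    return generator_loss(discriminator(d_params, generator(g_params, z)))

k_g, k_d, k_real, k_z = jax.random.split(jax.random.PRNGKey(0), 4)
g_params = jax.random.normal(k_g, (2, 3))
d_params = jax.random.normal(k_d, (3, 1))
real = jax.random.normal(k_real, (8, 3))  # stand-in for a batch of real data
z = jax.random.normal(k_z, (8, 2))        # latent noise

# Each network gets gradients of its own objective only.
d_grads = jax.grad(d_objective)(d_params, g_params, real, z)
g_grads = jax.grad(g_objective)(g_params, d_params, z)
```

Training alternates between the two updates: a step on `d_params` using `d_grads`, then a step on `g_params` using `g_grads`.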

Original paper: https://arxiv.org/abs/1406.2661

Several ideas have improved GAN training over the years, for example:

- Deep Convolutional GAN (DCGAN) paper: https://arxiv.org/abs/1511.06434

- Progressive Growing GAN (ProGAN) paper: https://arxiv.org/abs/1710.10196

The goal of this project is to implement these ideas in JAX.
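To give a flavour of one of these ideas, here is a minimal sketch of the fade-in used in progressive growing: when a higher-resolution block is added, its output is blended with the upsampled output of the previous resolution, and the blending weight `alpha` is ramped from 0 to 1 during training. The function names are hypothetical and this is not necessarily how the repository implements it.

```python
import jax.numpy as jnp

def upsample_2x(images):
    """Nearest-neighbour upsampling of NHWC images by a factor of 2."""
    images = jnp.repeat(images, 2, axis=1)  # double the height
    return jnp.repeat(images, 2, axis=2)    # double the width

def fade_in(low_res_rgb, high_res_rgb, alpha):
    """Blend the upsampled old output with the new block's output."""
    return (1.0 - alpha) * upsample_2x(low_res_rgb) + alpha * high_res_rgb

low = jnp.zeros((1, 4, 4, 3))   # output of the previous 4x4 stage
high = jnp.ones((1, 8, 8, 3))   # output of the newly added 8x8 block
blended = fade_in(low, high, alpha=0.25)
print(blended.shape)
```

With `alpha=0` the network behaves exactly like the previous stage, so the new layers are introduced without a sudden jump in the generated images.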

Installation

You can install JAX by following the instructions at JAX - Installation.

It is strongly recommended to run JAX on Linux with CUDA available (Windows has no stable support yet). In that case, you can install JAX with the following command:

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Then you can install TensorFlow to benefit from tf.data.Dataset for data handling and from its built-in datasets. However, TensorFlow reserves GPU memory as soon as it uses the GPU, which leaves less for JAX computations. You should therefore install the CPU-only version of TensorFlow. See TensorFlow - Installation with pip to install the CPU-only version of TensorFlow 2 for your OS and Python version.

Example with Linux and Python 3.9:

pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow_cpu-2.6.0-cp39-cp39-manylinux2010_x86_64.whl

Then you can install the other libraries from requirements.txt. This installs Haiku and Optax, two useful add-on libraries for implementing and optimizing machine learning models with JAX.

pip install -r requirements.txt

Install CelebA dataset (optional)

To use the CelebA dataset, download it from Kaggle and place the images folder img_align_celeba/ in data/CelebA/images. Downloading from this source is recommended because the faces are already cropped.

Note: the other datasets are downloaded automatically through Keras or tensorflow-datasets.

Quick Start

You can test a pretrained GAN model by using apps/test.py. It will download the model from pretrained models (in pre_trained/) and generate pictures. You can change the GAN to test by changing the path in the script.

You can also train your own GAN from scratch with apps/train.py. To change the training parameters, edit the configs in the script. You can also change the dataset or the type of GAN by changing the imports (there is only one word to change for each).

For example, to train a GAN on CelebA (64x64):

from utils.data import load_images_celeba_64 as load_images

To train a DCGAN:

from gan.dcgan import DCGAN as GAN

You can then implement your own GAN and train or test it on your own dataset by overriding the appropriate functions (see the examples in the repository).

Some results of pre-trained models

- Deep Convolutional GAN

  • On MNIST

  • On CIFAR-10

  • On CelebA (64x64)

- Progressive Growing GAN

  • On MNIST

  • On CIFAR-10

  • On CelebA (64x64)

  • On CelebA (128x128)

Owner
Valentin Goldité
Student at CentraleSupelec (a top French engineering school), specialized in machine learning (Computer Vision, NLP, Audio, RL, Time Analysis).