GAN JAX - A toy project to generate images from GANs with JAX

This project aims to bring the power of JAX, a Python framework developed by Google and DeepMind, to the training of Generative Adversarial Networks (GANs) for image generation.

JAX

JAX is a framework developed by Google and DeepMind that lets you build machine learning models in a more powerful way (thanks to XLA compilation) and a more flexible way than its counterpart TensorFlow, through an API almost entirely based on NumPy's ndarray (with arrays stored on the GPU, or on the TPU if available). It also provides new utilities for gradient computation (per-example gradients, Jacobians with reverse-mode and forward-mode differentiation, Hessians, etc.), a better seed system based on explicit PRNG keys (for reproducibility), and a tool (vmap) that vectorizes complicated operations over a batch automatically and efficiently.

GitHub link: https://github.com/google/jax
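The features listed above map onto a handful of JAX functions. The sketch below is illustrative and not taken from this repository: the function f, the toy loss and all array values are made up for the example. It shows jax.grad, forward- and reverse-mode Jacobians (jax.jacfwd, jax.jacrev), jax.hessian, per-example gradients with jax.vmap, explicit PRNG keys and jax.jit compilation.

```python
import jax
import jax.numpy as jnp

# Toy scalar function (made up for illustration): f(x) = sum(x^3)
def f(x):
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])

grad_f = jax.grad(f)(x)                 # gradient: 3 * x^2
jac_fwd = jax.jacfwd(jax.grad(f))(x)    # Jacobian of the gradient, forward mode
jac_rev = jax.jacrev(jax.grad(f))(x)    # same matrix, reverse mode
hess = jax.hessian(f)(x)                # Hessian: diag(6 * x)

# Per-example gradients: vmap maps grad over the batch axis of xs only
def loss(w, xi):
    return jnp.dot(w, xi) ** 2

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.arange(6.0).reshape(2, 3)      # a batch of 2 examples
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, xs)

# Explicit PRNG keys: the same key always produces the same numbers
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (2, 3))

# XLA compilation of the whole function
fast_f = jax.jit(f)
```

Because randomness is driven by explicit keys rather than a hidden global state, reusing subkey reproduces noise exactly, which is the reproducibility property mentioned above.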

GAN

Generative adversarial networks (GANs) are algorithmic architectures that use two neural networks, pitting one against the other (hence "adversarial"), in order to generate new, synthetic instances of data that can pass for real data. They are widely used in image, video and voice generation. GANs were introduced in a 2014 paper by Ian Goodfellow and other researchers at the University of Montreal, including Yoshua Bengio. Referring to GANs, Facebook's AI research director Yann LeCun called adversarial training the most interesting idea in the last 10 years in ML. (source)

Original paper: https://arxiv.org/abs/1406.2661
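In the original paper, the discriminator D and the generator G play the minimax game min_G max_D E_x[log D(x)] + E_z[log(1 - D(G(z)))], and the paper suggests training G to maximize log D(G(z)) instead, which gives stronger gradients early in training. Below is a minimal JAX sketch of both losses. It assumes the discriminator returns logits (scores before the sigmoid), and the function names are hypothetical, not this repository's API.

```python
import jax
import jax.numpy as jnp

# Both losses take discriminator logits; D(x) = sigmoid(logit).
# log D = log_sigmoid(logit) and log(1 - D) = log_sigmoid(-logit),
# which avoids computing log(0) when D saturates.

def discriminator_loss(real_logits, fake_logits):
    # D maximizes E[log D(x)] + E[log(1 - D(G(z)))]; we minimize the negative
    return (-jnp.mean(jax.nn.log_sigmoid(real_logits))
            - jnp.mean(jax.nn.log_sigmoid(-fake_logits)))

def generator_loss(fake_logits):
    # Non-saturating generator loss: G maximizes E[log D(G(z))]
    return -jnp.mean(jax.nn.log_sigmoid(fake_logits))
```

When all logits are 0 (D outputs 0.5 everywhere), the discriminator loss is 2 log 2 and the generator loss is log 2.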

Over the years, several ideas have improved GAN training. For example:

Deep Convolutional GAN (DCGAN) paper: https://arxiv.org/abs/1511.06434

Progressive Growing GAN (ProGAN) paper: https://arxiv.org/abs/1710.10196

The goal of this project is to implement these ideas in JAX.
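One ProGAN idea is to fade in each new, higher-resolution block smoothly: during the transition, the output image is alpha * (image from the new block) + (1 - alpha) * (upsampled image from the previous resolution), with alpha increasing from 0 to 1. The sketch below shows only that blending step in JAX. It assumes NHWC image batches and nearest-neighbor upsampling by a factor of 2, the helper names are hypothetical, and it is not the repository's implementation.

```python
import jax.numpy as jnp

def upsample_nearest(images, factor=2):
    # images: (batch, height, width, channels); repeat rows, then columns
    return jnp.repeat(jnp.repeat(images, factor, axis=1), factor, axis=2)

def fade_in(low_res_rgb, high_res_rgb, alpha):
    # alpha = 0: only the upsampled previous output; alpha = 1: only the new block
    return alpha * high_res_rgb + (1.0 - alpha) * upsample_nearest(low_res_rgb)
```

During training, alpha would typically be increased linearly over a fixed number of steps after the new block is added.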

Installation

You can install JAX by following the instructions in JAX - Installation.

It is strongly recommended to run JAX on Linux with CUDA available (Windows has no stable support yet). In this case you can install JAX using the following command:

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Then you can install TensorFlow to benefit from tf.data.Dataset for data handling and from its built-in datasets. However, by default TensorFlow maps nearly all of the GPU memory as soon as it uses the GPU, which is not optimal when the computations run with JAX. Therefore, you should install the CPU-only version of TensorFlow instead of the GPU version. See TensorFlow - Installation with pip to install the CPU-only version of TensorFlow 2 for your OS and Python version.

Example with Linux and Python 3.9:

pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow_cpu-2.6.0-cp39-cp39-manylinux2010_x86_64.whl

Then you can install the other libraries from requirements.txt. This installs Haiku and Optax, two useful add-on libraries for implementing and optimizing machine learning models with JAX.

pip install -r requirements.txt

Install CelebA dataset (optional)

To use the CelebA dataset, download it from Kaggle and put the images (the img_align_celeba/ folder) in data/CelebA/images. This source is recommended because the faces are already cropped.
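For the CelebA (64x64) configuration, each aligned image (178x218 pixels, width x height) must be brought down to 64x64. The repository's load_images_celeba_64 is not reproduced here; the sketch below is one assumed way to do it in JAX: a center crop to a square, a bilinear resize with jax.image.resize, and scaling to [-1, 1], as is common for generators with a tanh output (also an assumption). The helper preprocess_face is hypothetical.

```python
import jax
import jax.numpy as jnp

def preprocess_face(image, size=64):
    """Hypothetical helper: center-crop an RGB face to a square, resize it
    to size x size and scale pixel values from [0, 255] to [-1, 1]."""
    h, w, _ = image.shape
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    square = jnp.asarray(image[top:top + side, left:left + side], dtype=jnp.float32)
    # Bilinear resize (JAX antialiases by default when downsampling)
    resized = jax.image.resize(square, (size, size, 3), method="bilinear")
    return resized / 127.5 - 1.0
```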

Note: the other datasets are downloaded automatically by Keras or TensorFlow Datasets.

Quick Start

You can test a pre-trained GAN model with apps/test.py. It loads a model from the pre-trained models (in pre_trained/) and generates pictures. To test a different GAN, change the path in the script.

You can also train your own GAN from scratch with apps/train.py. To change the training parameters, edit the configs in the script. You can also change the dataset or the type of GAN by changing the imports (only one word needs to change for each).

For example, to train a GAN on CelebA (64x64):

from utils.data import load_images_celeba_64 as load_images

To train a DCGAN:

from gan.dcgan import DCGAN as GAN

You can then implement your own GAN and train/test it on your own dataset by overriding the appropriate functions (see the examples in the repository).
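To make the moving parts of a GAN training run concrete, here is a self-contained sketch of an alternating training step in plain JAX. It does not reproduce the repository's classes or overridable methods: generator, discriminator, train_step and the other names are hypothetical, the 1-D affine "networks" and the N(3, 1) toy dataset are made up, and plain SGD stands in for an Optax optimizer.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.05  # plain SGD step size (stand-in for an Optax optimizer)
BATCH_SIZE = 128

def generator(g_params, z):
    # Hypothetical toy generator: affine map of 1-D latent noise
    return g_params["w"] * z + g_params["b"]

def discriminator(d_params, x):
    # Hypothetical toy discriminator: returns logits for 1-D samples
    return d_params["w"] * x + d_params["b"]

def d_loss_fn(d_params, g_params, real, z):
    fake = generator(g_params, z)
    real_logits = discriminator(d_params, real)
    fake_logits = discriminator(d_params, fake)
    return (-jnp.mean(jax.nn.log_sigmoid(real_logits))
            - jnp.mean(jax.nn.log_sigmoid(-fake_logits)))

def g_loss_fn(g_params, d_params, z):
    # Non-saturating generator loss
    fake_logits = discriminator(d_params, generator(g_params, z))
    return -jnp.mean(jax.nn.log_sigmoid(fake_logits))

def sgd(params, grads):
    # One SGD step on every leaf of the parameter dict
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

@jax.jit
def train_step(d_params, g_params, key):
    k_real, k_z_d, k_z_g = jax.random.split(key, 3)
    # Toy "dataset": samples from N(3, 1)
    real = 3.0 + jax.random.normal(k_real, (BATCH_SIZE,))
    # 1) Discriminator step: gradients w.r.t. d_params only (argnums=0)
    z = jax.random.normal(k_z_d, (BATCH_SIZE,))
    d_loss, d_grads = jax.value_and_grad(d_loss_fn)(d_params, g_params, real, z)
    d_params = sgd(d_params, d_grads)
    # 2) Generator step against the updated discriminator
    z = jax.random.normal(k_z_g, (BATCH_SIZE,))
    g_loss, g_grads = jax.value_and_grad(g_loss_fn)(g_params, d_params, z)
    g_params = sgd(g_params, g_grads)
    return d_params, g_params, d_loss, g_loss

d_params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
g_params = {"w": jnp.array(1.0), "b": jnp.array(0.0)}
key = jax.random.PRNGKey(0)
for _ in range(200):
    key, step_key = jax.random.split(key)
    d_params, g_params, d_loss, g_loss = train_step(d_params, g_params, step_key)
```

One discriminator step followed by one generator step per batch matches the original GAN paper's algorithm with k = 1. In a real model, the affine maps would be replaced by Haiku networks and sgd by an Optax optimizer.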

Some results of pre-trained models

- Deep Convolutional GAN

  • On MNIST:

DCGAN MNIST

  • On Cifar10:

DCGAN Cifar10

  • On CelebA (64x64):

DCGAN CelebA-64

- Progressive Growing GAN

  • On MNIST:

  • On Cifar10:

  • On CelebA (64x64):

  • On CelebA (128x128):

Owner
Valentin Goldité
Student at CentraleSupelec (a top French engineering school), specializing in machine learning (Computer Vision, NLP, Audio, RL, Time Analysis).