Simple GAN
This is my attempt at a wrapper class for a GAN in Keras that abstracts away the whole process of building the architecture.
Overview
Flow Chart
Setting up a Generative Adversarial Network involves a discriminator and a generator working in tandem. The discriminator learns to tell real samples from generated ones, while the generator learns to produce samples that the discriminator cannot distinguish from real (valid) samples.
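The adversarial setup can be made concrete with the standard binary cross-entropy objectives. The sketch below is a NumPy illustration, not code from this package: it assumes the discriminator outputs a probability that its input is real, and it uses the common "non-saturating" generator loss. The function names are hypothetical.

```python
import numpy as np


def bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy between labels and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)))


def discriminator_loss(d_real, d_fake):
    """Hypothetical helper: D is trained to output 1 for real samples and 0 for generated ones."""
    return bce(np.ones_like(d_real), d_real) + bce(np.zeros_like(d_fake), d_fake)


def generator_loss(d_fake):
    """Hypothetical helper (non-saturating loss): G is rewarded when D labels its samples as real."""
    return bce(np.ones_like(d_fake), d_fake)


if __name__ == "__main__":
    # A confident, correct discriminator: low D loss, high G loss.
    d_real = np.array([0.9, 0.95])
    d_fake = np.array([0.1, 0.05])
    print(discriminator_loss(d_real, d_fake), generator_loss(d_fake))

    # When D can no longer tell real from generated, it outputs 0.5 everywhere.
    half = np.full(2, 0.5)
    print(discriminator_loss(half, half), generator_loss(half))
```

In a real training loop the two losses are minimized alternately: one step updates only the discriminator's weights, the next updates only the generator's, with the discriminator frozen.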
Installation
pip install adversarials
Example
import numpy as np
from keras.datasets import mnist

from adversarials.core import Log
from adversarials import SimpleGAN

if __name__ == '__main__':
    # Load the MNIST training images; labels and the test split are not needed.
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale pixel values from [0, 255] to [-1, 1]
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    # Add a channel axis: (N, 28, 28) -> (N, 28, 28, 1)
    X_train = np.expand_dims(X_train, axis=3)
    Log.info('X_train.shape = {}'.format(X_train.shape))

    gan = SimpleGAN(save_to_dir="./assets/images",
                    save_interval=20)
    gan.train(X_train, epochs=40)
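The rescaling step in the example maps 8-bit pixel values from [0, 255] into [-1, 1], which matches the output range of a tanh activation (an assumption: the source does not state which output activation SimpleGAN's generator uses). A small NumPy sketch of the forward mapping and its inverse, useful for turning generated samples back into viewable images, might look like this; both function names are hypothetical and not part of the adversarials package.

```python
import numpy as np


def to_gan_range(images):
    """Hypothetical helper: map uint8 pixels in [0, 255] to float32 in [-1, 1] and add a channel axis."""
    scaled = (images.astype(np.float32) - 127.5) / 127.5
    return np.expand_dims(scaled, axis=-1)  # e.g. (N, 28, 28) -> (N, 28, 28, 1)


def to_pixel_range(samples):
    """Hypothetical helper: inverse mapping from floats in [-1, 1] back to uint8 pixels."""
    pixels = samples * 127.5 + 127.5
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)


if __name__ == "__main__":
    batch = np.array([[[0, 127, 255]]], dtype=np.uint8)  # shape (1, 1, 3)
    x = to_gan_range(batch)
    print(x.shape, x.min(), x.max())
    print(to_pixel_range(x[..., 0]))  # drop the channel axis before converting back
```

Clipping in the inverse guards against generator outputs that drift slightly outside [-1, 1] before the cast to uint8.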
Documentation
Credits
- Understanding Generative Adversarial Networks - Naoki Shibuya
- GitHub: Keras-GAN
- Simple GAN
Contribution
You are very welcome to modify and use this code in your own projects.
Please keep a link to the original repository. If you have made a fork with substantial modifications that you think may be useful to others, please open a new issue on GitHub with a link and a short description.
License (MIT)
This project is released under the MIT License, which allows very broad use for both academic and commercial purposes.
A few of the images used for demonstration purposes may be under copyright. These images are included under fair use.
Todo
- Visualize discriminator and generator training side by side using TensorBoard
- Provision for parallel data processing and multithreading
- Save models to Protobuf files
- Use TensorFlow GraphDef and other techniques that could speed up training and inference