Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. High-resolution image generation that can be trained within a day or two.

Overview

512x512 flowers after 12 hours of training, 1 gpu

256x256 flowers after 12 hours of training, 1 gpu

Pizza

'Lightweight' GAN


Implementation of 'lightweight' GAN proposed in ICLR 2021, in PyTorch. The main contributions of the paper are a skip-layer excitation in the generator, paired with autoencoding self-supervised learning in the discriminator. Quoting the one-line summary: "converge on single gpu with few hours' training, on 1024 resolution sub-hundred images".
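
To give a rough idea of the skip-layer excitation, here is a minimal PyTorch sketch (an illustrative module, not the repository's exact code): a low-resolution feature map is squeezed into per-channel gates that modulate a higher-resolution feature map, giving the generator a cheap long-range skip connection.

import torch
from torch import nn

class SkipLayerExcitation(nn.Module):
    # sketch: low-resolution features gate the channels of high-resolution features
    def __init__(self, low_channels, high_channels):
        super().__init__()
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),                    # squeeze the low-res map to 4x4
            nn.Conv2d(low_channels, high_channels, 4),  # 4x4 conv -> 1x1 spatially
            nn.LeakyReLU(0.1),
            nn.Conv2d(high_channels, high_channels, 1),
            nn.Sigmoid(),                               # per-channel gates in (0, 1)
        )

    def forward(self, x_high, x_low):
        return x_high * self.gate(x_low)                # broadcast over height and width

# e.g. gate a 128x128 feature map with an 8x8 one
sle = SkipLayerExcitation(low_channels=256, high_channels=64)
out = sle(torch.randn(1, 64, 128, 128), torch.randn(1, 256, 8, 8))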

Install

$ pip install lightweight-gan

Use

One command

$ lightweight_gan --data ./path/to/images --image-size 512

The model will be saved to ./models/{name} every 1000 iterations, and samples from the model will be saved to ./results/{name}. name defaults to default.

Training settings

Pretty self-explanatory for deep learning practitioners.

$ lightweight_gan \
    --data ./path/to/images \
    --name {name of run} \
    --batch-size 16 \
    --gradient-accumulate-every 4 \
    --num-train-steps 200000

Augmentation

Augmentation is essential for Lightweight GAN to work effectively in a low-data setting.

By default, the augmentation types are set to translation and cutout, with color omitted. You can include color as well with the following:

$ lightweight_gan --data ./path/to/images --aug-prob 0.25 --aug-types [translation,cutout,color]
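
Conceptually, --aug-prob controls how often the listed augmentations are applied to the images the discriminator sees. A minimal sketch of that gating logic (illustrative only, not the library's internals):

import random
import torch

def maybe_augment(images, aug_fns, prob=0.25):
    # with probability `prob`, run the batch through the augmentation pipeline
    if random.random() < prob:
        for fn in aug_fns:
            images = fn(images)
    return images

# e.g. a toy "color" augmentation that jitters brightness
brighten = lambda imgs: (imgs + 0.2 * (torch.rand(imgs.size(0), 1, 1, 1) - 0.5)).clamp(0., 1.)
batch = maybe_augment(torch.rand(8, 3, 256, 256), [brighten], prob=0.25)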

Test augmentation

You can test and see how your images will be augmented before they are passed into the neural network (if you use augmentation). Let's see how it works on this image:

Basic usage

Base command to augment your image: set --aug-test and pass the path to your image via --data:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg

After this, the file lena_augs.jpg will be created; it will look something like this:

Options

You can use a few options to change the result:

  • --image-size 256 to change the size of the image tiles in the result. Default: 256.
  • --aug-types [color,cutout,translation] to combine several augmentations. Default: [cutout,translation].
  • --batch-size 10 to change the number of images in the result image. Default: 10.
  • --num-image-tiles 5 to change the number of tiles in the result image. Default: 5.

Try this command:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg \
    --batch-size 16 \
    --num-image-tiles 4 \
    --aug-types [color,translation]

The result will be something like this:

Types of augmentations

This library contains several types of built-in augmentations.
Some of them run by default; others can be enabled via the --aug-types option:

  • Horizontal flip (runs by default, not configurable; applied in the AugWrapper class);
  • color randomly changes brightness, saturation and contrast;
  • cutout creates random black boxes on the image;
  • offset randomly shifts the image along the x- and y-axes, wrapping it around at the borders;
    • offset_h shifts only along the x-axis;
    • offset_v shifts only along the y-axis;
  • translation randomly moves the image on a black canvas;

The full set of augmentations is --aug-types [color,cutout,offset,translation].
The general recommendation is to use augmentations suitable for your data, and as many of them as possible; then, after some time of training, disable the most destructive (for the image) ones.
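
To make the descriptions above concrete, here are simplified sketches of the cutout, offset and translation operations on a batch of image tensors (illustrative versions, not the library's exact code):

import torch
import torch.nn.functional as F

def cutout(images, size_frac=0.5):
    # paint a random black box onto each image
    b, _, h, w = images.shape
    ch, cw = int(h * size_frac), int(w * size_frac)
    out = images.clone()
    for i in range(b):
        y = torch.randint(0, h - ch + 1, (1,)).item()
        x = torch.randint(0, w - cw + 1, (1,)).item()
        out[i, :, y:y + ch, x:x + cw] = 0.
    return out

def offset(images, max_frac=0.125):
    # shift with wrap-around, so the image repeats at the borders
    _, _, h, w = images.shape
    dy = torch.randint(-int(h * max_frac), int(h * max_frac) + 1, (1,)).item()
    dx = torch.randint(-int(w * max_frac), int(w * max_frac) + 1, (1,)).item()
    return torch.roll(images, shifts=(dy, dx), dims=(2, 3))

def translation(images, max_frac=0.125):
    # shift on a black canvas, with no wrap-around
    _, _, h, w = images.shape
    pad = int(h * max_frac)
    padded = F.pad(images, (pad, pad, pad, pad))      # zero padding = black border
    y = torch.randint(0, 2 * pad + 1, (1,)).item()
    x = torch.randint(0, 2 * pad + 1, (1,)).item()
    return padded[:, :, y:y + h, x:x + w]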

Color

Cutout

Offset

Only x-axis:

Only y-axis:

Translation

Mixed precision

You can turn on automatic mixed precision with one flag, --amp.

You should expect it to be 33% faster and to save up to 40% memory.
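
Mixed precision in PyTorch is typically handled with torch.cuda.amp. A stand-alone sketch of the mechanism (a generic training loop, not necessarily this repo's trainer):

import torch
from torch import nn

model = nn.Linear(512, 1).cuda()
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
scaler = torch.cuda.amp.GradScaler()              # rescales gradients to avoid fp16 underflow

for _ in range(10):
    x = torch.randn(16, 512, device='cuda')
    with torch.cuda.amp.autocast():               # forward pass runs in mixed precision
        loss = model(x).mean()
    opt.zero_grad()
    scaler.scale(loss).backward()                 # backward on the scaled loss
    scaler.step(opt)                              # unscales gradients, then steps
    scaler.update()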

Multiple GPUs

Also just one flag to use: --multi-gpus.

Generating

Once you have finished training, you can generate samples with one command. You can select which checkpoint number to load from. If --load-from is not specified, it will default to the latest.

$ lightweight_gan \
  --name {name of run} \
  --load-from {checkpoint num} \
  --generate \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command, you will get a folder next to the results image folder with the postfix "-generated-{checkpoint num}".

You can also generate interpolations:

$ lightweight_gan --name {name of run} --generate-interpolation

Show progress

After creating several checkpoints of the model, you can generate its progress as a sequence of images with this command:

$ lightweight_gan \
  --name {name of run} \
  --show-progress \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command you will get a new folder in the results folder, with the postfix "-progress". You can convert the images to a video with ffmpeg:

$ ffmpeg -framerate 10 -pattern_type glob -i '*-ema.jpg' out.mp4

Show progress gif demonstration

Show progress video demonstration

Discriminator output size

The author has kindly let me know that the discriminator output size (5x5 vs 1x1) leads to different results on different datasets (as an example, 5x5 works better for art than for faces). You can toggle this with a single flag.

# disc output size is by default 1x1
$ lightweight_gan --data ./path/to/art --image-size 512 --disc-output-size 5

Attention

You can add linear + axial attention to specific resolution layers with the following:

# make sure there are no spaces between the values within the brackets []
$ lightweight_gan --data ./path/to/images --image-size 512 --attn-res-layers [32,64] --aug-prob 0.25

Bonus

You can also train with transparent images

$ lightweight_gan --data ./path/to/images --transparent

Or greyscale

$ lightweight_gan --data ./path/to/images --greyscale

Alternatives

If you want the current state-of-the-art GAN, you can find it at https://github.com/lucidrains/stylegan2-pytorch

Citations

@inproceedings{
    anonymous2021towards,
    title={Towards Faster and Stabilized {GAN} Training for High-fidelity Few-shot Image Synthesis},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=1Fqg133qRaI},
    note={under review}
}
@inproceedings{
    anonymous2021global,
    title={Global Self-Attention Networks},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=KiFeuZu24k},
    note={under review}
}
@misc{cao2020global,
    title={Global Context Networks},
    author={Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
    year={2020},
    eprint={2012.13375},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{qin2020fcanet,
    title={FcaNet: Frequency Channel Attention Networks},
    author={Zequn Qin and Pengyi Zhang and Fei Wu and Xi Li},
    year={2020},
    eprint={2012.11879},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{sinha2020topk,
    title={Top-k Training of GANs: Improving GAN Performance by Throwing Away Bad Samples},
    author={Samarth Sinha and Zhengli Zhao and Anirudh Goyal and Colin Raffel and Augustus Odena},
    year={2020},
    eprint={2002.06224},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}

What I cannot create, I do not understand - Richard Feynman

Comments
  • Troubles with global context module in 0.15.0

    Troubles with global context module in 0.15.0

    @lucidrains

    After updating to version 0.15.0 (https://github.com/lucidrains/lightweight-gan/releases/tag/0.15.0), I can't continue training my network and had to start from zero. The previous version was at 117k batches of 4 (468k images, around 66 hours of training) and the images were pretty good. In the new version 0.15.0, on the same dataset with the same parameters (--image-size 1024 --aug-types [color,offset_h] --aug-prob 1 --amp --batch-size 7), after 77k batches of 7 (539k images, around 49 hours of training) I see artifacts that look like oil puddles. Have you seen this, or do you know how to avoid it?


    In the previous version, with sle-spatial, I didn't see anything like this.

    opened by Dok11 9
  • What is sle_spatial?

    What is sle_spatial?

    I have seen this function argument mentioned in this issue:

    https://github.com/lucidrains/lightweight-gan/issues/14#issuecomment-733432989

    What is sle_spatial?

    opened by woctezuma 8
  • unable to load save model. please try downgrading the package to the version specified by the saved model

    unable to load save model. please try downgrading the package to the version specified by the saved model

    I have the following problem since today. How to do/solve this?

    continuing from previous epoch - 118 loading from version 0.21.4 unable to load save model. please try downgrading the package to the version specified by the saved model Traceback (most recent call last): File "/opt/conda/bin/lightweight_gan", line 8, in sys.exit(main()) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 193, in main fire.Fire(train_from_folder) File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 141, in Fire component_trace = _Fire(component, args, parsed_flag_args, context, name) File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire component, remaining_args = _CallAndUpdateTrace( File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace component = fn(*varargs, **kwargs) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 184, in train_from_folder run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 59, in run_training model.load(load_from) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1603, in load raise e File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1600, in load self.GAN.load_state_dict(load_data['GAN']) File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for LightweightGAN: Missing key(s) in state_dict: "G.layers.0.0.2.1.weight", "G.layers.0.0.2.1.bias", "G.layers.0.0.4.weight", "G.layers.0.0.4.bias", "G.layers.0.0.4.running_mean", "G.layers.0.0.4.running_var", "G.layers.1.0.2.1.weight", "G.layers.1.0.2.1.bias", "G.layers.1.0.4.weight", "G.layers.1.0.4.bias", "G.layers.1.0.4.running_mean", "G.layers.1.0.4.running_var", "G.layers.2.0.2.1.weight", "G.layers.2.0.2.1.bias", "G.layers.2.0.4.weight", "G.layers.2.0.4.bias", "G.layers.2.0.4.running_mean", "G.layers.2.0.4.running_var", "G.layers.3.0.2.1.weight", "G.layers.3.0.2.1.bias", "G.layers.3.0.4.weight", "G.layers.3.0.4.bias", "G.layers.3.0.4.running_mean", "G.layers.3.0.4.running_var", "G.layers.3.2.fn.to_lin_q.weight", "G.layers.3.2.fn.to_lin_kv.net.0.weight", "G.layers.3.2.fn.to_lin_kv.net.1.weight", "G.layers.3.2.fn.to_kv.weight", "G.layers.4.0.2.1.weight", "G.layers.4.0.2.1.bias", "G.layers.4.0.4.weight", "G.layers.4.0.4.bias", "G.layers.4.0.4.running_mean", "G.layers.4.0.4.running_var", "G.layers.5.0.2.1.weight", "G.layers.5.0.2.1.bias", "G.layers.5.0.4.weight", "G.layers.5.0.4.bias", "G.layers.5.0.4.running_mean", "G.layers.5.0.4.running_var", "D.residual_layers.3.1.fn.to_lin_q.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D.residual_layers.3.1.fn.to_kv.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.weight", "GE.layers.0.0.2.1.weight", "GE.layers.0.0.2.1.bias", "GE.layers.0.0.4.weight", "GE.layers.0.0.4.bias", "GE.layers.0.0.4.running_mean", "GE.layers.0.0.4.running_var", 
"GE.layers.1.0.2.1.weight", "GE.layers.1.0.2.1.bias", "GE.layers.1.0.4.weight", "GE.layers.1.0.4.bias", "GE.layers.1.0.4.running_mean", "GE.layers.1.0.4.running_var", "GE.layers.2.0.2.1.weight", "GE.layers.2.0.2.1.bias", "GE.layers.2.0.4.weight", "GE.layers.2.0.4.bias", "GE.layers.2.0.4.running_mean", "GE.layers.2.0.4.running_var", "GE.layers.3.0.2.1.weight", "GE.layers.3.0.2.1.bias", "GE.layers.3.0.4.weight", "GE.layers.3.0.4.bias", "GE.layers.3.0.4.running_mean", "GE.layers.3.0.4.running_var", "GE.layers.3.2.fn.to_lin_q.weight", "GE.layers.3.2.fn.to_lin_kv.net.0.weight", "GE.layers.3.2.fn.to_lin_kv.net.1.weight", "GE.layers.3.2.fn.to_kv.weight", "GE.layers.4.0.2.1.weight", "GE.layers.4.0.2.1.bias", "GE.layers.4.0.4.weight", "GE.layers.4.0.4.bias", "GE.layers.4.0.4.running_mean", "GE.layers.4.0.4.running_var", "GE.layers.5.0.2.1.weight", "GE.layers.5.0.2.1.bias", "GE.layers.5.0.4.weight", "GE.layers.5.0.4.bias", "GE.layers.5.0.4.running_mean", "GE.layers.5.0.4.running_var", "D_aug.D.residual_layers.3.1.fn.to_lin_q.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.weight". Unexpected key(s) in state_dict: "G.layers.0.0.2.weight", "G.layers.0.0.2.bias", "G.layers.0.0.3.bias", "G.layers.0.0.3.running_mean", "G.layers.0.0.3.running_var", "G.layers.0.0.3.num_batches_tracked", "G.layers.1.0.2.weight", "G.layers.1.0.2.bias", "G.layers.1.0.3.bias", "G.layers.1.0.3.running_mean", "G.layers.1.0.3.running_var", "G.layers.1.0.3.num_batches_tracked", "G.layers.2.0.2.weight", "G.layers.2.0.2.bias", "G.layers.2.0.3.bias", "G.layers.2.0.3.running_mean", "G.layers.2.0.3.running_var", "G.layers.2.0.3.num_batches_tracked", "G.layers.3.0.2.weight", "G.layers.3.0.2.bias", "G.layers.3.0.3.bias", "G.layers.3.0.3.running_mean", "G.layers.3.0.3.running_var", "G.layers.3.0.3.num_batches_tracked", "G.layers.3.2.fn.to_kv.net.0.weight", "G.layers.3.2.fn.to_kv.net.1.weight", "G.layers.4.0.2.weight", "G.layers.4.0.2.bias", "G.layers.4.0.3.bias", "G.layers.4.0.3.running_mean", "G.layers.4.0.3.running_var", "G.layers.4.0.3.num_batches_tracked", "G.layers.5.0.2.weight", "G.layers.5.0.2.bias", "G.layers.5.0.3.bias", "G.layers.5.0.3.running_mean", "G.layers.5.0.3.running_var", "G.layers.5.0.3.num_batches_tracked", "D.residual_layers.3.1.fn.to_kv.net.0.weight", "D.residual_layers.3.1.fn.to_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight", "GE.layers.0.0.2.weight", "GE.layers.0.0.2.bias", "GE.layers.0.0.3.bias", "GE.layers.0.0.3.running_mean", "GE.layers.0.0.3.running_var", "GE.layers.0.0.3.num_batches_tracked", "GE.layers.1.0.2.weight", "GE.layers.1.0.2.bias", "GE.layers.1.0.3.bias", "GE.layers.1.0.3.running_mean", "GE.layers.1.0.3.running_var", "GE.layers.1.0.3.num_batches_tracked", "GE.layers.2.0.2.weight", "GE.layers.2.0.2.bias", "GE.layers.2.0.3.bias", "GE.layers.2.0.3.running_mean", 
"GE.layers.2.0.3.running_var", "GE.layers.2.0.3.num_batches_tracked", "GE.layers.3.0.2.weight", "GE.layers.3.0.2.bias", "GE.layers.3.0.3.bias", "GE.layers.3.0.3.running_mean", "GE.layers.3.0.3.running_var", "GE.layers.3.0.3.num_batches_tracked", "GE.layers.3.2.fn.to_kv.net.0.weight", "GE.layers.3.2.fn.to_kv.net.1.weight", "GE.layers.4.0.2.weight", "GE.layers.4.0.2.bias", "GE.layers.4.0.3.bias", "GE.layers.4.0.3.running_mean", "GE.layers.4.0.3.running_var", "GE.layers.4.0.3.num_batches_tracked", "GE.layers.5.0.2.weight", "GE.layers.5.0.2.bias", "GE.layers.5.0.3.bias", "GE.layers.5.0.3.running_mean", "GE.layers.5.0.3.running_var", "GE.layers.5.0.3.num_batches_tracked", "D_aug.D.residual_layers.3.1.fn.to_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight". size mismatch for G.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). size mismatch for G.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]). size mismatch for GE.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). 
size mismatch for GE.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for D_aug.D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). size mismatch for D_aug.D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D_aug.D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]).

    opened by sebastiantrella 7
  • Greyscale image generation

    Greyscale image generation

    Hi,

    thank you for this repo, I've been playing with it a bit and it seems very good! I am trying to generate greyscale images, so I modified the channel count accordingly:

    init_channel = 4 if transparent else 1

    Unfortunately, this seemed to have no effect, as the generated images are still RGB (even though they converge towards greyscale over time). Even weirder, IMO, is that I can modify the number of channels for the generator and keep the original 3 for the discriminator without any issue.

    I have also changed this part, to no effect:

    convert_image_fn = partial(convert_image_to, 'RGBA' if transparent else 'L')
    num_channels = 1 if not transparent else 4

    Am I missing something here?

    opened by stefanorosss 7
  • Getting NoneType is not subscriptable when trying to start training.

    Getting NoneType is not subscriptable when trying to start training.

    I've been able to train models before, but after changing my dataset I'm getting this error.

    My trace:

    File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/lightweight_gan.py", line 1356, in load
        name = checkpoints[-1]
    TypeError: 'NoneType' object is not subscriptable

    opened by TomCallan 6
  • Optimal parameters for Google Colab

    Optimal parameters for Google Colab

    Hello,

    First of all, thank you for sharing your code and insights with the rest of us!

    As for your code, I plan to run it for 12 hours on Google Colab, similar to the set-up shown in the README.

    My dataset consists of images at 256x256 resolution, and I have started training with the following command line:

    !lightweight_gan \
     --data {image_dir} \
     --disc-output-size 5 \
     --aug-prob 0.25 \
     --aug-types [translation,cutout,color] \
     --amp \
    

    I have noticed that the expected training time is 112.5 hours for 150k iterations (the default setting), which is consistent with the average time of 2.7 seconds per iteration shown in the log. However, that is ~9 times more than what is shown in the README, so I wonder if I am doing something wrong. I see 2 solutions.

    First, I could decrease the number of iterations so that it takes 12 hours, by choosing 16k iterations instead of 150k with:

     --num-train-steps 16000 \
    

    Is it what you have done for the results shown in the README?

    Second, I have noticed that I am only using 3.8 GB of GPU memory, so I could increase the batch size, as you mentioned in https://github.com/lucidrains/lightweight-gan/issues/13#issuecomment-732486110. Edit: however, the time per iteration increases with a larger batch size. For instance, I am using 7.2 GB of GPU memory, and it takes 8.2 seconds per iteration, with the following:

     --batch-size 32 \
     --gradient-accumulate-every 4 \
    
    opened by woctezuma 6
  • Added Experimentation Tracking.

    Added Experimentation Tracking.

    Added Experimentation Tracking using Aim.

    Now you can:

    • Track all the model hyperparameters and architectural choices.
    • Track all types of losses.
    • Filter all the experiments with respect to hyperparameters or the architecture.
    • Group and aggregate w.r.t. all the trackables to dive into granular experiment assessment.
    • Track the generated images to see how the model improves.

    opened by hnhnarek 5
  • Aim installation error

    Aim installation error

    I'm trying to run the generator after training, to generate fake samples, using the following command:

    lightweight_gan --generate --load-from 299

    I get the following error:

    Traceback (most recent call last):
      File "C:\anaconda3\lib\runpy.py", line 197, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "C:\anaconda3\lib\runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "C:\anaconda3\Scripts\lightweight_gan.exe\__main__.py", line 7, in <module>
      File "C:\anaconda3\lib\site-packages\lightweight_gan\cli.py", line 195, in main
        fire.Fire(train_from_folder)
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 141, in Fire
        component_trace = _Fire(component, args, parsed_flag_args, context, name)
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 466, in _Fire
        component, remaining_args = _CallAndUpdateTrace(
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 681, in _CallAndUpdateTrace
        component = fn(*varargs, **kwargs)
      File "C:\anaconda3\lib\site-packages\lightweight_gan\cli.py", line 158, in train_from_folder
        model = Trainer(**model_args)
      File "C:\anaconda3\lib\site-packages\lightweight_gan\lightweight_gan.py", line 1057, in __init__
        self.run = self.aim.Run(run_hash=aim_run_hash, repo=aim_repo)
    AttributeError: 'Trainer' object has no attribute 'aim'
    

    and when I try to run pip install aim, I get a dependency error with aimrocks:

      ERROR: Command errored out with exit status 1:
       command: 'C:\Anaconda3\envs\aerialweb\python.exe' 'C:\Anaconda3\envs\aerialweb\lib\site-packages\pip' install --ignore-installed --no-user --prefix 'C:\Users\ahmed\AppData\Local\Temp\pip-build-env-b2ysw94t\overlay' --no-warn-script-location --no-binary :none: --only-binary :none: -i https://pypi.org/simple -- setuptools 'cython >= 3.0.0a9' 'aimrocks == 0.2.1'
           cwd: None
      Complete output (12 lines):
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
      Collecting setuptools
        Using cached setuptools-59.6.0-py3-none-any.whl (952 kB)
      Collecting cython>=3.0.0a9
        Using cached Cython-3.0.0a10-py2.py3-none-any.whl (1.1 MB)
      ERROR: Could not find a version that satisfies the requirement aimrocks==0.2.1 (from versions: 0.1.3a14, 0.2.0.dev1, 0.2.0)
      ERROR: No matching distribution found for aimrocks==0.2.1
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
    

    What is aimrocks, and what does it actually do? I am unable to find a matching distribution or even a wheel file to install it manually. Please help.

    opened by demiahmed 4
  • Can't find "__main__" module (sorry if noob question)

    Can't find "__main__" module (sorry if noob question)

    Hello, I hope it's not too much of a noob question; I don't have any background in coding.

    After creating the env and installing PyTorch, I ran "python setup.py install" and then "python lightweight_gan --data /source --image-size 512" (I filled a "source" folder with pictures of fish), but I get the error "can't find '__main__' module". More exactly: C:\Programmes perso\Logiciels\Anaconda\envs\lightweightgan\python.exe: can't find '__main__' module in 'C:\Programmes perso\Logiciels\LightweightGan\lightweight_gan'. I tried copying and renaming some of the other modules (__init__, lightweight_gan, ...); the code seems to start running but stops before doing anything. So I guess some file must be missing, or did I do something wrong?

    Thanks a lot for the repo and have a nice day

    opened by SPotaufeux 4
  • Hard cutoff straight lines/boxes of nothing in generated images

    Hard cutoff straight lines/boxes of nothing in generated images

    Hello! Training on Google Colab with

    !lightweight_gan --data my/images/ --name my-name --image-size 256 --transparent --dual-contrast-loss --num-train-steps 250000
    

    I'm at 250k iterations over the course of 5 days at 2s/it, and have gotten strange results with boxes.

    I've circled some examples of this below.

    My training data is 22k images of 256x256 .pngs that do not contain large hard edges or boxes like this. They're video game sprites, with hard edges limited to at most 10x10 px.

    Are there any arguments I could tweak to decrease the chance of the model learning that transparent boxes are good? Would converting to a white background help?

    Thank you!

    opened by timendez 4
  • Amount of training steps

    Amount of training steps

    If I bring down the number of training steps from 150 000 to 30 000, will the trained model be overall bad? Does it really need the 100 000 or 150 000 training steps?

    opened by MartimQS 4
  • Executing with a trailing \ in the arguments sets the variable new to the truthy value '\\' and deletes all progress

    Executing with a trailing \ in the arguments sets the variable new to the truthy value '\\' and deletes all progress

    A rather frustrating issue:

    calling it with a trailing \ like lightweight_gan --data full_cleaned256/ --name c256 --models_dir models --results_dir results \

    sets the variable new to the truthy value '\' and deletes all progress.

    This might well be an issue with Fire, but it might be mitigated or fixed here too; I am unsure about that.

    Thanks. Jonas

    opened by deklesen 0
  • Projecting generated Images to Latent Space

    Projecting generated Images to Latent Space

    Is there any way to reverse-engineer the generated images back into the latent space?

    I am trying to embed fresh RGB images, as well as ones generated by the Generator, into the latent space so I can find their nearest neighbours, pretty much like AI image editing tools.

    I plan to convert my RGB image into tensor embeddings based on my trained model and tweak the feature vectors.

    How can I achieve this with lightweight-gan?
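
    There is no built-in projector, but a common generic approach is to optimize a latent vector so that the generator's output matches a target image. A minimal sketch, assuming a generator G that maps a latent vector of size latent_dim to an image tensor:

    import torch
    import torch.nn.functional as F

    def project(G, target, latent_dim=256, steps=500, lr=0.05):
        # optimize a latent vector so that G(latent) reconstructs `target` of shape (1, C, H, W)
        latent = torch.randn(1, latent_dim, requires_grad=True)
        opt = torch.optim.Adam([latent], lr=lr)
        for _ in range(steps):
            loss = F.mse_loss(G(latent), target)
            opt.zero_grad()
            loss.backward()
            opt.step()
        return latent.detach()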

    opened by demiahmed 0
  • Discriminator Loss converges to 0 while Generator loss pretty high

    Discriminator Loss converges to 0 while Generator loss pretty high

    I am trying to train on a custom image dataset for about 600,000 epochs. At about the halfway point, my D_loss converges to 0 while my G_loss stays put at 2.5.

    My evaluation outputs are slowly starting to fade out to either black or white.

    Is there anything I could do to tweak my model? Either by increasing the threshold for the Discriminator, or by training the Generator only?

    opened by demiahmed 3
  • loss implementation differs from paper

    loss implementation differs from paper

    Hi,

    Thanks for this amazing implementation! I have a question concerning the loss implementation, as it seems to differ from the original equations. The screenshot below shows the GAN loss as presented in the paper:

    [screenshot: the GAN loss equations from the paper]

    • in red, the discriminator loss (D loss) on the true labels,
    • in green the D loss on labels for fake generated images,
    • and in blue, the generator loss (G loss) on labels for fake images.

    This makes sense to me, since it is assumed that D outputs values between 0 and 1 (0 = fake, 1 = real):

    • in red, we want D to output 1 for true images → let's assume D indeed outputs 1 for true images: -min(0, -1 + D(x)) = 0, which is indeed the minimum achievable,
    • in green, we want D to output 0 (from the discriminator's perspective) for fake images → let's assume D indeed outputs 0 for fake images: -min(0, -1 - D(x^)) = 1, which is the minimum achievable if D outputs values only between 0 and 1,
    • in blue, we want D to output 1 (from the generator's perspective) for fake images: the equation follows directly.

    Now, the way the authors implement this in the code provided in the supplementary materials of the paper is as follows (the colors match the ones in the picture above):

    [screenshots: the original code for the D loss on real images, the D loss on fake images, and the G loss]

    Except for the strange randomness involved (already explained in https://github.com/lucidrains/lightweight-gan/issues/11), their implementation is a one-to-one match with the paper equations.


    The way it is implemented in this repo, however, is quite different, and I do not understand why.

    [screenshot: the loss implementation in this repo]

    Let's start with the discriminator loss:

    • in red, you want D to output small values (negative if allowed), to set this term as small as possible (0 if D can output negative values)
    • in green, you want D to output values as large as possible (larger or equal to 1) to cancel this term out as well

    For the generator loss :

    • in blue, you want the opposite of green, that is for D to output values as small as possible

    This implementation seems meaningful and yields coherent results (as the examples prove). It also seems to me that D is not limited to outputting values between 0 and 1, but can output any real value (I might be wrong). I am just wondering about this choice: could you perhaps elaborate on why you decided to implement the loss differently from the original paper?
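
    For reference, the paper's equations above can be written compactly with ReLU, since -min(0, -1 + D(x)) = relu(1 - D(x)). A minimal sketch of that standard hinge formulation, written from the equations (not a claim about this repo's exact code):

    import torch.nn.functional as F

    def d_hinge_loss(real_scores, fake_scores):
        # -E[min(0, -1 + D(x))] - E[min(0, -1 - D(x^))]
        return F.relu(1. - real_scores).mean() + F.relu(1. + fake_scores).mean()

    def g_hinge_loss(fake_scores):
        # the generator wants D to score its fakes as high as possible
        return -fake_scores.mean()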

    opened by maximeraafat 1
  • showing results while training?

    showing results while training?

    How can I show generator results after every epoch during training?

    this is my current configuration

     lightweight_gan \
      --data "/content/dataset/Dataset/" \
      --num-train-steps 100000 \
      --image-size 128 \
      --name GAN2DBlood5k \
      --batch-size 32 \
      --gradient-accumulate-every 5 \
      --disc-output-size 1 \
      --dual-contrast-loss \
      --attn-res-layers [] \
      --calculate_fid_every 1000\
      --greyscale \
      --amp
    

    Using --show-progress only works after training. Also, it seems that there are no longer checkpoints per epoch.

    opened by galaelized 2
Releases(1.1.1)
Owner
Phil Wang
Working with Attention. It's all we need.