Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. High-resolution image generation that can be trained within a day or two.

Overview

512x512 flowers after 12 hours of training, 1 gpu

256x256 flowers after 12 hours of training, 1 gpu

Pizza

'Lightweight' GAN


Implementation of 'lightweight' GAN proposed in ICLR 2021, in PyTorch. The main contributions of the paper are a skip-layer excitation in the generator, paired with autoencoding self-supervised learning in the discriminator. Quoting the one-line summary: "converge on single gpu with few hours' training, on 1024 resolution sub-hundred images".

Install

$ pip install lightweight-gan

Use

One command

$ lightweight_gan --data ./path/to/images --image-size 512

The model will be saved to ./models/{name} every 1000 iterations, and samples from the model will be saved to ./results/{name}. If you do not pass a name, it is set to default.

Training settings

Pretty self-explanatory for deep learning practitioners.

$ lightweight_gan \
    --data ./path/to/images \
    --name {name of run} \
    --batch-size 16 \
    --gradient-accumulate-every 4 \
    --num-train-steps 200000
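
The interaction between --batch-size and --gradient-accumulate-every can be illustrated with a small, generic sketch. It assumes the usual meaning of gradient accumulation (each micro-batch loss is scaled by the number of accumulation steps before the gradients are summed), so 16 x 4 = 64 images contribute to each optimizer step. The toy one-parameter model and the function names are hypothetical and are not taken from the library.

```python
# Generic sketch of gradient accumulation on a 1-parameter least-squares model.
# Illustration only: this is not the library's training loop.

def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, batch_size, accumulate_every):
    """Sum the scaled gradients of `accumulate_every` micro-batches."""
    total = 0.0
    for i in range(accumulate_every):
        lo, hi = i * batch_size, (i + 1) * batch_size
        # Dividing by the number of accumulation steps makes the summed
        # gradient equal to the gradient of one large batch.
        total += grad_mse(w, xs[lo:hi], ys[lo:hi]) / accumulate_every
    return total

batch_size, accumulate_every = 16, 4
effective_batch = batch_size * accumulate_every  # samples per optimizer step

xs = [0.1 * i for i in range(effective_batch)]
ys = [3.0 * x + 1.0 for x in xs]

full = grad_mse(0.5, xs, ys)
accum = accumulated_grad(0.5, xs, ys, batch_size, accumulate_every)
print(effective_batch, abs(full - accum) < 1e-9)
```

Under this assumption, raising --gradient-accumulate-every trades wall-clock time per step for a larger effective batch without using more GPU memory.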

Augmentation

Augmentation is essential for Lightweight GAN to work effectively in a low-data setting.

By default, the augmentation types are set to translation and cutout, with color omitted. You can include color as well with the following.

$ lightweight_gan --data ./path/to/images --aug-prob 0.25 --aug-types [translation,cutout,color]

Test augmentation

You can test and see how your images will be augmented before they are passed into the neural network (if you use augmentation). Let's see how it works on this image:

Basic usage

To augment a single image, pass --aug-test and set --data to the path of your image:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg

This creates the file lena_augs.jpg, which will look something like this:

Options

You can use the following options to change the result:

  • --image-size 256 to change the size of the image tiles in the result. Default: 256.
  • --aug-types [color,cutout,translation] to combine several augmentations. Default: [cutout,translation].
  • --batch-size 10 to change the number of images in the result image. Default: 10.
  • --num-image-tiles 5 to change the number of tiles in the result image. Default: 5.

Try this command:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg \
    --batch-size 16 \
    --num-image-tiles 4 \
    --aug-types [color,translation]

The result will look something like this:

Types of augmentations

This library contains several types of built-in augmentations.
Some of them are always applied; the others can be selected on the command line through the --aug-types option:

  • Horizontal flip (always applied, not configurable; runs in the AugWrapper class);
  • color: randomly changes brightness, saturation and contrast;
  • cutout: creates random black boxes on the image;
  • offset: randomly moves the image along the x- and y-axes, wrapping the image around;
    • offset_h: only along the x-axis;
    • offset_v: only along the y-axis;
  • translation: randomly moves the image on a canvas with a black background.

The full set of augmentations is --aug-types [color,cutout,offset,translation].
The general recommendation is to use as many augmentations as are suitable for your data, then, after some time of training, disable the ones that are most destructive to the image.
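
To make the difference between cutout, offset and translation concrete, here is a toy sketch on a 3x3 single-channel image stored as a list of rows. The function names, signatures and fixed shift values are hypothetical and chosen for illustration; the library applies these augmentations randomly and on tensors, and its actual implementation is not shown here.

```python
# Toy versions of three augmentations on a tiny single-channel image.
# Hypothetical helpers for illustration; not the library's implementation.

def cutout(img, top, left, size):
    """Blank out a size x size square (0 stands for black)."""
    out = [row[:] for row in img]
    for r in range(top, min(top + size, len(img))):
        for c in range(left, min(left + size, len(img[0]))):
            out[r][c] = 0
    return out

def offset(img, dx, dy):
    """Shift with wrap-around: pixels leaving one edge re-enter at the other."""
    h, w = len(img), len(img[0])
    return [[img[(r - dy) % h][(c - dx) % w] for c in range(w)] for r in range(h)]

def translation(img, dx, dy):
    """Shift on a black canvas: pixels leaving the frame are lost."""
    h, w = len(img), len(img[0])
    return [
        [img[r - dy][c - dx] if 0 <= r - dy < h and 0 <= c - dx < w else 0
         for c in range(w)]
        for r in range(h)
    ]

img = [[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]]

print(cutout(img, 0, 0, 2))    # [[0, 0, 3], [0, 0, 6], [7, 8, 9]]
print(offset(img, 1, 0))       # [[3, 1, 2], [6, 4, 5], [9, 7, 8]]
print(translation(img, 1, 0))  # [[0, 1, 2], [0, 4, 5], [0, 7, 8]]
```

In this sketch, offset with dy = 0 corresponds to the idea behind offset_h, and dx = 0 to offset_v.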

Color

Cutout

Offset

Only x-axis:

Only y-axis:

Translation

Mixed precision

You can turn on automatic mixed precision with a single flag, --amp.

You should expect it to be 33% faster and save up to 40% memory.

Multiple GPUs

Multi-GPU training is also enabled with a single flag, --multi-gpus.

Generating

Once you have finished training, you can generate samples with one command. You can select which checkpoint number to load from. If --load-from is not specified, it defaults to the latest checkpoint.

$ lightweight_gan \
  --name {name of run} \
  --load-from {checkpoint num} \
  --generate \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command you will get a folder next to the results image folder, with the postfix "-generated-{checkpoint num}".

You can also generate interpolations

$ lightweight_gan --name {name of run} --generate-interpolation

Show progress

After several checkpoints of the model have been saved, you can generate the training progress as a sequence of images with the command:

$ lightweight_gan \
  --name {name of run} \
  --show-progress \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command you will get a new folder in the results folder, with postfix "-progress". You can convert the images to a video with ffmpeg using the command "ffmpeg -framerate 10 -pattern_type glob -i '*-ema.jpg' out.mp4".

Show progress gif demonstration

Show progress video demonstration

Discriminator output size

The author has kindly let me know that the discriminator output size (5x5 vs 1x1) leads to different results on different datasets. (5x5 works better for art than for faces, as an example). You can toggle this with a single flag

# disc output size is by default 1x1
$ lightweight_gan --data ./path/to/art --image-size 512 --disc-output-size 5

Attention

You can add linear + axial attention to specific resolution layers with the following

# make sure there are no spaces between the values within the brackets []
$ lightweight_gan --data ./path/to/images --image-size 512 --attn-res-layers [32,64] --aug-prob 0.25

Bonus

You can also train with transparent images

$ lightweight_gan --data ./path/to/images --transparent

Or greyscale

$ lightweight_gan --data ./path/to/images --greyscale

Alternatives

If you want the current state of the art GAN, you can find it at https://github.com/lucidrains/stylegan2-pytorch

Citations

@inproceedings{
    anonymous2021towards,
    title={Towards Faster and Stabilized {\{}GAN{\}} Training for High-fidelity Few-shot Image Synthesis},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=1Fqg133qRaI},
    note={under review}
}
@inproceedings{
    anonymous2021global,
    title={Global Self-Attention Networks},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=KiFeuZu24k},
    note={under review}
}
@misc{cao2020global,
    title={Global Context Networks},
    author={Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
    year={2020},
    eprint={2012.13375},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{qin2020fcanet,
    title={FcaNet: Frequency Channel Attention Networks},
    author={Zequn Qin and Pengyi Zhang and Fei Wu and Xi Li},
    year={2020},
    eprint={2012.11879},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{sinha2020topk,
    title={Top-k Training of GANs: Improving GAN Performance by Throwing Away Bad Samples},
    author={Samarth Sinha and Zhengli Zhao and Anirudh Goyal and Colin Raffel and Augustus Odena},
    year={2020},
    eprint={2002.06224},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}

What I cannot create, I do not understand - Richard Feynman

Comments
  • Troubles with global context module in 0.15.0


    @lucidrains

    After updating to this version https://github.com/lucidrains/lightweight-gan/releases/tag/0.15.0 I can't continue training my network and had to start again from zero. The previous version had reached 117k batches of 4 (468k images, around 66 hours of training) and the results were pretty good. With the new version 0.15.0, on the same dataset and with the same parameters (--image-size 1024 --aug-types [color,offset_h] --aug-prob 1 --amp --batch-size 7), after 77k batches of 7 (539k images, around 49 hours of training) I see artifacts that look like oil puddles. Have you encountered this, or do you know how to avoid it?

    image

    In the previous version with sle-spatial I didn't encounter anything like this.

    opened by Dok11 9
  • What is sle_spatial?


    I have seen this function argument mentioned in this issue:

    https://github.com/lucidrains/lightweight-gan/issues/14#issuecomment-733432989

    What is sle_spatial?

    opened by woctezuma 8
  • unable to load save model. please try downgrading the package to the version specified by the saved model


    I have the following problem since today. How to do/solve this?

    continuing from previous epoch - 118
    loading from version 0.21.4
    unable to load save model. please try downgrading the package to the version specified by the saved model
    Traceback (most recent call last):
      File "/opt/conda/bin/lightweight_gan", line 8, in <module>
        sys.exit(main())
      File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 193, in main
        fire.Fire(train_from_folder)
      File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
        component_trace = _Fire(component, args, parsed_flag_args, context, name)
      File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire
        component, remaining_args = _CallAndUpdateTrace(
      File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
        component = fn(*varargs, **kwargs)
      File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 184, in train_from_folder
        run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash)
      File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 59, in run_training
        model.load(load_from)
      File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1603, in load
        raise e
      File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1600, in load
        self.GAN.load_state_dict(load_data['GAN'])
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for LightweightGAN:
    Missing key(s) in state_dict: "G.layers.0.0.2.1.weight", "G.layers.0.0.2.1.bias", "G.layers.0.0.4.weight", "G.layers.0.0.4.bias", "G.layers.0.0.4.running_mean", "G.layers.0.0.4.running_var", "G.layers.1.0.2.1.weight", "G.layers.1.0.2.1.bias", "G.layers.1.0.4.weight", "G.layers.1.0.4.bias", "G.layers.1.0.4.running_mean", "G.layers.1.0.4.running_var", 
"G.layers.2.0.2.1.weight", "G.layers.2.0.2.1.bias", "G.layers.2.0.4.weight", "G.layers.2.0.4.bias", "G.layers.2.0.4.running_mean", "G.layers.2.0.4.running_var", "G.layers.3.0.2.1.weight", "G.layers.3.0.2.1.bias", "G.layers.3.0.4.weight", "G.layers.3.0.4.bias", "G.layers.3.0.4.running_mean", "G.layers.3.0.4.running_var", "G.layers.3.2.fn.to_lin_q.weight", "G.layers.3.2.fn.to_lin_kv.net.0.weight", "G.layers.3.2.fn.to_lin_kv.net.1.weight", "G.layers.3.2.fn.to_kv.weight", "G.layers.4.0.2.1.weight", "G.layers.4.0.2.1.bias", "G.layers.4.0.4.weight", "G.layers.4.0.4.bias", "G.layers.4.0.4.running_mean", "G.layers.4.0.4.running_var", "G.layers.5.0.2.1.weight", "G.layers.5.0.2.1.bias", "G.layers.5.0.4.weight", "G.layers.5.0.4.bias", "G.layers.5.0.4.running_mean", "G.layers.5.0.4.running_var", "D.residual_layers.3.1.fn.to_lin_q.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D.residual_layers.3.1.fn.to_kv.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.weight", "GE.layers.0.0.2.1.weight", "GE.layers.0.0.2.1.bias", "GE.layers.0.0.4.weight", "GE.layers.0.0.4.bias", "GE.layers.0.0.4.running_mean", "GE.layers.0.0.4.running_var", "GE.layers.1.0.2.1.weight", "GE.layers.1.0.2.1.bias", "GE.layers.1.0.4.weight", "GE.layers.1.0.4.bias", "GE.layers.1.0.4.running_mean", "GE.layers.1.0.4.running_var", "GE.layers.2.0.2.1.weight", "GE.layers.2.0.2.1.bias", "GE.layers.2.0.4.weight", "GE.layers.2.0.4.bias", "GE.layers.2.0.4.running_mean", "GE.layers.2.0.4.running_var", "GE.layers.3.0.2.1.weight", "GE.layers.3.0.2.1.bias", "GE.layers.3.0.4.weight", "GE.layers.3.0.4.bias", 
"GE.layers.3.0.4.running_mean", "GE.layers.3.0.4.running_var", "GE.layers.3.2.fn.to_lin_q.weight", "GE.layers.3.2.fn.to_lin_kv.net.0.weight", "GE.layers.3.2.fn.to_lin_kv.net.1.weight", "GE.layers.3.2.fn.to_kv.weight", "GE.layers.4.0.2.1.weight", "GE.layers.4.0.2.1.bias", "GE.layers.4.0.4.weight", "GE.layers.4.0.4.bias", "GE.layers.4.0.4.running_mean", "GE.layers.4.0.4.running_var", "GE.layers.5.0.2.1.weight", "GE.layers.5.0.2.1.bias", "GE.layers.5.0.4.weight", "GE.layers.5.0.4.bias", "GE.layers.5.0.4.running_mean", "GE.layers.5.0.4.running_var", "D_aug.D.residual_layers.3.1.fn.to_lin_q.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.weight". 
Unexpected key(s) in state_dict: "G.layers.0.0.2.weight", "G.layers.0.0.2.bias", "G.layers.0.0.3.bias", "G.layers.0.0.3.running_mean", "G.layers.0.0.3.running_var", "G.layers.0.0.3.num_batches_tracked", "G.layers.1.0.2.weight", "G.layers.1.0.2.bias", "G.layers.1.0.3.bias", "G.layers.1.0.3.running_mean", "G.layers.1.0.3.running_var", "G.layers.1.0.3.num_batches_tracked", "G.layers.2.0.2.weight", "G.layers.2.0.2.bias", "G.layers.2.0.3.bias", "G.layers.2.0.3.running_mean", "G.layers.2.0.3.running_var", "G.layers.2.0.3.num_batches_tracked", "G.layers.3.0.2.weight", "G.layers.3.0.2.bias", "G.layers.3.0.3.bias", "G.layers.3.0.3.running_mean", "G.layers.3.0.3.running_var", "G.layers.3.0.3.num_batches_tracked", "G.layers.3.2.fn.to_kv.net.0.weight", "G.layers.3.2.fn.to_kv.net.1.weight", "G.layers.4.0.2.weight", "G.layers.4.0.2.bias", "G.layers.4.0.3.bias", "G.layers.4.0.3.running_mean", "G.layers.4.0.3.running_var", "G.layers.4.0.3.num_batches_tracked", "G.layers.5.0.2.weight", "G.layers.5.0.2.bias", "G.layers.5.0.3.bias", "G.layers.5.0.3.running_mean", "G.layers.5.0.3.running_var", "G.layers.5.0.3.num_batches_tracked", "D.residual_layers.3.1.fn.to_kv.net.0.weight", "D.residual_layers.3.1.fn.to_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight", "GE.layers.0.0.2.weight", "GE.layers.0.0.2.bias", "GE.layers.0.0.3.bias", "GE.layers.0.0.3.running_mean", "GE.layers.0.0.3.running_var", "GE.layers.0.0.3.num_batches_tracked", "GE.layers.1.0.2.weight", "GE.layers.1.0.2.bias", "GE.layers.1.0.3.bias", "GE.layers.1.0.3.running_mean", "GE.layers.1.0.3.running_var", "GE.layers.1.0.3.num_batches_tracked", "GE.layers.2.0.2.weight", "GE.layers.2.0.2.bias", "GE.layers.2.0.3.bias", "GE.layers.2.0.3.running_mean", "GE.layers.2.0.3.running_var", "GE.layers.2.0.3.num_batches_tracked", "GE.layers.3.0.2.weight", 
"GE.layers.3.0.2.bias", "GE.layers.3.0.3.bias", "GE.layers.3.0.3.running_mean", "GE.layers.3.0.3.running_var", "GE.layers.3.0.3.num_batches_tracked", "GE.layers.3.2.fn.to_kv.net.0.weight", "GE.layers.3.2.fn.to_kv.net.1.weight", "GE.layers.4.0.2.weight", "GE.layers.4.0.2.bias", "GE.layers.4.0.3.bias", "GE.layers.4.0.3.running_mean", "GE.layers.4.0.3.running_var", "GE.layers.4.0.3.num_batches_tracked", "GE.layers.5.0.2.weight", "GE.layers.5.0.2.bias", "GE.layers.5.0.3.bias", "GE.layers.5.0.3.running_mean", "GE.layers.5.0.3.running_var", "GE.layers.5.0.3.num_batches_tracked", "D_aug.D.residual_layers.3.1.fn.to_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight". size mismatch for G.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). size mismatch for G.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). 
size mismatch for D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]). size mismatch for GE.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). size mismatch for GE.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for D_aug.D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). 
size mismatch for D_aug.D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D_aug.D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]).

    opened by sebastiantrella 7
  • Greyscale image generation


    Hi,

    thank you for this repo, I've been playing with it a bit and it seems very good! I am trying to generate greyscale images, so I modified the channel accordingly

    init_channel = 4 if transparent else 1

    unfortunately, this seemed to have no effect as the images generated are still RGB (even though they converge towards greyscale with time), even weirder IMO is that I can modify the number of channels for the generator and keep the original 3 for the discriminator without any issue.

    I have also changed this part to no effect

    convert_image_fn = partial(convert_image_to, 'RGBA' if transparent else 'L')
    num_channels = 1 if not transparent else 4

    Am I missing something here?

    opened by stefanorosss 7
  • Getting NoneType is not subscriptable when trying to start training.


    I've been able to train models before but after changing my dataset I'm getting the error.

    My trace:

      File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/lightweight_gan.py", line 1356, in load
        name = checkpoints[-1]
    TypeError: 'NoneType' object is not subscriptable

    opened by TomCallan 6
  • Optimal parameters for Google Colab


    Hello,

    First of all, thank you for sharing your code and insights with the rest of us!

    As for your code, I plan to run it for 12 hours on Google Colab, similarly to the set-up for what is shown in the README.

    My datasets consists of images of 256x256 resolution, and I have started training with the following command-line:

    !lightweight_gan \
     --data {image_dir} \
     --disc-output-size 5 \
     --aug-prob 0.25 \
     --aug-types [translation,cutout,color] \
     --amp \
    

    I have noticed that the expected training time is 112.5 hours with 150k iterations (the default setting), which is consistent with the average time of 2.7 seconds per iteration shown in the log. However, it is ~ 9 times more than what is shown in the README. So I wonder if I am doing something wrong, and I see 2 solutions.

    First, I could decrease the number of iterations so that it takes 12 hours, by choosing 16k iterations instead of 150k with:

     --num-train-steps 16000 \
    

    Is it what you have done for the results shown in the README?
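
The figures quoted in this question can be checked with a few lines of arithmetic. The sketch below only uses the numbers stated above (2.7 seconds per iteration, 150k iterations, a 12-hour budget); the helper names are hypothetical.

```python
# Back-of-the-envelope check of the training-time figures quoted above.
# Helper names are hypothetical; the inputs come from the question itself.

def hours_for(steps, seconds_per_step):
    """Total wall-clock hours for a given number of training steps."""
    return steps * seconds_per_step / 3600

def steps_within(hours, seconds_per_step):
    """Number of steps that fit in a time budget, rounded to the nearest step."""
    return round(hours * 3600 / seconds_per_step)

default_hours = hours_for(150_000, 2.7)   # about 112.5 hours
budget_steps = steps_within(12, 2.7)      # about 16k steps
slowdown = default_hours / 12             # a bit more than 9x the 12-hour budget

print(default_hours, budget_steps, slowdown)
```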

    Second, I have noticed that I am only using 3.8 GB of GPU memory, so I could increase the batch size, as you mentioned in https://github.com/lucidrains/lightweight-gan/issues/13#issuecomment-732486110. Edit: However, the training time increases with a larger batch size. For instance, I am using 7.2 GB of GPU memory, and it takes 8.2 seconds per iteration, with the following:

     --batch-size 32 \
     --gradient-accumulate-every 4 \
    
    opened by woctezuma 6
  • Added Experimentation Tracking.


    Added Experimentation Tracking using Aim.

    Now you can:

    • Track all the model hyperparameters and architectural choices.
    • Track all types of losses.
    • Filter all the experiments with respect to hyperparameters or the architecture.
    • Group and aggregate w.r.t. all the trackables to dive into granular experimentation assessment.
    • Track the generated images to track how the model improves.

    opened by hnhnarek 5
  • Aim installation error


    I'm trying to run the generator after training, to generate fake samples using the following command

    lightweight_gan --generate --load-from 299

    I get this following error:

    Traceback (most recent call last):
      File "C:\anaconda3\lib\runpy.py", line 197, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "C:\anaconda3\lib\runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "C:\anaconda3\Scripts\lightweight_gan.exe\__main__.py", line 7, in <module>
      File "C:\anaconda3\lib\site-packages\lightweight_gan\cli.py", line 195, in main
        fire.Fire(train_from_folder)
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 141, in Fire
        component_trace = _Fire(component, args, parsed_flag_args, context, name)
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 466, in _Fire
        component, remaining_args = _CallAndUpdateTrace(
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 681, in _CallAndUpdateTrace
        component = fn(*varargs, **kwargs)
      File "C:\anaconda3\lib\site-packages\lightweight_gan\cli.py", line 158, in train_from_folder
        model = Trainer(**model_args)
      File "C:\anaconda3\lib\site-packages\lightweight_gan\lightweight_gan.py", line 1057, in __init__
        self.run = self.aim.Run(run_hash=aim_run_hash, repo=aim_repo)
    AttributeError: 'Trainer' object has no attribute 'aim'
    

    and when I try to run pip install aim, I get a dependency error with aimrocks

      ERROR: Command errored out with exit status 1:
       command: 'C:\Anaconda3\envs\aerialweb\python.exe' 'C:\Anaconda3\envs\aerialweb\lib\site-packages\pip' install --ignore-installed --no-user --prefix 'C:\Users\ahmed\AppData\Local\Temp\pip-build-env-b2ysw94t\overlay' --no-warn-script-location --no-binary :none: --only-binary :none: -i https://pypi.org/simple -- setuptools 'cython >= 3.0.0a9' 'aimrocks == 0.2.1'
           cwd: None
      Complete output (12 lines):
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
      Collecting setuptools
        Using cached setuptools-59.6.0-py3-none-any.whl (952 kB)
      Collecting cython>=3.0.0a9
        Using cached Cython-3.0.0a10-py2.py3-none-any.whl (1.1 MB)
      ERROR: Could not find a version that satisfies the requirement aimrocks==0.2.1 (from versions: 0.1.3a14, 0.2.0.dev1, 0.2.0)
      ERROR: No matching distribution found for aimrocks==0.2.1
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
    

    What is aimrocks and what does it actually do? I am unable to find a matching distribution or even a wheel file to install it manually. Please help

    opened by demiahmed 4
  • Can't find "__main__" module (sorry if noob question)

    Hello, I hope it's not too much of a noob question, I don't have any background in coding.

    After creating the env and installing PyTorch I ran "python setup.py install" and then "python lightweight_gan --data /source --image-size 512" (I filled a "source" folder with pictures of fish), but I get the error "can't find '__main__' module". More exactly: C:\Programmes perso\Logiciels\Anaconda\envs\lightweightgan\python.exe: can't find '__main__' module in 'C:\Programmes perso\Logiciels\LightweightGan\lightweight_gan'. I tried to copy and rename some of the other modules (__init__, lightweight_gan...); the code seems to start to run but stops before doing anything. So I guess some file must be missing, or did I do something wrong?

    Thanks a lot for the repo and have a nice day

    opened by SPotaufeux 4
  • Hard cutoff straight lines/boxes of nothing in generated images


    Hello! Training on Google Colab with

    !lightweight_gan --data my/images/ --name my-name --image-size 256 --transparent --dual-contrast-loss --num-train-steps 250000
    

    I'm at 250k iterations over the course of 5 days at 2s/it, and have gotten strange results with boxes.

    I've circled some examples of this below. image

    My training data is 22k images of 256x256 .pngs that do not contain large hard edges or boxes like this. They're video game sprites with hard edges being limited to at most 10x10px

    Are there any suggestions I can do with arguments in order to decrease the chance of the models learning that transparent boxes are good? Would converting to a white background help?

    Thank you!

    opened by timendez 4
  • Amount of training steps


    If I bring down the number of training steps from 150 000 to 30 000, will the trained model be overall bad? Does it really need the 100 000 or 150 000 training steps?

    opened by MartimQS 4
  • Executing with a trailing \ in the arguments sets the variable new to the truthy value '\\' and deletes all progress


    A rather frustrating issue:

    calling it with a trailing \ like lightweight_gan --data full_cleaned256/ --name c256 --models_dir models --results_dir results \

    sets the variable new to the truthy value '\' and deletes all progress.

    This might well be an issue with Fire but might be mitigated or fixed here too, I am unsure about that.

    Thanks. Jonas

    opened by deklesen 0
  • Projecting generated Images to Latent Space


    Is there any way to reverse engineer the generated images into the latent space?

    I am trying to embed fresh RGB as well as ones generated by the Generator into the latent space so I can find its nearest neighbour, pretty much like AI image editing tools.

    I plan to convert my RGB image into tensor embeddings based on my trained model and tweak the feature vectors.

    How can I achieve this with lightweight-gan?

    opened by demiahmed 0
  • Discriminator Loss converges to 0 while Generator loss pretty high


    I am trying to train with a custom image dataset for about 600,000 epochs. At about halfway, my D_loss converges to 0 while my G_loss stays put at 2.5

    My evaluation outputs are slowly starting to fade out to either black or white.

    Is there anything I could do to tweak my model? For example, increasing the threshold for the discriminator, or training the generator only?

    opened by demiahmed 3
  • loss implementation differs from paper


    Hi,

    Thanks for this amazing implementation! I have a question concerning the loss implementation, as it seems to differ from the original equations. The screenshot below shows the GAN loss as presented in the paper :

    paper_losses

    • in red, the discriminator loss (D loss) on the true labels,
    • in green the D loss on labels for fake generated images,
    • and in blue, the generator loss (G loss) on labels for fake images.

    This makes sense to me. Since it is assumed that D outputs values between 0 and 1 (0 = fake, 1 = real) :

    • in red, we want D to output 1 for true images → let's assume D indeed outputs 1 for true images : -min(0, -1 + D(x)) = 0, which is indeed the minimum achievable,
    • in green, we want D to output 0 (from the discriminator perspective) for fake images → let's assume D indeed outputs 0 for fake images : -min(0, -1 - D(x^)) = 1, which is the minimum achievable if D outputs values only between 0 and 1,
    • in blue, we want D to output 1 (from the generator perspective) for fake images : the equation follows directly.

    Now, the way the authors implement this in the code provided in the supplementary materials of the paper is as follows (the colors match the ones in the above picture)

    og_code_loss_d_real og_code_loss_d_fake og_code_loss_g

    Except for the strange randomness involved (already explained in https://github.com/lucidrains/lightweight-gan/issues/11), their implementation is a one-to-one match with the paper equations.

    The way it is implemented in this repo, however, is quite different, and I do not understand why.

    lighweight_gan_losses

    Let's start with the discriminator loss:

    • in red, you want D to output small values (negative if allowed), to make this term as small as possible (0 if D can output negative values),
    • in green, you want D to output values as large as possible (greater than or equal to 1) to cancel this term out as well.

    For the generator loss:

    • in blue, you want the opposite of green, that is, for D to output values as small as possible.

    This implementation seems to be meaningful, and yields coherent results (as shown by the examples). It also seems to me that D is not limited to outputting values between 0 and 1, but can output any real value (I might be wrong). I am just wondering why this choice was made. Could you perhaps elaborate on why you decided to implement the loss differently from the original paper?
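One way to compare the two formulations described in this issue is to write both as scalar functions and check how they relate. The sketch below is plain Python, not the repo's code: `paper_*` follows the terms quoted from the paper above, while `repo_*` is an assumption reconstructed from the verbal description of the red, green and blue terms (the screenshots are not reproduced here). Under that assumption, the two versions coincide once the sign of D's output is flipped, so "high score = real" in one corresponds to "low score = real" in the other.

```python
def relu(v: float) -> float:
    """max(0, v), the scalar version of a ReLU."""
    return max(0.0, v)


# Paper form, as quoted in the issue: a high D output means "real".
def paper_d_loss(d_real: float, d_fake: float) -> float:
    return -min(0.0, -1.0 + d_real) - min(0.0, -1.0 - d_fake)


def paper_g_loss(d_fake: float) -> float:
    return -d_fake


# ASSUMPTION: repo form reconstructed from the issue's description
# (real term wants small outputs, fake term wants outputs >= 1).
def repo_d_loss(d_real: float, d_fake: float) -> float:
    return relu(1.0 + d_real) + relu(1.0 - d_fake)


def repo_g_loss(d_fake: float) -> float:
    return d_fake


# Flipping the sign of D's outputs maps one formulation onto the other.
for d_real, d_fake in [(1.0, 0.0), (2.5, -3.0), (-0.5, 0.25), (0.0, 0.0)]:
    assert paper_d_loss(d_real, d_fake) == repo_d_loss(-d_real, -d_fake)
    assert paper_g_loss(d_fake) == repo_g_loss(-d_fake)
```

If the reconstruction is right, the difference is a labelling convention rather than a different objective. Hinge losses of this kind are normally applied to unbounded scores, which is consistent with the observation that D is not restricted to values between 0 and 1.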

    opened by maximeraafat 1
  • showing results while training ?

    How can I show generator results after every epoch during training?

    This is my current configuration:

     lightweight_gan \
      --data "/content/dataset/Dataset/" \
      --num-train-steps 100000 \
      --image-size 128 \
      --name GAN2DBlood5k \
      --batch-size 32 \
      --gradient-accumulate-every 5 \
      --disc-output-size 1 \
      --dual-contrast-loss \
      --attn-res-layers [] \
      --calculate_fid_every 1000 \
      --greyscale \
      --amp
    

    Using --show-progress only works after training. Also, it seems that there are no longer checkpoints per epoch.
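The README states that samples are written to ./results/{name} while the model trains, so one possible workaround is to look at the newest file in that folder from a second process or notebook cell. The helper below is a minimal sketch under assumptions: `latest_sample` and `IMAGE_SUFFIXES` are hypothetical names, not part of lightweight-gan, and the sample file extensions are guessed rather than taken from the repo.

```python
from pathlib import Path
from typing import Optional

# ASSUMPTION: sample images use one of these extensions.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def latest_sample(results_dir: str, name: str) -> Optional[Path]:
    """Return the most recently modified image in results_dir/name, or None."""
    run_dir = Path(results_dir) / name
    if not run_dir.is_dir():
        return None
    images = [
        p for p in run_dir.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    ]
    # Pick the newest file by modification time; None if the folder has no images.
    return max(images, key=lambda p: p.stat().st_mtime, default=None)


# Example with the run name from the configuration above.
print(latest_sample("./results", "GAN2DBlood5k"))
```

Calling this periodically, for instance once per save interval, gives a rough view of progress without waiting for training to finish.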

    opened by galaelized 2
Releases(1.1.1)
Owner
Phil Wang
Working with Attention. It's all we need.