基于pytorch构建cyclegan示例

Overview

cyclegan-demo

基于Pytorch构建CycleGAN示例

如何运行

准备数据集

将数据集整理成4个文件,分别命名为

  • trainA, trainB:训练集,A、B代表两类图片
  • testA, testB:测试集,A、B代表两类图片

例如

D:\CODE\CYCLEGAN-DEMO\DATA\SUMMER2WINTER
├─testA
├─testB
├─trainA
└─trainB

之后在main.py中将root设为数据集的路径。

参数设置

main.py中的初始化参数

# 初始化参数
# seed: 随机种子
# root: 数据集路径
# output_model_root: 模型的输出路径
# image_size: 图片尺寸
# batch_size: 一次喂入的数据量
# lr: 学习率
# betas: 一阶和二阶动量
# epochs: 训练总次数
# historical_epochs: 历史训练次数
# - 0表示不沿用历史模型
# - >0表示对应训练次数的模型
# - -1表示最后一次训练的模型
# save_every: 保存频率
# loss_range: Loss的显示范围
seed = 123
data_root = 'D:/code/cyclegan-demo/data/summer2winter'
output_model_root = 'output/model'
image_size = 64
batch_size = 16
lr = 2e-4
betas = (.5, .999)
epochs = 100
historical_epochs = -1
save_every = 1
loss_range = 1000

安装和运行

  1. 安装依赖
pip install -r requirements.txt
  1. 打开命令行,运行Visdom
python -m visdom.server
  1. 运行主程序
python main.py

训练过程的可视化展示

访问地址http://localhost:8097即可进入Visdom可视化页面,页面中将展示:

  • A类真实图片 -【A2B生成器】 -> B类虚假图片 -【B2A生成器】 -> A类重构图片
  • B类真实图片 -【B2A生成器】 -> A类虚假图片 -【A2B生成器】 -> B类重构图片
  • 判别器A、B以及生成器的Loss曲线

一些可视化的具体用法可见Visdom的使用方法。

测试

TODO

介绍

目录结构

  • dataset.py 数据集
  • discriminator.py 判别器
  • generater.py 生成器
  • main.py 主程序
  • replay_buffer.py 缓冲区
  • resblk.py 残差块
  • util.py 工具方法

原理介绍

残差块是生成器的组成部分,其结构如下

Resblk(
  (main): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (2): InstanceNorm2d(3, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): ReflectionPad2d((1, 1, 1, 1))
    (5): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (6): InstanceNorm2d(3, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
)

生成器结构如下,由于采用全卷积结构,事实上其结构与图片尺寸无关

Generater(
  (input): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
  )
  (downsampling): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
  )
  (resnet): Sequential(
    (0): Resblk
    (1): Resblk
    (2): Resblk
    (3): Resblk
    (4): Resblk
    (5): Resblk
    (6): Resblk
    (7): Resblk
    (8): Resblk
  )

  (upsampling): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
  )
  (output): Sequential(
    (0): ReplicationPad2d((3, 3, 3, 3))
    (1): Conv2d(64, 3, kernel_size=(7, 7), stride=(1, 1))
    (2): Tanh()
  )
)

判别器结构如下,池化层具体尺寸由图片尺寸决定,64x64的图片对应池化层为6x6

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
  (output): Sequential(
    (0): AvgPool2d(kernel_size=torch.Size([6, 6]), stride=torch.Size([6, 6]), padding=0)
    (1): Flatten(start_dim=1, end_dim=-1)
  )
)

训练共有三个优化器,分别负责生成器、判别器A、判别器B的优化。

损失有三种类型:

  • 一致性损失:A(B)类真实图片与经生成器生成的图片的误差,该损失使得生成后的风格与原图更接近,采用L1Loss
  • 对抗损失:A(B)类图片经生成器得到B(A)类图片,再经判别器判别的错误率,采用MSELoss
  • 循环损失:A(B)类图片经生成器得到B(A)类图片,再经生成器得到A(B)类的重建图片,原图和重建图片的误差,采用L1Loss

生成器的训练过程:

  1. 将A(B)类真实图片送入生成器,得到生成的图片,计算生成图片与原图的一致性损失
  2. 将A(B)类真实图片送入生成器得到虚假图片,再送入判别器得到判别结果,计算判别结果与真实标签1的对抗损失(虚假图片应能被判别器判别为真实图片,即生成器能骗过判别器)
  3. 将A(B)类虚假图片送入生成器,得到重建图片,计算重建图片与原图的循环损失
  4. 计算、更新梯度

判别器A的训练过程:

  1. 将A类真实图片送入判别器A,得到判别结果,计算判别结果与真实标签1的对抗损失(判别器应将真实图片判别为真实)
  2. 将A类虚假图片送入判别器A,得到判别结果,计算判别结果与虚假标签0的对抗损失(判别器应将虚假图片判别为虚假)
  3. 计算、更新梯度
Owner
Koorye
学习?学个屁
Koorye
PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation

PyGRANSO PyGRANSO: A PyTorch-enabled port of GRANSO with auto-differentiation Please check https://ncvx.org/PyGRANSO for detailed instructions (introd

SUN Group @ UMN 26 Nov 16, 2022
Learning Generative Models of Textured 3D Meshes from Real-World Images, ICCV 2021

Learning Generative Models of Textured 3D Meshes from Real-World Images This is the reference implementation of "Learning Generative Models of Texture

Dario Pavllo 115 Jan 07, 2023
Official codebase for Pretrained Transformers as Universal Computation Engines.

universal-computation Overview Official codebase for Pretrained Transformers as Universal Computation Engines. Contains demo notebook and scripts to r

Kevin Lu 210 Dec 28, 2022
Synthesizing and manipulating 2048x1024 images with conditional GANs

pix2pixHD Project | Youtube | Paper Pytorch implementation of our method for high-resolution (e.g. 2048x1024) photorealistic image-to-image translatio

NVIDIA Corporation 6k Dec 27, 2022
Validated, scalable, community developed variant calling, RNA-seq and small RNA analysis

Validated, scalable, community developed variant calling, RNA-seq and small RNA analysis. You write a high level configuration file specifying your in

Blue Collar Bioinformatics 917 Jan 03, 2023
Navigating StyleGAN2 w latent space using CLIP

Navigating StyleGAN2 w latent space using CLIP an attempt to build sth with the official SG2-ADA Pytorch impl kinda inspired by Generating Images from

Mike K. 55 Dec 06, 2022
BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment

BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment

Holy Wu 35 Jan 01, 2023
Rafael Project- Classifying rockets to different types using data science algorithms.

Rocket-Classify Rafael Project- Classifying rockets to different types using data science algorithms. In this project we received data base with data

Hadassah Engel 5 Sep 18, 2021
Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

Super Resolution Examples We run this script under TensorFlow 2.0 and the TensorLayer2.0+. For TensorLayer 1.4 version, please check release. 🚀 🚀 🚀

TensorLayer Community 2.9k Jan 08, 2023
SHRIMP: Sparser Random Feature Models via Iterative Magnitude Pruning

SHRIMP: Sparser Random Feature Models via Iterative Magnitude Pruning This repository is the official implementation of "SHRIMP: Sparser Random Featur

Bobby Shi 0 Dec 16, 2021
Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis

Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis [Paper] [Online Demo] The following results are obtained by our SCUNet with purely syn

Kai Zhang 312 Jan 07, 2023
Groceries ARL: Association Rules (Birliktelik Kuralı)

Groceries_ARL Association Rules (Birliktelik Kuralı) Birliktelik kuralları, mark

Şebnem 5 Feb 08, 2022
Weakly Supervised 3D Object Detection from Point Cloud with Only Image Level Annotation

SCCKTIM Weakly Supervised 3D Object Detection from Point Cloud with Only Image-Level Annotation Our code will be available soon. The class knowledge t

1 Nov 12, 2021
Some useful blender add-ons for SMPL skeleton's poses and global translation.

Blender add-ons for SMPL skeleton's poses and trans There are two blender add-ons for SMPL skeleton's poses and trans.The first is for making an offli

犹在镜中 154 Jan 04, 2023
Defending graph neural networks against adversarial attacks (NeurIPS 2020)

GNNGuard: Defending Graph Neural Networks against Adversarial Attacks Authors: Xiang Zhang ( Zitnik Lab @ Harvard 44 Dec 07, 2022

ML-Decoder: Scalable and Versatile Classification Head

ML-Decoder: Scalable and Versatile Classification Head Paper Official PyTorch Implementation Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben-Baru

189 Jan 04, 2023
Generative Autoregressive, Normalized Flows, VAEs, Score-based models (GANVAS)

GANVAS-models This is an implementation of various generative models. It contains implementations of the following: Autoregressive Models: PixelCNN, G

MRSAIL (Mini Robotics, Software & AI Lab) 6 Nov 26, 2022
Res2Net for Instance segmentation and Object detection using MaskRCNN

Res2Net for Instance segmentation and Object detection using MaskRCNN Since the MaskRCNN-benchmark of facebook is deprecated, we suggest to use our mm

Res2Net Applications 55 Oct 30, 2022
Official code for NeurIPS 2021 paper "Towards Scalable Unpaired Virtual Try-On via Patch-Routed Spatially-Adaptive GAN"

Towards Scalable Unpaired Virtual Try-On via Patch-Routed Spatially-Adaptive GAN Official code for NeurIPS 2021 paper "Towards Scalable Unpaired Virtu

68 Dec 21, 2022
WSDM2022 "A Simple but Effective Bidirectional Extraction Framework for Relational Triple Extraction"

BiRTE WSDM2022 "A Simple but Effective Bidirectional Extraction Framework for Relational Triple Extraction" Requirements The main requirements are: py

9 Dec 27, 2022