基于pytorch构建cyclegan示例

Overview

cyclegan-demo

基于Pytorch构建CycleGAN示例

如何运行

准备数据集

将数据集整理成4个文件,分别命名为

  • trainA, trainB:训练集,A、B代表两类图片
  • testA, testB:测试集,A、B代表两类图片

例如

D:\CODE\CYCLEGAN-DEMO\DATA\SUMMER2WINTER
├─testA
├─testB
├─trainA
└─trainB

之后在main.py中将root设为数据集的路径。

参数设置

main.py中的初始化参数

# 初始化参数
# seed: 随机种子
# root: 数据集路径
# output_model_root: 模型的输出路径
# image_size: 图片尺寸
# batch_size: 一次喂入的数据量
# lr: 学习率
# betas: 一阶和二阶动量
# epochs: 训练总次数
# historical_epochs: 历史训练次数
# - 0表示不沿用历史模型
# - >0表示对应训练次数的模型
# - -1表示最后一次训练的模型
# save_every: 保存频率
# loss_range: Loss的显示范围
seed = 123
data_root = 'D:/code/cyclegan-demo/data/summer2winter'
output_model_root = 'output/model'
image_size = 64
batch_size = 16
lr = 2e-4
betas = (.5, .999)
epochs = 100
historical_epochs = -1
save_every = 1
loss_range = 1000

安装和运行

  1. 安装依赖
pip install -r requirements.txt
  1. 打开命令行,运行Visdom
python -m visdom.server
  1. 运行主程序
python main.py

训练过程的可视化展示

访问地址http://localhost:8097即可进入Visdom可视化页面,页面中将展示:

  • A类真实图片 -【A2B生成器】 -> B类虚假图片 -【B2A生成器】 -> A类重构图片
  • B类真实图片 -【B2A生成器】 -> A类虚假图片 -【A2B生成器】 -> B类重构图片
  • 判别器A、B以及生成器的Loss曲线

一些可视化的具体用法可见Visdom的使用方法。

测试

TODO

介绍

目录结构

  • dataset.py 数据集
  • discriminator.py 判别器
  • generater.py 生成器
  • main.py 主程序
  • replay_buffer.py 缓冲区
  • resblk.py 残差块
  • util.py 工具方法

原理介绍

残差块是生成器的组成部分,其结构如下

Resblk(
  (main): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (2): InstanceNorm2d(3, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): ReflectionPad2d((1, 1, 1, 1))
    (5): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (6): InstanceNorm2d(3, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
)

生成器结构如下,由于采用全卷积结构,事实上其结构与图片尺寸无关

Generater(
  (input): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
  )
  (downsampling): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
  )
  (resnet): Sequential(
    (0): Resblk
    (1): Resblk
    (2): Resblk
    (3): Resblk
    (4): Resblk
    (5): Resblk
    (6): Resblk
    (7): Resblk
    (8): Resblk
  )

  (upsampling): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
  )
  (output): Sequential(
    (0): ReplicationPad2d((3, 3, 3, 3))
    (1): Conv2d(64, 3, kernel_size=(7, 7), stride=(1, 1))
    (2): Tanh()
  )
)

判别器结构如下,池化层具体尺寸由图片尺寸决定,64x64的图片对应池化层为6x6

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
  (output): Sequential(
    (0): AvgPool2d(kernel_size=torch.Size([6, 6]), stride=torch.Size([6, 6]), padding=0)
    (1): Flatten(start_dim=1, end_dim=-1)
  )
)

训练共有三个优化器,分别负责生成器、判别器A、判别器B的优化。

损失有三种类型:

  • 一致性损失:A(B)类真实图片与经生成器生成的图片的误差,该损失使得生成后的风格与原图更接近,采用L1Loss
  • 对抗损失:A(B)类图片经生成器得到B(A)类图片,再经判别器判别的错误率,采用MSELoss
  • 循环损失:A(B)类图片经生成器得到B(A)类图片,再经生成器得到A(B)类的重建图片,原图和重建图片的误差,采用L1Loss

生成器的训练过程:

  1. 将A(B)类真实图片送入生成器,得到生成的图片,计算生成图片与原图的一致性损失
  2. 将A(B)类真实图片送入生成器得到虚假图片,再送入判别器得到判别结果,计算判别结果与真实标签1的对抗损失(虚假图片应能被判别器判别为真实图片,即生成器能骗过判别器)
  3. 将A(B)类虚假图片送入生成器,得到重建图片,计算重建图片与原图的循环损失
  4. 计算、更新梯度

判别器A的训练过程:

  1. 将A类真实图片送入判别器A,得到判别结果,计算判别结果与真实标签1的对抗损失(判别器应将真实图片判别为真实)
  2. 将A类虚假图片送入判别器A,得到判别结果,计算判别结果与虚假标签0的对抗损失(判别器应将虚假图片判别为虚假)
  3. 计算、更新梯度
Owner
Koorye
学习?学个屁
Koorye
Rafael Project- Classifying rockets to different types using data science algorithms.

Rocket-Classify Rafael Project- Classifying rockets to different types using data science algorithms. In this project we received data base with data

Hadassah Engel 5 Sep 18, 2021
Hands-On Machine Learning for Algorithmic Trading, published by Packt

Hands-On Machine Learning for Algorithmic Trading Hands-On Machine Learning for Algorithmic Trading, published by Packt This is the code repository fo

Packt 981 Dec 29, 2022
An implementation of Equivariant e2 convolutional kernals into a convolutional self attention network, applied to radio astronomy data.

EquivariantSelfAttention An implementation of Equivariant e2 convolutional kernals into a convolutional self attention network, applied to radio astro

2 Nov 09, 2021
PyTorch Implement for Path Attention Graph Network

SPAGAN in PyTorch This is a PyTorch implementation of the paper "SPAGAN: Shortest Path Graph Attention Network" Prerequisites We prefer to create a ne

Yang Yiding 38 Dec 28, 2022
Multi-modal co-attention for drug-target interaction annotation and Its Application to SARS-CoV-2

CoaDTI Multi-modal co-attention for drug-target interaction annotation and Its Application to SARS-CoV-2 Abstract Environment The test was conducted i

Layne_Huang 7 Nov 14, 2022
Using Convolutional Neural Networks (CNN) for Semantic Segmentation of Breast Cancer Lesions (BRCA)

Using Convolutional Neural Networks (CNN) for Semantic Segmentation of Breast Cancer Lesions (BRCA). Master's thesis documents. Bibliography, experiments and reports.

Erick Cobos 73 Dec 04, 2022
An Unbiased Learning To Rank Algorithms (ULTRA) toolbox

Unbiased Learning to Rank Algorithms (ULTRA) This is an Unbiased Learning To Rank Algorithms (ULTRA) toolbox, which provides a codebase for experiment

back 3 Nov 18, 2022
Self-Correcting Quantum Many-Body Control using Reinforcement Learning with Tensor Networks

Self-Correcting Quantum Many-Body Control using Reinforcement Learning with Tensor Networks This repository contains the code and data for the corresp

Friederike Metz 7 Apr 23, 2022
Scalable implementation of Lee / Mykland (2012) and Ait-Sahalia / Jacod (2012) Jump tests for noisy high frequency data

JumpDetectR Name of QuantLet : JumpDetectR Published in : 'To be published as "Jump dynamics in high frequency crypto markets"' Description : 'Scala

LvB 12 Jan 01, 2023
Build upon neural radiance fields to create a scene-specific implicit 3D semantic representation, Semantic-NeRF

Semantic-NeRF: Semantic Neural Radiance Fields Project Page | Video | Paper | Data In-Place Scene Labelling and Understanding with Implicit Scene Repr

Shuaifeng Zhi 243 Jan 07, 2023
An Unsupervised Graph-based Toolbox for Fraud Detection

An Unsupervised Graph-based Toolbox for Fraud Detection Introduction: UGFraud is an unsupervised graph-based fraud detection toolbox that integrates s

SafeGraph 99 Dec 11, 2022
A collection of random and hastily hacked together scripts for investigating EU-DCC

A collection of random and hastily hacked together scripts for investigating EU-DCC

Ryan Barrett 8 Mar 01, 2022
🔅 Shapash makes Machine Learning models transparent and understandable by everyone

🎉 What's new ? Version New Feature Description Tutorial 1.6.x Explainability Quality Metrics To help increase confidence in explainability methods, y

MAIF 2.1k Dec 27, 2022
Photo2cartoon - 人像卡通化探索项目 (photo-to-cartoon translation project)

人像卡通化 (Photo to Cartoon) 中文版 | English Version 该项目为小视科技卡通肖像探索项目。您可使用微信扫描下方二维码或搜索“AI卡通秀”小程序体验卡通化效果。

Minivision_AI 3.5k Dec 30, 2022
Python wrapper of LSODA (solving ODEs) which can be called from within numba functions.

numbalsoda numbalsoda is a python wrapper to the LSODA method in ODEPACK, which is for solving ordinary differential equation initial value problems.

Nick Wogan 52 Jan 09, 2023
Official implementation of the ICML2021 paper "Elastic Graph Neural Networks"

ElasticGNN This repository includes the official implementation of ElasticGNN in the paper "Elastic Graph Neural Networks" [ICML 2021]. Xiaorui Liu, W

liuxiaorui 34 Dec 04, 2022
Gesture-controlled Video Game. Just swing your finger and play the game without touching your PC

Gesture Controlled Video Game Detailed Blog : https://www.analyticsvidhya.com/blog/2021/06/gesture-controlled-video-game/ Introduction This project is

Devbrat Anuragi 35 Jan 06, 2023
DECAF: Deep Extreme Classification with Label Features

DECAF DECAF: Deep Extreme Classification with Label Features @InProceedings{Mittal21, author = "Mittal, A. and Dahiya, K. and Agrawal, S. and Sain

46 Nov 06, 2022
DAN: Unfolding the Alternating Optimization for Blind Super Resolution

DAN-Basd-on-Openmmlab DAN: Unfolding the Alternating Optimization for Blind Super Resolution We reproduce DAN via mmediting based on open-sourced code

AlexZou 72 Dec 13, 2022
The official PyTorch implementation for the paper "sMGC: A Complex-Valued Graph Convolutional Network via Magnetic Laplacian for Directed Graphs".

Magnetic Graph Convolutional Networks About The official PyTorch implementation for the paper sMGC: A Complex-Valued Graph Convolutional Network via M

3 Feb 25, 2022