Jittor 64*64 implementation of StyleGAN

Overview

StyleGanJittor (Tsinghua University computer graphics course)

This project is a reproduction of StyleGAN based on Python 3.8, Jittor (计图), and the open-source StyleGAN-Pytorch project. I trained the model on the color_symbol_7k dataset for 40000 iterations. The model can generate 64×64 symbolic images.

StyleGAN is a generative adversarial network for image generation proposed by NVIDIA in 2018. According to the paper, the generator improves the state of the art in terms of traditional distribution quality metrics, leads to demonstrably better interpolation properties, and also better disentangles the latent factors of variation. The main improvement of this model over previous models is the structure of the generator: the addition of an eight-layer Mapping Network, the use of the AdaIN module, and the introduction of randomness into the image through noise inputs. These structures allow the generator to decouple the overall features of the image from its local features and so synthesize better images; the network also interpolates better in the latent space.

(Karras T, Laine S, Aila T. A style-based generator architecture for generative adversarial networks[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019: 4401-4410.)
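As a rough illustration of the AdaIN operation mentioned above, the sketch below normalizes one feature-map channel to zero mean and unit standard deviation, then applies a style-dependent scale and bias, following the formula in the paper. It is written in plain Python on nested lists so that it runs without any framework; the function name `adain_channel` and the list representation are assumptions made for this example and are not taken from model.py.

```python
import math


def adain_channel(feature_map, scale, bias, eps=1e-8):
    """Apply AdaIN to a single channel given as a 2D list of floats.

    scale and bias stand in for the per-channel style parameters that the
    generator derives from the mapping network output (hypothetical inputs).
    """
    values = [v for row in feature_map for v in row]
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    std = math.sqrt(var + eps)
    # Normalize each value, then re-style it with the given scale and bias.
    return [[scale * (v - mean) / std + bias for v in row] for row in feature_map]


channel = [[1.0, 2.0], [3.0, 4.0]]
styled = adain_channel(channel, scale=2.0, bias=0.5)
```

After the call, the channel's mean is approximately the bias and its standard deviation is approximately the scale, which is how the style controls the statistics of each feature map.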

The training results are shown in Video1trainingResult.avi. Video2GenerationResult1.avi and Video3GenerationResul2t.avi show images generated by the trained model.

The checkpoint folder holds the trained StyleGAN models; because they take up a lot of storage space, the models have been deleted. The data folder holds the color_symbol_7k dataset. The dataset is processed by the prepare_data file to obtain an LMDB database that accelerates training, and the database is stored in the mdb folder. The sample folder holds the images generated during model training, which can be used to trace the training process. The generateSample folder holds the sample images generated by calling StyleGenerator after model training is complete.

MultiResolutionDataset, which reads the LMDB database, is defined in dataset.py; the StyleGAN model reproduced in Jittor is defined in model.py; train.py is the model training script; and VideoWrite.py converts the generated images into a video.

Environment and execution instructions

Project environment dependencies include jittor, lmdb, PIL, argparse, tqdm and some common Python libraries.

First, unzip the dataset in the data folder. The model can then be trained by running the following command in a terminal within the project environment:

    python train.py --mixing "./mdb/color_symbol_7k_mdb"

Images can be generated from the trained model and compared with one another by running:

    python generate.py --size 64 --n_row 3 --n_col 5 --path './checkpoint/040000.model'

You can adjust the model training parameters by referring to the code in the args section of train.py and generate.py.

Details

The first step is dataset preparation, using the LMDB database to accelerate training. For model construction, I referred to the model structure from the original paper, shown in the first figure below, and to the reproduction approach used in the open-source PyTorch version. Following the model dependency framework shown in the second figure below, the original model is split into EqualConv2d, EqualLinear, StyleConvBlock, Convblock and other sub-parts, which are implemented separately and finally assembled into a complete StyleGenerator and Discriminator.

[Figure: StyleGAN model structure from the original paper]

[Figure: model dependency framework]
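The names EqualConv2d and EqualLinear suggest layers that use the equalized learning rate trick from progressive GAN training, in which weights are stored unscaled and multiplied at run time by a constant taken from He initialization. That reading is an assumption based on the layer names, not something confirmed by this project's code. A minimal sketch of the scaling constant, with hypothetical helper names:

```python
import math


def conv_fan_in(in_channels, kernel_size):
    """Number of inputs feeding one output unit of a square convolution."""
    return in_channels * kernel_size * kernel_size


def equalized_lr_scale(fan_in):
    """Run-time weight multiplier, sqrt(2 / fan_in), as in He initialization."""
    return math.sqrt(2.0 / fan_in)


# Hypothetical layer sizes chosen only for illustration.
linear_scale = equalized_lr_scale(512)
conv_scale = equalized_lr_scale(conv_fan_in(512, 3))
```

Because the multiplier depends only on the layer shape, every layer sees weights of comparable dynamic range, regardless of how many inputs it has.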

In the model building and training part, I followed the tutorial provided by the teaching assistant on the official website to convert the torch methods to their Jittor counterparts, and explored other means of implementing the rest myself. Jittor's documentation is relatively incomplete, and some methods differ from PyTorch. In these cases, I used lower-level methods for the implementation.

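For example, the Discriminator part of the model needs the standard deviation of a tensor along a given dimension, which a direct translation would write as jt.sqrt(out.var(0, unbiased=False) + 1e-8). I found no corresponding var() method in the Jittor framework, so I use ((out-out.mean(0)).sqr().sum(0)+1e-8).sqrt() instead. Note that this expression sums the squared deviations rather than averaging them, so it matches the original only up to a factor that depends on the batch size; replacing .sum(0) with .mean(0) reproduces the biased variance.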
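The arithmetic behind this replacement can be checked without any framework. The sketch below computes the per-column standard deviation of a small batch twice: once with the biased variance (the mean of squared deviations, which is what var(0, unbiased=False) denotes) and once with the sum of squared deviations. The function names are illustrative and do not appear in the project.

```python
import math


def _column_means(batch):
    n = len(batch)
    return [sum(row[j] for row in batch) / n for j in range(len(batch[0]))]


def std_from_mean_of_squares(batch, eps=1e-8):
    """sqrt(biased variance + eps) along the batch dimension."""
    n = len(batch)
    means = _column_means(batch)
    return [
        math.sqrt(sum((row[j] - means[j]) ** 2 for row in batch) / n + eps)
        for j in range(len(means))
    ]


def std_from_sum_of_squares(batch, eps=1e-8):
    """Same as above but without dividing by the batch size."""
    means = _column_means(batch)
    return [
        math.sqrt(sum((row[j] - means[j]) ** 2 for row in batch) + eps)
        for j in range(len(means))
    ]


batch = [[1.0, 2.0], [3.0, 6.0]]
biased = std_from_mean_of_squares(batch)
summed = std_from_sum_of_squares(batch)
```

For this two-row batch the summed version is larger by roughly a factor of the square root of 2, the square root of the batch size, which is why averaging is needed to match the biased variance exactly.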

Results

Limited by the hardware, model training takes a long time, and I did not have enough time to fine-tune the optimizers and the various parameters. As a result, the results obtained by training on Jittor are not as good as those I obtained by training the same model framework on PyTorch. Even so, the progressive training process can be clearly seen in the video: the generated symbols gradually become clearer and the results gradually improve.

The figures below are sample results obtained by training on Jittor and PyTorch respectively. For details, please refer to the video files in the folder. The training results of the same model and code on PyTorch can be found in the sample_torch folder.

[Figure: samples by Jittor] [Figure: samples by PyTorch]

To be continued

Owner
Song Shengyu