unofficial pytorch implementation of RefineGAN

Overview

RefineGAN

unofficial pytorch implementation of RefineGAN (https://arxiv.org/abs/1709.00753) for CSMRI reconstruction, the official code using tensorpack can be found at https://github.com/tmquan/RefineGAN

To Do

  • run the original tensorpack code (sorry, can't run tensorpack on my GPU)
  • pytorch implementation and experiments on brain images with radial mask
  • bug fixed. the mean psnr of zero-filled image is not exactly the same as the value in original paper, although the model improvement is similar
  • experiments on different masks

Install

python>=3.7.11 is required with all requirements.txt installed including pytorch>=1.10.0

git clone https://github.com/hellopipu/RefineGAN.git
cd RefineGAN
pip install -r requirements.txt

How to use

for training:

cd run_sh
sh train.sh

the model will be saved in folder weight, tensorboard information will be saved in folder log. You can change the arguments in script such as --mask_type and --sampling_rate for different experiment settings.

for tensorboard:

check the training curves while training

tensorboard --logdir log

the training info of my experiments is already in log folder

for testing:

test after training, or you can download my trained model weights from google drive.

cd run_sh
sh test.sh

for visualization:

cd run_sh
sh visualize.sh

training curves

sampling rates : 10%(light orange), 20%(dark blue), 30%(dark orange), 40%(light blue). You can check more loss curves of my experiments using tensorboard.

loss_G_loss_total loss_recon_img_Aa

PSNR on training set over 500 epochs, compared with results shown in original paper.

my_train_psnr paper_train_psnr

Test results

mean PSNR on validation dataset with radial mask of different sampling rates, batch_size is set as 4;

model 10% 20% 30% 40%
zero-filled 22.296 25.806 28.997 31.699
RefineGAN 32.705 36.734 39.961 42.903

Test cases visualization

rate from left to right: mask, zero-filled, prediction and ground truth error (zero-filled) and error (prediction)
10%
20%
30%
40%

Notes on RefineGAN

  • data processing before training : complex value represents in 2-channel , each channel rescale to [-1,1]; accordingly the last layer of generator is tanh()
  • Generator uses residual learning for reconstruction task
  • Generator is a cascade of two U-net, the U-net doesn't do concatenation but addition when combining the enc and dec features.
  • each U-net is followed by a Data-consistency (DC) module, although the paper doesn't mention it.
  • the last layer of generator is tanh layer on two-channel output, so when we revert output to original pixel scale and calculate abs, the pixel value may exceed 255; we need to do clipping while calculating psnr
  • while training, we get two random image samples A, B for each iteration, RefineGAN calculates a large amount of losses (it may be redundant) including reconstruction loss on different phases of generator output in both image domain and frequency domain, total variantion loss and WGAN loss
  • one special loss is D_loss_AB, D is trained to only distinguish from real samples and fake samples, so D should not only work for (real A, fake A) or (real B, fake B), but also work for (real A, fake B) input
  • WGAN-gp may be used to improve the performance
  • small batch size MAY BE better. In my experiment, batch_size=4 is better than batch_size=16

I will appreciate if you can find any implementation mistakes in codes.

Owner
xinby17
research interest: Medical Image Analysis, Computer Vision
xinby17
《Train in Germany, Test in The USA: Making 3D Object Detectors Generalize》(CVPR 2020)

Train in Germany, Test in The USA: Making 3D Object Detectors Generalize This paper has been accpeted by Conference on Computer Vision and Pattern Rec

Xiangyu Chen 101 Jan 02, 2023
One line to host them all. Bootstrap your image search case in minutes.

One line to host them all. Bootstrap your image search case in minutes. Survey NOW gives the world access to customized neural image search in just on

Jina AI 403 Dec 30, 2022
A highly efficient and modular implementation of Gaussian Processes in PyTorch

GPyTorch GPyTorch is a Gaussian process library implemented using PyTorch. GPyTorch is designed for creating scalable, flexible, and modular Gaussian

3k Jan 02, 2023
Experiments with the Robust Binary Interval Search (RBIS) algorithm, a Query-Based prediction algorithm for the Online Search problem.

OnlineSearchRBIS Online Search with Best-Price and Query-Based Predictions This is the implementation of the Robust Binary Interval Search (RBIS) algo

S. K. 1 Apr 16, 2022
Chinese Advertisement Board Identification(Pytorch)

Chinese-Advertisement-Board-Identification. We use YoloV5 to extract the ROI of the location of the chinese word. Next, we sort the bounding box and recognize every chinese words which we extracted.

Li-Wei Hsiao 12 Jul 21, 2022
Intelligent Video Analytics toolkit based on different inference backends.

English | 中文 OpenIVA OpenIVA is an end-to-end intelligent video analytics development toolkit based on different inference backends, designed to help

Quantum Liu 15 Oct 27, 2022
Header-only library for using Keras models in C++.

frugally-deep Use Keras models in C++ with ease Table of contents Introduction Usage Performance Requirements and Installation FAQ Introduction Would

Tobias Hermann 927 Jan 05, 2023
Semantically Contrastive Learning for Low-light Image Enhancement

Semantically Contrastive Learning for Low-light Image Enhancement Here, we propose an effective semantically contrastive learning paradigm for Low-lig

48 Dec 16, 2022
High dimensional black-box optimizer using Latent Action Monte Carlo Tree Search algorithm

LA-MCTS The code is based of paper Learning Search Space Partition for Black-box Optimization using Monte Carlo Tree Search. Component LA-MCTS has thr

Meta Research 18 Oct 24, 2022
Compares various time-series feature sets on computational performance, within-set structure, and between-set relationships.

feature-set-comp Compares various time-series feature sets on computational performance, within-set structure, and between-set relationships. Reposito

Trent Henderson 7 May 25, 2022
HeatNet is a python package that provides tools to build, train and evaluate neural networks designed to predict extreme heat wave events globally on daily to subseasonal timescales.

HeatNet HeatNet is a python package that provides tools to build, train and evaluate neural networks designed to predict extreme heat wave events glob

Google Research 6 Jul 07, 2022
Robot Hacking Manual (RHM). From robotics to cybersecurity. Papers, notes and writeups from a journey into robot cybersecurity.

RHM: Robot Hacking Manual Download in PDF RHM v0.4 ┃ Read online The Robot Hacking Manual (RHM) is an introductory series about cybersecurity for robo

Víctor Mayoral Vilches 233 Dec 30, 2022
Random Walk Graph Neural Networks

Random Walk Graph Neural Networks This repository is the official implementation of Random Walk Graph Neural Networks. Requirements Code is written in

Giannis Nikolentzos 38 Jan 02, 2023
Yoga - Yoga asana classifier for python

Yoga Asana Classifier Description Hi welcome to my new deep learning project "Yo

Programminghut 35 Dec 12, 2022
Official release of MSHT: Multi-stage Hybrid Transformer for the ROSE Image Analysis of Pancreatic Cancer axriv: http://arxiv.org/abs/2112.13513

MSHT: Multi-stage Hybrid Transformer for the ROSE Image Analysis This is the official page of the MSHT with its experimental script and records. We de

Tianyi Zhang 53 Dec 27, 2022
Official implementation of the ICML2021 paper "Elastic Graph Neural Networks"

ElasticGNN This repository includes the official implementation of ElasticGNN in the paper "Elastic Graph Neural Networks" [ICML 2021]. Xiaorui Liu, W

liuxiaorui 34 Dec 04, 2022
Author's PyTorch implementation of TD3 for OpenAI gym tasks

Addressing Function Approximation Error in Actor-Critic Methods PyTorch implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3). If y

Scott Fujimoto 1.3k Dec 25, 2022
A booklet on machine learning systems design with exercises

Machine Learning Systems Design Read this booklet here. This booklet covers four main steps of designing a machine learning system: Project setup Data

Chip Huyen 7.6k Jan 08, 2023
QICK: Quantum Instrumentation Control Kit

QICK: Quantum Instrumentation Control Kit The QICK is a kit of firmware and software to use the Xilinx RFSoC to control quantum systems. It consists o

81 Dec 15, 2022
A static analysis library for computing graph representations of Python programs suitable for use with graph neural networks.

python_graphs This package is for computing graph representations of Python programs for machine learning applications. It includes the following modu

Google Research 258 Dec 29, 2022