TensorFlow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution

Single Image Super-Resolution with EDSR, WDSR and SRGAN

A TensorFlow 2.x based implementation of EDSR, WDSR and SRGAN.

This is a complete rewrite of the old Keras/TensorFlow 1.x based implementation available here. Some parts are still work in progress, but you can already train models as described in the papers via a high-level training API. You can also fine-tune EDSR and WDSR models in an SRGAN context. Training and usage examples are given in the notebooks.

A DIV2K data provider automatically downloads DIV2K training and validation images for a given scale (2, 3, 4 or 8) and downgrade operator ("bicubic", "unknown", "mild" or "difficult").

Important: if you want to evaluate the pre-trained models with a dataset other than DIV2K, please read this comment (and replies) first.

Environment setup

Create a new conda environment with

conda env create -f environment.yml

and activate it with

conda activate sisr

Introduction

You can find an introduction to single-image super-resolution in this article. It also demonstrates how EDSR and WDSR models can be fine-tuned with SRGAN (see also this section).

Getting started

The examples in this section require the following pre-trained weights (see also the example notebooks):

Pre-trained weights

  • weights-edsr-16-x4.tar.gz
    • EDSR x4 baseline as described in the EDSR paper: 16 residual blocks, 64 filters, 1.52M parameters.
    • PSNR on DIV2K validation set = 28.89 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-wdsr-b-32-x4.tar.gz
    • WDSR B x4 custom model: 32 residual blocks, 32 filters, expansion factor 6, 0.62M parameters.
    • PSNR on DIV2K validation set = 28.91 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-srgan.tar.gz
    • SRGAN as described in the SRGAN paper: 1.55M parameters, trained with VGG54 content loss.

After download, extract them in the root folder of the project with

tar xvfz weights-<...>.tar.gz
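
As a quick sanity check after extraction, a small script can confirm that the weight files loaded by the examples in this section are in place. This is only a sketch: the helper name `missing_weights` is hypothetical and not part of the project, and the paths are simply the ones passed to `load_weights` in the examples of this section.

```python
from pathlib import Path

# Weight files loaded by the EDSR, WDSR and SRGAN usage examples.
EXPECTED_WEIGHTS = [
    'weights/edsr-16-x4/weights.h5',
    'weights/wdsr-b-32-x4/weights.h5',
    'weights/srgan/gan_generator.h5',
]


def missing_weights(root='.'):
    """Return the expected weight files that do not exist under `root`.

    Hypothetical helper for illustration only.
    """
    return [p for p in EXPECTED_WEIGHTS if not (Path(root) / p).is_file()]
```

Calling `missing_weights()` from the project root returns an empty list once all three archives have been extracted.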

EDSR

from model import resolve_single
from model.edsr import edsr

from utils import load_image, plot_sample

model = edsr(scale=4, num_res_blocks=16)
model.load_weights('weights/edsr-16-x4/weights.h5')

lr = load_image('demo/0851x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

WDSR

from model import resolve_single
from model.wdsr import wdsr_b

from utils import load_image, plot_sample

model = wdsr_b(scale=4, num_res_blocks=32)
model.load_weights('weights/wdsr-b-32-x4/weights.h5')

lr = load_image('demo/0829x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

Weight normalization in WDSR models is implemented with the WeightNormalization layer wrapper of TensorFlow Addons. In its latest version, this wrapper seems to corrupt weights when running model.predict(...). A workaround is to set model.run_eagerly = True or to compile the model with model.compile(loss='mae') in advance. The issue does not arise when calling the model directly with model(...). This is still to be investigated.
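
The two workarounds can be wrapped in a small helper that is applied before `model.predict(...)` is called. This is a minimal sketch, assuming `model` is a Keras model such as the one returned by `wdsr_b(...)`; the helper name `prepare_for_predict` is made up for illustration and does not exist in the project.

```python
def prepare_for_predict(model, eager=True):
    """Apply one of the two workarounds described above.

    eager=True  -> run the model eagerly
    eager=False -> compile the model in advance with an MAE loss
    Hypothetical helper, not part of the project.
    """
    if eager:
        model.run_eagerly = True
    else:
        model.compile(loss='mae')
    return model


# Intended usage (requires the project modules and extracted weights):
# model = prepare_for_predict(wdsr_b(scale=4, num_res_blocks=32))
# model.load_weights('weights/wdsr-b-32-x4/weights.h5')
# sr_batch = model.predict(lr_batch)
```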

SRGAN

from model import resolve_single
from model.srgan import generator

from utils import load_image, plot_sample

model = generator()
model.load_weights('weights/srgan/gan_generator.h5')

lr = load_image('demo/0869x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

DIV2K dataset

For training and validation on DIV2K images, applications should use the provided DIV2K data loader. It automatically downloads DIV2K images to the .div2k directory and converts them to a different format for faster loading.

Training dataset

from data import DIV2K

train_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult'
                     subset='train')      # training images are 001 - 800

# Create a tf.data.Dataset
train_ds = train_loader.dataset(batch_size=16,         # batch size as described in the EDSR and WDSR papers
                                random_transform=True, # random crop, flip, rotate as described in the EDSR paper
                                repeat_count=None)     # repeat iterating over training images indefinitely

# Iterate over LR/HR image pairs
for lr, hr in train_ds:
    ...  # process the batch here

Crop size in HR images is 96x96.
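
Because LR and HR images must stay aligned, a random crop has to be taken at corresponding positions in both. The sketch below only computes the crop boxes, assuming the HR image is exactly `scale` times the size of the LR image. The function name and return format are illustrative and are not the data loader's actual code. For scale 4, a 96x96 HR crop corresponds to a 24x24 LR crop.

```python
import random


def paired_crop_boxes(lr_height, lr_width, scale=4, hr_crop_size=96, rng=random):
    """Pick a random LR crop and the matching HR crop.

    Returns two (top, left, height, width) tuples: one for the LR image
    and one for the HR image, which is assumed to be `scale` times larger.
    """
    lr_crop_size = hr_crop_size // scale
    if lr_height < lr_crop_size or lr_width < lr_crop_size:
        raise ValueError('LR image is smaller than the requested crop')

    # Random top-left corner chosen so that the crop fits inside the LR image.
    lr_top = rng.randint(0, lr_height - lr_crop_size)
    lr_left = rng.randint(0, lr_width - lr_crop_size)

    lr_box = (lr_top, lr_left, lr_crop_size, lr_crop_size)
    hr_box = (lr_top * scale, lr_left * scale, hr_crop_size, hr_crop_size)
    return lr_box, hr_box
```

Choosing the position in LR coordinates and scaling it up guarantees that the HR crop starts on a pixel that maps exactly to the LR crop origin.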

Validation dataset

from data import DIV2K

valid_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult'
                     subset='valid')      # validation images are 801 - 900

# Create a tf.data.Dataset
valid_ds = valid_loader.dataset(batch_size=1,           # use batch size of 1 as DIV2K images have different sizes
                                random_transform=False, # use DIV2K images in original size
                                repeat_count=1)         # 1 epoch

# Iterate over LR/HR image pairs
for lr, hr in valid_ds:
    ...  # process the image pair here

Training

The following training examples use the training and validation datasets described earlier. The high-level training API is designed around steps (= minibatch updates) rather than epochs to better match the descriptions in the papers.
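
To relate step counts to the more familiar notion of epochs, the conversion is simple arithmetic. The sketch below uses the 800 DIV2K training images and the batch size of 16 from the training dataset example; the function name is made up for illustration.

```python
def steps_to_epochs(steps, num_images=800, batch_size=16):
    """Convert a number of minibatch updates into passes over the dataset."""
    steps_per_epoch = num_images / batch_size  # 50 for the defaults
    return steps / steps_per_epoch


print(steps_to_epochs(300_000))  # 6000.0
```

With these settings, the 300,000 steps used in the EDSR and WDSR examples correspond to 6,000 passes over the training images, each with fresh random crops.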

EDSR

from model.edsr import edsr
from train import EdsrTrainer

# Create a training context for an EDSR x4 model with 16 
# residual blocks.
trainer = EdsrTrainer(model=edsr(scale=4, num_res_blocks=16),
                      checkpoint_dir='.ckpt/edsr-16-x4')
                      
# Train EDSR model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)
              
# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():.3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/edsr-16-x4/weights.h5')                                    

Interrupting training and restarting it again resumes from the latest saved checkpoint. The trained Keras model can be accessed with trainer.model.
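
The interplay of `steps`, `evaluate_every` and `save_best_only` can be pictured with a simplified loop. This is a sketch of the general pattern only, not the trainer's actual implementation: `train_step`, `evaluate` and `save_checkpoint` are placeholder callables supplied by the caller, and resuming from an existing checkpoint is not modelled.

```python
def train_loop(train_step, evaluate, save_checkpoint, steps,
               evaluate_every=1000, save_best_only=False):
    """Run `steps` updates, evaluating and checkpointing periodically.

    Returns the best PSNR seen at any evaluation.
    """
    best_psnr = float('-inf')
    for step in range(1, steps + 1):
        train_step()
        if step % evaluate_every != 0:
            continue
        psnr = evaluate()
        if save_best_only and psnr <= best_psnr:
            # Skip checkpointing because PSNR did not improve.
            continue
        best_psnr = max(best_psnr, psnr)
        save_checkpoint(step, psnr)
    return best_psnr
```

With `save_best_only=True`, the most recent checkpoint is also the best one, which is why restoring afterwards yields the model with the highest validation PSNR.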

WDSR

from model.wdsr import wdsr_b
from train import WdsrTrainer

# Create a training context for a WDSR B x4 model with 32 
# residual blocks.
trainer = WdsrTrainer(model=wdsr_b(scale=4, num_res_blocks=32),
                      checkpoint_dir='.ckpt/wdsr-b-32-x4')

# Train WDSR B model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)

# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():.3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/wdsr-b-32-x4/weights.h5')
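
Both training examples report PSNR (peak signal-to-noise ratio) on the validation set. For reference, the metric itself can be sketched without any dependencies as shown below, assuming 8-bit images given as flat sequences of pixel values. This is an illustration of the formula, not the project's evaluation code; with TensorFlow, `tf.image.psnr(a, b, max_val=255)` computes the same quantity on image tensors.

```python
import math


def psnr(reference, estimate, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two equally sized images.

    Images are given as flat sequences of pixel values in [0, max_val].
    """
    if len(reference) != len(estimate):
        raise ValueError('images must have the same number of values')
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)
```

Higher values mean the super-resolved image is closer to the HR reference; a gain of 0.02 dB, as between the EDSR and WDSR models listed above, is a small difference.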

SRGAN

Generator pre-training

from model.srgan import generator
from train import SrganGeneratorTrainer

# Create a training context for the generator (SRResNet) alone.
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir='.ckpt/pre_generator')

# Pre-train the generator with 1,000,000 steps (100,000 works fine too). 
pre_trainer.train(train_ds, valid_ds.take(10), steps=1000000, evaluate_every=1000)

# Save weights of pre-trained generator (needed for fine-tuning with GAN).
pre_trainer.model.save_weights('weights/srgan/pre_generator.h5')

Generator fine-tuning (GAN)

from model.srgan import generator, discriminator
from train import SrganTrainer

# Create a new generator and init it with pre-trained weights.
gan_generator = generator()
gan_generator.load_weights('weights/srgan/pre_generator.h5')

# Create a training context for the GAN (generator + discriminator).
gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator())

# Train the GAN with 200,000 steps.
gan_trainer.train(train_ds, steps=200000)

# Save weights of generator and discriminator.
gan_trainer.generator.save_weights('weights/srgan/gan_generator.h5')
gan_trainer.discriminator.save_weights('weights/srgan/gan_discriminator.h5')

SRGAN for fine-tuning EDSR and WDSR models

It is also possible to fine-tune EDSR and WDSR x4 models with SRGAN. They can be used as a drop-in replacement for the original SRGAN generator. More details are given in this article.

from model.edsr import edsr
from model.srgan import discriminator
from train import SrganTrainer

# Create EDSR generator and init with pre-trained weights.
generator = edsr(scale=4, num_res_blocks=16)
generator.load_weights('weights/edsr-16-x4/weights.h5')

# Fine-tune EDSR model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)

from model.srgan import discriminator
from model.wdsr import wdsr_b
from train import SrganTrainer

# Create WDSR B generator and init with pre-trained weights.
generator = wdsr_b(scale=4, num_res_blocks=32)
generator.load_weights('weights/wdsr-b-32-x4/weights.h5')

# Fine-tune WDSR B model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)

Owner
Martin Krasser