TensorFlow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution

Overview

Single Image Super-Resolution with EDSR, WDSR and SRGAN

A TensorFlow 2.x based implementation of EDSR, WDSR and SRGAN.

This is a complete rewrite of the old Keras/TensorFlow 1.x based implementation available here. Some parts are still work in progress, but you can already train models as described in the papers via a high-level training API. You can also fine-tune EDSR and WDSR models in an SRGAN context. Training and usage examples are given in the notebooks.

A DIV2K data provider automatically downloads DIV2K training and validation images for a given scale (2, 3, 4 or 8) and downgrade operator ("bicubic", "unknown", "mild" or "difficult").

Important: if you want to evaluate the pre-trained models with a dataset other than DIV2K please read this comment (and replies) first.

Environment setup

Create a new conda environment with

conda env create -f environment.yml

and activate it with

conda activate sisr

Introduction

You can find an introduction to single-image super-resolution in this article. It also demonstrates how EDSR and WDSR models can be fine-tuned with SRGAN (see also this section).

Getting started

The examples in this section require the following pre-trained weights (see also the example notebooks):

Pre-trained weights

  • weights-edsr-16-x4.tar.gz
    • EDSR x4 baseline as described in the EDSR paper: 16 residual blocks, 64 filters, 1.52M parameters.
    • PSNR on DIV2K validation set = 28.89 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-wdsr-b-32-x4.tar.gz
    • WDSR B x4 custom model: 32 residual blocks, 32 filters, expansion factor 6, 0.62M parameters.
    • PSNR on DIV2K validation set = 28.91 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-srgan.tar.gz
    • SRGAN as described in the SRGAN paper: 1.55M parameters, trained with VGG54 content loss.
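
The PSNR figures above are in decibels. As a rough illustration of what the metric measures, here is a minimal standard-library sketch of PSNR for 8-bit pixel values. The function name `psnr` and the flat-list input are assumptions made for this example; the project's own evaluation code may compute PSNR differently (for instance, on tensors and with specific border handling).

```python
import math


def psnr(hr_pixels, sr_pixels, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two equally sized pixel sequences.

    Illustrative helper only (not part of this project's API).
    """
    if len(hr_pixels) != len(sr_pixels) or len(hr_pixels) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    # Mean squared error between ground truth (HR) and prediction (SR).
    mse = sum((h - s) ** 2 for h, s in zip(hr_pixels, sr_pixels)) / len(hr_pixels)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)
```

Higher values mean the super-resolved image is closer to the ground truth; identical inputs give an infinite PSNR.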

After download, extract them in the root folder of the project with

tar xvfz weights-<...>.tar.gz
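
If `tar` is not available (for example on some Windows setups), the same extraction can be sketched with Python's standard `tarfile` module. The helper name `extract_weights` is hypothetical and not part of this project.

```python
import tarfile
from pathlib import Path


def extract_weights(archive_path, dest_dir="."):
    """Extract a .tar.gz archive into dest_dir and return the member names.

    Hypothetical helper for illustration; only extract archives you trust.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, mode="r:gz") as archive:
        names = archive.getnames()
        archive.extractall(path=dest)
    return names
```

Calling `extract_weights('weights-edsr-16-x4.tar.gz')` from the project root should be equivalent to the `tar` command above for that archive.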

EDSR

from model import resolve_single
from model.edsr import edsr

from utils import load_image, plot_sample

model = edsr(scale=4, num_res_blocks=16)
model.load_weights('weights/edsr-16-x4/weights.h5')

lr = load_image('demo/0851x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

WDSR

from model.wdsr import wdsr_b

model = wdsr_b(scale=4, num_res_blocks=32)
model.load_weights('weights/wdsr-b-32-x4/weights.h5')

lr = load_image('demo/0829x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

Weight normalization in WDSR models is implemented with the WeightNormalization layer wrapper of TensorFlow Addons. In its latest version, this wrapper seems to corrupt weights when running model.predict(...). A workaround is to set model.run_eagerly = True or to compile the model with model.compile(loss='mae') in advance. The issue does not arise when calling the model directly with model(...). This is still to be investigated.

SRGAN

from model.srgan import generator

model = generator()
model.load_weights('weights/srgan/gan_generator.h5')

lr = load_image('demo/0869x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

DIV2K dataset

For training and validation on DIV2K images, applications should use the provided DIV2K data loader. It automatically downloads DIV2K images to the .div2k directory and converts them to a different format for faster loading.

Training dataset

from data import DIV2K

train_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult' 
                     subset='train')      # Training dataset are images 001 - 800
                     
# Create a tf.data.Dataset          
train_ds = train_loader.dataset(batch_size=16,         # batch size as described in the EDSR and WDSR papers
                                random_transform=True, # random crop, flip, rotate as described in the EDSR paper
                                repeat_count=None)     # repeat iterating over training images indefinitely

# Iterate over LR/HR image pairs (runs indefinitely with repeat_count=None)
for lr, hr in train_ds:
    ...  # process one batch of LR/HR crops here

Crop size in HR images is 96x96.
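
For a paired random crop, the LR and HR crops have to cover the same image region, so a 96x96 HR crop at scale 4 corresponds to a 24x24 LR crop. The following is a minimal sketch of that coordinate arithmetic, assuming the crop position is sampled in LR coordinates and then scaled up. The function name `paired_crop_boxes` and its signature are made up for this example and are not taken from the project's data loader.

```python
import random


def paired_crop_boxes(lr_height, lr_width, scale=4, hr_crop_size=96, rng=random):
    """Return aligned (top, left, size) crop boxes for an LR image and its HR pair.

    Illustrative helper only; names and signature are assumptions.
    """
    if hr_crop_size % scale != 0:
        raise ValueError("hr_crop_size must be divisible by scale")
    lr_crop_size = hr_crop_size // scale
    if lr_height < lr_crop_size or lr_width < lr_crop_size:
        raise ValueError("LR image is smaller than the LR crop size")
    # Sample the crop position in LR coordinates (randint is inclusive) ...
    lr_top = rng.randint(0, lr_height - lr_crop_size)
    lr_left = rng.randint(0, lr_width - lr_crop_size)
    # ... and scale it up so that both crops show the same image region.
    lr_box = (lr_top, lr_left, lr_crop_size)
    hr_box = (lr_top * scale, lr_left * scale, hr_crop_size)
    return lr_box, hr_box
```

Sampling in LR coordinates keeps the HR crop aligned to the LR pixel grid, which would not be guaranteed if the position were sampled in HR coordinates.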

Validation dataset

from data import DIV2K

valid_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult' 
                     subset='valid')      # Validation dataset are images 801 - 900
                     
# Create a tf.data.Dataset          
valid_ds = valid_loader.dataset(batch_size=1,           # use batch size of 1 as DIV2K images have different size
                                random_transform=False, # use DIV2K images in original size 
                                repeat_count=1)         # 1 epoch
                                
# Iterate over LR/HR image pairs
for lr, hr in valid_ds:
    ...  # process one LR/HR image pair here

Training

The following training examples use the training and validation datasets described earlier. The high-level training API is designed around steps (= minibatch updates) rather than epochs to better match the descriptions in the papers.
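
To relate steps to the more familiar notion of epochs, a quick back-of-the-envelope conversion can help. The sketch below assumes that one pass over the 800 DIV2K training images yields one random crop per image, so that an epoch is 800 / 16 = 50 steps at batch size 16. The function name `steps_to_epochs` is made up for this example.

```python
def steps_to_epochs(steps, num_images=800, batch_size=16):
    """Convert a number of minibatch updates into (fractional) epochs.

    Assumes one training sample per image per pass over the dataset.
    """
    steps_per_epoch = num_images / batch_size
    return steps / steps_per_epoch


print(steps_to_epochs(300_000))  # 6000.0
```

Under that assumption, the 300,000 steps used in the examples below correspond to 6,000 passes over the training images.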

EDSR

from model.edsr import edsr
from train import EdsrTrainer

# Create a training context for an EDSR x4 model with 16 
# residual blocks.
trainer = EdsrTrainer(model=edsr(scale=4, num_res_blocks=16),
                      checkpoint_dir='.ckpt/edsr-16-x4')
                      
# Train EDSR model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)
              
# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():.3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/edsr-16-x4/weights.h5')                                    

Interrupting training and restarting it again resumes from the latest saved checkpoint. The trained Keras model can be accessed with trainer.model.
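
The interplay of `evaluate_every` and `save_best_only` can be pictured with a small framework-free loop. This is only a sketch of the general pattern, not the project's trainer code: `run_training` and its callback parameters are hypothetical names introduced for illustration.

```python
def run_training(num_steps, evaluate_every, train_step, evaluate,
                 save_checkpoint, save_best_only=True):
    """Run num_steps updates, evaluating periodically and saving checkpoints.

    With save_best_only=True a checkpoint is written only when the
    evaluation PSNR improves on the best value seen so far.
    """
    best_psnr = float("-inf")
    for step in range(1, num_steps + 1):
        train_step(step)
        if step % evaluate_every != 0:
            continue
        psnr = evaluate()
        if save_best_only and psnr <= best_psnr:
            continue  # no improvement, keep the previous checkpoint
        best_psnr = max(best_psnr, psnr)
        save_checkpoint(step, psnr)
    return best_psnr
```

With evaluation scores of 28.0, 27.5 and 28.4 dB at steps 1000, 2000 and 3000, this loop would save checkpoints at steps 1000 and 3000 only.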

WDSR

from model.wdsr import wdsr_b
from train import WdsrTrainer

# Create a training context for a WDSR B x4 model with 32 
# residual blocks.
trainer = WdsrTrainer(model=wdsr_b(scale=4, num_res_blocks=32),
                      checkpoint_dir='.ckpt/wdsr-b-32-x4')

# Train WDSR B model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)

# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():.3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/wdsr-b-32-x4/weights.h5')

SRGAN

Generator pre-training

from model.srgan import generator
from train import SrganGeneratorTrainer

# Create a training context for the generator (SRResNet) alone.
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir='.ckpt/pre_generator')

# Pre-train the generator with 1,000,000 steps (100,000 works fine too). 
pre_trainer.train(train_ds, valid_ds.take(10), steps=1000000, evaluate_every=1000)

# Save weights of pre-trained generator (needed for fine-tuning with GAN).
pre_trainer.model.save_weights('weights/srgan/pre_generator.h5')

Generator fine-tuning (GAN)

from model.srgan import generator, discriminator
from train import SrganTrainer

# Create a new generator and init it with pre-trained weights.
gan_generator = generator()
gan_generator.load_weights('weights/srgan/pre_generator.h5')

# Create a training context for the GAN (generator + discriminator).
gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator())

# Train the GAN with 200,000 steps.
gan_trainer.train(train_ds, steps=200000)

# Save weights of generator and discriminator.
gan_trainer.generator.save_weights('weights/srgan/gan_generator.h5')
gan_trainer.discriminator.save_weights('weights/srgan/gan_discriminator.h5')
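
For readers unfamiliar with GAN objectives, the following standard-library sketch shows the general shape of the two losses in SRGAN-style training, using scalar discriminator outputs in (0, 1). It assumes a binary cross-entropy adversarial loss and a 1e-3 weighting of the adversarial term, as in the SRGAN paper; whether `SrganTrainer` uses exactly these choices is not stated here. All function names are illustrative, and `content_loss` stands in for a precomputed VGG feature loss.

```python
import math


def binary_cross_entropy(target, prob, eps=1e-7):
    """Binary cross-entropy for a single probability, clipped for stability."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(target * math.log(prob) + (1.0 - target) * math.log(1.0 - prob))


def discriminator_loss(d_hr, d_sr):
    """The discriminator should output 1 for real HR and 0 for generated SR images."""
    return binary_cross_entropy(1.0, d_hr) + binary_cross_entropy(0.0, d_sr)


def generator_loss(content_loss, d_sr, adversarial_weight=1e-3):
    """Perceptual loss: content loss plus a weighted adversarial term.

    The generator is rewarded when the discriminator rates SR images as real.
    """
    return content_loss + adversarial_weight * binary_cross_entropy(1.0, d_sr)
```

The small adversarial weight keeps the content loss dominant, so the generator is nudged toward realistic textures without drifting far from the ground truth.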

SRGAN for fine-tuning EDSR and WDSR models

It is also possible to fine-tune EDSR and WDSR x4 models with SRGAN. They can be used as a drop-in replacement for the original SRGAN generator. See this article for more details.

from model.edsr import edsr
from model.srgan import discriminator
from train import SrganTrainer

# Create EDSR generator and init with pre-trained weights
generator = edsr(scale=4, num_res_blocks=16)
generator.load_weights('weights/edsr-16-x4/weights.h5')

# Fine-tune EDSR model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)

from model.wdsr import wdsr_b

# Create WDSR B generator and init with pre-trained weights
generator = wdsr_b(scale=4, num_res_blocks=32)
generator.load_weights('weights/wdsr-b-32-x4/weights.h5')

# Fine-tune WDSR B model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)
Owner
Martin Krasser