TensorFlow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution

Overview

Single Image Super-Resolution with EDSR, WDSR and SRGAN

A TensorFlow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution.

This is a complete rewrite of the old Keras/TensorFlow 1.x based implementation available here. Some parts are still work in progress, but you can already train models as described in the papers via a high-level training API. You can also fine-tune EDSR and WDSR models in an SRGAN context. Training and usage examples are given in the notebooks.

A DIV2K data provider automatically downloads DIV2K training and validation images of given scale (2, 3, 4 or 8) and downgrade operator ("bicubic", "unknown", "mild" or "difficult").

Important: if you want to evaluate the pre-trained models with a dataset other than DIV2K please read this comment (and replies) first.

Environment setup

Create a new conda environment with

conda env create -f environment.yml

and activate it with

conda activate sisr

Introduction

You can find an introduction to single-image super-resolution in this article. It also demonstrates how EDSR and WDSR models can be fine-tuned with SRGAN (see also this section).

Getting started

The examples in this section require the following pre-trained weights (see also the example notebooks):

Pre-trained weights

  • weights-edsr-16-x4.tar.gz
    • EDSR x4 baseline as described in the EDSR paper: 16 residual blocks, 64 filters, 1.52M parameters.
    • PSNR on DIV2K validation set = 28.89 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-wdsr-b-32-x4.tar.gz
    • WDSR B x4 custom model: 32 residual blocks, 32 filters, expansion factor 6, 0.62M parameters.
    • PSNR on DIV2K validation set = 28.91 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-srgan.tar.gz
    • SRGAN as described in the SRGAN paper: 1.55M parameters, trained with VGG54 content loss.
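
As a sanity check on the 1.52M parameter figure quoted for the EDSR x4 baseline, the count can be reproduced with simple arithmetic. The sketch below assumes a particular layer layout that this README does not spell out: one 3x3 input convolution, 16 residual blocks with two 3x3 convolutions each, one 3x3 convolution after the blocks, two x2 upsampling stages (each a 3x3 convolution to four times the filters followed by a parameter-free pixel shuffle), and a 3x3 output convolution to 3 channels. The helper names are hypothetical and not part of the project.

```python
def conv2d_params(in_channels, out_channels, kernel_size=3):
    """Weights plus biases of a 2D convolution with a square kernel."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


def edsr_param_count(scale=4, num_res_blocks=16, num_filters=64):
    """Parameter count for the ASSUMED EDSR layout described in the lead-in."""
    head = conv2d_params(3, num_filters)
    # Each residual block is assumed to hold two 3x3 convolutions.
    res_blocks = num_res_blocks * 2 * conv2d_params(num_filters, num_filters)
    after_blocks = conv2d_params(num_filters, num_filters)
    # Assumed: x2 needs one upsampling stage, x4 needs two.
    num_x2_stages = {2: 1, 4: 2}[scale]
    upsample = num_x2_stages * conv2d_params(num_filters, 4 * num_filters)
    tail = conv2d_params(num_filters, 3)
    return head + res_blocks + after_blocks + upsample + tail


total = edsr_param_count()
print(f'{total:,} parameters ({total / 1e6:.2f}M)')  # 1,517,571 parameters (1.52M)
```

Under these assumptions the total rounds to the 1.52M stated above, which suggests the assumed layout is at least consistent with the published figure.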

After downloading, extract them in the root folder of the project with

tar xvfz weights-<...>.tar.gz

EDSR

from model import resolve_single
from model.edsr import edsr

from utils import load_image, plot_sample

model = edsr(scale=4, num_res_blocks=16)
model.load_weights('weights/edsr-16-x4/weights.h5')

lr = load_image('demo/0851x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

result-edsr

WDSR

from model import resolve_single
from model.wdsr import wdsr_b

from utils import load_image, plot_sample

model = wdsr_b(scale=4, num_res_blocks=32)
model.load_weights('weights/wdsr-b-32-x4/weights.h5')

lr = load_image('demo/0829x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

result-wdsr

Weight normalization in WDSR models is implemented with the WeightNormalization layer wrapper from TensorFlow Addons. In its latest version, this wrapper seems to corrupt weights when running model.predict(...). As a workaround, set model.run_eagerly = True or compile the model with model.compile(loss='mae') beforehand. The issue does not arise when calling the model directly with model(...). This is still to be investigated.
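
For background, weight normalization reparameterizes a weight vector as w = g * v / ||v||, so that the length of w is controlled by a scalar g and its direction by v. The following is a minimal plain-Python sketch of that formula only. It is not the TensorFlow Addons wrapper, and the function name `weight_norm` is made up for illustration.

```python
import math


def weight_norm(v, g):
    """Return w = g * v / ||v|| for a weight vector v and a scalar gain g."""
    norm = math.sqrt(sum(x * x for x in v))
    return [g * x / norm for x in v]


v = [3.0, 4.0]           # unnormalized direction, ||v|| = 5
w = weight_norm(v, g=2.0)
print(w)                 # same direction as v, Euclidean norm equal to g
```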

SRGAN

from model import resolve_single
from model.srgan import generator

from utils import load_image, plot_sample

model = generator()
model.load_weights('weights/srgan/gan_generator.h5')

lr = load_image('demo/0869x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

result-srgan

DIV2K dataset

For training and validation on DIV2K images, applications should use the provided DIV2K data loader. It automatically downloads DIV2K images to the .div2k directory and converts them to a different format for faster loading.

Training dataset

from data import DIV2K

train_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult' 
                     subset='train')      # Training dataset are images 001 - 800
                     
# Create a tf.data.Dataset          
train_ds = train_loader.dataset(batch_size=16,         # batch size as described in the EDSR and WDSR papers
                                random_transform=True, # random crop, flip, rotate as described in the EDSR paper
                                repeat_count=None)     # repeat iterating over training images indefinitely

# Iterate over LR/HR image pairs                                
for lr, hr in train_ds:
    ...  # process one batch of LR/HR image pairs here

Crop size in HR images is 96x96.
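
With an HR crop of 96x96 and scale 4, the matching LR crop is 24x24, and both crops have to cover the same image region. The sketch below illustrates one way to keep them aligned: pick the crop position in LR coordinates and multiply it by the scale for the HR image. It uses plain nested lists as toy images. The function `random_crop` and its parameters are illustrative assumptions, not the data loader's actual code.

```python
import random


def random_crop(lr_img, hr_img, hr_crop_size=96, scale=4, rng=random):
    """Cut aligned random crops from an LR image and its HR counterpart."""
    lr_crop_size = hr_crop_size // scale
    lr_height, lr_width = len(lr_img), len(lr_img[0])

    # Choose the top-left corner in LR coordinates ...
    lr_top = rng.randrange(lr_height - lr_crop_size + 1)
    lr_left = rng.randrange(lr_width - lr_crop_size + 1)
    # ... and map it to HR coordinates so both crops show the same region.
    hr_top, hr_left = lr_top * scale, lr_left * scale

    lr_crop = [row[lr_left:lr_left + lr_crop_size]
               for row in lr_img[lr_top:lr_top + lr_crop_size]]
    hr_crop = [row[hr_left:hr_left + hr_crop_size]
               for row in hr_img[hr_top:hr_top + hr_crop_size]]
    return lr_crop, hr_crop


# Toy images whose "pixels" store their own (row, column) coordinates.
lr = [[(y, x) for x in range(40)] for y in range(30)]
hr = [[(y, x) for x in range(160)] for y in range(120)]

lr_crop, hr_crop = random_crop(lr, hr)
print(len(lr_crop), len(lr_crop[0]), len(hr_crop), len(hr_crop[0]))  # 24 24 96 96
```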

Validation dataset

from data import DIV2K

valid_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult' 
                     subset='valid')      # Validation dataset are images 801 - 900
                     
# Create a tf.data.Dataset          
valid_ds = valid_loader.dataset(batch_size=1,           # use batch size of 1 as DIV2K images have different size
                                random_transform=False, # use DIV2K images in original size 
                                repeat_count=1)         # 1 epoch
                                
# Iterate over LR/HR image pairs                                
for lr, hr in valid_ds:
    ...  # process one LR/HR image pair here

Training

The following training examples use the training and validation datasets described earlier. The high-level training API is designed around steps (= minibatch updates) rather than epochs to better match the descriptions in the papers.
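
To relate steps to epochs, note that the training set has 800 images and the training dataset above uses a batch size of 16, so one pass over the data takes 50 steps. The small helper below is a hypothetical convenience function, not part of the training API, that does this conversion.

```python
import math


def steps_to_epochs(steps, num_images=800, batch_size=16):
    """Convert a number of minibatch updates into passes over the training set."""
    steps_per_epoch = math.ceil(num_images / batch_size)
    return steps / steps_per_epoch


print(steps_to_epochs(300_000))  # 6000.0
```

The 300,000 steps used in the examples below therefore correspond to 6,000 passes over the 800 training images.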

EDSR

from model.edsr import edsr
from train import EdsrTrainer

# Create a training context for an EDSR x4 model with 16 
# residual blocks.
trainer = EdsrTrainer(model=edsr(scale=4, num_res_blocks=16), 
                      checkpoint_dir=f'.ckpt/edsr-16-x4')
                      
# Train EDSR model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)
              
# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():.3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/edsr-16-x4/weights.h5')                                    

Interrupting training and restarting it again resumes from the latest saved checkpoint. The trained Keras model can be accessed with trainer.model.
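
The metric reported in the example above is PSNR (peak signal-to-noise ratio), defined as 10 * log10(MAX^2 / MSE), where MAX is the largest possible pixel value and MSE is the mean squared error between two images. The following plain-Python sketch shows the formula for 8-bit images. The function names are hypothetical, and averaging per-image PSNR over a dataset is an assumption about how an evaluation might aggregate results, not a statement about what trainer.evaluate does internally.

```python
import math


def psnr(img_a, img_b, max_value=255.0):
    """PSNR in dB between two equally sized images given as nested lists."""
    squared_errors = [(a - b) ** 2
                      for row_a, row_b in zip(img_a, img_b)
                      for a, b in zip(row_a, row_b)]
    mse = sum(squared_errors) / len(squared_errors)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


def mean_psnr(image_pairs):
    """Average PSNR over (super-resolved, ground-truth) pairs (assumed aggregation)."""
    values = [psnr(sr, hr) for sr, hr in image_pairs]
    return sum(values) / len(values)


hr_img = [[10, 20], [30, 40]]
sr_img = [[11, 21], [31, 41]]  # every pixel is off by one, so MSE = 1
print(f'{psnr(sr_img, hr_img):.3f} dB')
```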

WDSR

from model.wdsr import wdsr_b
from train import WdsrTrainer

# Create a training context for a WDSR B x4 model with 32 
# residual blocks.
trainer = WdsrTrainer(model=wdsr_b(scale=4, num_res_blocks=32), 
                      checkpoint_dir=f'.ckpt/wdsr-b-32-x4')

# Train WDSR B model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)

# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():.3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/wdsr-b-32-x4/weights.h5')

SRGAN

Generator pre-training

from model.srgan import generator
from train import SrganGeneratorTrainer

# Create a training context for the generator (SRResNet) alone.
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir=f'.ckpt/pre_generator')

# Pre-train the generator with 1,000,000 steps (100,000 works fine too). 
pre_trainer.train(train_ds, valid_ds.take(10), steps=1000000, evaluate_every=1000)

# Save weights of pre-trained generator (needed for fine-tuning with GAN).
pre_trainer.model.save_weights('weights/srgan/pre_generator.h5')

Generator fine-tuning (GAN)

from model.srgan import generator, discriminator
from train import SrganTrainer

# Create a new generator and init it with pre-trained weights.
gan_generator = generator()
gan_generator.load_weights('weights/srgan/pre_generator.h5')

# Create a training context for the GAN (generator + discriminator).
gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator())

# Train the GAN with 200,000 steps.
gan_trainer.train(train_ds, steps=200000)

# Save weights of generator and discriminator.
gan_trainer.generator.save_weights('weights/srgan/gan_generator.h5')
gan_trainer.discriminator.save_weights('weights/srgan/gan_discriminator.h5')
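
For orientation, the SRGAN paper trains the generator on a perceptual loss that adds a small adversarial term (weighted by 0.001) to a content loss, while the discriminator is trained with binary cross-entropy to separate real HR images from generated ones. The sketch below illustrates that objective in plain Python on discriminator output probabilities. It follows the paper's formulation and is not taken from SrganTrainer, whose internals are not shown here. All function names are hypothetical, and the content loss is passed in as a plain number rather than computed from VGG features.

```python
import math


def binary_cross_entropy(labels, probs, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        total -= y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
    return total / len(labels)


def generator_adversarial_loss(sr_probs):
    """The generator wants the discriminator to label its outputs as real (1)."""
    return binary_cross_entropy([1.0] * len(sr_probs), sr_probs)


def discriminator_loss(hr_probs, sr_probs):
    """The discriminator wants real images labeled 1 and generated images labeled 0."""
    real_loss = binary_cross_entropy([1.0] * len(hr_probs), hr_probs)
    fake_loss = binary_cross_entropy([0.0] * len(sr_probs), sr_probs)
    return real_loss + fake_loss


def perceptual_loss(content_loss, adversarial_loss, adversarial_weight=1e-3):
    """Content loss plus a weighted adversarial term (weight as in the SRGAN paper)."""
    return content_loss + adversarial_weight * adversarial_loss


# An undecided discriminator outputs 0.5 for everything.
print(discriminator_loss(hr_probs=[0.5, 0.5], sr_probs=[0.5, 0.5]))
print(perceptual_loss(content_loss=1.0, adversarial_loss=generator_adversarial_loss([0.5])))
```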

SRGAN for fine-tuning EDSR and WDSR models

It is also possible to fine-tune EDSR and WDSR x4 models with SRGAN. They can be used as a drop-in replacement for the original SRGAN generator. More details can be found in this article.

from model.edsr import edsr
from model.srgan import discriminator
from train import SrganTrainer

# Create EDSR generator and init with pre-trained weights
generator = edsr(scale=4, num_res_blocks=16)
generator.load_weights('weights/edsr-16-x4/weights.h5')

# Fine-tune EDSR model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)

from model.wdsr import wdsr_b
from model.srgan import discriminator
from train import SrganTrainer

# Create WDSR B generator and init with pre-trained weights
generator = wdsr_b(scale=4, num_res_blocks=32)
generator.load_weights('weights/wdsr-b-32-x4/weights.h5')

# Fine-tune WDSR B model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)

Owner
Martin Krasser