PyTorch implementation of the paper "StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement" (ICCV 2021 Oral)

Overview

StarEnhancer

StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement (ICCV 2021 Oral)

Abstract: Image enhancement is a subjective process whose targets vary with user preferences. In this paper, we propose a deep learning-based image enhancement method covering multiple tonal styles using only a single model dubbed StarEnhancer. It can transform an image from one tonal style to another, even if that style is unseen. With a simple one-time setting, users can customize the model to make the enhanced images more in line with their aesthetics. To make the method more practical, we propose a well-designed enhancer that can process 4K-resolution images at over 200 FPS while surpassing contemporaneous single-style image enhancement methods in terms of PSNR, SSIM, and LPIPS. Finally, our proposed enhancement method offers good interactivity, allowing users to fine-tune the enhanced image with intuitive options.

Getting started

Install

We tested the code with PyTorch 1.8.1, CUDA 11.1, and cuDNN 8.0.5; nearby versions should also work.

pip install -r requirements.txt
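To compare your installation with the tested configuration, a short script like the following can help. It is a sketch that only reads version attributes PyTorch exposes (torch.__version__, torch.version.cuda, torch.backends.cudnn.version()); environment_report is a hypothetical helper name, and a matching report does not by itself guarantee that the code runs.

```python
import torch


def environment_report() -> dict:
    """Collect the PyTorch, CUDA, and cuDNN versions of the current install."""
    return {
        "torch": torch.__version__,
        # None for CPU-only builds of PyTorch
        "cuda": torch.version.cuda,
        # Integer such as 8005 for cuDNN 8.0.5, or None if cuDNN is unavailable
        "cudnn": torch.backends.cudnn.version(),
        "gpu_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    # Tested configuration from this README: PyTorch 1.8.1, CUDA 11.1, cuDNN 8.0.5
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```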

We mainly trained the model on four RTX 2080 Ti GPUs, but training with a smaller mini-batch size also works.

Prepare

You can either generate the dataset yourself or download the one we generated.

The final directory layout should look like this:

┬─ save_model
│   ├─ stylish.pth.tar
│   └─ ... (model & embedding)
└─ data
    ├─ train
    │   ├─ 01-Experts-A
    │   │   ├─ a0001.jpg
    │   │   └─ ... (id.jpg)
    │   └─ ... (style folder)
    ├─ valid
    │   └─ ... (style folder)
    └─ test
        └─ ... (style folder)
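Before training, a small check like the one below can catch a misplaced folder early. It is a sketch, not part of the repository: check_layout is a hypothetical helper, and its same-id check assumes that every style folder within a split holds the same image ids (as with the paired FiveK expert styles). Pass require_same_ids=False if your data is not paired.

```python
from pathlib import Path


def check_layout(root=".", splits=("train", "valid", "test"), require_same_ids=True):
    """Check the save_model/ and data/<split>/<style>/<id>.jpg layout.

    Returns {split: [style folder names]}; raises on the first problem found.
    """
    root_path = Path(root)
    if not (root_path / "save_model").is_dir():
        raise FileNotFoundError(f"missing directory: {root_path / 'save_model'}")

    styles = {}
    for split in splits:
        split_dir = root_path / "data" / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing directory: {split_dir}")
        style_dirs = sorted(p for p in split_dir.iterdir() if p.is_dir())
        if not style_dirs:
            raise FileNotFoundError(f"no style folders in {split_dir}")
        if require_same_ids:
            # Assumption: paired data, i.e. each style folder holds the same <id>.jpg files
            id_sets = [{f.stem for f in d.glob("*.jpg")} for d in style_dirs]
            if any(ids != id_sets[0] for ids in id_sets[1:]):
                raise ValueError(f"style folders in {split_dir} hold different image ids")
        styles[split] = [d.name for d in style_dirs]
    return styles
```

Run it from the repository root, for example print(check_layout(".")), to list the style folders found in each split.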

Download

Data and pretrained models are available on Google Drive.

Generate

  1. Download the raw data from the MIT-Adobe FiveK Dataset.
  2. Download the modified Lightroom database fivek.lrcat and replace the original database with it.
  3. Export the dataset in JPEG format with quality 100 (see this issue for reference).
  4. Run generate_dataset.py in the data folder to generate the dataset.

Train

First, train the style encoder:

python train_stylish.py

Second, extract the style embedding of each sample in the training set:

python fetch_embedding.py

Finally, train the curve encoder and the mapping network:

python train_enhancer.py
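The three stages must run in order, since fetch_embedding.py presumably relies on the trained style encoder and train_enhancer.py on the extracted embeddings. A minimal driver sketch, assuming each script runs with its default arguments from the repository root (run_pipeline and STAGES are hypothetical names):

```python
import subprocess
import sys

# The three training stages from this README, in the order they must run.
STAGES = ["train_stylish.py", "fetch_embedding.py", "train_enhancer.py"]


def run_pipeline(dry_run: bool = False) -> list:
    """Run each stage with the current interpreter; stop at the first failure."""
    commands = [[sys.executable, script] for script in STAGES]
    for cmd in commands:
        if dry_run:
            # Only show what would be executed
            print(" ".join(cmd))
            continue
        # check=True raises CalledProcessError, so a failed stage blocks the next one
        subprocess.run(cmd, check=True)
    return commands
```

Calling run_pipeline(dry_run=True) prints the commands without running anything, which is a cheap way to confirm the order before a long training run.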

Test

Just run:

python test.py

Computing LPIPS requires about 10 GB of GPU memory. If you run out of GPU memory (OOM), you can skip the LPIPS evaluation by replacing the following line

lpips_val = loss_fn_alex(output * 2 - 1, target_img * 2 - 1).item()

with

lpips_val = 0
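Setting lpips_val to 0 drops LPIPS entirely. If you still want an LPIPS estimate on a smaller GPU, one option (not part of this repository) is to score the image tile by tile and take an area-weighted average. This is only an approximation of full-image LPIPS, because features near tile borders see less context, so the numbers will not exactly match the reported ones. The helper tiled_metric is hypothetical and generic: metric is any callable returning a float for two equally shaped arrays or tensors whose last two dimensions are height and width.

```python
def tiled_metric(metric, output, target, tile: int = 512) -> float:
    """Area-weighted average of metric(output_tile, target_tile) over a tile grid.

    output and target are equally shaped arrays or tensors whose last two
    dimensions are height and width (e.g. N x C x H x W).
    """
    if tuple(output.shape) != tuple(target.shape):
        raise ValueError("output and target must have the same shape")
    height, width = output.shape[-2], output.shape[-1]
    total, area = 0.0, 0
    for y in range(0, height, tile):
        for x in range(0, width, tile):
            out_tile = output[..., y:y + tile, x:x + tile]
            tgt_tile = target[..., y:y + tile, x:x + tile]
            tile_area = out_tile.shape[-2] * out_tile.shape[-1]
            # Weight by tile area so smaller border tiles count proportionally
            total += metric(out_tile, tgt_tile) * tile_area
            area += tile_area
    return total / area
```

With the repository's loss, the metric could be lambda a, b: loss_fn_alex(a * 2 - 1, b * 2 - 1).item(), evaluated inside torch.no_grad(). Pick a tile size that avoids very thin border tiles, since tiny inputs may be too small for the AlexNet backbone.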

Notes

Due to agreements, we are unable to release part of the source code. This repository provides a pure Python implementation for research use. It differs from the paper in the following ways:

  1. The repository uses a ResNet-18 without batch normalization as the curve encoder's backbone, whereas the paper uses a more lightweight model.
  2. The paper implements the color transform function in CUDA, whereas the repository implements it with torch.gather.
  3. The repository omits some tricks used to train the lightweight models.

Overall, this repository can achieve higher enhancement quality but runs slightly slower.
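To illustrate the second difference, here is a minimal sketch of applying per-channel curves with torch.gather. It assumes each channel has its own curve sampled at L evenly spaced points over [0, 1] and uses nearest-sample lookup. The repository's actual color transform may differ (for example, it may interpolate between samples or mix channels), so treat apply_curves and its tensor shapes as hypothetical.

```python
import torch


def apply_curves(image: torch.Tensor, curves: torch.Tensor) -> torch.Tensor:
    """Map each pixel through its channel's curve (hypothetical helper).

    image:  (N, C, H, W) tensor with values in [0, 1]
    curves: (N, C, L) tensor; curves[n, c, k] is the output for input k / (L - 1)
    """
    n, c, h, w = image.shape
    num_points = curves.shape[-1]
    # Quantize intensities to the nearest curve sample index in [0, L - 1]
    index = (image.clamp(0, 1) * (num_points - 1)).round().long()
    # Gather along the last dim: out[n, c, p] = curves[n, c, index[n, c, p]]
    out = torch.gather(curves, 2, index.reshape(n, c, h * w))
    return out.reshape(n, c, h, w)
```

Because gather is a pure indexing operation, the cost is one lookup per pixel regardless of how the curves were predicted, which is why a lookup-based transform suits high-resolution inputs.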

Comments
  • Multi-style, unpaired setting

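    Hello, in the multi-style, unpaired setting, could we swap the roles of source and target, pass the resulting output_A and output_B through the enhancer again to obtain recover_A and recover_B, and finally compute l1_loss(source, recover_A) and l1_loss(target, recover_B) together with Triplet_loss(output_A, target, source) and Triplet_loss(output_B, source, target)?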

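    import torch

    # AverageMeter and Feature_Extract_Model are not defined in this snippet: AverageMeter is
    # assumed to come from the repository's utilities, Feature_Extract_Model is the commenter's
    # own (assumed pretrained) feature extractor.

    def train(train_loader, mapping, enhancer, criterion, optimizer):
        losses = AverageMeter()
        # Triplet loss: pull each output toward the other style, away from its own input
        criterionTriplet = torch.nn.TripletMarginLoss(margin=1.0, p=2)
        # Fixed feature extractor: only used to compare images, never updated
        FEModel = Feature_Extract_Model().cuda().eval()
        FEModel.requires_grad_(False)

        mapping.train()
        enhancer.train()

        for (source_img, source_center, target_img, target_center) in train_loader:
            source_img = source_img.cuda(non_blocking=True)
            source_center = source_center.cuda(non_blocking=True)
            target_img = target_img.cuda(non_blocking=True)
            target_center = target_center.cuda(non_blocking=True)

            style_A = mapping(source_center)
            style_B = mapping(target_center)

            # Translate A -> B and, with source and target swapped, B -> A
            output_A = enhancer(source_img, style_A, style_B)
            output_B = enhancer(target_img, style_B, style_A)
            # Cycle back to the original styles
            recover_A = enhancer(output_A, style_B, style_A)
            recover_B = enhancer(output_B, style_A, style_B)

            # Flatten to (N, D) so the triplet loss compares one feature vector per image
            source_feat = FEModel(source_img).flatten(1)
            target_feat = FEModel(target_img).flatten(1)
            output_A_feat = FEModel(output_A).flatten(1)
            output_B_feat = FEModel(output_B).flatten(1)

            # Cycle consistency: recover_A should match source, recover_B should match target
            loss_l1 = criterion(recover_A, source_img) + criterion(recover_B, target_img)
            # TripletMarginLoss(anchor, positive, negative)
            loss_triplet = criterionTriplet(output_A_feat, target_feat, source_feat) + \
                           criterionTriplet(output_B_feat, source_feat, target_feat)
            loss = loss_l1 + loss_triplet

            # Use the actual batch size (the last batch may be smaller) instead of a global args
            losses.update(loss.item(), source_img.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return losses.avg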
    
    opened by jxust01 4
  • Questions about dataset preparation

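    Hello, I would like to run your project on my own data, and I currently have a set of input/output pairs. How were the other four target styles among A-E in the training data generated? Can this target-style data be unpaired? If there is only one style, can I simply copy the same data into all of the A-E targets? The single-style training script train_enhancer.py requires an embeddings.npy file; is this file necessary for single-style training?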

    opened by zener90818 4
  • Dataset processing

    Hi, I could not find "(default) input with ExpertC" from the DeepUPE issue in the fivek.lrcat you provided. For the single-style experiments, is the input "InputAsShotZeroed" or "(Q)InputZeroed with ExpertC WhiteBalance"?

    opened by madfff 2
  • Configure Renovate

    Welcome to Renovate! This is an onboarding PR to help you understand and configure settings before regular Pull Requests begin.

    🚦 To activate Renovate, merge this Pull Request. To disable Renovate, simply close this Pull Request unmerged.


    Detected Package Files

    • requirements.txt (pip_requirements)

    Configuration Summary

    Based on the default config's presets, Renovate will:

    • Start dependency updates only once this onboarding PR is merged
    • Enable Renovate Dependency Dashboard creation
    • If semantic commits detected, use semantic commit type fix for dependencies and chore for all others
    • Ignore node_modules, bower_components, vendor and various test/tests directories
    • Autodetect whether to pin dependencies or maintain ranges
    • Rate limit PR creation to a maximum of two per hour
    • Limit to maximum 20 open PRs at any time
    • Group known monorepo packages together
    • Use curated list of recommended non-monorepo package groupings
    • Fix some problems with very old Maven commons versions
    • Ignore spring cloud 1.x releases
    • Ignore http4s digest-based 1.x milestones
    • Use node versioning for @types/node
    • Limit concurrent requests to reduce load on Repology servers until we can fix this properly, see issue 10133

    🔡 Would you like to change the way Renovate is upgrading your dependencies? Simply edit the renovate.json in this branch with your custom config and the list of Pull Requests in the "What to Expect" section below will be updated the next time Renovate runs.


    What to Expect

    With your current configuration, Renovate will create 1 Pull Request:

    Pin dependency torch to ==1.10.0
    • Schedule: ["at any time"]
    • Branch name: renovate/pin-dependencies
    • Merge into: main
    • Pin torch to ==1.10.0

    ❓ Got questions? Check out Renovate's Docs, particularly the Getting Started section. If you need any further assistance then you can also request help here.


    This PR has been generated by WhiteSource Renovate. View repository job log here.

    opened by renovate[bot] 1
  • The results are not the same as the paper

    I am the author.

    Some peers have emailed me because the performance of the open-source model does not match the results in the paper. As stated in the README, the released model is not the model from the paper, but its performance is similar. The expected results are: PSNR 25.41, SSIM 0.942, LPIPS 0.085.

    If your results differ, the JPEG codec is likely different; this depends on the OpenCV version and how it was installed.

    You can uninstall OpenCV (whether it was installed with pip or conda) and reinstall it with pip. It must be pip, because conda installs a different JPEG codec:

    pip install opencv-python==4.5.5.62
    
    opened by IDKiro 0
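Since the author attributes metric differences to the JPEG codec, it helps to see which JPEG library your OpenCV build uses. cv2.getBuildInformation() returns OpenCV's build summary as text, and its Media I/O section normally contains a "JPEG:" entry. The hypothetical helper below extracts that entry; the summary format is not a stable API, so the parsing is best-effort.

```python
from typing import Optional


def jpeg_codec_line(build_info: str) -> Optional[str]:
    """Return the JPEG entry from OpenCV's build information, or None if absent."""
    for line in build_info.splitlines():
        stripped = line.strip()
        # Matches "JPEG:" but not "JPEG 2000:"
        if stripped.startswith("JPEG:"):
            return stripped[len("JPEG:"):].strip()
    return None
```

With opencv-python installed, print(cv2.__version__, jpeg_codec_line(cv2.getBuildInformation())) shows the version and JPEG library, which you can compare before and after reinstalling.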
Owner
IDKiro