Unofficial implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (https://arxiv.org/abs/2103.14030)

Overview

Swin-Transformer-Tensorflow

A direct translation of the official PyTorch implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" to TensorFlow 2.

The official PyTorch implementation can be found here.

Introduction:

Swin Transformer Architecture Diagram

Swin Transformer (the name Swin stands for Shifted window) was first described on arXiv (https://arxiv.org/abs/2103.14030) and serves as a general-purpose backbone for computer vision. It is a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connections.
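To make the shifted-window idea concrete, here is a minimal pure-Python sketch (not code from this repository) that reports which local window a token falls into, with and without the shift. The function name window_of and the 8x8 grid are illustrative assumptions. The shift of half a window and the cyclic roll follow the description in the paper.

```python
def window_of(row, col, window_size, grid_size, shift=0):
    """Return the (window_row, window_col) that holds the token at (row, col).

    With shift > 0 the feature map is cyclically rolled by `shift` positions
    toward the top-left before it is cut into non-overlapping windows.
    Tokens that wrap around the border end up in a window with tokens they
    are not adjacent to; the paper masks attention between such tokens.
    """
    r = (row - shift) % grid_size
    c = (col - shift) % grid_size
    return (r // window_size, c // window_size)


grid, win = 8, 4          # hypothetical 8x8 token grid, 4x4 windows
a, b = (3, 3), (4, 4)     # two diagonally adjacent tokens

# Regular partition: the two neighbours sit in different windows,
# so they cannot attend to each other in this layer.
print(window_of(*a, win, grid), window_of(*b, win, grid))
# (0, 0) (1, 1)

# Shifted partition (shift = window_size // 2): they now share a window,
# which is the cross-window connection mentioned above.
print(window_of(*a, win, grid, shift=win // 2),
      window_of(*b, win, grid, shift=win // 2))
# (0, 0) (0, 0)
```

Alternating the two partitions in consecutive layers keeps each attention computation local while still letting information cross window borders.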

Swin Transformer achieves strong performance on COCO object detection (58.7 box AP and 51.1 mask AP on test-dev) and ADE20K semantic segmentation (53.5 mIoU on val), surpassing previous models by a large margin.

Usage:

1. To Run a Pre-trained Swin Transformer

Swin-T:

python main.py --cfg configs/swin_tiny_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-S:

python main.py --cfg configs/swin_small_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-B:

python main.py --cfg configs/swin_base_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

The possible options for cfg and weights_type are:

cfg                                          weights_type      22K model  1K model
configs/swin_tiny_patch4_window7_224.yaml    imagenet_1k       -          github
configs/swin_small_patch4_window7_224.yaml   imagenet_1k       -          github
configs/swin_base_patch4_window7_224.yaml    imagenet_1k       -          github
configs/swin_base_patch4_window12_384.yaml   imagenet_1k       -          github
configs/swin_base_patch4_window7_224.yaml    imagenet_22kto1k  -          github
configs/swin_base_patch4_window12_384.yaml   imagenet_22kto1k  -          github
configs/swin_large_patch4_window7_224.yaml   imagenet_22kto1k  -          github
configs/swin_large_patch4_window12_384.yaml  imagenet_22kto1k  -          github
configs/swin_base_patch4_window7_224.yaml    imagenet_22k      github     -
configs/swin_base_patch4_window12_384.yaml   imagenet_22k      github     -
configs/swin_large_patch4_window7_224.yaml   imagenet_22k      github     -
configs/swin_large_patch4_window12_384.yaml  imagenet_22k      github     -
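Because not every config is available with every weights_type, it can help to check the combination before launching main.py. The sketch below transcribes the table above into a lookup and assembles the same command line shown in the examples. PRETRAINED and build_command are hypothetical helper names, not part of this repository.

```python
# Valid config files per weights_type, transcribed from the table above.
BASE_AND_LARGE = [
    "configs/swin_base_patch4_window7_224.yaml",
    "configs/swin_base_patch4_window12_384.yaml",
    "configs/swin_large_patch4_window7_224.yaml",
    "configs/swin_large_patch4_window12_384.yaml",
]

PRETRAINED = {
    "imagenet_1k": [
        "configs/swin_tiny_patch4_window7_224.yaml",
        "configs/swin_small_patch4_window7_224.yaml",
        "configs/swin_base_patch4_window7_224.yaml",
        "configs/swin_base_patch4_window12_384.yaml",
    ],
    "imagenet_22kto1k": list(BASE_AND_LARGE),
    "imagenet_22k": list(BASE_AND_LARGE),
}


def build_command(cfg, weights_type):
    """Return the main.py argument list, or raise ValueError for a
    (cfg, weights_type) pair that the table does not list."""
    if cfg not in PRETRAINED.get(weights_type, []):
        raise ValueError(f"no {weights_type} weights listed for {cfg}")
    return [
        "python", "main.py",
        "--cfg", cfg,
        "--include_top", "1",
        "--resume", "1",
        "--weights_type", weights_type,
    ]


# Reproduces the Swin-T command from the usage section.
print(" ".join(build_command(
    "configs/swin_tiny_patch4_window7_224.yaml", "imagenet_1k")))
```

The returned list can be passed directly to subprocess.run if the command should be launched from Python.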

2. Create custom models

To create a custom classification model:

import argparse

import tensorflow as tf

from config import get_config
from models.build import build_model

parser = argparse.ArgumentParser('Custom Swin Transformer')

parser.add_argument(
    '--cfg',
    type=str,
    metavar="FILE",
    help='path to config file',
    default="CUSTOM_YAML_FILE_PATH"
)
parser.add_argument(
    '--resume',
    type=int,
    help='Whether or not to resume training from pretrained weights',
    choices={0, 1},
    default=1,
)
parser.add_argument(
    '--weights_type',
    type=str,
    help='Type of pretrained weight file to load including number of classes',
    choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"},
    default="imagenet_1k",
)

args = parser.parse_args()
custom_config = get_config(args, include_top=False)

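# Backbone without the classification head, followed by a new Dense head.
# Replace CUSTOM_NUM_CLASSES with the number of classes in your dataset.
swin_transformer = tf.keras.Sequential([
    build_model(config=custom_config, load_pretrained=args.resume, weights_type=args.weights_type),
    tf.keras.layers.Dense(CUSTOM_NUM_CLASSES)
])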

Model outputs are logits, so don't forget to include softmax in training/inference!
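The Dense head in the snippet above has no activation, so its outputs are raw scores. As a rough illustration of what applying softmax does to such scores, here is a dependency-free sketch; the logits are made-up values. In TensorFlow the same step is typically done with tf.nn.softmax at inference time, or folded into the loss during training with tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True).

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1.

    Subtracting the maximum first keeps math.exp from overflowing
    on large logits without changing the result.
    """
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]            # hypothetical model output for one image
probs = softmax(logits)
predicted_class = probs.index(max(probs))

print([round(p, 3) for p in probs])
print(predicted_class)              # 0, the index of the largest logit
```

Softmax preserves the ranking of the logits, so the predicted class is unchanged; it only matters when probabilities are needed or when the loss expects them.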

You can easily customize the model configs with custom YAML files. Predefined YAML files provided by Microsoft are located in the configs directory.
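One practical note on the custom-model snippet: inside a Jupyter or Colab notebook the kernel adds its own -f argument to the command line, and parser.parse_args() then exits with "unrecognized arguments". A possible workaround, sketched here with only the standard library, is to use parse_known_args or to pass an explicit argument list. The kernel path below is a made-up stand-in.

```python
import argparse

parser = argparse.ArgumentParser('Custom Swin Transformer')
parser.add_argument('--cfg', type=str, metavar="FILE",
                    help='path to config file',
                    default="CUSTOM_YAML_FILE_PATH")
parser.add_argument('--resume', type=int, choices=[0, 1], default=1)
parser.add_argument('--weights_type', type=str,
                    choices=["imagenet_1k", "imagenet_22k", "imagenet_22kto1k"],
                    default="imagenet_1k")

# Stand-in for what a notebook kernel puts on the command line (hypothetical path).
notebook_argv = ["-f", "/tmp/kernel-example.json"]

# Option 1: keep the arguments the parser knows and collect the rest.
# In a real notebook, call parser.parse_known_args() with no argument.
args, unknown = parser.parse_known_args(notebook_argv)
print(args.cfg, args.resume, args.weights_type)
print(unknown)

# Option 2: ignore the command line entirely and rely on the defaults.
defaults_only = parser.parse_args([])
```

Either resulting namespace can then be handed to get_config in place of the one produced by parser.parse_args().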

3. Convert PyTorch pretrained weights into TensorFlow checkpoints

We provide a Python script that converts the official PyTorch weights into TensorFlow checkpoints.

$ python convert_weights.py --cfg config_file --weights the_path_to_pytorch_weights --weights_type type_of_pretrained_weights --output the_path_to_output_tf_weights
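The internals of convert_weights.py are not described here. As a general illustration of one thing a PyTorch-to-TensorFlow conversion usually has to handle: PyTorch's nn.Linear stores its weight as (out_features, in_features), while a Keras Dense kernel is (in_features, out_features), so dense weights need a transpose. The sketch below shows that step on plain lists; whether the script does exactly this is an assumption, and linear_to_dense_kernel is a hypothetical name.

```python
def linear_to_dense_kernel(weight):
    """Transpose an (out_features, in_features) matrix, given as a list of
    rows, into the (in_features, out_features) layout of a Dense kernel."""
    return [list(column) for column in zip(*weight)]


# Toy PyTorch-style weight: 2 output features, 3 input features.
torch_style = [
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
]

keras_style = linear_to_dense_kernel(torch_style)
print(keras_style)
# [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
```

Bias vectors have the same shape in both frameworks and would not need this treatment.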

TODO:

  • Translate model code over to TensorFlow
  • Load PyTorch pretrained weights into TensorFlow model
  • Write trainer code
  • Reproduce results presented in paper
    • Object Detection
  • Reproduce training efficiency of official code in TensorFlow

Citations:

@misc{liu2021swin,
      title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, 
      author={Ze Liu and Yutong Lin and Yue Cao and Han Hu and Yixuan Wei and Zheng Zhang and Stephen Lin and Baining Guo},
      year={2021},
      eprint={2103.14030},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Comments
  • Custom Swin Transformer: error: unrecognized arguments

    parser = argparse.ArgumentParser('Custom Swin Transformer')

    parser.add_argument(
        '--cfg',
        type=str,
        metavar="FILE",
        help='/content/Swin-Transformer-Tensorflow/configs/swin_tiny_patch4_window7_224.yaml',
        default="CUSTOM_YAML_FILE_PATH"
    )
    parser.add_argument(
        '--resume',
        type=int,
        help=1,
        choices={0, 1},
        default=1,
    )
    parser.add_argument(
        '--weights_type',
        type=str,
        help='imagenet_22k',
        choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"},
        default="imagenet_1k",
    )

    args = parser.parse_args()
    custom_config = get_config(args, include_top=False)

    I am trying to use it, but it throws the error below:

    usage: Custom Swin Transformer [-h] [--cfg FILE] [--resume {0,1}] [--weights_type {imagenet_22kto1k,imagenet_1k,imagenet_22k}]
    Custom Swin Transformer: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-ee309a98-1f20-4bb7-aa12-c2980aea076c.json
    An exception has occurred, use %tb to see the full traceback.

    SystemExit: 2

    opened by AliKayhanAtay 1
  • train dataset

    Thank you for providing your code. I've been running the pretrained model, and I'd like to know how to train on custom data with the code you provided and how to apply transfer learning to custom data using the pretrained model. Thank you.

    opened by hoyeoung 1