Unofficial implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (https://arxiv.org/abs/2103.14030)

Overview

Swin-Transformer-Tensorflow

A direct translation of the official PyTorch implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" to TensorFlow 2.

The official PyTorch implementation is available at https://github.com/microsoft/Swin-Transformer.

Introduction:

Swin Transformer Architecture Diagram

Swin Transformer (the name Swin stands for Shifted window) was introduced in the arXiv paper linked above and serves as a general-purpose backbone for computer vision. It is a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme improves efficiency by limiting self-attention computation to non-overlapping local windows, while still allowing cross-window connections.
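To make the windowing scheme concrete, here is a minimal NumPy sketch (not code from this repository) of three ideas from the paper. The first splits an (H, W, C) feature map into non-overlapping M x M windows. The second applies the cyclic shift of M // 2 tokens used between consecutive blocks. The third evaluates the paper's complexity formulas for global versus window self-attention. The function names are hypothetical, and the attention mask that the real model applies after the shift (so wrapped-around tokens do not attend to unrelated regions) is omitted.

```python
from typing import Optional

import numpy as np


def window_partition(x: np.ndarray, window_size: int) -> np.ndarray:
    """Split an (H, W, C) feature map into non-overlapping M x M windows.

    Returns shape (num_windows, M, M, C); H and W must be multiples of M.
    """
    h, w, c = x.shape
    m = window_size
    x = x.reshape(h // m, m, w // m, m, c)
    # Reorder to (window rows, window cols, M, M, C), then flatten the window grid.
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, m, m, c)


def cyclic_shift(x: np.ndarray, window_size: int) -> np.ndarray:
    """Roll the map by M // 2 tokens towards the top-left, with wrap-around,
    so the next layer's windows straddle the previous window borders."""
    s = window_size // 2
    return np.roll(x, shift=(-s, -s), axis=(0, 1))


def attention_cost(h: int, w: int, c: int, m: Optional[int] = None) -> int:
    """Paper's complexity: global MSA if m is None, otherwise window MSA."""
    projections = 4 * h * w * c * c  # QKV and output projections
    if m is None:
        return projections + 2 * (h * w) ** 2 * c  # every token attends to every token
    return projections + 2 * m * m * h * w * c  # each token attends within its window


if __name__ == "__main__":
    x = np.arange(8 * 8).reshape(8, 8, 1)
    print(window_partition(x, 4).shape)  # (4, 4, 4, 1)
    # First-stage Swin-T sizes: 56 x 56 tokens, C = 96, M = 7.
    print(attention_cost(56, 56, 96), attention_cost(56, 56, 96, 7))
```

The global term grows quadratically with the number of tokens hw, while the window term grows linearly. At the first Swin-T stage this is roughly a 14x reduction in this estimate.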

As reported in the paper, Swin Transformer achieves strong performance on COCO object detection (58.7 box AP and 51.1 mask AP on test-dev) and ADE20K semantic segmentation (53.5 mIoU on val), surpassing the previous state of the art by a large margin.

Usage:

1. Run a pretrained Swin Transformer

Swin-T:

python main.py --cfg configs/swin_tiny_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-S:

python main.py --cfg configs/swin_small_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-B:

python main.py --cfg configs/swin_base_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

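The available combinations of cfg and weights_type are listed below. In the last two columns, "github" marks an available pretrained checkpoint (originally a download link) and "-" means none is provided: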

cfg                                          weights_type        22K model  1K model
-------------------------------------------  ------------------  ---------  --------
configs/swin_tiny_patch4_window7_224.yaml    imagenet_1k         -          github
configs/swin_small_patch4_window7_224.yaml   imagenet_1k         -          github
configs/swin_base_patch4_window7_224.yaml    imagenet_1k         -          github
configs/swin_base_patch4_window12_384.yaml   imagenet_1k         -          github
configs/swin_base_patch4_window7_224.yaml    imagenet_22kto1k    -          github
configs/swin_base_patch4_window12_384.yaml   imagenet_22kto1k    -          github
configs/swin_large_patch4_window7_224.yaml   imagenet_22kto1k    -          github
configs/swin_large_patch4_window12_384.yaml  imagenet_22kto1k    -          github
configs/swin_base_patch4_window7_224.yaml    imagenet_22k        github     -
configs/swin_base_patch4_window12_384.yaml   imagenet_22k        github     -
configs/swin_large_patch4_window7_224.yaml   imagenet_22k        github     -
configs/swin_large_patch4_window12_384.yaml  imagenet_22k        github     -
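The config file names follow a regular pattern. Reading them as swin_{size}_patch{P}_window{M}_{resolution} is our interpretation, consistent with the official Microsoft configs. The sketch below parses a config path and derives the first-stage token grid. SwinVariant, parse_cfg_name and first_stage_grid are hypothetical helpers, not part of this repository.

```python
import re
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class SwinVariant:
    size: str         # tiny / small / base / large
    patch_size: int   # side length of each image patch, in pixels
    window_size: int  # side length M of each attention window, in tokens
    image_size: int   # input resolution, in pixels


_NAME_RE = re.compile(
    r"swin_(?P<size>tiny|small|base|large)"
    r"_patch(?P<patch>\d+)_window(?P<window>\d+)_(?P<img>\d+)"
)


def parse_cfg_name(path: str) -> SwinVariant:
    """Extract the variant hyperparameters encoded in a config file name."""
    match = _NAME_RE.search(path)
    if match is None:
        raise ValueError(f"unrecognised config name: {path!r}")
    return SwinVariant(
        size=match["size"],
        patch_size=int(match["patch"]),
        window_size=int(match["window"]),
        image_size=int(match["img"]),
    )


def first_stage_grid(v: SwinVariant) -> Tuple[int, int]:
    """Tokens per side after patch embedding, and windows per side."""
    tokens = v.image_size // v.patch_size
    return tokens, tokens // v.window_size


if __name__ == "__main__":
    v = parse_cfg_name("configs/swin_base_patch4_window12_384.yaml")
    print(v, first_stage_grid(v))
```

Both the 224-pixel and the 384-pixel configs end up with an 8 x 8 grid of windows in the first stage (56 / 7 and 96 / 12). This explains why the larger input resolution is paired with a larger window size.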

2. Create custom models

To create a custom classification model:

import argparse

import tensorflow as tf

from config import get_config
from models.build import build_model

# Placeholder: set this to the number of classes in your dataset.
CUSTOM_NUM_CLASSES = 10

parser = argparse.ArgumentParser('Custom Swin Transformer')

parser.add_argument(
    '--cfg',
    type=str,
    metavar="FILE",
    help='path to config file',
    default="CUSTOM_YAML_FILE_PATH"
)
parser.add_argument(
    '--resume',
    type=int,
    help='Whether or not to resume training from pretrained weights',
    choices={0, 1},
    default=1,
)
parser.add_argument(
    '--weights_type',
    type=str,
    help='Type of pretrained weight file to load including number of classes',
    choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"},
    default="imagenet_1k",
)

args = parser.parse_args()
# include_top=False: omit the model's own classification head.
custom_config = get_config(args, include_top=False)

# Pretrained Swin backbone followed by a new classification layer.
swin_transformer = tf.keras.Sequential([
    build_model(config=custom_config, load_pretrained=args.resume, weights_type=args.weights_type),
    tf.keras.layers.Dense(CUSTOM_NUM_CLASSES)
])

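Model outputs are logits, so remember to apply a softmax during training and inference, or use a loss that accepts logits directly (for example, a Keras loss created with from_logits=True).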
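Here is what "apply a softmax to the logits" means numerically. The sketch below is plain NumPy, for illustration only. In TensorFlow the equivalents are tf.nn.softmax and a loss such as tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True). The helper names are hypothetical.

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax: subtract the max before exponentiating."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def cross_entropy_from_logits(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Per-example cross-entropy computed directly from logits (log-sum-exp)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]


if __name__ == "__main__":
    logits = np.array([[2.0, 1.0, 0.1]])  # e.g. one row of the Dense layer's output
    print(softmax(logits), cross_entropy_from_logits(logits, np.array([0])))
```

Computing the loss from logits in a single step, rather than taking the log of an already-computed softmax, avoids log(0) when one class dominates. This is the reason Keras losses offer the from_logits option.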

You can easily customize the model configs with custom YAML files. Predefined YAML files provided by Microsoft are located in the configs directory.
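A config mainly controls the embedding width, the number of blocks and heads per stage, and the patch and window sizes. The sketch below writes such a config as a flattened Python dict. The key names are an assumption modeled on the official Microsoft YAML files, so check the files in the configs directory for the exact schema this repository reads. The values are those of Swin-T in the paper. The helper derives per-stage shapes from the paper's rule that each patch-merging step halves the resolution and doubles the channel count.

```python
from typing import Dict, List

# Hypothetical flattened config; the values correspond to Swin-T.
SWIN_TINY = {
    "IMG_SIZE": 224,
    "PATCH_SIZE": 4,
    "EMBED_DIM": 96,
    "DEPTHS": [2, 2, 6, 2],
    "NUM_HEADS": [3, 6, 12, 24],
    "WINDOW_SIZE": 7,
}


def stage_shapes(cfg: Dict) -> List[Dict[str, int]]:
    """Tokens per side, channel width, block count and heads for each stage."""
    side = cfg["IMG_SIZE"] // cfg["PATCH_SIZE"]
    dim = cfg["EMBED_DIM"]
    stages = []
    for i, (depth, heads) in enumerate(zip(cfg["DEPTHS"], cfg["NUM_HEADS"])):
        if dim % heads != 0:
            raise ValueError(f"stage {i + 1}: {dim} channels not divisible by {heads} heads")
        stages.append({"stage": i + 1, "side": side, "channels": dim,
                       "blocks": depth, "heads": heads})
        # Patch merging between stages: each 2 x 2 group of tokens becomes one.
        side //= 2
        dim *= 2
    return stages


if __name__ == "__main__":
    for stage in stage_shapes(SWIN_TINY):
        print(stage)
```

Editing DEPTHS or EMBED_DIM in a custom config changes these shapes accordingly. The divisibility check catches head counts that would not split the channels evenly.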

3. Convert PyTorch pretrained weights into TensorFlow checkpoints

We provide a Python script, convert_weights.py, that converts the official PyTorch weights into TensorFlow checkpoints:

$ python convert_weights.py --cfg config_file --weights the_path_to_pytorch_weights --weights_type type_of_pretrained_weights --output the_path_to_output_tf_weights
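The internals of convert_weights.py are not shown here. The sketch below illustrates one detail that any PyTorch-to-Keras conversion has to handle, which is tensor layout. PyTorch nn.Linear stores weights as (out, in) while a Keras Dense kernel is (in, out). PyTorch Conv2d stores (out, in, kh, kw) while a channels-last Keras Conv2D kernel is (kh, kw, in, out). The state-dict keys in the test are modeled on the official checkpoints. The suffix-and-rank dispatch rule and the helper names are hypothetical simplifications, and mapping PyTorch names to TensorFlow variable names is omitted.

```python
from typing import Dict

import numpy as np


def torch_linear_to_keras_dense(weight: np.ndarray) -> np.ndarray:
    """(out_features, in_features) -> (input_dim, units)."""
    return weight.T


def torch_conv2d_to_keras_conv2d(weight: np.ndarray) -> np.ndarray:
    """(out_ch, in_ch, kh, kw) -> (kh, kw, in_ch, out_ch)."""
    return weight.transpose(2, 3, 1, 0)


def convert_state_dict(state: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Re-layout PyTorch weight arrays for Keras; other tensors are copied."""
    converted = {}
    for name, value in state.items():
        if name.endswith(".weight") and value.ndim == 4:
            converted[name] = torch_conv2d_to_keras_conv2d(value)
        elif name.endswith(".weight") and value.ndim == 2:
            converted[name] = torch_linear_to_keras_dense(value)
        else:
            # Biases, LayerNorm parameters and 2-D tables such as
            # relative_position_bias_table keep their layout.
            converted[name] = value
    return converted
```

Dispatching on the ".weight" suffix matters here. A rule that simply transposes every 2-D array would wrongly transpose the relative position bias tables, whose shape is ((2M - 1)^2, num_heads).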

TODO:

  • Translate model code over to TensorFlow
  • Load PyTorch pretrained weights into TensorFlow model
  • Write trainer code
  • Reproduce results presented in paper
    • Object Detection
  • Reproduce training efficiency of official code in TensorFlow

Citations:

@misc{liu2021swin,
      title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, 
      author={Ze Liu and Yutong Lin and Yue Cao and Han Hu and Yixuan Wei and Zheng Zhang and Stephen Lin and Baining Guo},
      year={2021},
      eprint={2103.14030},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
You might also like...
This is an official implementation of our CVPR 2021 paper "Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression" (https://arxiv.org/abs/2104.02300)

Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression Introduction In this paper, we are interested in the bottom-up paradigm of estima

Non-Official Pytorch implementation of "Face Identity Disentanglement via Latent Space Mapping" https://arxiv.org/abs/2005.07728 Using StyleGAN2 instead of StyleGAN

Face Identity Disentanglement via Latent Space Mapping - Implement in pytorch with StyleGAN 2 Description Pytorch implementation of the paper Face Ide

Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.

PAWS-TF 🐾 Implementation of Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples (PAWS)

A PyTorch implementation of EventProp [https://arxiv.org/abs/2009.08378], a method to train Spiking Neural Networks

Spiking Neural Network training with EventProp This is an unofficial PyTorch implementation of EventProp, a method to compute exact gradients for Spiki

Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286

Pytorch-DPPO Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286 Using PPO with clip loss (from https

Tensorflow implementation of Semi-supervised Sequence Learning (https://arxiv.org/abs/1511.01432)

Transfer Learning for Text Classification with Tensorflow Tensorflow implementation of Semi-supervised Sequence Learning(https://arxiv.org/abs/1511.01

PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)

Asym-Siam: On the Importance of Asymmetry for Siamese Representation Learning This is a PyTorch implementation of the Asym-Siam paper, CVPR 2022: @inp

This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Predicting Patient Outcomes with Graph Representation Learning This repository contains the code used for Predicting Patient Outcomes with Graph Repre

https://arxiv.org/abs/2102.11005

LogME LogME: Practical Assessment of Pre-trained Models for Transfer Learning How to use Just feed the features f and labels y to the function, and yo

Comments
  • Custom Swin Transformer: error: unrecognized arguments


    parser = argparse.ArgumentParser('Custom Swin Transformer')

    parser.add_argument(
        '--cfg', type=str, metavar="FILE",
        help='/content/Swin-Transformer-Tensorflow/configs/swin_tiny_patch4_window7_224.yaml',
        default="CUSTOM_YAML_FILE_PATH"
    )
    parser.add_argument('--resume', type=int, help=1, choices={0, 1}, default=1)
    parser.add_argument(
        '--weights_type', type=str, help='imagenet_22k',
        choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"},
        default="imagenet_1k"
    )

    args = parser.parse_args()
    custom_config = get_config(args, include_top=False)

    I am trying to use it, but it throws the error below:

    usage: Custom Swin Transformer [-h] [--cfg FILE] [--resume {0,1}] [--weights_type {imagenet_22kto1k,imagenet_1k,imagenet_22k}]
    Custom Swin Transformer: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-ee309a98-1f20-4bb7-aa12-c2980aea076c.json
    An exception has occurred, use %tb to see the full traceback.

    SystemExit: 2

    opened by AliKayhanAtay 1
  • train dataset


    Thank you for providing your code. I've been running the pretrained model, and I'd like to know how to train on custom data with the code you provided, and how to apply transfer learning to custom data using the pretrained model. Thank you.

    opened by hoyeoung 1
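The "unrecognized arguments" error in the first comment is what happens when the argument parser runs inside Jupyter or Colab. The kernel process is launched with its own "-f <connection file>" flag, and parse_args() reads it from sys.argv and exits. A minimal standard-library workaround is to use parse_known_args(), which returns unrecognized flags instead of exiting, or to pass an explicit argument list. The parser below mirrors the one in "Create custom models". build_parser and parse_notebook_safe are hypothetical helper names.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Same flags as the custom-model example above."""
    parser = argparse.ArgumentParser("Custom Swin Transformer")
    parser.add_argument("--cfg", type=str, metavar="FILE", default="CUSTOM_YAML_FILE_PATH")
    parser.add_argument("--resume", type=int, choices={0, 1}, default=1)
    parser.add_argument(
        "--weights_type", type=str,
        choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"},
        default="imagenet_1k",
    )
    return parser


def parse_notebook_safe(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse known flags; extras such as the kernel's '-f <file>' are ignored."""
    args, _unknown = build_parser().parse_known_args(argv)
    return args


if __name__ == "__main__":
    # In a notebook, passing the values explicitly also avoids reading sys.argv.
    args = build_parser().parse_args(
        ["--cfg", "configs/swin_tiny_patch4_window7_224.yaml", "--weights_type", "imagenet_22k"]
    )
    print(args)
```

Note that the config path belongs in default= (or in the explicit argument list), not in help=, because help only changes the text printed by --help.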
The Self-Supervised Learner can be used to train a classifier with fewer labeled examples needed using self-supervised learning.

Published by SpaceML • About SpaceML • Quick Colab Example Self-Supervised Learner The Self-Supervised Learner can be used to train a classifier with

SpaceML 92 Nov 30, 2022
Code for Pose-Controllable Talking Face Generation by Implicitly Modularized Audio-Visual Representation (CVPR 2021)

Pose-Controllable Talking Face Generation by Implicitly Modularized Audio-Visual Representation (CVPR 2021) Hang Zhou, Yasheng Sun, Wayne Wu, Chen Cha

Hang_Zhou 628 Dec 28, 2022
Collapse by Conditioning: Training Class-conditional GANs with Limited Data

Collapse by Conditioning: Training Class-conditional GANs with Limited Data Moha

Mohamad Shahbazi 33 Dec 06, 2022
[ICLR 2021] Heteroskedastic and Imbalanced Deep Learning with Adaptive Regularization

Heteroskedastic and Imbalanced Deep Learning with Adaptive Regularization Kaidi Cao, Yining Chen, Junwei Lu, Nikos Arechiga, Adrien Gaidon, Tengyu Ma

Kaidi Cao 29 Oct 20, 2022
Histology images query (unsupervised)

110-1-NTU-DBME5028-Histology-images-query Final Project: Histology images query (unsupervised) Kaggle: https://www.kaggle.com/c/histology-images-query

1 Jan 05, 2022
Improving Deep Network Debuggability via Sparse Decision Layers

Improving Deep Network Debuggability via Sparse Decision Layers This repository contains the code for our paper: Leveraging Sparse Linear Layers for D

Madry Lab 35 Nov 14, 2022
ParmeSan: Sanitizer-guided Greybox Fuzzing

ParmeSan: Sanitizer-guided Greybox Fuzzing ParmeSan is a sanitizer-guided greybox fuzzer based on Angora. Published Work USENIX Security 2020: ParmeSa

VUSec 158 Dec 31, 2022
Implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

CrossViT : Cross-Attention Multi-Scale Vision Transformer for Image Classification This is an unofficial PyTorch implementation of CrossViT: Cross-Att

Rishikesh (ऋषिकेश) 103 Nov 25, 2022
Distinguishing Commercial from Editorial Content in News

Distinguishing Commercial from Editorial Content in News In this repository you can find the following: An anonymized version of the data used for my

Timo Kats 3 Sep 26, 2022
Analyzes your GitHub Profile and presents you with a report on how likely you are to become the next MLH Fellow!

Fellowship Prediction GitHub Profile Comparative Analysis Tool Built with BentoML Table of Contents: Features Disclaimer Technologies Used Contributin

Damir Temir 51 Dec 29, 2022
Implementation of Learning Gradient Fields for Molecular Conformation Generation (ICML 2021).

[PDF] | [Slides] The official implementation of Learning Gradient Fields for Molecular Conformation Generation (ICML 2021 Long talk) Installation Inst

MilaGraph 117 Dec 09, 2022
PyTorch implementation of SCAFFOLD (Stochastic Controlled Averaging for Federated Learning, ICML 2020).

Scaffold-Federated-Learning PyTorch implementation of SCAFFOLD (Stochastic Controlled Averaging for Federated Learning, ICML 2020). Environment numpy=

KI 30 Dec 29, 2022
An improvement of FasterGICP: Acceptance-rejection Sampling based 3D Lidar Odometry

fasterGICP This package is an improvement of fast_gicp Please cite our paper if possible. W. Jikai, M. Xu, F. Farzin, D. Dai and Z. Chen, "FasterGICP:

79 Dec 31, 2022
Official code for On Path Integration of Grid Cells: Group Representation and Isotropic Scaling (NeurIPS 2021)

On Path Integration of Grid Cells: Group Representation and Isotropic Scaling This repo contains the official implementation for the paper On Path Int

Ruiqi Gao 39 Nov 10, 2022
A motion detection system with RaspberryPi, OpenCV, Python

Human Detection System using Raspberry Pi Functionality Activates a relay on detecting motion. You may need following components to get the expected R

Omal Perera 55 Dec 04, 2022
Code for the paper "Controllable Video Captioning with an Exemplar Sentence"

SMCG Code for the paper "Controllable Video Captioning with an Exemplar Sentence" Introduction We investigate a novel and challenging task, namely con

10 Dec 04, 2022
The official implementation of Equalization Loss v1 & v2 (CVPR 2020, 2021) based on MMDetection.

The Equalization Losses for Long-tailed Object Detection and Instance Segmentation This repo is official implementation CVPR 2021 paper: Equalization

Jingru Tan 129 Dec 16, 2022
Code for IntraQ, PyTorch implementation of our paper under review

IntraQ: Learning Synthetic Images with Intra-Class Heterogeneity for Zero-Shot Network Quantization paper Requirements Python = 3.7.10 Pytorch == 1.7

1 Nov 19, 2021
Implementation of the SUMO (Slim U-Net trained on MODA) model

SUMO - Slim U-Net trained on MODA Implementation of the SUMO (Slim U-Net trained on MODA) model as described in: TODO: add reference to paper once ava

6 Nov 19, 2022
Patch-Diffusion Code (AAAI2022)

Patch-Diffusion This is an official PyTorch implementation of "Patch Diffusion: A General Module for Face Manipulation Detection" in AAAI2022. Require

H 7 Nov 02, 2022