TensorFlow implementation of the Swin Transformer model.

Overview

Swin Transformer (TensorFlow)

TensorFlow reimplementation of the Swin Transformer model.

Based on the official PyTorch implementation.

Requirements

  • tensorflow >= 2.4.1

Pretrained Swin Transformer Checkpoints

ImageNet-1K and ImageNet-22K Pretrained Checkpoints

| name | pretrain | resolution | acc@1 | #params | model |
| --- | --- | --- | --- | --- | --- |
| swin_tiny_224 | ImageNet-1K | 224x224 | 81.2 | 28M | github |
| swin_small_224 | ImageNet-1K | 224x224 | 83.2 | 50M | github |
| swin_base_224 | ImageNet-22K | 224x224 | 85.2 | 88M | github |
| swin_base_384 | ImageNet-22K | 384x384 | 86.4 | 88M | github |
| swin_large_224 | ImageNet-22K | 224x224 | 86.3 | 197M | github |
| swin_large_384 | ImageNet-22K | 384x384 | 87.3 | 197M | github |

Examples

Initializing the model:

from swintransformer import SwinTransformer

model = SwinTransformer('swin_tiny_224', num_classes=1000, include_top=True, pretrained=False)

You can use a pretrained model like this:

import tensorflow as tf
from swintransformer import SwinTransformer

model = tf.keras.Sequential([
  # IMAGE_SIZE (e.g. [224, 224] for swin_tiny_224) and NUM_CLASSES are user-defined placeholders.
  tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
  SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),
  tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])

If you use a pretrained model with a TPU on Kaggle, specify the use_tpu option:

import tensorflow as tf
from swintransformer import SwinTransformer

model = tf.keras.Sequential([
  tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
  SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
  tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])

Example: TPU training on Kaggle
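
For reference, a minimal sketch of the surrounding TPU setup on Kaggle (the cluster-resolver boilerplate is the usual TF 2.x pattern and an assumption here, not taken from the example notebook):

import tensorflow as tf
from swintransformer import SwinTransformer

# Standard TF 2.x TPU initialization on Kaggle (assumed environment).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Lambda(
            lambda data: tf.keras.applications.imagenet_utils.preprocess_input(
                tf.cast(data, tf.float32), mode="torch"),
            input_shape=[224, 224, 3]),
        SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
        tf.keras.layers.Dense(10, activation='softmax')  # illustrative class count
    ])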

Citation

@article{liu2021Swin,
  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
  journal={arXiv preprint arXiv:2103.14030},
  year={2021}
}
Comments
  • no module named 'swintransformer' error

    I wonder where from swintransformer import SwinTransformer comes from. I tried to pip install it, but it said there is no such module. How can I overcome this problem?

    opened by HunarAA 2
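
    The package does not appear to be published on PyPI; one likely workaround (an assumption, not an official install path) is to clone the repository and import from the local copy:

    # Assumes the repo has been cloned next to the notebook, e.g.:
    #   git clone https://github.com/rishigami/Swin-Transformer-TF.git
    import sys
    sys.path.append('./Swin-Transformer-TF')  # hypothetical local path

    from swintransformer import SwinTransformer
    model = SwinTransformer('swin_tiny_224', num_classes=1000, include_top=True, pretrained=False)
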
  • Pretrained Swin-Transformer for multiple outputs

    Hi rishigami,

    Thank you for the implementation in TensorFlow. I am trying to use the Swin Transformer for a classification problem with multiple outputs. In your guide on how to use a pretrained model you put it in a Sequential model, but that way I am not able to stack multiple dense layers for the multiple outputs. Could you help me understand how to adapt your TF code to my problem, perhaps using the Functional API?

    opened by imanuelroz 2
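
    A hedged sketch of one way to do this with the Functional API (the head names and sizes below are illustrative, not from the repo):

    import tensorflow as tf
    from swintransformer import SwinTransformer

    IMAGE_SIZE = [224, 224]  # assumed input resolution for swin_tiny_224

    inputs = tf.keras.Input(shape=[*IMAGE_SIZE, 3])
    x = tf.keras.layers.Lambda(
        lambda data: tf.keras.applications.imagenet_utils.preprocess_input(
            tf.cast(data, tf.float32), mode="torch"))(inputs)
    features = SwinTransformer('swin_tiny_224', include_top=False, pretrained=True)(x)

    # Stack as many heads on the shared features as the task needs.
    out_a = tf.keras.layers.Dense(5, activation='softmax', name='head_a')(features)
    out_b = tf.keras.layers.Dense(3, activation='softmax', name='head_b')(features)

    model = tf.keras.Model(inputs=inputs, outputs=[out_a, out_b])
    model.compile(optimizer='adam',
                  loss={'head_a': 'categorical_crossentropy',
                        'head_b': 'categorical_crossentropy'})
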
  • NotImplementedError during model save

    I have defined a model as follows:

    import tensorflow as tf
    from tensorflow.keras import layers as L  # `L` is the Keras layers alias used below
    from swintransformer import SwinTransformer

    def buildModel(LR = LR):  # LR is a learning-rate constant defined elsewhere in the notebook
        backbone = SwinTransformer('swin_large_224', num_classes=None, include_top=False, pretrained=True, use_tpu=False)

        inp = L.Input(shape=(224, 224, 3))
        emb = backbone(inp)
        out = L.Dense(1, activation="relu")(emb)

        model = tf.keras.Model(inputs=inp, outputs=out)
        optimizer = tf.keras.optimizers.Adam(lr=LR)
        model.compile(loss="mse", optimizer=optimizer, metrics=[tf.keras.metrics.RootMeanSquaredError()])
        return model
    

    Now when I save this model using model.save("./model.hdf5") I get the following error:

    NotImplementedError                       Traceback (most recent call last)
    /tmp/ipykernel_43/131311624.py in <module>
    ----> 1 model.save("model.hdf5")
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
       2000     # pylint: enable=line-too-long
       2001     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
    -> 2002                     signatures, options, save_traces)
       2003 
       2004   def save_weights(self,
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
        152           'or using `save_weights`.')
        153     hdf5_format.save_model_to_hdf5(
    --> 154         model, filepath, overwrite, include_optimizer)
        155   else:
        156     saved_model_save.save(model, filepath, overwrite, include_optimizer,
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
        113 
        114   try:
    --> 115     model_metadata = saving_utils.model_metadata(model, include_optimizer)
        116     for k, v in model_metadata.items():
        117       if isinstance(v, (dict, list, tuple)):
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
        156   except NotImplementedError as e:
        157     if require_config:
    --> 158       raise e
        159 
        160   metadata = dict(
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
        153   model_config = {'class_name': model.__class__.__name__}
        154   try:
    --> 155     model_config['config'] = model.get_config()
        156   except NotImplementedError as e:
        157     if require_config:
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_config(self)
        648 
        649   def get_config(self):
    --> 650     return copy.deepcopy(get_network_config(self))
        651 
        652   @classmethod
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_network_config(network, serialize_layer_fn)
       1347         filtered_inbound_nodes.append(node_data)
       1348 
    -> 1349     layer_config = serialize_layer_fn(layer)
       1350     layer_config['name'] = layer.name
       1351     layer_config['inbound_nodes'] = filtered_inbound_nodes
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
        248         return serialize_keras_class_and_config(
        249             name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
    --> 250       raise e
        251     serialization_config = {}
        252     for key, item in config.items():
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
        243     name = get_registered_name(instance.__class__)
        244     try:
    --> 245       config = instance.get_config()
        246     except NotImplementedError as e:
        247       if _SKIP_FAILED_SERIALIZATION:
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in get_config(self)
       2252 
       2253   def get_config(self):
    -> 2254     raise NotImplementedError
       2255 
       2256   @classmethod
    
    NotImplementedError: 
    
    opened by Bibhash123 1
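
    The HDF5 format requires get_config() on every layer, which the custom SwinTransformer layer does not implement. A hedged workaround (standard Keras practice, not a repo-specific fix) is to save only the weights and rebuild the architecture in code:

    # Save weights only; the architecture is reconstructed by buildModel().
    model = buildModel()
    model.save_weights("swin_regressor_weights.h5")

    # Later: rebuild the same architecture, then restore the weights.
    restored = buildModel()
    restored.load_weights("swin_regressor_weights.h5")

    # The TF SavedModel format, model.save("export_dir", save_format="tf"), may also
    # work, since it can serialize traced call graphs without get_config().
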
  • Invalid argument

    This is my basic model:

    
    with tpu_strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Lambda(
                lambda data: tf.keras.applications.imagenet_utils.preprocess_input(data, mode="torch"),
                input_shape=[224, 224, 3]),
            SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
            tf.keras.layers.Dense(1, activation='sigmoid')
        ])

    model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                  optimizer=tf.keras.optimizers.Adam(learning_rate=cfg['LEARNING_RATE']),
                  metrics=RMSE)
    
    

    I am getting this error:

    (3) Invalid argument: {{function_node __inference_train_function_705020}} Reshape's input dynamic dimension is decomposed into multiple output dynamic dimensions, but the constraint is ambiguous and XLA can't infer the output dimension %reshape.12202 = f32[256,144,576]{2,1,0} reshape(f32[36864,576]{1,0} %transpose.12194), metadata={op_type="Reshape" op_name="sequential_40/swin_large_384/sequential_39/basic_layer_28/sequential_35/swin_transformer_block_169/window_attention_169/layers0/blocks1/attn/qkv/Tensordot"}. [[{{node TPUReplicate/_compile/_17658394825749957328/_4}}]] [[tpu_compile_succeeded_assert/_11424487196827204192/_5/_209]]

    opened by AliKayhanAtay 1
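
    XLA on TPUs needs static tensor shapes, so a variable batch dimension coming out of tf.data is a common trigger for this kind of reshape error. A hedged mitigation (an assumption, not a confirmed fix for this repo) is to feed fixed-size batches:

    import tensorflow as tf

    BATCH_SIZE = 128  # illustrative; typically a multiple of the number of TPU cores

    # drop_remainder=True gives every batch a static shape, which XLA requires.
    train_ds = (
        tf.data.Dataset.from_tensor_slices((images, labels))  # hypothetical tensors
        .shuffle(2048)
        .batch(BATCH_SIZE, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )

    model.fit(train_ds, epochs=5)
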
  • relative_position_bias_table initialization

    Hi, in the official code, relative_position_bias_table is initialized with a truncated normal distribution. Is that part missing in this repo?

    Official code: https://github.com/microsoft/Swin-Transformer/blob/6bbd83ca617db8480b2fb9b335c476ffaf5afb1a/models/swin_transformer.py#L110

    This implementation: https://github.com/rishigami/Swin-Transformer-TF/blob/8986ca7b0e1f984437db2d8f17e0ecd87fadcd4f/swintransformer/model.py#L70

    opened by gathierry 1
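
    For reference, a hedged sketch of what a matching initialization could look like on the TensorFlow side (the shapes and variable name below are illustrative, not the repo's actual code; the official implementation uses trunc_normal_ with std 0.02):

    import tensorflow as tf

    window_size, num_heads = 7, 3  # illustrative values (e.g. swin_tiny, first stage)

    # Truncated-normal init for the relative position bias table, mirroring the
    # official PyTorch trunc_normal_(std=0.02) call (a sketch, not the repo's code).
    table_shape = [(2 * window_size - 1) * (2 * window_size - 1), num_heads]
    relative_position_bias_table = tf.Variable(
        tf.keras.initializers.TruncatedNormal(stddev=0.02)(table_shape),
        trainable=True,
        name='relative_position_bias_table')
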
  • Image sizes other than the defaults don't work

    • Notebook: https://colab.research.google.com/drive/1nqYkQCUzShkVdqGxW4TyMrtAb0n5MBZR#scrollTo=G9ZVlphmqD7d
    • Issue: with swin_tiny_224 I've tried multiples of 224, 512x512, and multiples of window_size, but nothing works other than 224x224.
    • The same goes for swin_large_384; only the default size 384x384 works.

    I'm wondering if this is expected behavior or not. Is there any way to make it work for non-square images?

    opened by awsaf49 1
  • Added 3D support for SwinTransformerModel, e.g. for medical imaging tasks

    Tested and working, e.g.:

    import numpy as np
    import tensorflow as tf
    # swin_transformer_nd is the module added in this PR

    IMAGE_SIZE = [112, 112, 112]
    NUM_CLASSES = 10

    model_3d = tf.keras.Sequential([
      swin_transformer_nd.SwinTransformerModel(img_size=IMAGE_SIZE, patch_size=(4, 4, 4), depths=[2, 2, 6]),
      tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])
    model_3d.compile(tf.keras.optimizers.Adam(), "categorical_crossentropy")

    for i in range(100):
        x = np.zeros([1, *IMAGE_SIZE, 1])
        y = tf.zeros([1, NUM_CLASSES])

        model_3d.fit(x, y)
        print("Trained on a batch")
    
    opened by MohamadZeina 0
  • Could you provide a weight-conversion script?

    I tried the code and weights you provided and found that the performance is bad. Could you please provide the weight-conversion script so I can figure out this issue?

    Many thanks

    opened by edwardyehuang 0
  • tf load model error

    import tensorflow as tf
    from swintransformer import SwinTransformer

    model = tf.keras.Sequential([
        tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
        SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),
        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])

    TensorFlow can't load the pretrained model; this step errors.

    opened by jangjiun 0
  • Please run in eager mode or implement the `compute_output_shape` method on your layer (SwinTransformerModel)

    Has anyone tried to use the pretrained model with a TimeDistributed layer?

    import tensorflow as tf
    from swintransformer import SwinTransformer

    model = tf.keras.Sequential([
        tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"),
                               input_shape=[224, 224, 3]),
        SwinTransformer('swin_base_224', include_top=False, pretrained=True)])

    model_f = tf.keras.Sequential()
    model_f.add(tf.keras.layers.TimeDistributed(model, input_shape=(8, 224, 224, 3)))
    
    

    I get the following error:

    NotImplementedError: Exception encountered when calling layer "time_distributed" (type TimeDistributed).
    
    Please run in eager mode or implement the `compute_output_shape` method on your layer (SwinTransformerModel).
    
    Call arguments received by layer "time_distributed" (type TimeDistributed):
      • inputs=tf.Tensor(shape=(None, 8, 224, 224, 3), dtype=float32)
      • training=False
    
    
    opened by atelili 0
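
    A hedged sketch of one possible workaround: wrap the backbone in a thin custom layer that implements compute_output_shape so TimeDistributed can infer shapes (the 1024-dim pooled feature size is assumed for swin_base; preprocessing is omitted for brevity):

    import tensorflow as tf
    from swintransformer import SwinTransformer

    class SwinFeatures(tf.keras.layers.Layer):
        """Wraps the backbone and adds compute_output_shape (a sketch, not the repo's API)."""

        def __init__(self, model_name='swin_base_224', feature_dim=1024, **kwargs):
            super().__init__(**kwargs)
            self.feature_dim = feature_dim
            self.backbone = SwinTransformer(model_name, include_top=False, pretrained=True)

        def call(self, inputs, training=None):
            return self.backbone(inputs, training=training)

        def compute_output_shape(self, input_shape):
            return (input_shape[0], self.feature_dim)

    frames = tf.keras.Input(shape=(8, 224, 224, 3))
    features = tf.keras.layers.TimeDistributed(SwinFeatures())(frames)  # -> (batch, 8, 1024)
    model = tf.keras.Model(frames, features)
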
Releases (v0.1-tf-swin-weights)