Includes PyTorch -> Keras model porting code for the ConvNeXt family of models, along with fine-tuning and inference notebooks.

Overview

ConvNeXt-TF

This repository provides TensorFlow / Keras implementations of different ConvNeXt [1] variants. It also provides the TensorFlow / Keras models that have been populated with the original ConvNeXt pre-trained weights available from [2]. These models are not black-box SavedModels, i.e., they can be fully expanded into tf.keras.Model objects, and one can call all the utility functions on them (for example, .summary()).

As of today, TensorFlow / Keras versions of all the models listed in the official repository [2] are available here, except for the isotropic ones. This includes both the ImageNet-1k and the ImageNet-21k models.

Refer to the "Using the models" section to get started. Additionally, here's a related blog post that jots down my experience.

Conversion

TensorFlow / Keras implementations are available in models/convnext_tf.py. Conversion utilities are in convert.py.
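Porting weights from PyTorch to Keras comes down to copying each tensor while reordering its axes, because the two frameworks store kernels in different layouts. The following NumPy sketch assumes the standard layouts. PyTorch `Conv2d` weights are (out, in, kH, kW) and Keras `Conv2D` kernels are (kH, kW, in, out). PyTorch depthwise weights are (C, 1, kH, kW) and Keras `DepthwiseConv2D` kernels are (kH, kW, C, 1). PyTorch `Linear` weights are (out, in) and Keras `Dense` kernels are (in, out). The helper names are hypothetical, and this is not necessarily how convert.py performs the conversion.

```python
import numpy as np


def conv2d_torch_to_keras(w: np.ndarray) -> np.ndarray:
    """Regular conv kernel: (out, in, kH, kW) -> (kH, kW, in, out)."""
    return np.transpose(w, (2, 3, 1, 0))


def depthwise_torch_to_keras(w: np.ndarray) -> np.ndarray:
    """Depthwise kernel with depth_multiplier=1: (C, 1, kH, kW) -> (kH, kW, C, 1)."""
    return np.transpose(w, (2, 3, 0, 1))


def linear_torch_to_keras(w: np.ndarray) -> np.ndarray:
    """Fully connected kernel: (out, in) -> (in, out). Biases need no change."""
    return w.T


# Shapes modeled on a ConvNeXt-T stem (4x4, stride 4, 96 channels)
# and its first block (7x7 depthwise conv, 4x MLP expansion).
stem = np.random.rand(96, 3, 4, 4).astype("float32")
dw = np.random.rand(96, 1, 7, 7).astype("float32")
fc = np.random.rand(384, 96).astype("float32")
print(conv2d_torch_to_keras(stem).shape)   # (4, 4, 3, 96)
print(depthwise_torch_to_keras(dw).shape)  # (7, 7, 96, 1)
print(linear_torch_to_keras(fc).shape)     # (96, 384)
```

Depthwise kernels need a different permutation from regular convolutions. Applying the `Conv2D` permutation to a depthwise weight would produce (kH, kW, 1, C), which `DepthwiseConv2D` would reject.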

Models

The converted models are available on TF-Hub.

There should be a total of 15 different models, each having two variants: a classifier and a feature extractor. You can load any model and get started like so:

import tensorflow as tf

model_gcs_path = "gs://tfhub-modules/sayakpaul/convnext_tiny_1k_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)
model.summary(expand_nested=True)  # summary() prints directly and returns None
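A classifier loaded this way still needs suitably shaped and scaled inputs. The following NumPy sketch center-crops an image and normalizes it. It assumes that the 224 models expect 224x224 RGB inputs, which the "_224" suffix suggests. It also assumes the inputs should be normalized with the standard ImageNet mean and standard deviation. That normalization is not stated in this README, so check it against the notebooks or the TF-Hub model page before relying on it. The helper names are hypothetical. Resizing the shorter side before cropping is left out so that the sketch stays dependency-free.

```python
import numpy as np

# ASSUMPTION: standard ImageNet statistics (RGB, [0, 1] scale); verify first.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def center_crop(image: np.ndarray, size: int) -> np.ndarray:
    """Crop an (H, W, C) image to (size, size, C) around its center."""
    h, w = image.shape[:2]
    if h < size or w < size:
        raise ValueError(f"image {h}x{w} is smaller than crop size {size}")
    top = (h - size) // 2
    left = (w - size) // 2
    return image[top:top + size, left:left + size]


def preprocess(image_uint8: np.ndarray, size: int = 224) -> np.ndarray:
    """Convert a uint8 (H, W, 3) image to a float32 (1, size, size, 3) batch."""
    x = center_crop(image_uint8, size).astype(np.float32) / 255.0
    x = (x - IMAGENET_MEAN) / IMAGENET_STD  # per-channel normalization
    return x[np.newaxis, ...]  # add a batch dimension


batch = preprocess(np.full((256, 256, 3), 128, dtype=np.uint8))
print(batch.shape)  # (1, 224, 224, 3)
```

With TensorFlow available, a batch like this could then be passed to the loaded model (for example, `model.predict(batch)`). Drop the normalization step if the model turns out to normalize its inputs internally.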

The model names are interpreted as follows:

  • convnext_large_21k_1k_384: This means that the model was first pre-trained on the ImageNet-21k dataset and was then fine-tuned on the ImageNet-1k dataset. Resolution used during pre-training and fine-tuning: 384x384. large denotes the topology of the underlying model.
  • convnext_large_1k_224: This means that the model was pre-trained on the ImageNet-1k dataset with a resolution of 224x224.
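The naming scheme above is regular enough to parse mechanically, which can help when iterating over many model handles. The sketch below uses `parse_model_name` and `ConvNeXtName`, which are hypothetical helpers and not part of this repository. The pattern assumes that every name follows one of the two forms described above.

```python
import re
from typing import NamedTuple, Optional


class ConvNeXtName(NamedTuple):
    size: str                # topology of the underlying model, e.g. "tiny", "large"
    pretrain: str            # pre-training dataset: "1k" or "21k"
    finetune: Optional[str]  # fine-tuning dataset, or None if not fine-tuned
    resolution: int          # square input resolution in pixels


# convnext_<size>_<pretrain>[_<finetune>]_<resolution>
_PATTERN = re.compile(r"^convnext_([a-z]+)_(1k|21k)(?:_(1k))?_(\d+)$")


def parse_model_name(name: str) -> ConvNeXtName:
    match = _PATTERN.match(name)
    if match is None:
        raise ValueError(f"unrecognized model name: {name!r}")
    size, pretrain, finetune, resolution = match.groups()
    return ConvNeXtName(size, pretrain, finetune, int(resolution))


print(parse_model_name("convnext_large_21k_1k_384"))
# ConvNeXtName(size='large', pretrain='21k', finetune='1k', resolution=384)
print(parse_model_name("convnext_large_1k_224"))
# ConvNeXtName(size='large', pretrain='1k', finetune=None, resolution=224)
```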

Results

Results are on ImageNet-1k validation set (top-1 accuracy).

| name                       | original acc@1 | keras acc@1 |
|----------------------------|----------------|-------------|
| convnext_tiny_1k_224       | 82.1           | 81.312      |
| convnext_small_1k_224      | 83.1           | 82.392      |
| convnext_base_1k_224       | 83.8           | 83.28       |
| convnext_base_1k_384       | 85.1           | 84.876      |
| convnext_large_1k_224      | 84.3           | 83.844      |
| convnext_large_1k_384      | 85.5           | 85.376      |
| convnext_base_21k_1k_224   | 85.8           | 85.364      |
| convnext_base_21k_1k_384   | 86.8           | 86.79       |
| convnext_large_21k_1k_224  | 86.6           | 86.36       |
| convnext_large_21k_1k_384  | 87.5           | 87.504      |
| convnext_xlarge_21k_1k_224 | 87.0           | 86.732      |
| convnext_xlarge_21k_1k_384 | 87.8           | 87.68       |

Differences in the results are primarily due to differences between the library implementations, especially in how image resizing is implemented in PyTorch and TensorFlow. Results can be verified with the code in i1k_eval. Logs are available at this URL.
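The numbers in the table are top-1 accuracies. Top-1 accuracy is the percentage of validation images for which the class with the highest logit matches the ground-truth label. The following NumPy sketch computes that metric. It assumes logits of shape (num_images, num_classes) and integer labels. It is not taken from i1k_eval.

```python
import numpy as np


def top1_accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Percentage of samples whose highest-scoring class equals the label."""
    predictions = np.argmax(logits, axis=-1)
    return float(np.mean(predictions == labels) * 100.0)


logits = np.array([[0.1, 2.0, 0.3],
                   [1.5, 0.2, 0.1],
                   [0.0, 0.1, 3.0],
                   [0.9, 0.8, 0.7]])
labels = np.array([1, 0, 2, 1])
print(top1_accuracy(logits, labels))  # 75.0 (the last sample is misclassified)
```

Read against the table, convnext_tiny_1k_224 has a gap of 82.1 - 81.312 = 0.788 percentage points between the original and the Keras port.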

Using the models

Pre-trained models:

Randomly initialized models:

from models.convnext_tf import get_convnext_model

convnext_tiny = get_convnext_model()
convnext_tiny.summary(expand_nested=True)  # summary() prints directly and returns None

To view different model configurations, refer here.
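The ConvNeXt paper [1] defines each variant by two lists: the number of blocks in each of its four stages and the channel width of each stage. The sketch below records these as a plain Python dict. The name `CONVNEXT_CONFIGS` is hypothetical, and the dict need not match how models/convnext_tf.py or the arguments of `get_convnext_model` encode the configurations.

```python
# Per-stage block depths and channel widths of the ConvNeXt variants in [1].
CONVNEXT_CONFIGS = {
    "tiny":   {"depths": [3, 3, 9, 3],  "dims": [96, 192, 384, 768]},
    "small":  {"depths": [3, 3, 27, 3], "dims": [96, 192, 384, 768]},
    "base":   {"depths": [3, 3, 27, 3], "dims": [128, 256, 512, 1024]},
    "large":  {"depths": [3, 3, 27, 3], "dims": [192, 384, 768, 1536]},
    "xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]},
}

for name, cfg in CONVNEXT_CONFIGS.items():
    # Total blocks across the four stages and the width of the last stage.
    print(f"{name:>6}: {sum(cfg['depths'])} blocks, final width {cfg['dims'][-1]}")
```

Every variant doubles its width from one stage to the next. Apart from tiny, which has 18 blocks, all variants share the same depth of 36 blocks and differ only in width.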

Upcoming (contributions welcome)

  • Align layer initializers (useful for anyone who wants to train the models from scratch)
  • Allow the models to accept arbitrary shapes (useful for downstream tasks)
  • Convert the isotropic models as well
  • Fine-tuning notebook (thanks to awsaf49)
  • Off-the-shelf-classification notebook
  • Publish models on TF-Hub

References

[1] ConvNeXt paper: https://arxiv.org/abs/2201.03545

[2] Official ConvNeXt code: https://github.com/facebookresearch/ConvNeXt

Acknowledgements

Owner
Sayak Paul
ML Engineer at @carted | One PR at a time