Implementation of Supervised Contrastive Learning with AMP, EMA, SWA, and many other tricks

Overview

SupCon-Framework

The repo is an implementation of Supervised Contrastive Learning. It's based on another implementation, but with several differences:

  • Fixed bugs (incorrect ResNet implementations, which led to a very small maximum batch size),
  • Offers a lot of additional functionality (first of all, rich validation).

To be more precise, in this implementation you will find:

  • Augmentations with albumentations (a minimal example is sketched right after this list)
  • Hyperparameters are moved to .yml configs
  • t-SNE visualizations
  • 2-step validation (for features before and after the projection head) using metrics such as AMI, NMI, mAP, precision_at_1, etc., with PyTorch Metric Learning.
  • Exponential Moving Average for more stable training, and Stochastic Weight Averaging for better generalization and overall performance.
  • Automatic Mixed Precision (torch-native) training, which allows training with roughly 2x larger batch sizes.
  • LabelSmoothing loss, and LRFinder for the second stage of the training (FC).
  • TensorBoard logs, checkpoints
  • Support of timm models, and pytorch-optimizer
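
The augmentation policy lives in tools.utils.build_transforms. As a rough illustration of the "cropping + color jittering" scheme SupCon relies on, here is a minimal albumentations pipeline (assuming the albumentations 1.x API; the exact transforms and values in the repo may differ):

import albumentations as A
from albumentations.pytorch import ToTensorV2

# a sketch of a SupCon-style augmentation policy for 32x32 CIFAR images
train_transform = A.Compose([
    A.RandomResizedCrop(height=32, width=32, scale=(0.2, 1.0)),
    A.HorizontalFlip(p=0.5),
    A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
    A.ToGray(p=0.2),
    A.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),
    ToTensorV2(),
])

# albumentations transforms are called with keyword arguments and return a dict:
# augmented = train_transform(image=numpy_image)['image']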

Install

  1. Clone the repo:
git clone https://github.com/ivanpanshin/SupCon-Framework && cd SupCon-Framework/
  2. Create a clean virtual environment
python3 -m venv venv
source venv/bin/activate
  3. Install dependencies
python -m pip install --upgrade pip
pip install -r requirements.txt

Training

In order to execute CIFAR10 training, run:

python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage1.yml
python swa.py --config_name configs/train/swa_supcon_resnet18_cifar10_stage1.yml
python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage2.yml
python swa.py --config_name configs/train/swa_supcon_resnet18_cifar10_stage2.yml
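
Conceptually, the first stage trains the encoder and projection head with the SupCon loss, and the second stage trains a linear (FC) classifier on top of the frozen encoder. Below is a minimal sketch of what a stage-1 step looks like, using pytorch-metric-learning's SupConLoss and torch AMP; names like model and loader are illustrative, and the repo's actual loop lives in train.py:

import torch
from pytorch_metric_learning import losses

def train_supcon_epoch(model, loader, optimizer, scaler, device="cuda"):
    # loader yields two augmented views per image plus the label
    criterion = losses.SupConLoss(temperature=0.1)
    model.train()
    for (view1, view2), labels in loader:
        view1, view2, labels = view1.to(device), view2.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():  # mixed precision forward pass
            embeddings = model(torch.cat([view1, view2], dim=0))  # encoder + projection head
            loss = criterion(embeddings, torch.cat([labels, labels], dim=0))
        scaler.scale(loss).backward()  # scaled backward to avoid fp16 underflow
        scaler.step(optimizer)
        scaler.update()

# scaler = torch.cuda.amp.GradScaler() is created once and reused across epochs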

In order to run LRFinder on the second stage of the training, run:

python learning_rate_finder.py --config_name configs/train/lr_finder_supcon_resnet18_cifar10_stage2.yml
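
As a rough sketch of what an LR range test does, here is one using the torch-lr-finder package (this is an assumption about the tooling; learning_rate_finder.py may implement the sweep differently):

import torch
from torch_lr_finder import LRFinder

def run_lr_range_test(model, train_loader, device="cuda"):
    # model here would be the frozen encoder plus the trainable FC head (stage 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
    criterion = torch.nn.CrossEntropyLoss()
    lr_finder = LRFinder(model, optimizer, criterion, device=device)
    lr_finder.range_test(train_loader, end_lr=1, num_iter=100)  # exponentially sweep the LR
    lr_finder.plot()   # pick a value just before the loss starts to blow up
    lr_finder.reset()  # restore the initial weights and optimizer state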

The process of training CIFAR100 is exactly the same; just change the config names from cifar10 to cifar100.
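
The swa.py steps above average weights from the tail of training. As a rough illustration of the idea (not the repo's exact implementation), PyTorch ships SWA utilities in torch.optim.swa_utils:

import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

def train_with_swa(model, loader, optimizer, criterion, epochs, swa_start, device="cuda"):
    swa_model = AveragedModel(model)               # running average of the weights
    swa_scheduler = SWALR(optimizer, swa_lr=0.05)  # constant LR during the SWA phase
    for epoch in range(epochs):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            criterion(model(images), labels).backward()
            optimizer.step()
        if epoch >= swa_start:
            swa_model.update_parameters(model)     # fold current weights into the average
            swa_scheduler.step()
    update_bn(loader, swa_model, device=device)    # recompute BatchNorm statistics for the averaged model
    return swa_model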

After that, you can check the results of training in the logs or runs directories. For example, to check the TensorBoard logs for the first stage of CIFAR10 training, run:

tensorboard --logdir runs/supcon_first_stage_cifar10

Visualizations

This repo is supplied with t-SNE visualizations so that you can check the embeddings you get after training. Check t-SNE.ipynb for details.
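
If you just want a quick look at the embeddings outside of the notebook, a basic t-SNE plot only needs scikit-learn and matplotlib (a minimal sketch; see t-SNE.ipynb for the version used here):

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_tsne(embeddings, labels, out_path="tsne.png"):
    # embeddings: (N, D) array of features, labels: (N,) integer class ids
    points = TSNE(n_components=2, init="pca", perplexity=30).fit_transform(embeddings)
    plt.figure(figsize=(8, 8))
    plt.scatter(points[:, 0], points[:, 1], c=labels, s=3, cmap="tab10")
    plt.colorbar()
    plt.savefig(out_path, dpi=150)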

These are t-SNE visualizations for CIFAR10 for validation and train with SupCon (top), and validation and train with CE (bottom).

These are t-SNE visualizations for CIFAR100 for validation and train with SupCon (top), and validation and train with CE (bottom).

Results

Model      Stage    Dataset    Accuracy
ResNet18   First    CIFAR10    95.9
ResNet18   Second   CIFAR10    94.9
ResNet18   First    CIFAR100   79.0
ResNet18   Second   CIFAR100   77.9

Note that even though the second-stage accuracy is lower here, that is not always the case. In my experience, the difference between stages is usually around 1 percent, and sometimes it favors the second stage.

Training time for the whole pipeline (without any early stopping) on CIFAR10 or CIFAR100 is around 4 hours (a single 2080Ti with AMP). With reasonable early stopping, that goes down to around 2.5-3 hours.

Custom datasets

It's fairly easy to adapt this pipeline to custom datasets. First, check tools/datasets.py. Second, add a new class for your dataset. The only guideline is to follow the same augmentation logic, that is:

        if self.second_stage:
            image = self.transform(image=image)['image']  # albumentations-style call returns a dict
        else:
            image = self.transform(image)  # first-stage transform is called directly on the image

Third, add your dataset to the DATASETS dict (still inside tools/datasets.py), and you're good to go.
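
A hypothetical example of such a class (the name, constructor arguments, and image loading are illustrative; only the second_stage branching mirrors the snippet above):

import cv2
from torch.utils.data import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, image_paths, labels, transform, second_stage):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.second_stage = second_stage

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        if self.second_stage:
            image = self.transform(image=image)['image']  # albumentations-style call
        else:
            image = self.transform(image)                 # first-stage transform called directly
        return image, self.labels[idx]

# and register it, e.g. DATASETS['my_custom_dataset'] = MyCustomDataset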

FAQ

  • Q: What hyperparameters should I try to change?

    A: First of all, the learning rate. Second, try changing the augmentation policy. SupCon is built around a "cropping + color jittering" scheme, so you can try changing the crop size or the intensity of the jittering. Check tools.utils.build_transforms for that.

  • Q: What backbone and batch size should I use?

    A: This is quite simple. Take the biggest backbone you can, and then the highest batch size your GPU can handle. The reason: SupCon benefits more from stronger backbones than regular classification training with CE/LabelSmoothing/etc. Moreover, it has a property of explicit hard positive and negative mining: the higher the batch size, the more difficult and helpful samples you supply to your model.

  • Q: Do I need the second stage of the training?

    A: Not necessarily. You can do classification based only on embeddings. In order to do that, compute embeddings for the train set, and at inference time do the following: take a sample, compute its embedding, find the closest train embedding, and take its class. To make this fast and efficient, you can use something like faiss for similarity search (see the sketch after this FAQ). Note that this is actually how validation is done in this repo. Moreover, during training you will see a metric called precision_at_1, which is simply accuracy based solely on embeddings.

  • Q: Should I use AMP?

    A: If your GPU has tensor cores (like a 2080Ti), yes. If it doesn't (like a 1080Ti), check the speed with and without AMP. If the speed drops only slightly (or even increases a bit), use it, since SupCon works better with bigger batch sizes.

  • Q: How should I use EMA?

    A: You only need to choose the ema_decay_per_epoch parameter in the config. The heuristic is fairly simple: if your dataset is big, something as small as 0.3 will do just fine, and as your dataset gets smaller, you can increase ema_decay_per_epoch. Thanks to bonlime for this idea. I advise you to check out his great pytorch-tools repo; it's a hidden gem.

  • Q: Is it better than training with Cross Entropy/Label Smoothing/etc?

    A: Unfortunately, in my experience, it's much easier to get better results with something like CE. It's more stable, faster to train, and simply produces better or comparable results. For instance, in the case of CIFAR10/100 it's trivial to train ResNet18 up to 96/81 percent respectively. Of course, I've seen cases where SupCon performs better, but it takes quite a bit of work to make it outperform CE.

  • Q: How long should I train with SupCon?

    A: The answer is tricky. On one hand, the authors of the original paper claim that the longer you train with SupCon, the better it gets. However, I did not observe such behavior in my tests. So the only recommendation I can give is the following: start with 100 epochs for easy datasets (like CIFAR10/100), and 1000 for more industrial ones. Then monitor the training process: if the validation metric (such as precision_at_1) doesn't improve for several dozen epochs, you can stop the training. You might want to incorporate early stopping into the pipeline for this reason.
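
A minimal sketch of the embedding-only classification described in the FAQ above, using faiss for 1-nearest-neighbor search (variable names are illustrative):

import faiss
import numpy as np

def knn_classify(train_embeddings, train_labels, query_embeddings):
    # L2-normalize so that inner product equals cosine similarity
    train = np.ascontiguousarray(train_embeddings, dtype='float32')
    query = np.ascontiguousarray(query_embeddings, dtype='float32')
    faiss.normalize_L2(train)
    faiss.normalize_L2(query)
    index = faiss.IndexFlatIP(train.shape[1])  # exact inner-product search
    index.add(train)
    _, nearest = index.search(query, 1)        # index of the closest train embedding
    return np.asarray(train_labels)[nearest[:, 0]]  # its class is the prediction (this is precision_at_1)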

Owner
Ivan Panshin
Machine Learning Engineer: CV, NLP, tabular data. Kaggle (top 0.003% worldwide) and Open Source.