2D human pose estimation using transformers. Implementation in PyTorch.

Overview

PE-former: Pose Estimation Transformer

Vision transformer architectures perform very well for image classification tasks. Efforts to solve more challenging vision tasks with transformers rely on convolutional backbones for feature extraction.

POTR is a pure transformer architecture (no CNN backbone) for 2D body pose estimation. It uses an encoder-decoder architecture with a vision transformer as an encoder and a transformer decoder (derived from DETR).

You can use the code in this repository to train and evaluate different POTR configurations on the COCO dataset.

Model

POTR is built from components derived from recent state-of-the-art (SOTA) models. As shown in the figure, there are two major components: a Vision Transformer encoder and a Transformer decoder.

[Figure: POTR model architecture]

The input image is first converted into patch tokens following the ViT paradigm. A position embedding is used to help retain the location of each patch. The tokens and the position embedding are fed to the transformer encoder, and the encoder's output tokens serve as the memory input of the transformer decoder. The decoder takes M learned queries as input and produces one joint prediction per query. The decoder's output tokens are then passed through two heads (FFNs):

  • The first is a classification head that predicts the joint type (i.e., class) of each query.
  • The second is a regression head that predicts the normalized coordinates (in the range [0,1]) of the joint in the input image.
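The encoder-decoder flow described above can be sketched in PyTorch as follows. This is a minimal, hypothetical illustration (the class name `MiniPOTR` and all sizes are assumptions), not the repository's implementation: the real model uses pretrained DeiT/XCiT encoders and a DETR-derived decoder, whereas this sketch uses stock `nn.TransformerEncoder`/`nn.TransformerDecoder` layers. It also assumes 17 COCO joint types, a 192-pixel-wide by 256-pixel-tall input, and one extra classifier output reserved for predictions that are not joints.

```python
import torch
import torch.nn as nn


class MiniPOTR(nn.Module):
    """Hypothetical, heavily simplified POTR-style model (not the repository's code)."""

    def __init__(self, img_w=192, img_h=256, patch=16, dim=384,
                 num_queries=100, num_joints=17, depth=2, heads=6):
        super().__init__()
        # ViT-style tokenizer: a strided conv cuts the image into patch x patch tokens.
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        num_patches = (img_w // patch) * (img_h // patch)
        # Learned position embedding, one vector per patch token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, dim))
        enc_layer = nn.TransformerEncoderLayer(dim, heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, depth)
        dec_layer = nn.TransformerDecoderLayer(dim, heads, batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, depth)
        # M learned queries; each one yields a single joint prediction.
        self.queries = nn.Parameter(torch.randn(1, num_queries, dim))
        # Classification head: joint types plus one extra "no object" class.
        self.class_head = nn.Linear(dim, num_joints + 1)
        # Regression head: 2D coordinates squashed into [0, 1] by a sigmoid.
        self.coord_head = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 2))

    def forward(self, images):
        # images: (B, 3, H, W) -> patch tokens: (B, N, dim)
        tokens = self.patch_embed(images).flatten(2).transpose(1, 2)
        memory = self.encoder(tokens + self.pos_embed)
        queries = self.queries.expand(images.shape[0], -1, -1)
        out = self.decoder(queries, memory)  # (B, M, dim)
        return self.class_head(out), self.coord_head(out).sigmoid()
```

For example, `MiniPOTR()(torch.randn(2, 3, 256, 192))` returns class logits of shape `(2, 100, 18)` and normalized coordinates of shape `(2, 100, 2)`. The sigmoid on the regression output is one simple way to keep coordinates in [0, 1].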

Predictions that do not correspond to joints are mapped to a "no object" class.
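This README does not spell out how the M per-query outputs are turned into one keypoint per joint at inference time. The sketch below shows one plausible DETR-style decoding, written as an assumption rather than the repository's method: each query takes its most likely class, "no object" queries are dropped, and for each joint type the highest-scoring query wins. The helpers `decode_joints` and `to_pixels` are hypothetical, and `to_pixels` assumes `--input_size 192 256` means width 192, height 256.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_joints(class_logits, coords, num_joints):
    """Hypothetical decoding: keep the best-scoring query for each joint type.

    class_logits: M lists of num_joints + 1 scores; the last index is "no object".
    coords: M (x, y) pairs normalized to [0, 1].
    Returns {joint_id: (x, y, score)}; joints that no query claims are omitted.
    """
    best = {}
    for logits, (x, y) in zip(class_logits, coords):
        probs = softmax(logits)
        label = max(range(len(probs)), key=probs.__getitem__)
        if label == num_joints:  # "no object" prediction: ignore it
            continue
        if label not in best or probs[label] > best[label][2]:
            best[label] = (x, y, probs[label])
    return best


def to_pixels(x, y, width=192, height=256):
    """Map normalized coordinates to pixels (width/height order is an assumption)."""
    return x * width, y * height
```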

Acknowledgements

The code in this repository is based on the following:

Thank you!

Preparing

Create a Python virtual environment and install the dependencies:

python -m venv pyenv
source pyenv/bin/activate
pip install -r requirements.txt

Training

Here are some CLI examples using the lit_main.py script.

Training POTR with a deit_small encoder, patch size of 16x16 pixels and input resolution 192x256:

python lit_main.py --vit_arch deit_deit_small --patch_size 16 --batch_size 42 --input_size 192 256 --hidden_dim 384 --vit_dim 384 --gpus 1 --num_workers 24
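The patch size and input resolution together set the length of the token sequence the encoder processes; for standard self-attention the cost grows quadratically with that length. The hypothetical helper below does the arithmetic for the configurations used in this README, assuming `--input_size` lists width then height (the count excludes any class token).

```python
def num_patch_tokens(width, height, patch_size):
    """Number of patch tokens a ViT-style tokenizer produces (class token excluded)."""
    if width % patch_size or height % patch_size:
        raise ValueError("input size must be divisible by the patch size")
    return (width // patch_size) * (height // patch_size)


for w, h, p in [(192, 256, 16), (192, 256, 8), (288, 384, 16)]:
    print(f"{w}x{h}, patch {p}: {num_patch_tokens(w, h, p)} tokens")
# 192x256, patch 16: 192 tokens
# 192x256, patch 8: 768 tokens
# 288x384, patch 16: 432 tokens
```

Halving the patch size from 16 to 8 quadruples the token count, which is one reason the p8 models are paired with the smaller 192x256 input in some configurations.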

POTR with an xcit_small_12_p16 encoder and input resolution 288x384:

python lit_main.py --vit_arch xcit_small_12_p16 --batch_size 42 --input_size 288 384 --hidden_dim 384 --vit_dim 384 --gpus 1 --num_workers 24 --vit_weights https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pth

POTR with the ViT as Backbone (VAB) configuration:

python lit_main.py --vit_as_backbone --vit_arch resnet50 --batch_size 42 --input_size 192 256 --hidden_dim 384 --vit_dim 384 --gpus 1 --position_embedding learned_nocls --num_workers 16 --num_queries 100 --dim_feedforward 1536 --accumulate_grad_batches 1
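The `--batch_size`, `--gpus` and `--accumulate_grad_batches` flags together determine how many samples contribute to each optimizer step. The sketch below assumes, without confirmation from this README, that `--accumulate_grad_batches` and `--gpus` are forwarded to the PyTorch Lightning `Trainer` options of the same names and that `--batch_size` is per GPU (as with DDP); `effective_batch_size` is a hypothetical helper.

```python
def effective_batch_size(batch_size: int, accumulate_grad_batches: int = 1, gpus: int = 1) -> int:
    """Samples per optimizer step, assuming --batch_size is per GPU (e.g. DDP)."""
    # Gradients from `accumulate_grad_batches` consecutive batches are accumulated
    # before a single optimizer step, and each GPU contributes its own batch.
    return batch_size * accumulate_grad_batches * gpus


print(effective_batch_size(42, accumulate_grad_batches=1, gpus=1))  # 42
print(effective_batch_size(42, accumulate_grad_batches=4, gpus=2))  # 336
```

Raising `--accumulate_grad_batches` is a common way to approximate a larger batch when a single GPU cannot hold it, at the cost of fewer optimizer steps per epoch.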

Baseline that uses a ResNet-50 (pretrained with DINO) as the encoder:

python lit_main.py --vit_arch resnet50 --patch_size 16 --batch_size 42 --input_size 192 256 --hidden_dim 384 --vit_dim 384 --gpus 1 --num_workers 24 --vit_weights https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth --position_embedding learned_nocls

For a complete list of lit_main.py command-line arguments, run:

python lit_main.py --help

Evaluation

Evaluate a trained model using the evaluate.py script.

For example, to evaluate POTR with an xcit_small_12_p8 encoder:

python evaluate.py --vit_arch xcit_small_12_p8 --patch_size 8 --batch_size 42 --input_size 192 256 --hidden_dim 384 --vit_dim 384 --position_embedding enc_xcit --num_workers 16 --num_queries 100 --dim_feedforward 1536 --init_weights paper_experiments/xcit_small12_p8_dino_192_256_paper/checkpoints/checkpoint-epoch\=065-AP\=0.736.ckpt --use_det_bbox

Evaluate POTR with a deit_small encoder:

python evaluate.py --vit_arch deit_deit_small --patch_size 16 --batch_size 42 --input_size 192 256 --hidden_dim 384 --vit_dim 384 --num_workers 24 --init_weights lightning_logs/version_0/checkpoints/checkpoint-epoch\=074-AP\=0.622.ckpt --use_det_bbox

Set --init_weights to the path of your model's checkpoint.
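The example checkpoint names encode the epoch and the validation AP (e.g. `checkpoint-epoch=065-AP=0.736.ckpt`; the backslashes in the commands only escape `=` for the shell). Assuming your checkpoints follow that same naming pattern, a small hypothetical helper such as `best_checkpoint` can pick the highest-AP file to pass to `--init_weights`:

```python
import re
from pathlib import Path

# Pattern taken from the example checkpoint names in this README;
# files named differently (e.g. "last.ckpt") are skipped.
CKPT_RE = re.compile(r"checkpoint-epoch=(\d+)-AP=([0-9.]+)\.ckpt$")


def best_checkpoint(paths):
    """Return (path, epoch, ap) for the highest-AP checkpoint, or None if none match."""
    best = None
    for path in paths:
        match = CKPT_RE.search(Path(path).name)
        if match is None:
            continue
        epoch, ap = int(match.group(1)), float(match.group(2))
        if best is None or ap > best[2]:
            best = (str(path), epoch, ap)
    return best
```

Usage: `best_checkpoint(Path("lightning_logs/version_0/checkpoints").glob("*.ckpt"))`.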

Model Zoo

Model                Input    Params  AP    AR    URL
POTR-Deit-dino-p8    192x256  36.4M   70.6  78.1  model
POTR-Xcit-p16        288x384  40.6M   70.2  77.4  model
POTR-Xcit-dino-p16   288x384  40.6M   70.7  77.9  model
POTR-Xcit-dino-p8    192x256  40.5M   71.6  78.7  model
POTR-Xcit-dino-p8    288x384  40.5M   72.6  79.4  model
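To compare a locally built model against the Params column, parameters can be counted with standard PyTorch calls. The helpers `count_parameters` and `format_millions` below are hypothetical, and whether the table counts every parameter (encoder, decoder and heads) is an assumption.

```python
import torch.nn as nn


def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Total number of parameters, optionally only those with requires_grad=True."""
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)


def format_millions(n: int) -> str:
    """Format a parameter count the way the table does, e.g. 36.4M."""
    return f"{n / 1e6:.1f}M"
```

For example, `count_parameters(nn.Linear(10, 5))` returns 55 (50 weights plus 5 biases).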

Check the experiments folder for configuration files and evaluation results.

All trained models and TensorBoard training logs can be downloaded from this drive folder.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Owner
Panteleris Paschalis