Official implementation of the PICASO: Permutation-Invariant Cascaded Attentional Set Operator

Related tags

Deep LearningPICASO
Overview

PICASO

Official PyTorch implemetation for the paper PICASO:Permutation-Invariant Cascaded Attentive Set Operator.

Requirements

  • Python 3
  • torch >= 1.0
  • numpy
  • matplotlib
  • scipy
  • tqdm

Abstract

Set-input deep networks have recently drawn much interest in computer vision and machine learning. This is in part due to the increasing number of important tasks such as meta-learning, clustering, and anomaly detection that are defined on set inputs. These networks must take an arbitrary number of input samples and produce the output invariant to the input set permutation. Several algorithms have been recently developed to address this urgent need. Our paper analyzes these algorithms using both synthetic and real-world datasets, and shows that they are not effective in dealing with common data variations such as image translation or viewpoint change. To address this limitation, we propose a permutation-invariant cascaded attentional set operator (PICASO). The gist of PICASO is a cascade of multihead attention blocks with dynamic templates. The proposed operator is a stand-alone module that can be adapted and extended to serve different machine learning tasks. We demonstrate the utilities of PICASO in four diverse scenarios: (i) clustering, (ii) image classification under novel viewpoints, (iii) image anomaly detection, and (iv) state prediction. PICASO increases the SmallNORB image classification accuracy with novel viewpoints by about 10% points. For set anomaly detection on CelebA dataset, our model improves the areas under ROC and PR curves dataset by about 22% and 10%, respectively. For the state prediction on CLEVR dataset, it improves the AP by about 40%.

Experiments

This repository implements the amortized clustering, classification, set anomaly detection, and state prediction experiments in the paper.

Amortized Clustering

You can use run.py to implement the experiment. To shift the data domain, you can use mvn_diag.py and add shift value to X.

Classification

We have used preprocessed smallNORB dataset for this experiment.

Set Anomaly Detection

In this experiment, we have used CelebA dataset. The preprocessing code is also provided in Set Anomaly Detection folder.

State Prediction

We used the same process employed in the Slot Attention paper. We recommend using multiple GPUs for this experiment.

Reference

If you found our code useful, please consider citing our work.

@misc{zare2021picaso,
      title={PICASO: Permutation-Invariant Cascaded Attentional Set Operator}, 
      author={Samira Zare and Hien Van Nguyen},
      year={2021},
      eprint={2107.08305},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
Owner
Samira Zare
Samira Zare
Contains source code for the winning solution of the xView3 challenge

Winning Solution for xView3 Challenge This repository contains source code and pretrained models for my (Eugene Khvedchenya) solution to xView 3 Chall

Eugene Khvedchenya 51 Dec 30, 2022
Title: Graduate-Admissions-Predictor

The purpose of this project is create a predictive model capable of identifying the probability of a person securing an admit based on their personal profile parameters. Simplified visualisations hav

Akarsh Singh 1 Jan 26, 2022
Performance Analysis of Multi-user NOMA Wireless-Powered mMTC Networks: A Stochastic Geometry Approach

Performance Analysis of Multi-user NOMA Wireless-Powered mMTC Networks: A Stochastic Geometry Approach Thanh Luan Nguyen, Tri Nhu Do, Georges Kaddoum

Thanh Luan Nguyen 2 Oct 10, 2022
A lightweight tool to get an AI Infrastructure Stack up in minutes not days.

K3ai will take care of setup K8s for You, deploy the AI tool of your choice and even run your code on it.

k3ai 105 Dec 04, 2022
an implementation of Video Frame Interpolation via Adaptive Separable Convolution using PyTorch

This work has now been superseded by: https://github.com/sniklaus/revisiting-sepconv sepconv-slomo This is a reference implementation of Video Frame I

Simon Niklaus 985 Jan 08, 2023
Tensorflow Implementation of the paper "Spectral Normalization for Generative Adversarial Networks" (ICML 2017 workshop)

tf-SNDCGAN Tensorflow implementation of the paper "Spectral Normalization for Generative Adversarial Networks" (https://www.researchgate.net/publicati

Nhat M. Nguyen 248 Nov 25, 2022
TorchFlare is a simple, beginner-friendly, and easy-to-use PyTorch Framework train your models effortlessly.

TorchFlare TorchFlare is a simple, beginner-friendly and an easy-to-use PyTorch Framework train your models without much effort. It provides an almost

Atharva Phatak 85 Dec 26, 2022
Uses OpenCV and Python Code to detect a face on the screen

Simple-Face-Detection This code uses OpenCV and Python Code to detect a face on the screen. This serves as an example program. Important prerequisites

Denis Woolley (CreepyD) 1 Feb 12, 2022
using STGCN to achieve egg classification task

EEG Classification   The task requires us to classify electroencephalography(EEG) into six categories, including human body, human face, animal body,

4 Jun 13, 2022
Code, pre-trained models and saliency results for the paper "Boosting RGB-D Saliency Detection by Leveraging Unlabeled RGB Images".

Boosting RGB-D Saliency Detection by Leveraging Unlabeled RGB This repository is the official implementation of the paper. Our results comming soon in

Xiaoqiang Wang 8 May 22, 2022
Real-Time Semantic Segmentation in Mobile device

Real-Time Semantic Segmentation in Mobile device This project is an example project of semantic segmentation for mobile real-time app. The architectur

708 Jan 01, 2023
Official repository of IMPROVING DEEP IMAGE MATTING VIA LOCAL SMOOTHNESS ASSUMPTION.

IMPROVING DEEP IMAGE MATTING VIA LOCAL SMOOTHNESS ASSUMPTION This is the official repository of IMPROVING DEEP IMAGE MATTING VIA LOCAL SMOOTHNESS ASSU

电线杆 14 Dec 15, 2022
Get 2D point positions (e.g., facial landmarks) projected on 3D mesh

points2d_projection_mesh Input 2D points (e.g. facial landmarks) on an image Camera parameters (extrinsic and intrinsic) of the image Aligned 3D mesh

5 Dec 08, 2022
Implementation of the Triangle Multiplicative module, used in Alphafold2 as an efficient way to mix rows or columns of a 2d feature map, as a standalone package for Pytorch

Triangle Multiplicative Module - Pytorch Implementation of the Triangle Multiplicative module, used in Alphafold2 as an efficient way to mix rows or c

Phil Wang 22 Oct 28, 2022
Deep Markov Factor Analysis (NeurIPS2021)

Deep Markov Factor Analysis (DMFA) Codes and experiments for deep Markov factor analysis (DMFA) model accepted for publication at NeurIPS2021: A. Farn

Sarah Ostadabbas 2 Dec 16, 2022
Dataset and Code for ICCV 2021 paper "Real-world Video Super-resolution: A Benchmark Dataset and A Decomposition based Learning Scheme"

Dataset and Code for RealVSR Real-world Video Super-resolution: A Benchmark Dataset and A Decomposition based Learning Scheme Xi Yang, Wangmeng Xiang,

Xi Yang 92 Jan 04, 2023
Official implementation for “Unsupervised Low-Light Image Enhancement via Histogram Equalization Prior”

HEP Unsupervised Low-Light Image Enhancement via Histogram Equalization Prior Implementation Python3 PyTorch=1.0 NVIDIA GPU+CUDA Training process The

FengZhang 34 Dec 04, 2022
PyTorch implementation of the supervised learning experiments from the paper Model-Agnostic Meta-Learning (MAML)

pytorch-maml This is a PyTorch implementation of the supervised learning experiments from the paper Model-Agnostic Meta-Learning (MAML): https://arxiv

Kate Rakelly 516 Jan 05, 2023
Spline is a tool that is capable of running locally as well as part of well known pipelines like Jenkins (Jenkinsfile), Travis CI (.travis.yml) or similar ones.

Welcome to spline - the pipeline tool Important note: Since change in my job I didn't had the chance to continue on this project. My main new project

Thomas Lehmann 29 Aug 22, 2022
GarmentNets: Category-Level Pose Estimation for Garments via Canonical Space Shape Completion

GarmentNets This repository contains the source code for the paper GarmentNets: Category-Level Pose Estimation for Garments via Canonical Space Shape

Columbia Artificial Intelligence and Robotics Lab 43 Nov 21, 2022