
SLATE

This is the official source code for SLATE. We provide the code for the model, the training code, and a dataset loader for the 3D Shapes dataset. The code is implemented in PyTorch.

arXiv: https://arxiv.org/pdf/2110.11405.pdf
Project Page: https://sites.google.com/view/slate-autoencoder

Dataset

The current release provides boilerplate code for training the model on the 3D Shapes dataset. The dataset class is in shapes_3d.py; edit or replace this class to run the code on a different dataset. The 3D Shapes dataset can be downloaded from the official URL https://console.cloud.google.com/storage/browser/3d-shapes, which provides the dataset file 3dshapes.h5. During training, pass the path to this file with the argument --data_path.
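If you replace the class in shapes_3d.py, the replacement presumably needs to behave like a PyTorch map-style dataset (an object with __len__ and __getitem__). The sketch below is one minimal, hypothetical replacement: the class name ImageArrayDataset is invented here, and the output format (float32, channels-first, scaled to [0, 1]) is an assumption, so check shapes_3d.py for the image size and normalization the model actually expects. For 3D Shapes, the wrapped array could be the "images" dataset inside 3dshapes.h5 opened with h5py, which holds 480000 uint8 RGB images of size 64x64.

```python
import numpy as np


class ImageArrayDataset:
    """Hypothetical stand-in for the dataset class in shapes_3d.py.

    Implements the PyTorch map-style protocol (__len__ / __getitem__).
    The [0, 1] float CHW output format is an assumption, not the repo's spec.
    """

    def __init__(self, images):
        # images: uint8 array-like of shape (N, H, W, 3), e.g.
        # h5py.File(data_path, "r")["images"] for 3dshapes.h5.
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Scale uint8 pixels to float32 in [0, 1].
        img = np.asarray(self.images[idx], dtype=np.float32) / 255.0
        # HWC -> CHW, the layout PyTorch convolution layers expect.
        return np.transpose(img, (2, 0, 1))


if __name__ == "__main__":
    fake = np.random.randint(0, 256, size=(10, 64, 64, 3), dtype=np.uint8)
    ds = ImageArrayDataset(fake)
    print(len(ds), ds[0].shape)  # 10 (3, 64, 64)
```

Because PyTorch's default collate function converts NumPy arrays to tensors, an instance like this can be passed directly to torch.utils.data.DataLoader.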

Training

To train the model, simply execute:

python train.py

Check train.py to see the full list of training arguments.
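The --flag syntax suggests that train.py parses its arguments with Python's argparse, though this README does not confirm it. The sketch below is not the repository's parser: it only declares the flags this README mentions, with assumed types and no defaults (argparse then reports None for any flag that is omitted), to show how those flags are passed on the command line. The real defaults and the full argument list are in train.py.

```python
import argparse


def build_parser():
    """Hypothetical subset of train.py's arguments.

    Flag names come from the README; types are assumptions, and defaults
    are deliberately left as None rather than guessed.
    """
    p = argparse.ArgumentParser(description="SLATE training arguments (sketch)")
    p.add_argument("--data_path", type=str, help="path to 3dshapes.h5")
    p.add_argument("--log_path", type=str, help="TensorBoard logging directory")
    p.add_argument("--lr_main", type=float, help="main learning rate")
    p.add_argument("--num_slots", type=int, help="number of slots")
    p.add_argument("--num_iterations", type=int, help="Slot Attention iterations")
    return p


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--data_path", "data/3dshapes.h5", "--num_slots", "6"]
    )
    print(args.data_path, args.num_slots, args.lr_main)  # data/3dshapes.h5 6 None
```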

Outputs

The training code writes TensorBoard logs. To view them, run TensorBoard on the logging directory passed through the training argument --log_path. The logs contain the training loss curves as well as visualizations of the reconstructions and the object attention maps.

Hyperparameters of Interest

  • Learning Rate can be tuned using the training argument --lr_main; different choices can affect the characteristics of the object attention maps.
  • Number of Slots can be tuned using the training argument --num_slots. Number of slots should be set higher than the number of objects you expect to see in the images.
  • Number of Slot Attention Iterations can be tuned using the training argument --num_iterations. In general, keep the number of iterations as small as possible because too many iterations can prevent slots from learning to diversify and attach to different objects.
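The --num_slots and --num_iterations arguments correspond to the number of slots and the number of refinement rounds in Slot Attention (Locatello et al., 2020). The NumPy sketch below is a simplified forward pass that only shows where those two numbers enter; it is not the code in slot_attn.py. Assumptions: fixed random projections replace the learned query, key, and value layers, the learned GRU and MLP update is replaced by directly taking the attention-weighted mean, and layer normalization is omitted.

```python
import numpy as np


def softmax(x, axis):
    # Subtract the max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def slot_attention(inputs, num_slots, num_iterations, seed=0, eps=1e-8):
    """Simplified Slot Attention forward pass (illustration, not slot_attn.py).

    inputs: (N, D) array of N input feature vectors.
    Returns slots with shape (num_slots, D) and attention with shape (N, num_slots).
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")
    rng = np.random.default_rng(seed)
    _, d = inputs.shape
    # Fixed random projections stand in for the learned query/key/value layers.
    w_q, w_k, w_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
    slots = rng.normal(size=(num_slots, d))  # random initial slots
    k, v = inputs @ w_k, inputs @ w_v
    for _ in range(num_iterations):
        q = slots @ w_q
        logits = k @ q.T / np.sqrt(d)  # (N, num_slots)
        # Softmax over the slot axis: slots compete to explain each input.
        attn = softmax(logits, axis=1)
        # Renormalize over inputs so each slot takes a weighted mean of the values.
        weights = attn + eps
        weights = weights / weights.sum(axis=0, keepdims=True)
        # The real model feeds this into a learned GRU + MLP; here it is assigned directly.
        slots = weights.T @ v  # (num_slots, D)
    return slots, attn


if __name__ == "__main__":
    features = np.random.default_rng(1).normal(size=(64, 16))
    slots, attn = slot_attention(features, num_slots=4, num_iterations=3)
    print(slots.shape, attn.shape)  # (4, 16) (64, 4)
```

Because the softmax is taken over the slot axis, each input's attention sums to one across slots, which is what makes slots compete for inputs; each additional iteration repeats this competition with the updated slots.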

Code Files

This repository provides the following files.

  • train.py contains the main code for running the training.
  • slate.py provides the model class for SLATE.
  • shapes_3d.py contains the dataset class for the 3D Shapes dataset.
  • dvae.py provides the encoder and the decoder for the discrete VAE (dVAE).
  • slot_attn.py provides the model class for the Slot Attention encoder.
  • transformer.py provides the model classes for the Transformer.
  • utils.py provides helper classes and functions for the implementation.

Owner

Gautam Singh, PhD student at Rutgers CS