Implementation of TabTransformer, an attention network for tabular data, in PyTorch

Overview

Tab Transformer

Implementation of Tab Transformer, an attention network for tabular data, in PyTorch. This simple architecture came within a hair's breadth of the performance of gradient-boosted decision trees (GBDT).

Install

$ pip install tab-transformer-pytorch

Usage

import torch
import torch.nn as nn
from tab_transformer_pytorch import TabTransformer

cont_mean_std = torch.randn(10, 2)

model = TabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    dim = 32,                           # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    depth = 6,                          # depth, paper recommended 6
    heads = 8,                          # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
    mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
    continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm
)

x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 - max number of categories, in the order as passed into the constructor above
x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually

pred = model(x_categ, x_cont)

Unsupervised Training

To perform the unsupervised pre-training described in the paper, first convert your category tokens to the appropriate unique ids, then use Electra on model.transformer.
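As a rough illustration of the id conversion step, the sketch below maps per-column category indices to ids that are unique across all columns by adding a cumulative offset to each column. The helper names (category_offsets, to_unique_ids) and the num_special_tokens default of 2 are assumptions made for this example, not a documented part of the library, and plain Python lists are used instead of tensors to keep it dependency-free.

```python
from itertools import accumulate

def category_offsets(categories, num_special_tokens=2):
    """Starting id for each column; ids below num_special_tokens stay reserved."""
    # Column i starts after the special tokens and all values of columns 0..i-1.
    starts = accumulate((num_special_tokens,) + tuple(categories[:-1]))
    return list(starts)

def to_unique_ids(row, offsets):
    """Shift each per-column category index by its column's offset."""
    return [value + offset for value, offset in zip(row, offsets)]

categories = (10, 5, 6, 5, 8)   # same tuple as passed to the constructor above
offsets = category_offsets(categories)
unique_ids = to_unique_ids([3, 0, 5, 4, 7], offsets)
print(offsets)     # [2, 12, 17, 23, 28]
print(unique_ids)  # [5, 12, 22, 27, 35]
```

With these offsets, no two columns can produce the same id, so a single embedding table can cover every categorical column.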

Citations

@misc{huang2020tabtransformer,
    title={TabTransformer: Tabular Data Modeling Using Contextual Embeddings}, 
    author={Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin},
    year={2020},
    eprint={2012.06678},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
Comments
  • Minor Bug: activation function being applied to output layer in class MLP

    The code for class MLP mistakenly applies the activation function to the last (i.e. output) layer. The error is in the evaluation of the is_last flag. The current code is:

    class MLP(nn.Module):
        def __init__(self, dims, act = None):
            super().__init__()
            dims_pairs = list(zip(dims[:-1], dims[1:]))
            layers = []
            for ind, (dim_in, dim_out) in enumerate(dims_pairs):
                is_last = ind >= (len(dims) - 1)
    

    The last line should be changed to is_last = ind >= (len(dims) - 2):

    class MLP(nn.Module):
        def __init__(self, dims, act = None):
            super().__init__()
            dims_pairs = list(zip(dims[:-1], dims[1:]))
            layers = []
            for ind, (dim_in, dim_out) in enumerate(dims_pairs):
                is_last = ind >= (len(dims) - 2)
    

    If you like, I can do a pull request.
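The off-by-one can be checked in isolation. The sketch below recomputes only the is_last flag for both versions of the comparison; the helper name last_layer_flags and the example dims are hypothetical and chosen purely for illustration.

```python
def last_layer_flags(dims, offset):
    """Return the is_last flag per layer for `ind >= len(dims) - offset`."""
    dims_pairs = list(zip(dims[:-1], dims[1:]))
    return [ind >= (len(dims) - offset) for ind, _ in enumerate(dims_pairs)]

dims = [64, 256, 128, 1]                  # 4 sizes -> 3 linear layers (indices 0, 1, 2)
buggy = last_layer_flags(dims, offset=1)  # original check: never true
fixed = last_layer_flags(dims, offset=2)  # proposed check: true only for index 2
print(buggy)  # [False, False, False]
print(fixed)  # [False, False, True]
```

Because dims_pairs has one element fewer than dims, the largest index is len(dims) - 2, so comparing against len(dims) - 1 can never mark a layer as last.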

    opened by rminhas 1
  • Update tab_transformer_pytorch.py

    Apply the activation function outside the loop, once for the whole model, rather than after each of the linear layers. The 'if is_last' condition was producing a linear output every time, regardless of the activation function.

    opened by EveryoneDirn 0
  • Unindent continuous_mean_std buffer

    Problem: continuous_mean_std is not an attribute of TabTransformer if it is not passed explicitly as an argument. Example reproducing the AttributeError:

    model = TabTransformer(
        categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
        num_continuous = 10,                # number of continuous values
        dim = 32,                           # dimension, paper set at 32
        dim_out = 1,                        # binary prediction, but could be anything
        depth = 6,                          # depth, paper recommended 6
        heads = 8,                          # heads, paper recommends 8
        attn_dropout = 0.1,                 # post-attention dropout
        ff_dropout = 0.1,                   # feed forward dropout
        mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
        mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
        # continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm
    )
    x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 - max number of categories, in the order as passed into the constructor above
    x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually
    pred = model(x_categ, x_cont) # gives AttributeError
    
    

    Solution: Simply un-indenting the buffer registration of continuous_mean_std.
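A minimal sketch of the pattern described here, assuming the registration originally sat inside a conditional that only ran when the argument was given. ContinuousNorm is a hypothetical stand-in written for this example and is not the library's actual class; only nn.Module.register_buffer is a real PyTorch call.

```python
import torch
from torch import nn

class ContinuousNorm(nn.Module):
    """Hypothetical stand-in for the part of the model that holds continuous_mean_std."""

    def __init__(self, num_continuous, continuous_mean_std=None):
        super().__init__()
        if continuous_mean_std is not None:
            assert continuous_mean_std.shape == (num_continuous, 2)
        # Registered outside the `if`, so the attribute exists even when it is None.
        self.register_buffer('continuous_mean_std', continuous_mean_std)

    def forward(self, x_cont):
        if self.continuous_mean_std is not None:
            mean, std = self.continuous_mean_std.unbind(dim=-1)
            x_cont = (x_cont - mean) / std
        return x_cont

plain = ContinuousNorm(num_continuous=10)   # no mean/std supplied
out = plain(torch.randn(1, 10))             # attribute lookup succeeds, input passes through
```

Registering a buffer with the value None is allowed, which is what lets the later attribute check run without raising.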

    opened by spliew 0
  • low gpu usage,

    Hi.

    I'm having a problem running your code on my dataset: it's pretty slow. The GPU runs at 50% usage on average and each epoch takes almost 900 seconds.

    My dataset has 590540 rows, 24 categorical features, and 192 continuous features. Categories are encoded using a label encoder. Total dataset size is around 600 MB. My GPU is an integrated NVIDIA RTX 3060 with 6 GB of RAM. The optimizer is Adam.

    These are the software versions:

    Windows 10
    Python: 3.7.11
    PyTorch: 1.7.0+cu110
    NumPy: 1.21.2

    Let me know if you need more info from my side.

    Thanks.

    Xin.

    opened by xinqiao123 0
  • Intended usage of num_special_tokens?

    From what I understand, these are supposed to be reserved for out-of-vocabulary (OOV) values. Is the intended usage to set OOV values in the input to some negative number and overwrite the offset? That seems to be what it would take to achieve the desired outcome, but it also seems somewhat confusing and clunky. Or perhaps I am misunderstanding its purpose? Thanks!

    opened by LLYX 2
  • No Category Shared Embedding?

    I noticed that this implementation does not seem to have the feature of a shared embedding between each value belonging to the same category (unless I missed it) that the paper mentions (c_phi_i). If it's indeed missing, do you have plans to add that?

    Thanks for this implementation!

    opened by LLYX 3
  • index -1 is out of bounds for dimension 1 with size 17

    I encountered this error during training. What could be causing it, and how can I fix it? Thanks!

      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 583, in forward
        return self.tabnet(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 468, in forward
        steps_output, M_loss = self.encoder(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 160, in forward
        M = self.att_transformers[step](prior, att)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 637, in forward
        x = self.selector(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 109, in forward
        return sparsemax(input, self.dim)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 52, in forward
        tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 94, in _threshold_and_support
        tau = input_cumsum.gather(dim, support_size - 1)
    RuntimeError: index -1 is out of bounds for dimension 1 with size 17
    Experiment has terminated.
    
    opened by hengzhe-zhang 2
  • Is there any training example about tabtransformer?

    Hi, I want to use this model for supervised learning on a tabular dataset, but I don't really know how to train it on a dataset (there seems to be no such content in the README). Could you please help me? Thank you.

    opened by pancodex 0
Owner
Phil Wang