A PyTorch implementation of a Factorization Machine module in cython.

Overview

fmpytorch

A library for factorization machines in pytorch. A factorization machine is like a linear model, except multiplicative interaction terms between the variables are modeled as well.

The input to a factorization machine layer is a vector, and the output is a scalar. Batching is fully supported.

This is a work in progress. Feedback and bugfixes welcome! Hopefully you find the code useful.

Usage

The factorization machine layers in fmpytorch can be used just like any other built-in module. Here's a simple feed-forward model. It projects a 100-D input down to 50-D with a linear layer and applies dropout. It then feeds the result into a factorization machine that models the interactions between those 50 features using k=5 factors.

import torch
from fmpytorch.second_order.fm import FactorizationMachine

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(100, 50)
        self.dropout = torch.nn.Dropout(.5)
        # This makes an FM layer mapping from 50-D to 1-D.
        # The number of factors is 5.
        self.fm = FactorizationMachine(50, 5)

    def forward(self, x):
        x = self.linear(x)
        x = self.dropout(x)
        x = self.fm(x)
        return x

See examples/toy.py or examples/regression.py for fuller examples.

Installation

This package requires pytorch, numpy, and cython.

To install, you can run:

cd fmpytorch
sudo python setup.py install

Factorization Machine brief intro

A linear model, given a vector x = (x_1, ..., x_n), models its output y as

$$y = w_0 + \sum_{i=1}^{n} w_i x_i$$

where w are the learnable weights of the model (w_0 is the bias).
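A minimal pure-Python sketch of the linear model, assuming the standard form y = w_0 + sum_i w_i x_i with a bias term w_0. The function name `linear_model` is hypothetical and is not part of fmpytorch. Each feature contributes its own term, independently of every other feature.

```python
def linear_model(w0, w, x):
    """Linear model y = w0 + sum_i w_i * x_i: each feature contributes independently."""
    if len(w) != len(x):
        raise ValueError("w and x must have the same length")
    return w0 + sum(wi * xi for wi, xi in zip(w, x))

print(linear_model(0.5, [1.0, -2.0, 3.0], [2.0, 1.0, 0.0]))  # 0.5
```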

However, in this model the input variables x_i contribute to y purely additively. In some cases, it might be useful to model the interactions between your variables, e.g., x_i * x_j. You could add terms into your model like

$$y = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} w^{(2)}_{ij} x_i x_j$$

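However, this introduces a large number of w2 variables. Specifically, this formulation introduces O(n^2) parameters, one for each interaction pair. A factorization machine approximates each pairwise weight with the dot product of two low-dimensional factors, $w^{(2)}_{ij} \approx \langle v_i, v_j \rangle$, i.e.,

$$y = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle v_i, v_j \rangle x_i x_j$$
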
where each v_i is a k-dimensional vector and k is the number of factors (k=5 in the usage example). This is the forward pass of a second-order factorization machine. This low-rank reformulation reduces the number of additional parameters for the factorization machine to O(k*n). Magically, the forward (and backward) pass can also be rewritten so that it is computed in O(k*n), rather than the O(k*n^2) cost of evaluating the formulation above pair by pair.
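The O(k*n) reformulation rests on the sum-of-squares identity from Rendle (2010):

$\sum_{i<j} \langle v_i, v_j \rangle x_i x_j = \frac{1}{2} \sum_{f=1}^{k} \big[ (\sum_i v_{i,f} x_i)^2 - \sum_i v_{i,f}^2 x_i^2 \big]$

Below is a pure-Python sketch that compares the naive pairwise loop with the identity-based version. The names `fm_naive` and `fm_fast` are hypothetical. fmpytorch's Cython kernels may organize the computation differently. The sizes match the usage example (n=50, k=5).

```python
import random

def fm_naive(w0, w, V, x):
    """Second-order FM forward pass, one interaction pair at a time: O(k * n^2)."""
    n = len(x)
    y = w0 + sum(wi * xi for wi, xi in zip(w, x))
    for i in range(n):
        for j in range(i + 1, n):
            dot = sum(a * b for a, b in zip(V[i], V[j]))  # <v_i, v_j>
            y += dot * x[i] * x[j]
    return y

def fm_fast(w0, w, V, x):
    """Same forward pass rewritten with the sum-of-squares identity: O(k * n)."""
    n, k = len(x), len(V[0])
    y = w0 + sum(wi * xi for wi, xi in zip(w, x))
    for f in range(k):
        s = sum(V[i][f] * x[i] for i in range(n))            # sum_i v_{i,f} x_i
        s_sq = sum((V[i][f] * x[i]) ** 2 for i in range(n))  # sum_i v_{i,f}^2 x_i^2
        y += 0.5 * (s * s - s_sq)
    return y

random.seed(0)
n, k = 50, 5  # 50-D input and k=5 factors, as in the usage example
w0 = random.gauss(0, 1)
w = [random.gauss(0, 1) for _ in range(n)]
V = [[random.gauss(0, 0.1) for _ in range(k)] for _ in range(n)]
x = [random.gauss(0, 1) for _ in range(n)]

print(abs(fm_naive(w0, w, V, x) - fm_fast(w0, w, V, x)) < 1e-9)  # True
# Extra parameters: one weight per pair vs. one k-vector per feature.
print(n * (n - 1) // 2, n * k)  # 1225 250

# Batching: one scalar output per input vector.
batch = [x, [0.0] * n]
ys = [fm_fast(w0, w, V, xb) for xb in batch]
```

Both functions return the same scalar. However, `fm_fast` touches each (feature, factor) entry only a constant number of times, which is where the O(k*n) cost comes from.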

Currently supported features

Currently, only a second-order factorization machine is supported. The forward and backward passes are implemented in Cython. Compared to relying on autodiff, the Cython passes run several orders of magnitude faster. I've only tested it with Python 2 at the moment.
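A hand-written backward pass can also run in O(k*n). The sketch below reuses the per-factor sums s_f = sum_j v_{j,f} x_j from the forward pass. Differentiating the sum-of-squares form gives dy/dw_0 = 1, dy/dw_i = x_i, and dy/dv_{i,f} = x_i s_f - v_{i,f} x_i^2. This is a pure-Python illustration under that derivation, not fmpytorch's Cython code. The helper `fm_forward_backward` is hypothetical. The analytic gradient is checked against central finite differences.

```python
import random

def fm_forward_backward(w0, w, V, x):
    """Forward pass plus gradients of y w.r.t. w0, w and V for one input, in O(k * n).

    Hypothetical helper for illustration; not part of fmpytorch.
    """
    n, k = len(x), len(V[0])
    # s[f] = sum_i v_{i,f} x_i; computed once and reused by the backward pass.
    s = [sum(V[i][f] * x[i] for i in range(n)) for f in range(k)]
    y = w0 + sum(wi * xi for wi, xi in zip(w, x))
    for f in range(k):
        s_sq = sum((V[i][f] * x[i]) ** 2 for i in range(n))
        y += 0.5 * (s[f] ** 2 - s_sq)
    grad_w0 = 1.0     # dy/dw0
    grad_w = list(x)  # dy/dw_i = x_i
    # dy/dv_{i,f} = x_i * s_f - v_{i,f} * x_i^2
    grad_V = [[x[i] * s[f] - V[i][f] * x[i] ** 2 for f in range(k)] for i in range(n)]
    return y, grad_w0, grad_w, grad_V

random.seed(1)
n, k = 6, 3
w0 = 0.1
w = [random.gauss(0, 1) for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(k)] for _ in range(n)]
x = [random.gauss(0, 1) for _ in range(n)]
y, g_w0, g_w, g_V = fm_forward_backward(w0, w, V, x)

# Check the analytic dy/dV against central finite differences.
eps = 1e-6
max_err = 0.0
for i in range(n):
    for f in range(k):
        orig = V[i][f]
        V[i][f] = orig + eps
        y_plus = fm_forward_backward(w0, w, V, x)[0]
        V[i][f] = orig - eps
        y_minus = fm_forward_backward(w0, w, V, x)[0]
        V[i][f] = orig  # restore the parameter
        numeric = (y_plus - y_minus) / (2 * eps)
        max_err = max(max_err, abs(numeric - g_V[i][f]))
print(max_err < 1e-6)  # True
```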

TODOs

  1. Support for sparse tensors.
  2. More interesting usage examples
  3. More testing, e.g., with python 3, etc.
  4. Make sure all of the code plays nice with torch-specific stuff, e.g., GPUs
  5. Arbitrary order factorization machine support
  6. Better organization/code cleaning
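As background for TODO 1 (sparse tensors), here is a brief sketch of why sparsity pays off. In the sum-of-squares form, a feature with x_i = 0 contributes nothing to any sum. Iterating only over the nonzero entries therefore costs O(k * nnz) instead of O(k * n), as noted in Rendle (2010). The dict-based input format and the name `fm_sparse` are assumptions for illustration only, not fmpytorch's API.

```python
from typing import Dict, List

def fm_sparse(w0: float, w: List[float], V: List[List[float]], x: Dict[int, float]) -> float:
    """Second-order FM forward pass for a sparse input given as {feature index: value}.

    Only the stored (nonzero) entries are visited, so the cost is O(k * nnz).
    """
    k = len(V[0])
    y = w0 + sum(w[i] * xi for i, xi in x.items())
    for f in range(k):
        s = sum(V[i][f] * xi for i, xi in x.items())           # sum_i v_{i,f} x_i
        s_sq = sum((V[i][f] * xi) ** 2 for i, xi in x.items())  # sum_i v_{i,f}^2 x_i^2
        y += 0.5 * (s * s - s_sq)
    return y

# Toy example: 4 features, k = 2; features 0 and 3 are "on" (one-hot style).
w = [0.1, 0.2, 0.3, 0.4]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]
print(fm_sparse(0.0, w, V, {0: 1.0, 3: 1.0}))  # 0.1 + 0.4 + <v_0, v_3> = 2.5
```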

Thanks to

Vlad Niculae (@vene) for his sage wisdom.

This layer is based on the original factorization machine paper:

@inproceedings{rendle2010factorization,
  title={Factorization machines},
  author={Rendle, Steffen},
  booktitle={ICDM},
  pages={995--1000},
  year={2010},
  organization={IEEE}
}
Owner
Jack Hessel
Research Scientist @ AI2: PhD in CS previously from Cornell