Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions

torch-imle

Concise and self-contained PyTorch library implementing the I-MLE gradient estimator proposed in our NeurIPS 2021 paper Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions.

This repository contains a library for transforming any black-box combinatorial solver into a differentiable layer. All code for reproducing the experiments in the NeurIPS paper is available in the official NEC Laboratories Europe repository.

Overview

Implicit MLE (I-MLE) makes it possible to include discrete combinatorial optimization algorithms, such as Dijkstra's algorithm or integer linear program (ILP) solvers, in standard deep learning architectures. The core idea of I-MLE is that it defines an implicit maximum likelihood objective whose gradients are used to update upstream parameters of the model. Every instance of I-MLE requires two ingredients:

  1. A method to approximately sample from a complex and intractable distribution. For this we use Perturb-and-MAP (aka the Gumbel-max trick) and propose a novel family of noise perturbations tailored to the problem at hand.
  2. A method to compute a surrogate empirical distribution. Vanilla MLE reduces the KL divergence between the current distribution and the empirical distribution. Since we do not have access to an empirical distribution in our setting, we have to design surrogate empirical distributions. Here we propose two families of surrogate distributions which are widely applicable and work well in practice.
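To make the first ingredient concrete, here is a minimal sketch of the Gumbel-max trick on a plain categorical distribution, written with the Python standard library only. It is an illustration and not part of torch-imle: the names `gumbel`, `gumbel_max_sample` and the example parameters `theta` are assumptions made for this sketch. For a categorical distribution the trick is exact, so the empirical frequencies should approach the softmax of `theta`; for combinatorial structures, perturbing each parameter independently only gives approximate samples.

```python
import math
import random


def gumbel(rng: random.Random) -> float:
    """Draw one standard Gumbel sample (the uniform is clamped to avoid log(0))."""
    u = max(rng.random(), 1e-300)
    return -math.log(-math.log(u))


def gumbel_max_sample(theta, rng):
    """Perturb every parameter with Gumbel noise and return the argmax (the MAP state)."""
    perturbed = [t + gumbel(rng) for t in theta]
    return max(range(len(theta)), key=lambda i: perturbed[i])


theta = [1.0, 0.0, -1.0]  # hypothetical unnormalised log-probabilities
rng = random.Random(0)
nb_samples = 20000

counts = [0] * len(theta)
for _ in range(nb_samples):
    counts[gumbel_max_sample(theta, rng)] += 1
frequencies = [c / nb_samples for c in counts]

# Reference distribution: softmax(theta).
normaliser = sum(math.exp(t) for t in theta)
softmax = [math.exp(t) / normaliser for t in theta]

print("empirical:", [round(f, 3) for f in frequencies])
print("softmax:  ", [round(p, 3) for p in softmax])
```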

Example

For example, let's consider a map from a simple game where the task is to find the shortest path from the top-left to the bottom-right corner. Black areas have the highest and white areas the lowest cost. In the centre, you can see what happens when we use the proposed sum-of-gamma noise distribution to sample paths. On the right, you can see the resulting marginal probabilities for every tile (the probability of each tile being part of a sampled path).
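The marginals described above can be estimated by averaging perturbed shortest paths. The following is a rough sketch under simplifying assumptions, not the script behind the figures: it uses Gumbel noise instead of the sum-of-gamma noise, restricts paths to right and down moves so that a simple dynamic program works even when perturbed costs become negative, and the names `shortest_monotone_path`, `path_marginals` and the 3x3 `cost` grid are made up for illustration.

```python
import math
import random


def gumbel(rng: random.Random) -> float:
    """Draw one standard Gumbel sample (the uniform is clamped to avoid log(0))."""
    u = max(rng.random(), 1e-300)
    return -math.log(-math.log(u))


def shortest_monotone_path(cost):
    """Min-cost path from top-left to bottom-right using only right/down moves.

    Returns a 0/1 grid marking the tiles that lie on the path.
    """
    n, m = len(cost), len(cost[0])
    best = [[math.inf] * m for _ in range(n)]
    best[0][0] = cost[0][0]
    for i in range(n):
        for j in range(m):
            if i > 0:
                best[i][j] = min(best[i][j], best[i - 1][j] + cost[i][j])
            if j > 0:
                best[i][j] = min(best[i][j], best[i][j - 1] + cost[i][j])
    # Backtrack from the goal, always stepping to the cheaper predecessor.
    path = [[0] * m for _ in range(n)]
    i, j = n - 1, m - 1
    path[i][j] = 1
    while (i, j) != (0, 0):
        if i > 0 and (j == 0 or best[i - 1][j] <= best[i][j - 1]):
            i -= 1
        else:
            j -= 1
        path[i][j] = 1
    return path


def path_marginals(cost, nb_samples, temperature, rng):
    """Estimate, for every tile, the probability of lying on a sampled path."""
    n, m = len(cost), len(cost[0])
    counts = [[0] * m for _ in range(n)]
    for _ in range(nb_samples):
        # Perturb-and-MAP: add scaled noise to the weights (negated costs), then solve.
        perturbed = [[c - temperature * gumbel(rng) for c in row] for row in cost]
        path = shortest_monotone_path(perturbed)
        for i in range(n):
            for j in range(m):
                counts[i][j] += path[i][j]
    return [[c / nb_samples for c in row] for row in counts]


cost = [
    [1.0, 1.0, 5.0],
    [5.0, 1.0, 5.0],
    [5.0, 1.0, 1.0],
]
marginals = path_marginals(cost, nb_samples=500, temperature=1.0, rng=random.Random(0))
for row in marginals:
    print(" ".join(f"{p:.2f}" for p in row))
```

The start and goal tiles always have marginal 1, and a higher temperature spreads probability mass over more alternative paths.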

Gradients and Learning

Let us assume that the optimal shortest path is the one on the left. Starting from random weights, the model can learn to produce the weights that result in the optimal shortest path via gradient descent, by minimising the Hamming loss between the produced path and the gold path. Here we show the paths being produced during training (middle), and the corresponding map weights (right).
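The learning procedure can be sketched on a much smaller problem. The code below is an illustrative toy and not the torch-imle implementation: the solver is a top-k selection instead of a shortest-path solver, the noise is Gumbel rather than sum-of-gamma, and the names `top_k`, `imle_step`, `lam` and `lr` are assumptions. It follows the perturbation-based recipe from the paper as we understand it: sample a solution with perturbed parameters, shift the parameters against the loss gradient to obtain target parameters, solve again with the same noise, and use the scaled difference of the two solutions as the gradient estimate.

```python
import math
import random


def gumbel(rng: random.Random) -> float:
    """Draw one standard Gumbel sample (the uniform is clamped to avoid log(0))."""
    u = max(rng.random(), 1e-300)
    return -math.log(-math.log(u))


def top_k(theta, k):
    """Toy MAP solver: 0/1 vector marking the k largest entries of theta."""
    chosen = sorted(range(len(theta)), key=lambda i: theta[i], reverse=True)[:k]
    return [1 if i in chosen else 0 for i in range(len(theta))]


def imle_step(theta, gold, k, lam, lr, temperature, rng):
    """One gradient-descent step on theta using an I-MLE style gradient estimate."""
    eps = [temperature * gumbel(rng) for _ in theta]
    # Forward pass: perturb-and-MAP sample from the current distribution.
    z = top_k([t + e for t, e in zip(theta, eps)], k)
    # Gradient of the Hamming loss sum_i |z_i - gold_i| w.r.t. a binary z.
    grad_z = [1 - 2 * g for g in gold]
    # Target parameters: move theta against the loss gradient.
    theta_target = [t - lam * g for t, g in zip(theta, grad_z)]
    # Sample from the target distribution, reusing the same noise.
    z_target = top_k([t + e for t, e in zip(theta_target, eps)], k)
    # Gradient estimate: scaled difference between the two solutions.
    grad_theta = [(a - b) / lam for a, b in zip(z, z_target)]
    return [t - lr * g for t, g in zip(theta, grad_theta)], z


rng = random.Random(0)
gold = [1, 0, 0, 1, 0]  # hypothetical gold solution
theta = [0.0] * len(gold)
for _ in range(200):
    theta, z = imle_step(theta, gold, k=2, lam=2.0, lr=0.5, temperature=1.0, rng=rng)

learned = top_k(theta, 2)
print("theta:  ", [round(t, 2) for t in theta])
print("learned:", learned)
```

In this sketch the updates stop once the sampled solution and the target solution agree, which is why the weights settle instead of growing without bound.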

Input noise temperature set to 0.0, and target noise temperature set to 0.0:

Input noise temperature set to 1.0, and target noise temperature set to 1.0:

Input noise temperature set to 2.0, and target noise temperature set to 2.0:

Input noise temperature set to 5.0, and target noise temperature set to 5.0:

Input noise temperature set to 5.0, and target noise temperature set to 0.0:

All animations were generated by this script.

Code

Using this library is extremely easy -- see this example as a reference. Assuming we have a method that implements a black-box combinatorial solver such as Dijkstra's algorithm:

import numpy as np

import torch
from torch import Tensor

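def torch_solver(weights_batch: Tensor) -> Tensor:
    # `solver` is the user-supplied black-box solver: it maps one NumPy weight array to one NumPy solution array.
    # The solver runs outside autograd, so detach the weights and convert them to NumPy.
    weights_batch = weights_batch.detach().cpu().numpy()
    # Solve every instance in the batch independently.
    y_batch = np.asarray([solver(w) for w in weights_batch])
    # The solutions carry no gradient of their own; I-MLE supplies it.
    return torch.tensor(y_batch, requires_grad=False)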
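The `solver` called above is left to the user. As a purely hypothetical stand-in, the sketch below selects the `k` highest-weight items of a one-dimensional weight vector and returns them as a 0/1 array; any function with the same shape contract (one NumPy array in, one NumPy array out) could take its place, for instance a Dijkstra-based shortest-path solver.

```python
import numpy as np


def solver(weights: np.ndarray, k: int = 2) -> np.ndarray:
    """Hypothetical black-box solver: mark the k highest-weight items with 1."""
    solution = np.zeros_like(weights, dtype=np.float32)
    top_idx = np.argsort(weights)[-k:]  # indices of the k largest weights
    solution[top_idx] = 1.0
    return solution


# Batched use, mirroring the list comprehension in torch_solver.
weights_batch = np.array([[0.1, 2.0, -1.0, 0.7],
                          [3.0, 0.2, 0.5, 0.4]])
y_batch = np.asarray([solver(w) for w in weights_batch])
print(y_batch)
```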

We can obtain the corresponding distribution and gradients in this way:

from imle.wrapper import imle
from imle.target import TargetDistribution
from imle.noise import SumOfGammaNoiseDistribution

# alpha and beta parameterise the surrogate target distribution.
target_distribution = TargetDistribution(alpha=0.0, beta=10.0)
# `k` is a hyperparameter of the sum-of-gamma noise and must be defined beforehand.
noise_distribution = SumOfGammaNoiseDistribution(k=k, nb_iterations=100)

# Wrap torch_solver (defined above). input_noise_temperature and
# target_noise_temperature must also be defined beforehand.
imle_solver = imle(torch_solver,
                   target_distribution=target_distribution,
                   noise_distribution=noise_distribution,
                   nb_samples=10,
                   input_noise_temperature=input_noise_temperature,
                   target_noise_temperature=target_noise_temperature)

Alternatively, imle can be applied as a function decorator:

@imle(target_distribution=target_distribution,
      noise_distribution=noise_distribution,
      nb_samples=10,
      input_noise_temperature=input_noise_temperature,
      target_noise_temperature=target_noise_temperature)
def imle_solver(weights_batch: Tensor) -> Tensor:
    return torch_solver(weights_batch)

Papers using I-MLE

Reference

@inproceedings{niepert21imle,
  author    = {Mathias Niepert and
               Pasquale Minervini and
               Luca Franceschi},
  title     = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family
               Distributions},
  booktitle = {NeurIPS},
  series    = {Proceedings of Machine Learning Research},
  publisher = {{PMLR}},
  year      = {2021}
}
Owner
UCL Natural Language Processing