Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions

Overview

torch-imle

Concise and self-contained PyTorch library implementing the I-MLE gradient estimator proposed in our NeurIPS 2021 paper Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions.

This repository contains a library for transforming any combinatorial black-box solver in a differentiable layer. All code for reproducing the experiments in the NeurIPS paper is available in the official NEC Laboratories Europe repository.

Overview

Implicit MLE (I-MLE) makes it possible to include discrete combinatorial optimization algorithms, such as Dijkstra's algorithm or integer linear program (ILP) solvers, in standard deep learning architectures. The core idea of I-MLE is that it defines an implicit maximum likelihood objective whose gradients are used to update upstream parameters of the model. Every instance of I-MLE requires two ingredients:

  1. A method to approximately sample from a complex and intractable distribution. For this we use Perturb-and-MAP (aka the Gumbel-max trick) and propose a novel family of noise perturbations tailored to the problem at hand.
  2. A method to compute a surrogate empirical distribution: Vanilla MLE reduces the KL divergence between the current distribution and the empirical distribution. Since in our setting, we do not have access to an empirical distribution, we have to design surrogate empirical distributions. Here we propose two families of surrogate distributions which are widely applicable and work well in practice.

Example

For example, let's consider a map from a simple game where the task is to find the shortest path from the top-left to the bottom-right corner. Black areas have the highest and white areas the lowest cost. In the centre, you can see what happens when we use the proposed sum-of-gamma noise distribution to sample paths. On the right, you can see the resulting marginal probabilities for every tile (the probability of each tile being part of a sampled path).

Gradients and Learning

Let us assume that the optimal shortest path is the one of the left. Starting from random weights, the model can learn to produce the weights that will result in the optimal shortest path via Gradient Descent, by minimising the Hamming loss between the produced path and the gold path. Here we show the paths being produced during training (middle), and the corresponding map weights (right).

Input noise temperature set to 0.0, and target noise temperature set to 0.0:

Input noise temperature set to 1.0, and target noise temperature set to 1.0:

Input noise temperature set to 2.0, and target noise temperature set to 2.0:

Input noise temperature set to 5.0, and target noise temperature set to 5.0:

Input noise temperature set to 5.0, and target noise temperature set to 0.0:

All animations were generated by this script.

Code

Using this library is extremely easy -- see this example as a reference. Assuming we have a method that implements a black-box combinatorial solver such as Dijkstra's algorithm:

import numpy as np

import torch
from torch import Tensor

def torch_solver(weights_batch: Tensor) -> Tensor:
    weights_batch = weights_batch.detach().cpu().numpy()
    y_batch = np.asarray([solver(w) for w in list(weights_batch)])
    return torch.tensor(y_batch, requires_grad=False)

We can obtain the corresponding distribution and gradients in this way:

from imle.wrapper import imle
from imle.target import TargetDistribution
from imle.noise import SumOfGammaNoiseDistribution

target_distribution = TargetDistribution(alpha=0.0, beta=10.0)
noise_distribution = SumOfGammaNoiseDistribution(k=k, nb_iterations=100)

def torch_solver(weights_batch: Tensor) -> Tensor:
    weights_batch = weights_batch.detach().cpu().numpy()
    y_batch = np.asarray([solver(w) for w in list(weights_batch)])
    return torch.tensor(y_batch, requires_grad=False)

imle_solver = imle(torch_solver,
                   target_distribution=target_distribution,
                    noise_distribution=noise_distribution,
                    nb_samples=10,
                    input_noise_temperature=input_noise_temperature,
                    target_noise_temperature=target_noise_temperature)

Or, alternatively, using a simple function annotation:

@imle(target_distribution=target_distribution,
      noise_distribution=noise_distribution,
      nb_samples=10,
      input_noise_temperature=input_noise_temperature,
      target_noise_temperature=target_noise_temperature)
def imle_solver(weights_batch: Tensor) -> Tensor:
    return torch_solver(weights_batch)

Papers using I-MLE

Reference

@inproceedings{niepert21imle,
  author    = {Mathias Niepert and
               Pasquale Minervini and
               Luca Franceschi},
  title     = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family
               Distributions},
  booktitle = {NeurIPS},
  series    = {Proceedings of Machine Learning Research},
  publisher = {{PMLR}},
  year      = {2021}
}
Owner
UCL Natural Language Processing
UCL Natural Language Processing
Structured Data Gradient Pruning (SDGP)

Structured Data Gradient Pruning (SDGP) Weight pruning is a technique to make Deep Neural Network (DNN) inference more computationally efficient by re

Bradley McDanel 10 Nov 11, 2022
Dieser Scanner findet Websites, die nicht direkt in Suchmaschinen auftauchen, aber trotzdem erreichbar sind.

Deep Web Scanner Dieses Script findet Websites, die per IPv4-Adresse erreichbar sind und speichert deren Metadaten. Die Ausgabe im Terminal wird nach

Alex K. 30 Nov 18, 2022
Hydra Lightning Template for Structured Configs

Hydra Lightning Template for Structured Configs Template for creating projects with pytorch-lightning and hydra. How to use this template? Create your

Model-driven Machine Learning 4 Jul 19, 2022
Code for our SIGCOMM'21 paper "Network Planning with Deep Reinforcement Learning".

0. Introduction This repository contains the source code for our SIGCOMM'21 paper "Network Planning with Deep Reinforcement Learning". Notes The netwo

NetX Group 68 Nov 24, 2022
A PyTorch implementation for our paper "Dual Contrastive Learning: Text Classification via Label-Aware Data Augmentation".

Dual-Contrastive-Learning A PyTorch implementation for our paper "Dual Contrastive Learning: Text Classification via Label-Aware Data Augmentation". Y

hoshi-hiyouga 85 Dec 26, 2022
Regularized Frank-Wolfe for Dense CRFs: Generalizing Mean Field and Beyond

CRF - Conditional Random Fields A library for dense conditional random fields (CRFs). This is the official accompanying code for the paper Regularized

Đ.Khuê Lê-Huu 21 Nov 26, 2022
Discord bot for notifying on github events

Git-Observer Discord bot for notifying on github events ⚠️ This bot is meant to write messages to only one channel (implementing this for multiple pro

ilu_vatar_ 0 Apr 19, 2022
SHIFT15M: multiobjective large-scale fashion dataset with distributional shifts

[arXiv] The main motivation of the SHIFT15M project is to provide a dataset that contains natural dataset shifts collected from a web service IQON, wh

ZOZO, Inc. 138 Nov 24, 2022
SeqAttack: a framework for adversarial attacks on token classification models

A framework for adversarial attacks against token classification models

Walter 23 Nov 25, 2022
Sandbox for training deep learning networks

Deep learning networks This repo is used to research convolutional networks primarily for computer vision tasks. For this purpose, the repo contains (

Oleg Sémery 2.7k Jan 01, 2023
Official implementation of "Generating 3D Molecules for Target Protein Binding"

Generating 3D Molecules for Target Protein Binding This is the official implementation of the GraphBP method proposed in the following paper. Meng Liu

DIVE Lab, Texas A&M University 74 Dec 07, 2022
pytorch implementation for Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network arXiv:1609.04802

PyTorch SRResNet Implementation of Paper: "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"(https://arxiv.org/abs

Jiu XU 436 Jan 09, 2023
Code for "Graph-Evolving Meta-Learning for Low-Resource Medical Dialogue Generation". [AAAI 2021]

Graph Evolving Meta-Learning for Low-resource Medical Dialogue Generation Code to be further cleaned... This repo contains the code of the following p

Shuai Lin 29 Nov 01, 2022
SNE-RoadSeg in PyTorch, ECCV 2020

SNE-RoadSeg Introduction This is the official PyTorch implementation of SNE-RoadSeg: Incorporating Surface Normal Information into Semantic Segmentati

242 Dec 20, 2022
Vision-Language Pre-training for Image Captioning and Question Answering

VLP This repo hosts the source code for our AAAI2020 work Vision-Language Pre-training (VLP). We have released the pre-trained model on Conceptual Cap

Luowei Zhou 373 Jan 03, 2023
Clean Machine Learning, a Coding Kata

Kata: Clean Machine Learning From Dirty Code First, open the Kata in Google Colab (or else download it) You can clone this project and launch jupyter-

Neuraxio 13 Nov 03, 2022
Implementation of popular bandit algorithms in batch environments.

batch-bandits Implementation of popular bandit algorithms in batch environments. Source code to our paper "The Impact of Batch Learning in Stochastic

Danil Provodin 2 Sep 11, 2022
Code for Universal Semi-Supervised Semantic Segmentation models paper accepted in ICCV 2019

USSS_ICCV19 Code for Universal Semi Supervised Semantic Segmentation accepted to ICCV 2019. Full Paper available at https://arxiv.org/abs/1811.10323.

Tarun K 68 Nov 24, 2022
MCMC samplers for Bayesian estimation in Python, including Metropolis-Hastings, NUTS, and Slice

Sampyl May 29, 2018: version 0.3 Sampyl is a package for sampling from probability distributions using MCMC methods. Similar to PyMC3 using theano to

Mat Leonard 304 Dec 25, 2022
Official PyTorch Implementation of "AgentFormer: Agent-Aware Transformers for Socio-Temporal Multi-Agent Forecasting".

AgentFormer This repo contains the official implementation of our paper: AgentFormer: Agent-Aware Transformers for Socio-Temporal Multi-Agent Forecast

Ye Yuan 161 Dec 23, 2022