TensorFlow implementation and notebooks for Implicit Maximum Likelihood Estimation


tf-imle

TensorFlow 2 and PyTorch implementation and Jupyter notebooks for Implicit Maximum Likelihood Estimation (I-MLE), proposed in the NeurIPS 2021 paper "Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions".

I-MLE is also available as a PyTorch library: https://github.com/uclnlp/torch-imle

Introduction

Implicit MLE (I-MLE) makes it possible to include discrete combinatorial optimization algorithms, such as Dijkstra's algorithm or integer linear programming (ILP) solvers, as well as complex discrete probability distributions in standard deep learning architectures. The figure below illustrates the setting I-MLE was developed for. An upstream standard neural network maps some input to the input parameters of a discrete combinatorial optimization algorithm or a discrete probability distribution, depicted as the black box. In the forward pass, the discrete component is executed and its discrete output is fed into a downstream neural network. With I-MLE, it is possible to estimate the gradients of a loss function with respect to the parameters of the discrete component; these estimates are used during backpropagation to update the parameters of the upstream neural network.

[Figure: Illustration of the problem addressed by I-MLE]
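To make the forward/backward split concrete, here is a minimal, framework-free sketch under simplifying assumptions: the "upstream network" is reduced to a learnable parameter vector `theta`, the black box is an argmax returning a one-hot vector, the downstream loss is the squared distance to a hypothetical target one-hot vector, and the perturb-and-MAP noise is omitted for brevity. The gradient estimate uses the same perturbation-based target as the layer example in this README, i.e. the difference between the MAP state of `theta` and that of `theta - lam * dL/dz`. All function names and the values of `lam` and `lr` are illustrative, not part of the library.

```python
def argmax_one_hot(theta):
    """Black-box MAP solver: one-hot vector marking the largest entry of theta."""
    i = max(range(len(theta)), key=lambda j: theta[j])
    return [1.0 if j == i else 0.0 for j in range(len(theta))]


def loss_grad(z, target):
    """Gradient of the downstream loss sum((z - target)^2) with respect to z."""
    return [2.0 * (zi - ti) for zi, ti in zip(z, target)]


def imle_argmax_grad(theta, dl_dz, lam):
    """Noise-free I-MLE estimate of dL/dtheta: MAP(theta) - MAP(theta - lam * dL/dz)."""
    z = argmax_one_hot(theta)
    theta_target = [t - lam * g for t, g in zip(theta, dl_dz)]
    z_target = argmax_one_hot(theta_target)
    return [a - b for a, b in zip(z, z_target)]


def train(theta, target, lam=10.0, lr=0.5, steps=20):
    for _ in range(steps):
        z = argmax_one_hot(theta)                   # forward pass through the black box
        dl_dz = loss_grad(z, target)                # gradient coming from the downstream loss
        grad = imle_argmax_grad(theta, dl_dz, lam)  # estimated gradient w.r.t. theta
        theta = [t - lr * g for t, g in zip(theta, grad)]  # update the "upstream" parameters
    return theta


theta = train([2.0, 1.0, 0.5], target=[0.0, 0.0, 1.0])
print(argmax_one_hot(theta))  # → [0.0, 0.0, 1.0]
```

Even though the argmax has zero gradient almost everywhere, the estimate moves probability mass away from the current MAP state and toward the state favoured by the loss, so after two updates the black box outputs the target configuration.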

The core idea of I-MLE is that it defines an implicit maximum likelihood objective whose gradients are used to update upstream parameters of the model. Every instance of I-MLE requires two ingredients:

  1. A method to approximately sample from a complex and possibly intractable distribution. For this we use Perturb-and-MAP (aka the Gumbel-max trick) and propose a novel family of noise perturbations tailored to the problem at hand.
  2. A method to compute a surrogate empirical distribution. Vanilla MLE reduces the KL divergence between the current distribution and the empirical distribution. Since we do not have access to such an empirical distribution in our setting, we design surrogate empirical distributions, which we term target distributions. Here we propose two families of target distributions that are widely applicable and work well in practice.
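The first ingredient can be illustrated with the classic Gumbel-max trick, the simplest instance of perturb-and-MAP: adding i.i.d. Gumbel(0, 1) noise to the logits of a categorical distribution and taking the argmax (the MAP state) yields an exact sample from softmax(logits). The sketch below uses only the Python standard library; the function names are illustrative, and it does not implement the novel noise families proposed in the paper.

```python
import math
import random


def sample_gumbel(rng):
    """Draw Gumbel(0, 1) noise via the inverse CDF: -log(-log(U))."""
    u = rng.random()
    while u == 0.0:  # rng.random() is in [0, 1); avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))


def perturb_and_map(logits, rng):
    """Perturb the logits with Gumbel noise and return the MAP (argmax) index."""
    perturbed = [x + sample_gumbel(rng) for x in logits]
    return max(range(len(perturbed)), key=lambda i: perturbed[i])


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
logits = [1.0, 0.0, -1.0]
n = 20000
counts = [0] * len(logits)
for _ in range(n):
    counts[perturb_and_map(logits, rng)] += 1

empirical = [c / n for c in counts]
print([round(p, 3) for p in empirical])        # empirical frequencies of the MAP states
print([round(p, 3) for p in softmax(logits)])  # should be close to the line above
```

For structured problems (k-subsets, shortest paths) the argmax is replaced by a combinatorial solver, and plain Gumbel noise is no longer exact, which is what motivates the tailored noise perturbations mentioned above.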

Requirements:

TensorFlow 2 implementation:

  • tensorflow==2.3.0 or tensorflow-gpu==2.3.0
  • numpy==1.18.5
  • matplotlib==3.1.1
  • scikit-learn==0.24.1
  • tensorflow-probability==0.7.0

PyTorch implementation:

Example: I-MLE as a Layer

The following is an instance of I-MLE implemented as a Keras layer. In this class, the optimization problem is computing the k-subset configuration (the k largest entries), the target distribution is based on perturbation-based implicit differentiation, and the perturb-and-MAP noise perturbations are drawn from the sum-of-gamma distribution.

import tensorflow as tf
from tensorflow.keras import backend as K


class IMLESubsetkLayer(tf.keras.layers.Layer):

    def __init__(self, k, _tau=10.0, _lambda=10.0):
        super(IMLESubsetkLayer, self).__init__()
        # number of 1s in a solution to the optimization problem (size of the subset)
        self.k = k
        # the temperature at which we want to sample
        self._tau = _tau
        # the perturbation strength (here we use a target distribution based on
        # perturbation-based implicit differentiation)
        self._lambda = _lambda

    def sample_sum_of_gamma(self, shape):
        # tf.random.gamma(shape, alpha, beta) interprets beta as a rate (inverse scale),
        # so rate t/k gives Gamma(shape=1/k, scale=k/t) for t = 1, ..., 10
        s = tf.map_fn(fn=lambda t: tf.random.gamma(shape, 1.0 / self.k, t / self.k),
                      elems=tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]))
        # now add the samples
        s = tf.reduce_sum(s, 0)
        # the log(m) term, with m = 10 gamma summands
        s = s - tf.math.log(10.0)
        # divide by k --> the sum of k such draws is approximately Gumbel(0, 1) (before scaling by tau)
        s = self._tau * (s / self.k)
        return s

    def sample_discrete(self, logits, noise):
        gamma_perturbed_logits = logits + noise
        # gamma_perturbed_logits is the input to the combinatorial opt algorithm;
        # the next two lines can be replaced by a custom black-box algorithm call
        threshold = tf.expand_dims(tf.nn.top_k(gamma_perturbed_logits, self.k, sorted=True)[0][:, -1], -1)
        y = tf.cast(tf.greater_equal(gamma_perturbed_logits, threshold), tf.float32)
        return y

    def subset_k(self, logits, training=None):

        @tf.custom_gradient
        def _subset_k(logits):
            # draw the noise once; the same sample is reused in the backward pass
            # (kept in a closure instead of on self, so it also works inside tf.function)
            noise = self.sample_sum_of_gamma(tf.shape(logits))
            # sample discretely with perturb-and-MAP
            z_train = self.sample_discrete(logits, noise)
            # compute the top-k discrete values of the unperturbed logits
            threshold = tf.expand_dims(tf.nn.top_k(logits, self.k, sorted=True)[0][:, -1], -1)
            z_test = tf.cast(tf.greater_equal(logits, threshold), tf.float32)
            # at training time we sample, at test time we take the top-k
            z_output = K.in_train_phase(z_train, z_test, training=training)

            def custom_grad(dy):
                # we perturb (implicit diff) and then reuse the noise sample for perturb-and-MAP
                map_dy = self.sample_discrete(logits - (self._lambda * dy), noise)
                # we now compute the gradients as the difference (I-MLE gradients)
                return tf.math.subtract(z_train, map_dy)

            return z_output, custom_grad

        return _subset_k(logits)

    def call(self, logits, training=None):
        return self.subset_k(logits, training=training)
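For experimenting without a TensorFlow 2.3 setup, the following standard-library sketch follows the same steps as the layer for a single row of logits: sum-of-gamma noise with m = 10 summands, a top-k MAP solver, and the I-MLE gradient given by the difference between the perturb-and-MAP sample and the MAP state of the perturbed target logits. It is an illustration under these assumptions, not the repository's code; the function names, example logits, and upstream gradient `dy` are hypothetical. Note that `random.gammavariate(alpha, beta)` takes a scale `beta`, so scale k/t here corresponds to rate t/k. Unlike the threshold-based TensorFlow version, this top-k always selects exactly k entries (ties broken by index).

```python
import math
import random


def sum_of_gamma(n, k, tau, rng, m=10):
    """Sum-of-gamma noise for n entries. With tau = 1, the sum of k independent
    draws is approximately Gumbel(0, 1); m is the number of gamma summands."""
    noise = []
    for _ in range(n):
        # Gamma(shape=1/k, scale=k/t) for t = 1, ..., m
        s = sum(rng.gammavariate(1.0 / k, k / t) for t in range(1, m + 1))
        noise.append(tau * (s - math.log(m)) / k)
    return noise


def top_k_mask(scores, k):
    """MAP solver for the k-subset problem: 1.0 for the k largest scores, else 0.0."""
    top = set(sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k])
    return [1.0 if i in top else 0.0 for i in range(len(scores))]


def imle_subset_k(logits, dy, k, tau=10.0, lam=10.0, rng=None):
    """Return the perturb-and-MAP sample z and the I-MLE gradient w.r.t. the logits."""
    rng = rng or random.Random()
    noise = sum_of_gamma(len(logits), k, tau, rng)             # drawn once, reused below
    z = top_k_mask([x + e for x, e in zip(logits, noise)], k)  # perturb-and-MAP sample
    target = [x - lam * g + e for x, g, e in zip(logits, dy, noise)]
    z_target = top_k_mask(target, k)                           # MAP state of the target
    grad = [a - b for a, b in zip(z, z_target)]                # I-MLE gradient estimate
    return z, grad


rng = random.Random(42)
logits = [2.0, 1.5, 0.3, -0.5, -1.0]
dy = [1.0, -1.0, 0.0, 0.0, 0.5]  # hypothetical upstream gradient dL/dz
z, grad = imle_subset_k(logits, dy, k=2, rng=rng)
print(z, grad)
```

Because both MAP states contain exactly k ones, each gradient estimate has entries in {-1, 0, 1} that sum to zero. With m = 10, the mean of the k-sum of the noise is H_10 - log(10) ≈ 0.63 rather than the Gumbel mean ≈ 0.58, a small truncation effect of using finitely many summands.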

Reference

@inproceedings{niepert21imle,
  author    = {Mathias Niepert and
               Pasquale Minervini and
               Luca Franceschi},
  title     = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family
               Distributions},
  booktitle = {NeurIPS},
  year      = {2021}
}
Owner: NEC Laboratories Europe