Loopy belief propagation for factor graphs on discrete variables, in JAX!

Overview


PGMax

PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.

  • General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
  • LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. vmap for processing batches of models/samples, grad for differentiating through the LBP iterative process), and can be easily used as part of a larger end-to-end differentiable system.
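To make the point about JAX transformations concrete, here is a toy sketch that is not PGMax code. `cycle_lbp_beliefs` is a made-up pure function that runs a fixed number of sum-product LBP iterations on a cycle of binary variables, with one pairwise log-potential table shared by every edge. Because it is a pure JAX function, `jax.vmap` batches it over several sets of unary log potentials, and `jax.grad` differentiates a belief through the whole iterative loop.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def cycle_lbp_beliefs(unaries, pairwise, num_iters=20):
    """Toy sum-product LBP on a cycle of binary variables (not PGMax code).

    unaries: (N, 2) unary log potentials. pairwise: (2, 2) log potentials
    indexed [x_i, x_{i+1}], shared by every edge (i, i+1 mod N).
    Returns (N, 2) normalized beliefs.
    """
    n = unaries.shape[0]
    # fwd[i]: message from i to i+1; bwd[i]: message from i to i-1 (log domain).
    init = (jnp.zeros((n, 2)), jnp.zeros((n, 2)))

    def step(_, msgs):
        fwd, bwd = msgs
        # Node i tells i+1 everything it knows except what came from i+1.
        new_fwd = logsumexp(
            (unaries + jnp.roll(fwd, 1, axis=0))[:, :, None] + pairwise[None], axis=1
        )
        # Node i tells i-1 everything it knows except what came from i-1.
        new_bwd = logsumexp(
            (unaries + jnp.roll(bwd, -1, axis=0))[:, None, :] + pairwise[None], axis=2
        )
        # Normalize each message to keep the iteration numerically stable.
        new_fwd = new_fwd - logsumexp(new_fwd, axis=1, keepdims=True)
        new_bwd = new_bwd - logsumexp(new_bwd, axis=1, keepdims=True)
        return new_fwd, new_bwd

    fwd, bwd = jax.lax.fori_loop(0, num_iters, step, init)
    log_beliefs = unaries + jnp.roll(fwd, 1, axis=0) + jnp.roll(bwd, -1, axis=0)
    return jax.nn.softmax(log_beliefs, axis=1)


batch_unaries = jax.random.normal(jax.random.PRNGKey(0), (8, 5, 2))
pairwise = jnp.array([[1.0, -1.0], [-1.0, 1.0]])  # favours equal neighbours
# vmap: one LBP run per set of unaries, sharing the pairwise table.
batched = jax.vmap(cycle_lbp_beliefs, in_axes=(0, None))(batch_unaries, pairwise)
# grad: sensitivity of one belief to the pairwise table, through all iterations.
grad_w = jax.grad(lambda w: cycle_lbp_beliefs(batch_unaries[0], w)[0, 1])(pairwise)
```

Keeping the iteration count fixed (`jax.lax.fori_loop` with Python-int bounds) keeps the loop reverse-differentiable; a data-dependent stopping rule implemented with `jax.lax.while_loop` would not be.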

See our blog post and companion paper for more details.


Installation

Install from PyPI

pip install pgmax

Install latest version from GitHub

pip install git+https://github.com/vicariousinc/PGMax.git

Developer

git clone https://github.com/vicariousinc/PGMax.git
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/install-poetry.py | python3 -
cd PGMax
poetry shell
poetry install
pre-commit install

Install on GPU

By default, the commands above install the CPU version of JAX. If you have access to a GPU, follow the official JAX installation instructions to install JAX with GPU support.

Getting Started

Here are a few self-contained Colab notebooks to help you get started with PGMax:

Citing PGMax

Please consider citing our companion paper if you use PGMax in your work:

@article{zhou2022pgmax,
  author = {Zhou, Guangyao and Kumar, Nishanth and L{\'a}zaro-Gredilla, Miguel and Kushagra, Shrinu and George, Dileep},
  title = {{PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX}},
  journal = {arXiv preprint arXiv:2202.04110},
  year={2022}
}

First two authors contributed equally.

Comments
  • Incomplete documentation for add_factor

    The documentation for add_factor currently says:

    log_potentials – Optional array of shape (num_val_configs,) or (num_factors, num_val_configs). If specified, it contains the log of the potential value for every possible configuration. If None, it is assumed the log potential is uniform 0 and such an array is automatically initialized.

    However, in order to use it one would need to know the order in which the factors should be specified. Could this be added to the documentation?

    bug documentation 
    opened by nathanielvirgo 8
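One way to reason about the ordering the issue asks about, offered as a sketch under an assumption rather than as PGMax's documented behavior: assume entry k of a flat log_potentials vector belongs to row k of the factor's array of valid configurations. If you start from a dense table indexed by the variables' states, enumerating configurations with the last variable varying fastest (row-major order) lines the rows up with NumPy's default `ravel()`. `dense_table_to_flat` is a hypothetical helper.

```python
import itertools

import numpy as np


def dense_table_to_flat(log_table):
    """Hypothetical helper: list every configuration with its log potential.

    log_table[s0, s1, ...] is the log potential of the joint configuration
    (s0, s1, ...). Returns (configs, flat) where row k of configs pairs with
    flat[k].
    """
    log_table = np.asarray(log_table, dtype=float)
    # itertools.product varies the last variable fastest (row-major order) ...
    configs = np.array(list(itertools.product(*(range(n) for n in log_table.shape))))
    # ... which is exactly the order ravel() uses by default, so the two line up.
    flat = log_table.ravel()
    return configs, flat


# Two variables with 2 and 3 states: row 4 of configs is (1, 1), paired with
# log_table[1, 1].
table = np.log(np.arange(1.0, 7.0).reshape(2, 3))
configs, flat = dense_table_to_flat(table)
```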
  • Add support for Python 3.9 and 3.10

    I don't know if this is a bug or a problem with my installation, but if I try to run a file containing only the line

    from pgmax.fg import graph
    

    I get the error

    % /opt/local/bin/python3 /Users/nathaniel/Dropbox/Code/PGMax/test01.py
    Traceback (most recent call last):
      File "/Users/nathaniel/Dropbox/Code/PGMax/test01.py", line 1, in <module>
        from pgmax.fg import graph
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/fg/graph.py", line 11, in <module>
        import pgmax.bp.infer as infer
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/bp/infer.py", line 6, in <module>
        import pgmax.bp.bp_utils as bp_utils
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/bp/bp_utils.py", line 11, in <module>
        @jax.partial(jax.jit, static_argnames="max_segment_length")
    AttributeError: module 'jax' has no attribute 'partial'
    

    This is with Python 3.9 installed using MacPorts on macOS 10.15.7. PGMax, JAX and the other prerequisites were installed with pip-3.9 install --user PGMax. I'm happy to provide any other information about my installation if you can tell me how to obtain it.

    enhancement 
    opened by nathanielvirgo 6
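The traceback fails on `jax.partial`, which newer JAX releases no longer provide; the v0.2.1 release notes list the fix as using `functools.partial` instead. A minimal sketch of the replacement decorator pattern, where `pad_segment` is a hypothetical function that reuses the `max_segment_length` static argument from the traceback:

```python
import functools

import jax
import jax.numpy as jnp


# Old spelling, no longer available in newer JAX:
#   @jax.partial(jax.jit, static_argnames="max_segment_length")
@functools.partial(jax.jit, static_argnames="max_segment_length")
def pad_segment(values, max_segment_length):
    """Hypothetical helper: zero-pad a 1-D array to a static length."""
    # max_segment_length is static, so it is a plain Python int while tracing
    # and can determine the output shape.
    return jnp.pad(values, (0, max_segment_length - values.shape[0]))


padded = pad_segment(jnp.array([1.0, 2.0]), max_segment_length=4)  # [1. 2. 0. 0.]
```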
  • Test sanity check example using new interface and inference modules, and put together the first unit test

    The unit test should run fast. One option is to cache new results. Another option is to just make the model really small.

    In the process, we should also:

    1. Deprecate the current contrib module and create a new examples directory to hold everything.
    2. Start figuring out what our user facing interface should look like.
    opened by StannisZhou 6
  • FactorGraph supports any type of factors + runs specialized inference for ORFactors

    In this PR we

    1. Redefine the factor graph abstraction by introducing factor types: factors in a graph are clustered in factor groups, which are grouped according to their factor types. See fg/graph.py.
    2. Specify two types of factors: EnumerationFactor (this class already existed) and ORFactor (this new class inherits from the new LogicalFactor). Each Factor class must have its own methods to compile and concatenate Wirings for inference. See factors/enumeration.py and factors/logical.py.
    3. Make running inference in a graph agnostic to the specific factor types supported. New factor types can then be added without modifying graph.py.
    4. Implement specialized inference for ORFactors (see pass_OR_fac_to_var_messages in factors/logical.py) and compare it with the existing inference for EnumerationFactors in the unit test tests/factors/test_or.py.
    opened by antoine-dedieu 5
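Why a specialized OR update pays off: for a factor enforcing c = OR(p_1, ..., p_n) over binary variables, the max-product factor-to-variable messages have a closed form computable in roughly linear time, whereas an enumeration factor scores all 2^n parent configurations. With d_i = m_i(1) - m_i(0), the message to the child at c = 1 is sum_i m_i(0) + sum_i max(d_i, 0) + min(0, max_i d_i), where the last term forces at least one parent on. The NumPy sketch below illustrates this idea in the log domain; it is not the `pass_OR_fac_to_var_messages` code from the PR, and both function names are hypothetical. Like the unit test described above, it is meant to be checked against brute-force enumeration.

```python
import itertools

import numpy as np


def or_factor_messages(parent_msgs, child_msg):
    """Max-product messages out of a factor enforcing child = OR(parents).

    parent_msgs: (n, 2) incoming log-messages for n binary parents.
    child_msg: (2,) incoming log-message for the binary child.
    Returns (parent_out of shape (n, 2), child_out of shape (2,)); each output
    excludes the receiving variable's own incoming message.
    """
    parent_msgs = np.asarray(parent_msgs, dtype=float)
    child_msg = np.asarray(child_msg, dtype=float)
    n = parent_msgs.shape[0]
    d = parent_msgs[:, 1] - parent_msgs[:, 0]  # gain from turning a parent on
    s0 = parent_msgs[:, 0].sum()               # score with every parent off
    relu = np.maximum(d, 0.0)

    # c = 0 forces all parents off; c = 1 needs at least one parent on, so if
    # no gain is positive we pay for the least costly switch (the max of d).
    child_out = np.array([s0, s0 + relu.sum() + min(0.0, d.max())])

    # max_{j != i} d_j for every i via the top-two trick (argsort for brevity).
    order = np.argsort(d)
    top = d[order[-1]]
    second = d[order[-2]] if n > 1 else -np.inf
    max_others = np.where(np.arange(n) == order[-1], second, top)

    s0_others = s0 - parent_msgs[:, 0]
    relu_others = relu.sum() - relu
    on = child_msg[1] + s0_others + relu_others   # p_i = 1 forces c = 1
    off_c0 = child_msg[0] + s0_others             # p_i = 0, all off, c = 0
    off_c1 = child_msg[1] + s0_others + relu_others + np.minimum(0.0, max_others)
    parent_out = np.stack([np.maximum(off_c0, off_c1), on], axis=1)
    return parent_out, child_out


def brute_force_or_messages(parent_msgs, child_msg):
    """Same messages by enumerating all 2^n parent configurations."""
    n = len(parent_msgs)
    parent_out = np.full((n, 2), -np.inf)
    child_out = np.full(2, -np.inf)
    for cfg in itertools.product([0, 1], repeat=n):
        c = int(any(cfg))
        score = sum(parent_msgs[j][cfg[j]] for j in range(n))
        child_out[c] = max(child_out[c], score)
        for i in range(n):
            total = score - parent_msgs[i][cfg[i]] + child_msg[c]
            parent_out[i, cfg[i]] = max(parent_out[i, cfg[i]], total)
    return parent_out, child_out
```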
  • RCN example

    This PR contains an example implementation of RCN using the pgmax package. We load an RCN model pre-trained on a very small subset of MNIST (20 examples) and test it on a small subset of MNIST (20 examples). The reported accuracy is 0.80.

    opened by shrinuKushagra 5
  • Variables refactor

    We update how variables are represented. In particular:

    • We get rid of variable names, as well as of the Variables and CompositeVariableGroup classes. A variable is now represented by a tuple (variable hash, variable num_states). In particular, a FactorGraph can now be instantiated directly as fg = graph.FactorGraph(variables=[hidden_variables, visible_variables]). Similarly, Factors are defined by directly passing the variables involved, as in [hidden_variables[ii], visible_variables[jj]].
    • We rewrite NDVariableArray so that users can access variables through NumPy array indexing. We also optimize some follow-up computations.
    opened by antoine-dedieu 4
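A toy sketch of the representation described in this PR, with each variable as a `(variable_hash, num_states)` tuple and NumPy indexing used to look variables up. `ToyNDVarArray` is hypothetical and only mimics the idea; in particular, assigning consecutive integer hashes from a `base_hash` is an assumption of this sketch, not how PGMax computes hashes.

```python
import numpy as np


class ToyNDVarArray:
    """Hypothetical N-d variable group; each variable is (variable_hash, num_states).

    Hashes are consecutive integers starting at base_hash in C order, an
    assumption of this sketch.
    """

    def __init__(self, shape, num_states, base_hash=0):
        self.shape = tuple(shape)
        self.num_states = np.broadcast_to(np.asarray(num_states), self.shape)
        size = int(np.prod(self.shape))
        self.hashes = base_hash + np.arange(size).reshape(self.shape)

    def __getitem__(self, index):
        # Plain NumPy indexing does the work, so slices and fancy indices are cheap.
        hashes = self.hashes[index]
        states = self.num_states[index]
        if np.ndim(hashes) == 0:
            return (int(hashes), int(states))
        return list(zip(hashes.ravel().tolist(), states.ravel().tolist()))


hidden = ToyNDVarArray((2, 3), num_states=2, base_hash=100)
visible = ToyNDVarArray((4,), num_states=[2, 3, 4, 5], base_hash=200)
factor_variables = [hidden[0, 0], visible[1]]  # [(100, 2), (201, 3)]
```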
  • Numba speedup for wiring + log potentials

    This PR is the continuation of https://github.com/vicariousinc/PGMax/pull/129 and part of our efforts to speed up the adding of FactorGroups and the wiring compilation.

    As https://github.com/vicariousinc/PGMax/pull/129 has moved most of the wiring computation to the FactorGroup level, we can now use numba to compute these wirings quickly.

    As a result:

    • adding factors for the RBM experiment takes 3s, and building run_bp takes 1s
    • adding factors for the convor experiment takes 2s, and building run_bp takes 1s
    opened by antoine-dedieu 4
  • RCN implementation on a small train and test set

    This PR contains the first implementation of the RCN example using the PGMax package. The file to run is examples/rcn/inference_pgmax_small.py. This code contains the implementation and visualization on a small set: the model is trained with 20 examples and tested on 20 examples.

    The inference code has been separated from the model creation code. Saved models are added to /storage/users/skushagra/pgmax_rcn_artifacts/.

    Running on the full dataset will be implemented in a later PR.

    opened by shrinuKushagra 4
  • Make `FactorGraph` mutable to support interactive model building

    Should implement interface for:

    1. Add factors
    2. Set evidence for variables
    3. Initialize messages by setting messages for factors
    4. Initialize messages by spreading beliefs from variables
    enhancement 
    opened by StannisZhou 4
  • Add customized class for pairwise factors; Default to have uniform potentials

    Currently, users have to manually create an array of all possible configurations and a uniform potential, but it would be nice to do this behind the scenes. For example, if either of these is None during init, we could assume all possible configurations or a uniform potential, respectively, and create them automatically.

    enhancement 
    opened by NishanthJKumar 4
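A minimal sketch of the proposed default, assuming a pairwise factor is described by an array of valid configurations plus one log potential per configuration: when no potential is given, enumerate every (state_a, state_b) pair and use log potential 0, i.e. a uniform potential. `default_pairwise_factor` is a hypothetical helper, not a PGMax API; later releases expose a `log_potential_matrix` argument on `PairwiseFactorGroup`, as the v0.4.0 notes show.

```python
import numpy as np


def default_pairwise_factor(num_states_a, num_states_b, log_potential_matrix=None):
    """Hypothetical helper: (factor_configs, log_potentials) for a pairwise factor.

    With log_potential_matrix=None, every configuration gets log potential 0,
    i.e. a uniform potential.
    """
    if log_potential_matrix is None:
        log_potential_matrix = np.zeros((num_states_a, num_states_b))
    log_potential_matrix = np.asarray(log_potential_matrix, dtype=float)
    if log_potential_matrix.shape != (num_states_a, num_states_b):
        raise ValueError("log_potential_matrix must have shape (num_states_a, num_states_b)")
    # All (state_a, state_b) pairs, with state_b varying fastest.
    a, b = np.meshgrid(np.arange(num_states_a), np.arange(num_states_b), indexing="ij")
    factor_configs = np.stack([a.ravel(), b.ravel()], axis=1)
    # Look up one log potential per configuration row.
    log_potentials = log_potential_matrix[factor_configs[:, 0], factor_configs[:, 1]]
    return factor_configs, log_potentials


configs, log_potentials = default_pairwise_factor(2, 3)  # 6 configs, all zeros
```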
  • Make BP closer to jax optimizer

    Resolves https://github.com/vicariousinc/PGMax/issues/124

    We make graph.BP behave more like the JAX optimizers (https://jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html).

    opened by antoine-dedieu 3
  • Provide high-level syntax for creating factors

    One of the speed bottlenecks in creating a FactorGraph is building the variables_for_factors list, which is currently slow because we loop through the individual variables.

    However, when all the variable groups are NDVarArrays, we can speed up this step considerably by proposing a generic get_factors interface: the user defines the general rule for the factors, and the corresponding list is generated with numba.

    One option is to have a first argument consisting of the variable groups whose dimensions we loop over, and a second argument consisting of the variable groups we do not loop over. For example, get_factors({x:(i, j), y:(k, l)}, {z:(i+k, j+l)}) would mean

    factors = []
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for k in range(y.shape[0]):
                for l in range(y.shape[1]): 
                    factors.append((x[i, j], y[k, l], z[i+k, j+l]))
    
    opened by antoine-dedieu 0
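The nested loops in this issue can also be generated without a Python loop. The issue proposes numba; the sketch below swaps in plain NumPy broadcasting (a different technique) to produce the index triples for the rule get_factors({x:(i, j), y:(k, l)}, {z:(i+k, j+l)}) in the same order as the loops. `conv_factor_indices` is hypothetical and returns indices rather than variables; mapping them to variable objects is left to the variable-group class.

```python
import numpy as np


def conv_factor_indices(x_shape, y_shape):
    """Hypothetical helper: indices for get_factors({x:(i, j), y:(k, l)}, {z:(i+k, j+l)}).

    Returns three (num_factors, 2) integer arrays; row f gives the indices into
    x, y and z of factor f, in the same order as the nested loops.
    """
    # One grid axis per loop variable in loop order (i, j, k, l), so C-order
    # flattening reproduces the loop nesting (l varies fastest).
    i, j, k, l = (
        g.ravel()
        for g in np.meshgrid(
            np.arange(x_shape[0]), np.arange(x_shape[1]),
            np.arange(y_shape[0]), np.arange(y_shape[1]),
            indexing="ij",
        )
    )
    x_idx = np.stack([i, j], axis=1)
    y_idx = np.stack([k, l], axis=1)
    z_idx = np.stack([i + k, j + l], axis=1)
    return x_idx, y_idx, z_idx


x_idx, y_idx, z_idx = conv_factor_indices((2, 3), (2, 2))  # 2 * 3 * 2 * 2 = 24 factors
```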
  • Modify `vars_to_starts` representation

    Creating vars_to_starts as a dictionary mapping each variable to an int is expensive when there are many variables.

    Instead, it could map a variable group to an array (for an NDVariableArray) or a list (for a VariableDict).

    opened by antoine-dedieu 0
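A sketch of the per-group alternative, assuming a flat layout where each variable occupies `num_states` consecutive entries and variables are laid out group by group: a single exclusive prefix sum over all sizes yields every start offset, and the result is split back into one array per group. `group_starts` is a hypothetical helper.

```python
import numpy as np


def group_starts(num_states_per_group):
    """Hypothetical helper: flat start offset of every variable, one array per group.

    num_states_per_group: list of integer arrays (any shape), one per variable
    group, giving each variable's number of states. Groups are laid out in list
    order, each in C order.
    """
    sizes = [np.asarray(ns, dtype=np.int64) for ns in num_states_per_group]
    flat_sizes = np.concatenate([s.ravel() for s in sizes])
    # Exclusive prefix sum: each variable starts where the previous one ends.
    flat_starts = np.cumsum(flat_sizes) - flat_sizes
    starts, offset = [], 0
    for s in sizes:
        starts.append(flat_starts[offset:offset + s.size].reshape(s.shape))
        offset += s.size
    return starts


# A 2x2 grid of 3-state variables followed by two variables with 2 and 5 states.
starts = group_starts([np.full((2, 2), 3), [2, 5]])
# starts[0] -> [[0, 3], [6, 9]], starts[1] -> [12, 14]
```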
  • Improve documentation for variables/variable groups

    Currently it's not clear what names PGMax assigns to different variables (e.g. https://github.com/vicariousinc/PGMax/issues/115). Add documentation to make this clearer.

    documentation 
    opened by StannisZhou 0
Releases (v0.4.1)
  • v0.4.1 (May 19, 2022)

    Highlights

    • Fix two minor issues when running BP with variable groups that have different numbers of states, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/144

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.4.0...v0.4.1

  • v0.4.0 (May 9, 2022)

    Breaking changes

    ⚠️ This release changes the high-level API as well as import paths ⚠️

    This release makes several major breaking changes to improve usability and efficiency of the package.

    1. Interacting with variables through VarGroup objects

    We no longer refer to variables by names, but instead directly interact with VarGroup objects. This change has several implications.

    • We no longer have a Variable class. Instead, we access individual variables by indexing into VarGroup objects.

    • FactorGraph can no longer be initialized with a dictionary of variable groups (as we no longer have names for variables). Instead, we initialize a FactorGraph by

    from pgmax import fgraph
    fg = fgraph.FactorGraph(variable_groups=variable_groups)
    

    where variable_groups is either a VarGroup or a list of VarGroups.

    • We can directly construct Factor/FactorGroup using individual variables, and have a unified add_factors interface for adding Factors and FactorGroups to the FactorGraph.

    For example, we can create a PairwiseFactorGroup via:

    from pgmax import fgroup
    pairwise_factors = fgroup.PairwiseFactorGroup(
        variables_for_factors=variables_for_factors,
        log_potential_matrix=log_potential_matrix,
    )
    

    where variables_for_factors is a list of lists of individual variables. We can then add factors to a FactorGraph fg by

    fg.add_factors(factors=factors)
    

    where factors can be individual Factor, individual FactorGroup, or a list of Factors and FactorGroups.

    • We access LBP results by indexing with a VarGroup. For example, after running BP, we can get the MAP decoding for the VarGroup visible_variables via

    beliefs = bp.get_beliefs(bp_arrays)
    map_states_visible = infer.decode_map_states(beliefs)[visible_variables]
    

    2. Efficient construction of FactorGroup

    We have implemented efficient construction of FactorGroup. Going forward, we always recommend constructing FactorGroup instead of individual Factor.

    3. Improved LBP interface

    We first create the functions used to run BP with temperature T via

    from pgmax import infer
    bp = infer.BP(fg.bp_state, temperature=T)
    

    where bp contains functions that initialize or update the arrays involved in LBP.

    We can initialize bp_arrays by

    bp_arrays = bp.init()
    

    apply log-potential, message, and evidence updates by

    bp_arrays = bp.update(
        bp_arrays=bp_arrays,
        log_potentials_updates=log_potentials_updates,
        ftov_msgs_updates=ftov_msgs_updates,
        evidence_updates=evidence_updates,
    )
    

    and run BP for a given number of iterations by

    bp_arrays = bp.run_bp(bp_arrays, num_iters=num_iters, damping=damping)
    

    Note that we can arbitrarily interleave bp.update with bp.run_bp, which allows flexible control over how we run LBP.
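To illustrate the init/update/run pattern (and damping) outside of PGMax, here is a small NumPy sketch of max-product LBP on a pairwise model, written as pure functions over a dict of arrays. It is not PGMax's implementation: `make_bp`, the array layouts, and `msgs_updates` are assumptions of this sketch, while the other argument names loosely mirror the API above. In this sketch, each damped step sets the messages to `damping * old + (1 - damping) * update`.

```python
import numpy as np


def make_bp(num_vars, edges, num_states):
    """Toy max-product LBP on a pairwise model as pure functions (not PGMax).

    edges: list of (a, b) variable pairs. log_potentials[e] is a
    (num_states, num_states) table indexed [x_a, x_b] for edge e.
    """
    edges = np.asarray(edges)
    src = edges.ravel()            # directed edge 2e is a -> b, 2e + 1 is b -> a
    dst = edges[:, ::-1].ravel()
    rev = np.arange(src.size) ^ 1  # index of the opposite directed edge

    def init():
        return {
            "log_potentials": np.zeros((len(edges), num_states, num_states)),
            "msgs": np.zeros((src.size, num_states)),
            "evidence": np.zeros((num_vars, num_states)),
        }

    def update(bp_arrays, log_potentials_updates=None, msgs_updates=None,
               evidence_updates=None):
        # Return a new dict rather than mutating the input, like optimizer state.
        new = dict(bp_arrays)
        for key, value in (("log_potentials", log_potentials_updates),
                           ("msgs", msgs_updates), ("evidence", evidence_updates)):
            if value is not None:
                new[key] = np.asarray(value, dtype=float)
        return new

    def get_beliefs(bp_arrays):
        beliefs = bp_arrays["evidence"].copy()
        np.add.at(beliefs, dst, bp_arrays["msgs"])  # add every incoming message
        return beliefs

    def run_bp(bp_arrays, num_iters, damping=0.5):
        pots = bp_arrays["log_potentials"]
        # Orient each table as [x_src, x_dst] for both directions of its edge.
        directed = np.stack([pots, pots.transpose(0, 2, 1)], axis=1)
        directed = directed.reshape(-1, num_states, num_states)
        msgs = bp_arrays["msgs"]
        for _ in range(num_iters):
            beliefs = get_beliefs({**bp_arrays, "msgs": msgs})
            cavity = beliefs[src] - msgs[rev]  # exclude what dst sent to src
            new_msgs = np.max(cavity[:, :, None] + directed, axis=1)
            new_msgs -= new_msgs.max(axis=1, keepdims=True)  # normalize
            msgs = damping * msgs + (1.0 - damping) * new_msgs
        return {**bp_arrays, "msgs": msgs}

    return init, update, run_bp, get_beliefs


# Usage mirroring bp.init / bp.update / bp.run_bp from the notes above.
init, update, run_bp, get_beliefs = make_bp(3, [(0, 1), (1, 2)], num_states=2)
bp_arrays = update(
    init(),
    evidence_updates=[[1.0, 0.0], [0.0, 0.0], [0.0, 0.5]],
    log_potentials_updates=np.tile(np.eye(2), (2, 1, 1)),  # reward equal neighbours
)
bp_arrays = run_bp(bp_arrays, num_iters=50, damping=0.5)
map_states = get_beliefs(bp_arrays).argmax(axis=1)
```

Because `update` returns a new dict instead of mutating its input, calls to `update` and `run_bp` can be interleaved freely, in the same spirit as the note on interleaving above.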

    4. Improved high-level module organization

    There are now five main high-level modules: fgraph for factor graphs, factor for factors, vgroup for variable groups, fgroup for factor groups, and infer for LBP.

    Details of what has changed:

    • Speed up the process of adding Factors and compiling wiring for a FactorGraph by moving all the computations to the FactorGroup level, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/129
    • Speed up the process of computing log potentials + wiring for FactorGroup with numba, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/133
    • Make the BP class behavior closer to JAX optimizers by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/135
    • Get rid of the Variables and CompositeVariableGroup classes + of the variable names + adopt a simpler representation for variables + rely on numpy arrays to make NDVarArray efficient, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/136
    • Overall module reorganization, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/140

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.3.0...v0.4.0

  • v0.3.0 (Mar 25, 2022)

    Highlights

    • Refactors to support adding different factor types with specialized inference procedures by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/122
    • Specialized logical AND/OR factors by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/122 https://github.com/vicariousinc/PGMax/pull/126
    • New example on 2D binary blind deconvolution by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/127

    New Contributors

    • @antoine-dedieu made his first contribution in https://github.com/vicariousinc/PGMax/pull/122

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.3...v0.3.0

  • v0.2.3 (Feb 19, 2022)

    What's Changed

    • Links to blog post and companion paper; Documentation updates by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/111
    • Get rid of redundant array shape for log_potentials by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/116
    • Support python 3.9/3.10; Improve documentation for add_factor; Bump up version for new release by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/120

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.2...v0.2.3

  • v0.2.2 (Jan 22, 2022)

    What's Changed

    • Update README by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/103
    • RCN example by @shrinuKushagra in https://github.com/vicariousinc/PGMax/pull/96
    • Add support for sum-product with temperature by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/104
    • Include Grid Markov Random Field example by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/107
    • Changes for blog post by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/109
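Sum-product with temperature can be understood as replacing the max in a max-product message update with a softened reduction, T * logsumexp(x / T): T = 1 gives sum-product and T -> 0 recovers max-product. That is the usual convention and is assumed here; check the PGMax docs for the exact definition the library uses. `soft_max_reduce` is a hypothetical helper, written in NumPy with the usual max-shift for numerical stability.

```python
import numpy as np


def soft_max_reduce(x, temperature, axis=-1):
    """Compute T * logsumexp(x / T) along `axis`.

    temperature = 1.0 gives the sum-product (log-sum-exp) reduction and
    temperature = 0.0 the max-product (hard max) reduction.
    """
    x = np.asarray(x, dtype=float)
    m = np.max(x, axis=axis, keepdims=True)
    if temperature == 0.0:
        return np.squeeze(m, axis=axis)  # limit T -> 0: plain max
    # T * logsumexp(x / T) = m + T * log(sum(exp((x - m) / T))), avoiding overflow.
    total = np.sum(np.exp((x - m) / temperature), axis=axis)
    return np.squeeze(m, axis=axis) + temperature * np.log(total)


values = np.array([0.0, 1.0, 2.0])
for t in (1.0, 0.5, 0.1, 0.0):
    print(t, soft_max_reduce(values, t))  # decreases toward max(values) = 2.0
```

In a log-domain message update, swapping this reduction in for the max over the sending variable's states gives the temperature-T message.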

    New Contributors

    • @shrinuKushagra made their first contribution in https://github.com/vicariousinc/PGMax/pull/96

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.1...v0.2.2

  • v0.2.1 (Dec 1, 2021)

    What's Changed

    • Bump versions for publishing by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/63
    • Example notebook with PMAP sampling of RBMs trained on MNIST digits by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/80
    • Use functools.partial instead of jax.partial by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/83
    • First pass for speeding up graph and evidence operations by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/84
    • Moving to a functional interface by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/88
    • Update README in preparation for making repo public by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/89
    • Pre commit ci test by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/90
    • [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/vicariousinc/PGMax/pull/91
    • Update dependency requirements to be less aggressive by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/92
    • adds codecov badge to README! by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/94
    • Fix GPU memory leak that came up in RCN example by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/97
    • [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/vicariousinc/PGMax/pull/95
    • Docs update by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/98
    • fixes bug where wrong conf.py path was specified by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/99
    • Rtd warning fix by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/100
    • includes minor changes to README and documentation by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/101
    • Bump up version; Fixes for docs by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/102

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.0...v0.2.1

  • v0.2.0 (Sep 6, 2021)

    Features

    • Efficient and scalable max-product belief propagation using a fully flat representation
    • A general factor graph interface that supports easy specification of PGMs with pairwise factors and higher-order factors based on explicit enumeration
    • Mechanisms for evidence and message manipulation
    • 3 example notebooks showcasing the functionalities of PGMax