Loopy belief propagation for factor graphs on discrete variables, in JAX!

Overview


PGMax

PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.

  • General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
  • LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. vmap for processing batches of models/samples, grad for differentiating through the LBP iterative process), and can be easily used as part of a larger end-to-end differentiable system.
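
To make the description concrete, here is a minimal NumPy sketch of sum-product loopy belief propagation on a pairwise factor graph, computed in log space. It only illustrates the algorithm. It is not PGMax's implementation, which generates JAX functions over a flat message representation and also supports higher-order factors. `run_pairwise_lbp` and its argument layout are hypothetical.

```python
import numpy as np


def logsumexp(a, axis):
    # Numerically stable log-sum-exp along one axis.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def run_pairwise_lbp(unaries, edges, pairwise, num_iters=50):
    """Sum-product loopy BP in log space (hypothetical helper, not PGMax's API).

    unaries[v][s]:       log potential of variable v in state s
    edges[e] = (u, v):   the two variables connected by pairwise factor e
    pairwise[e][su, sv]: log potential of factor e for states (su, sv)
    Returns one normalized belief (approximate marginal) array per variable.
    """
    # msgs[(e, var)] is the log message from factor e to variable var.
    msgs = {}
    for e, (u, v) in enumerate(edges):
        msgs[(e, u)] = np.zeros(len(unaries[u]))
        msgs[(e, v)] = np.zeros(len(unaries[v]))

    def var_to_factor(var, exclude_e):
        # Unary term plus every incoming factor message except the one from exclude_e.
        out = np.array(unaries[var], dtype=float)
        for (e, target), m in msgs.items():
            if target == var and e != exclude_e:
                out = out + m
        return out

    for _ in range(num_iters):
        new_msgs = {}
        for e, (u, v) in enumerate(edges):
            mu = var_to_factor(u, e)
            mv = var_to_factor(v, e)
            # Sum out the other endpoint (in log space), then normalize.
            to_v = logsumexp(pairwise[e] + mu[:, None], axis=0)
            to_u = logsumexp(pairwise[e] + mv[None, :], axis=1)
            new_msgs[(e, v)] = to_v - logsumexp(to_v, axis=0)
            new_msgs[(e, u)] = to_u - logsumexp(to_u, axis=0)
        msgs = new_msgs  # synchronous ("flooding") schedule

    beliefs = []
    for var in range(len(unaries)):
        b = var_to_factor(var, exclude_e=None)
        beliefs.append(np.exp(b - logsumexp(b, axis=0)))
    return beliefs


rng = np.random.default_rng(0)
# A 3-variable chain (a tree) with 2, 3 and 2 states.
unaries = [rng.normal(size=2), rng.normal(size=3), rng.normal(size=2)]
edges = [(0, 1), (1, 2)]
pairwise = [rng.normal(size=(2, 3)), rng.normal(size=(3, 2))]
beliefs = run_pairwise_lbp(unaries, edges, pairwise, num_iters=10)
```

On a tree such as this chain the beliefs match the exact marginals. On graphs with cycles they are approximations and convergence is not guaranteed, which is one reason damping is commonly used.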

See our blog post and companion paper for more details.


Installation

Install from PyPI

pip install pgmax

Install latest version from GitHub

pip install git+https://github.com/vicariousinc/PGMax.git

Developer

git clone https://github.com/vicariousinc/PGMax.git
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/install-poetry.py | python3 -
cd PGMax
poetry shell
poetry install
pre-commit install

Install on GPU

By default, the commands above install the CPU version of JAX. If you have access to a GPU, follow the official JAX installation instructions to install JAX with GPU support.

Getting Started

Here are a few self-contained Colab notebooks to help you get started on using PGMax:

Citing PGMax

Please consider citing our companion paper if you use PGMax in your work:

@article{zhou2022pgmax,
  author = {Zhou, Guangyao and Kumar, Nishanth and L{\'a}zaro-Gredilla, Miguel and Kushagra, Shrinu and George, Dileep},
  title = {{PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX}},
  journal = {arXiv preprint arXiv:2202.04110},
  year = {2022}
}

First two authors contributed equally.

Comments
  • Incomplete documentation for add_factor

    The documentation for add_factor currently says:

    log_potentials – Optional array of shape (num_val_configs,) or (num_factors, num_val_configs). If specified, it contains the log of the potential value for every possible configuration. If None, the log potential is assumed to be uniformly 0 and such an array is automatically initialized.

    However, in order to use it one would need to know the order in which the factors should be specified. Could this be added to the documentation?

    bug documentation 
    opened by nathanielvirgo 8
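
The ordering question can be illustrated independently of PGMax. The sketch below assumes that each entry of `log_potentials` lines up with one row of a configuration array, and that configurations are enumerated in row-major (`itertools.product`) order. Under that assumption, flattening a pairwise log-potential matrix in C order gives exactly the per-configuration vector. Both helper names are hypothetical.

```python
import itertools

import numpy as np


def enumerate_configs(num_states):
    # Every joint configuration, with the last variable varying fastest.
    return np.array(list(itertools.product(*[range(n) for n in num_states])))


def log_potentials_from_matrix(log_potential_matrix):
    # Row-major (C-order) flattening matches the enumeration order above:
    # configuration (i, j) lands at position i * num_states[1] + j.
    return np.asarray(log_potential_matrix, dtype=float).reshape(-1)


num_states = (2, 3)
configs = enumerate_configs(num_states)  # shape (6, 2)
matrix = np.arange(6, dtype=float).reshape(num_states)
log_potentials = log_potentials_from_matrix(matrix)

# Entry k of log_potentials is the log potential of the configuration configs[k].
for config, lp in zip(configs, log_potentials):
    assert matrix[tuple(config)] == lp
```

Under the same assumption, a 2D array of shape (num_factors, num_val_configs) would hold one such row per factor, in the order in which the factors' variables are listed; that factor order is the detail the issue asks the documentation to state.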
  • Add support for Python 3.9 and 3.10

    I don't know if this is a bug or a problem with my installation, but if I try to run a file containing only the line

    from pgmax.fg import graph
    

    I get the error

    % /opt/local/bin/python3 /Users/nathaniel/Dropbox/Code/PGMax/test01.py
    Traceback (most recent call last):
      File "/Users/nathaniel/Dropbox/Code/PGMax/test01.py", line 1, in <module>
        from pgmax.fg import graph
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/fg/graph.py", line 11, in <module>
        import pgmax.bp.infer as infer
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/bp/infer.py", line 6, in <module>
        import pgmax.bp.bp_utils as bp_utils
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/bp/bp_utils.py", line 11, in <module>
        @jax.partial(jax.jit, static_argnames="max_segment_length")
    AttributeError: module 'jax' has no attribute 'partial'
    

    This is with Python 3.9 installed using MacPorts on macOS 10.15.7. PGMax, jax, and other prerequisites were installed with pip-3.9 install --user PGMax. I'm happy to give any other information about my installation if you can tell me how to obtain it.

    enhancement 
    opened by nathanielvirgo 6
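
The traceback points at the `@jax.partial(...)` decorator. `jax.partial` was a re-export of `functools.partial` that newer JAX releases removed, and the v0.2.1 changelog below lists "Use functools.partial instead of jax.partial". The minimal sketch below shows the replacement pattern. To keep it runnable without JAX, `fake_jit` is a hypothetical stand-in for `jax.jit` that accepts a `static_argnames` keyword. In real code you would write `functools.partial(jax.jit, static_argnames=...)`.

```python
import functools


def fake_jit(fun, static_argnames=()):
    """Hypothetical stand-in for jax.jit: records its options and returns a wrapper."""
    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        return fun(*args, **kwargs)
    wrapper.static_argnames = static_argnames
    return wrapper


# Before (fails on JAX versions without jax.partial):
#   @jax.partial(jax.jit, static_argnames="max_segment_length")
# After: functools.partial pre-binds the keyword, giving a one-argument decorator.
@functools.partial(fake_jit, static_argnames="max_segment_length")
def segment_max(values, max_segment_length):
    # Toy body: maximum over the first max_segment_length entries.
    return max(values[:max_segment_length])


result = segment_max([3, 1, 4, 1, 5], max_segment_length=3)
```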
  • Test sanity check example using new interface and inference modules, and put together the first unit test

    The unit test should run fast. One option is to cache new results. Another option is to just make the model really small.

    In the process, we should also:

    1. Deprecate the current contrib module and create a new examples directory to hold everything.
    2. Start figuring out what our user facing interface should look like.
    opened by StannisZhou 6
  • FactorGraph supports any type of factors + runs specialized inference for ORFactors

    In this PR we

    1. Redefine the factor graph abstraction by introducing factor types: factors in a graph are clustered in factor groups, which are grouped according to their factor types. See fg/graph.py
    2. Specify two types of factors: EnumerationFactor (this class already existed) and ORFactor (this new class inherits from the new LogicalFactor). Each Factor class must have its own methods to compile and concatenate Wirings for inference. See factors/enumeration.py and factors/logical.py.
    3. Make running inference in a graph agnostic to the types of factors it contains. New factor types can then be added without modifying graph.py.
    4. Implement a specialized inference for ORFactors (see pass_OR_fac_to_var_messages in factors/logical.py) and compare it with the existing one for EnumerationFactors in the unit test tests/factors/test_or.py
    opened by antoine-dedieu 5
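
A specialized ORFactor helps because its messages have a closed form, so the exponential enumeration over parent configurations can be avoided. Write d_i = msg_i(1) - msg_i(0). The best max-product score with child = 1 is the sum of msg_i(0), plus the sum of max(d_i, 0), plus max_i d_i when every d_i is negative, because at least one parent must still be switched on. The NumPy sketch below computes the factor-to-child message that way and checks it against brute-force enumeration. It illustrates the idea only and is not the `pass_OR_fac_to_var_messages` implementation. Incoming messages are assumed to be log-space arrays of shape `(num_parents, 2)`.

```python
import itertools

import numpy as np


def or_to_child_closed_form(parent_msgs):
    """Max-product log message from an OR factor (child = OR(parents)) to the child.

    parent_msgs[i, s] is the incoming log message from parent i for state s in {0, 1}.
    Returns [m(child=0), m(child=1)] in O(num_parents) time.
    """
    off = parent_msgs[:, 0]
    diff = parent_msgs[:, 1] - off
    # child = 0 is only compatible with every parent being 0.
    m0 = off.sum()
    # child = 1: take each parent's preferred state; if none prefers 1,
    # switch on the parent whose switch costs least (diff.max() < 0 then).
    m1 = off.sum() + np.maximum(diff, 0.0).sum() + min(0.0, diff.max())
    return np.array([m0, m1])


def or_to_child_enumeration(parent_msgs):
    # Reference: score all 2**num_parents parent configurations explicitly.
    best = np.full(2, -np.inf)
    for config in itertools.product((0, 1), repeat=parent_msgs.shape[0]):
        score = sum(parent_msgs[i, s] for i, s in enumerate(config))
        child = int(any(config))
        best[child] = max(best[child], score)
    return best


rng = np.random.default_rng(0)
msgs = rng.normal(size=(5, 2))
fast, slow = or_to_child_closed_form(msgs), or_to_child_enumeration(msgs)
```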
  • RCN example

    This PR contains an example implementation of RCN using the pgmax package. We load an RCN model pre-trained on a very small subset of MNIST (20 examples) and test it on a small subset of MNIST (20 examples). The reported accuracy is 0.80.

    opened by shrinuKushagra 5
  • Variables refactor

    We update the way of representing variables. In particular:

    • We get rid of variable names, as well as of the Variables and CompositeVariableGroup classes. A variable is now represented by a tuple (variable hash, variable num_states). In particular, a FactorGraph can then be instantiated directly as fg = graph.FactorGraph(variables=[hidden_variables, visible_variables]). Similarly, Factors are defined by directly passing the variables involved, as [hidden_variables[ii], visible_variables[jj]].
    • We rewrite NDVariableArray so that users can access variables through NumPy array indexing. We also optimize some follow-up computations.
    opened by antoine-dedieu 4
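
Here is a hedged sketch of the representation described above. Every variable is a `(hash, num_states)` tuple, and an NDVariableArray-like group stores its hashes in a NumPy array so that indexing, including slicing, works directly. `ToyNDVarArray` and its attributes are hypothetical and are not PGMax's class.

```python
import numpy as np


class ToyNDVarArray:
    """Hypothetical stand-in for an NDVarArray: a grid of variables sharing num_states."""

    def __init__(self, shape, num_states, first_hash=0):
        # Consecutive integer hashes laid out with the group's shape.
        self.hashes = first_hash + np.arange(int(np.prod(shape))).reshape(shape)
        self.num_states = num_states

    def __getitem__(self, index):
        hashes = self.hashes[index]
        if np.ndim(hashes) == 0:
            # A single variable is a (hash, num_states) tuple.
            return (int(hashes), self.num_states)
        # Slices return a list of such tuples, in C order.
        return [(int(h), self.num_states) for h in hashes.reshape(-1)]


hidden = ToyNDVarArray((2, 3), num_states=2)
visible = ToyNDVarArray((2, 3), num_states=2, first_hash=hidden.hashes.size)
# Variables for one factor, in the style of [hidden_variables[ii], visible_variables[jj]].
factor_variables = [hidden[0, 1], visible[0, 1]]  # [(1, 2), (7, 2)]
```

Offsetting `first_hash` keeps hashes unique across groups, which is what lets a bare tuple identify a variable without a name.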
  • Numba speedup for wiring + log potentials

    This PR is the continuation of https://github.com/vicariousinc/PGMax/pull/129 and part of our efforts to speed up the adding of FactorGroups and the wiring compilation.

    As https://github.com/vicariousinc/PGMax/pull/129 has moved most of the wiring computation to the FactorGroup level, we can now use numba to compute these wirings quickly.

    As a result:

    • adding factors for the RBM experiment takes 3s, and building run_bp takes 1s
    • adding factors for the convor experiment takes 2s, and building run_bp takes 1s
    opened by antoine-dedieu 4
  • RCN implementation on a small train and test set

    This PR contains the first implementation of the RCN example using the PGMax package. The file to run is examples/rcn/inference_pgmax_small.py. This code contains the implementation and visualizations on a small set: the model is trained on 20 examples and tested on 20 examples.

    The inference has been separated from model creation code. Saved models are added to /storage/users/skushagra/pgmax_rcn_artifacts/ .

    Implementation on the full dataset will be implemented in a later PR.

    opened by shrinuKushagra 4
  • Make `FactorGraph` mutable to support interactive model building

    Should implement interface for:

    1. Add factors
    2. Set evidence for variables
    3. Initialize messages by setting messages for factors
    4. Initialize messages by spreading beliefs from variables
    enhancement 
    opened by StannisZhou 4
  • Add customized class for pairwise factors; Default to have uniform potentials

    Currently, users have to manually create an array of all possible configs and a uniform potential; it would be nice to do this behind the scenes. For example, if either of these is None during init, we could assume all possible configs or a uniform potential, respectively, and create them automatically.

    enhancement 
    opened by NishanthJKumar 4
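
A minimal sketch of the proposed defaults, assuming a hypothetical helper that receives each variable's number of states. When `factor_configs` is None it enumerates every configuration, and when `log_potentials` is None it returns zeros, which is a uniform potential because log 1 = 0.

```python
import numpy as np


def default_configs_and_potentials(num_states, factor_configs=None, log_potentials=None):
    """Hypothetical helper that fills in defaults for a factor.

    num_states: number of states of each variable in the factor, e.g. (2, 3).
    """
    if factor_configs is None:
        # All configurations, shape (num_configs, num_vars), last variable fastest.
        factor_configs = np.indices(num_states).reshape(len(num_states), -1).T
    factor_configs = np.asarray(factor_configs)
    if log_potentials is None:
        # Uniform potential: potential 1 everywhere, i.e. log potential 0.
        log_potentials = np.zeros(factor_configs.shape[0])
    log_potentials = np.asarray(log_potentials, dtype=float)
    if log_potentials.shape != (factor_configs.shape[0],):
        raise ValueError("expected one log potential per configuration")
    return factor_configs, log_potentials


configs, log_pots = default_configs_and_potentials((2, 3))
```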
  • Make BP closer to jax optimizer

    Resolves https://github.com/vicariousinc/PGMax/issues/124

    We make graph.BP closer to JAX optimizers https://jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html

    opened by antoine-dedieu 3
  • Provide high-level syntax for creating factors

    One of the speed bottlenecks in creating a FactorGraph is the time to create the variables_for_factors list, which is currently slow because we loop through the individual variables.

    However, when all the variable groups are NDVarArr, we can speed up this step a lot by proposing a generic get_factors interface, where the user defines the general rule for the factors and the corresponding list is generated with numba.

    One option is to have a first argument consisting of variable groups whose dimensions we loop over, and a second argument consisting of variable groups we do not loop over. For example, get_factors({x:(i, j), y:(k, l)}, {z:(i+k, j+l)}) would mean

    factors = []
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for k in range(y.shape[0]):
                for l in range(y.shape[1]): 
                    factors.append((x[i, j], y[k, l], z[i+k, j+l]))
    
    opened by antoine-dedieu 0
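
The semantics of the proposed `get_factors({x: (i, j), y: (k, l)}, {z: (i + k, j + l)})` call can be checked against a vectorized version. In this sketch, `x`, `y` and `z` are assumed to be NumPy arrays of integer variable ids, standing in for NDVarArr groups. `z` must have shape `(x.shape[0] + y.shape[0] - 1, x.shape[1] + y.shape[1] - 1)` for the indices to stay in range. `get_factors_vectorized` is hypothetical and uses NumPy broadcasting rather than numba.

```python
import numpy as np


def get_factors_loop(x, y, z):
    # Direct transcription of the nested loops above.
    factors = []
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for k in range(y.shape[0]):
                for l in range(y.shape[1]):
                    factors.append((x[i, j], y[k, l], z[i + k, j + l]))
    return np.array(factors)


def get_factors_vectorized(x, y, z):
    # Open grids over (i, j, k, l), broadcast to the full 4D index space.
    shape = (*x.shape, *y.shape)
    i, j, k, l = np.ix_(*[np.arange(n) for n in shape])
    cols = [
        np.broadcast_to(x[i, j], shape),
        np.broadcast_to(y[k, l], shape),
        np.broadcast_to(z[i + k, j + l], shape),
    ]
    # C-order flattening visits (i, j, k, l) in the same order as the loops.
    return np.stack([c.reshape(-1) for c in cols], axis=1)


x = np.arange(4).reshape(2, 2)
y = 100 + np.arange(6).reshape(2, 3)
z = 1000 + np.arange(12).reshape(3, 4)
factors = get_factors_vectorized(x, y, z)  # one row (x var, y var, z var) per factor
```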
  • Modify `vars_to_starts` representation

    Creating vars_to_starts as a dictionary mapping each variable to an int is expensive when we have a lot of variables.

    Instead it could map a variable group to an array (in the case of a NDVariableArray) or a list (for a VariableDict)

    opened by antoine-dedieu 0
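
To illustrate the proposal, here is a hedged sketch that computes message start offsets for a whole variable group at once. It assumes each variable's messages occupy `num_states` consecutive slots in one flat array. Instead of a dictionary entry per variable, the group maps to a single array of starts with the group's own shape, so a lookup becomes plain NumPy indexing. `group_starts` is hypothetical, and for simplicity it also uses a 1D array (rather than a list) for a dict-like group.

```python
import numpy as np


def group_starts(num_states, offset=0):
    """Hypothetical helper: start index of each variable's block in a flat array.

    num_states: integer array shaped like the variable group, states per variable.
    offset: position of the group's first block in the flat array.
    Returns (starts, next_offset); starts has the same shape as num_states.
    """
    num_states = np.asarray(num_states)
    flat = num_states.reshape(-1)
    # Exclusive cumulative sum: each block starts where the previous ones end.
    starts = offset + np.cumsum(flat) - flat
    return starts.reshape(num_states.shape), offset + int(flat.sum())


# A 2x3 grid of binary variables, followed by 4 variables with 3 states each.
grid_starts, next_offset = group_starts(np.full((2, 3), 2))
list_starts, total = group_starts(np.full(4, 3), offset=next_offset)
```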
  • Improve documentation for variables/variable groups

    Currently it's not clear what names PGMax assigns to different variables (e.g. https://github.com/vicariousinc/PGMax/issues/115). Add documentation to make this clearer.

    documentation 
    opened by StannisZhou 0
Releases (v0.4.1)
  • v0.4.1 (May 19, 2022)

    Highlights

    • Fix two minor issues when running BP with variable groups defined with different numbers of states by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/144

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.4.0...v0.4.1

  • v0.4.0 (May 9, 2022)

    Breaking changes

    ⚠️ This release changes the high-level API as well as import paths ⚠️

    This release makes several major breaking changes to improve usability and efficiency of the package.

    1. Interacting with variables through VarGroup objects

    We no longer refer to variables by names, but instead directly interact with VarGroup objects. This change has several implications.

    • We no longer have a Variable class. Instead, we access individual variables by indexing into VarGroup objects.

    • FactorGraph can no longer be initialized with a dictionary of variable groups (as we no longer have names for variables). Instead, we initialize a FactorGraph by

    from pgmax import fgraph
    fg = fgraph.FactorGraph(variable_groups=variable_groups)
    

    where variable_groups is either a VarGroup or a list of VarGroups.

    • We can directly construct Factor/FactorGroup using individual variables, and have a unified add_factors interface for adding Factors and FactorGroups to the FactorGraph.

    For example, we can create a PairwiseFactorGroup via:

    from pgmax import fgroup
    pairwise_factors = fgroup.PairwiseFactorGroup(
        variables_for_factors=variables_for_factors,
        log_potential_matrix=log_potential_matrix,
    )
    

    where variables_for_factors is a list of lists of individual variables. We can then add factors to a FactorGraph fg by

    fg.add_factors(factors=factors)
    

    where factors can be individual Factor, individual FactorGroup, or a list of Factors and FactorGroups.

    • We access LBP results by indexing with VarGroup. For example, after running BP, we can get the MAP decoding for the VarGroup visible_variables via

    beliefs = bp.get_beliefs(bp_arrays)
    map_states_visible = infer.decode_map_states(beliefs)[visible_variables]
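
Conceptually, MAP decoding from LBP beliefs is an argmax over the state axis for every variable. The NumPy sketch below assumes that the beliefs for each variable group are stored as an array of shape `group_shape + (num_states,)` and keys them by a string standing in for the VarGroup object. It mirrors what indexing the decoded states with a VarGroup returns, but it is not `infer.decode_map_states`, and `decode_map_states_sketch` is a hypothetical name.

```python
import numpy as np


def decode_map_states_sketch(beliefs):
    """Hypothetical: map each group's beliefs (..., num_states) to MAP states (...)."""
    return {group: np.argmax(b, axis=-1) for group, b in beliefs.items()}


# Toy log beliefs for a 2x2 grid of ternary "visible" variables.
beliefs = {
    "visible": np.array([
        [[0.1, 2.0, -1.0], [3.0, 0.0, 0.0]],
        [[-1.0, -2.0, 0.5], [0.0, 0.2, 0.1]],
    ])
}
map_states_visible = decode_map_states_sketch(beliefs)["visible"]  # [[1, 0], [2, 1]]
```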
    

    2. Efficient construction of FactorGroup

    We have implemented efficient construction of FactorGroup. Going forward, we always recommend constructing FactorGroup instead of individual Factor.

    3. Improved LBP interface

    We first create the functions used to run BP with temperature T via

    from pgmax import infer
    bp = infer.BP(fg.bp_state, temperature=T)
    

    where bp contains functions that initialize or update the arrays involved in LBP.
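
The temperature T interpolates between max-product and sum-product, assuming the usual convention in which T = 1 gives sum-product and T = 0 gives max-product. In log space, the hard max over states is replaced by the smoothed max `T * logsumexp(x / T)`. At T = 1 this is the ordinary sum-product reduction, and as T goes to 0 it approaches `max(x)`. Below is a small sketch of that reduction. The helper name is hypothetical, and PGMax's internals may organize the computation differently.

```python
import numpy as np


def soft_max_reduce(x, temperature):
    """Hypothetical helper: T * logsumexp(x / T), with T = 0 meaning a hard max."""
    x = np.asarray(x, dtype=float)
    m = x.max()
    if temperature == 0.0:
        return m
    # T * logsumexp(x / T) == m + T * log(sum(exp((x - m) / T))); shifting by the
    # max keeps the exponentials from overflowing at small temperatures.
    return m + temperature * np.log(np.sum(np.exp((x - m) / temperature)))


x = [1.0, 2.0, 0.5]
values = {T: soft_max_reduce(x, T) for T in (1.0, 0.1, 0.01, 0.0)}
```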

    We can initialize bp_arrays by

    bp_arrays = bp.init()
    

    apply log potentials, messages and evidence updates by

    bp_arrays = bp.update(
        bp_arrays=bp_arrays,
        log_potentials_updates=log_potentials_updates,
        ftov_msgs_updates=ftov_msgs_updates,
        evidence_updates=evidence_updates,
    )
    

    and run BP for a given number of iterations by

    bp_arrays = bp.run_bp(bp_arrays, num_iters=num_iters, damping=damping)
    

    Note that we can arbitrarily interleave bp.update with bp.run_bp, which allows flexible control over how we run LBP.
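
Here is a toy sketch of this functional pattern, where arrays are passed in and returned and never mutated, together with damping. It assumes the common convention `new = damping * old + (1 - damping) * computed`, so damping = 0 means undamped updates. The names `init`, `update` and `run` and the toy fixed-point map are hypothetical stand-ins, not PGMax's BP functions.

```python
import numpy as np


def init(num_msgs):
    # Initial state: all log messages and evidence set to zero.
    return {"msgs": np.zeros(num_msgs), "evidence": np.zeros(num_msgs)}


def update(arrays, evidence_updates=None):
    # Return a new dict; the input arrays are left untouched.
    new = dict(arrays)
    if evidence_updates is not None:
        new["evidence"] = np.asarray(evidence_updates, dtype=float)
    return new


def run(arrays, num_iters, damping=0.5):
    msgs = arrays["msgs"]
    for _ in range(num_iters):
        # Toy fixed-point map standing in for one round of message passing.
        computed = 0.5 * msgs + arrays["evidence"]
        # Damped update: keep a fraction `damping` of the old messages.
        msgs = damping * msgs + (1 - damping) * computed
    return {**arrays, "msgs": msgs}


# Interleave runs and updates freely.
arrays = init(3)
arrays = run(arrays, num_iters=5)  # no evidence yet, so messages stay at zero
arrays = update(arrays, evidence_updates=[1.0, -1.0, 0.0])
arrays = run(arrays, num_iters=200, damping=0.5)
```

The toy map has fixed point `msgs = 2 * evidence`. Damping does not move that fixed point; it only changes how quickly and how smoothly the iterates approach it.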

    4. Improved high-level module organization

    We now have 5 main high-level modules: fgraph for factor graphs, factor for factors, vgroup for variable groups, fgroup for factor groups, and infer for LBP.

    Details of what has changed:

    • Speed up the process of adding Factors and compiling wiring for a FactorGraph by moving all the computations to the FactorGroup level, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/129
    • Speed up the process of computing log potentials + wiring for FactorGroup with numba, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/133
    • Make the BP class behavior closer to JAX optimizers by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/135
    • Get rid of the Variables and CompositeVariableGroup classes + of the variable names + adopt a simpler representation for variables + rely on numpy arrays to make NDVarArray efficient, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/136
    • Overall module reorganization, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/140

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.3.0...v0.4.0

  • v0.3.0 (Mar 25, 2022)

    Highlights

    • Refactors to support adding different factor types with specialized inference procedures by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/122
    • Specialized logical AND/OR factors by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/122 https://github.com/vicariousinc/PGMax/pull/126
    • New example on 2D binary blind deconvolution by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/127

    New Contributors

    • @antoine-dedieu made his first contribution in https://github.com/vicariousinc/PGMax/pull/122

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.3...v0.3.0

  • v0.2.3 (Feb 19, 2022)

    What's Changed

    • Links to blog post and companion paper; Documentation updates by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/111
    • Get rid of redundant array shape for log_potentials by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/116
    • Support python 3.9/3.10; Improve documentation for add_factor; Bump up version for new release by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/120

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.2...v0.2.3

  • v0.2.2 (Jan 22, 2022)

    What's Changed

    • Update README by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/103
    • RCN example by @shrinuKushagra in https://github.com/vicariousinc/PGMax/pull/96
    • Add support for sum-product with temperature by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/104
    • Include Grid Markov Random Field example by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/107
    • Changes for blog post by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/109

    New Contributors

    • @shrinuKushagra made their first contribution in https://github.com/vicariousinc/PGMax/pull/96

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.1...v0.2.2

  • v0.2.1 (Dec 1, 2021)

    What's Changed

    • Bump versions for publishing by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/63
    • Example notebook with PMAP sampling of RBMs trained on MNIST digits by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/80
    • Use functools.partial instead of jax.partial by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/83
    • First pass for speeding up graph and evidence operations by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/84
    • Moving to a functional interface by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/88
    • Update README in preparation for making repo public by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/89
    • Pre commit ci test by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/90
    • [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/vicariousinc/PGMax/pull/91
    • Update dependency requirements to be less aggressive by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/92
    • adds codecov badge to README! by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/94
    • Fix GPU memory leak that came up in RCN example by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/97
    • [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/vicariousinc/PGMax/pull/95
    • Docs update by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/98
    • fixes bug where wrong conf.py path was specified by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/99
    • Rtd warning fix by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/100
    • includes minor changes to README and documentation by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/101
    • Bump up version; Fixes for docs by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/102

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.0...v0.2.1

  • v0.2.0 (Sep 6, 2021)

    Features

    • Efficient and scalable max-product belief propagation using a fully flat representation
    • A general factor graph interface that supports easy specification of PGMs with pairwise factors and higher-order factors based on explicit enumeration
    • Mechanisms for evidence and message manipulation
    • 3 example notebooks showcasing the functionalities of PGMax