Loopy belief propagation for factor graphs on discrete variables, in JAX!

Overview

PGMax

PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.

  • General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
  • LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. vmap for processing batches of models/samples, grad for differentiating through the LBP iterative process), and can be easily used as part of a larger end-to-end differentiable system.

See our blog post and companion paper for more details.

Installation

Install from PyPI

pip install pgmax

Install latest version from GitHub

pip install git+https://github.com/vicariousinc/PGMax.git

Developer

git clone https://github.com/vicariousinc/PGMax.git
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/install-poetry.py | python3 -
cd PGMax
poetry shell
poetry install
pre-commit install

Install on GPU

By default, the commands above install JAX for CPU. If you have access to a GPU, follow the official JAX installation instructions to install JAX with GPU support.

Getting Started

Here are a few self-contained Colab notebooks to help you get started with PGMax:

Citing PGMax

Please consider citing our companion paper if you use PGMax in your work:

@article{zhou2022pgmax,
  author = {Zhou, Guangyao and Kumar, Nishanth and L{\'a}zaro-Gredilla, Miguel and Kushagra, Shrinu and George, Dileep},
  title = {{PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX}},
  journal = {arXiv preprint arXiv:2202.04110},
  year={2022}
}

The first two authors contributed equally.

Comments
  • Incomplete documentation for add_factor

    The documentation for add_factor currently says:

    log_potentials – Optional array of shape (num_val_configs,) or (num_factors, num_val_configs). If specified, it contains the log of the potential value for every possible configuration. If none, it is assumed the log potential is uniform 0 and such an array is automatically initialized.

    However, in order to use it one would need to know the order in which the factors should be specified. Could this be added to the documentation?

    bug documentation 
    opened by nathanielvirgo 8
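    Reading the quoted docstring, each entry of log_potentials scores one valid configuration, so the natural contract is that log_potentials[k] belongs to the k-th row of the configuration array passed alongside it. That contract is an assumption here, not a quote from the PGMax documentation. A minimal plain-Python sketch, with a hypothetical helper configs_and_log_potentials, that keeps the two aligned by building both from a single mapping:

```python
def configs_and_log_potentials(scores):
    """Hypothetical helper: split a {config: log_potential} mapping into two
    aligned lists, so that log_potentials[k] always scores configs[k].

    Any fixed order works; sorting just makes the order reproducible.
    """
    configs = sorted(scores)
    log_potentials = [scores[config] for config in configs]
    return configs, log_potentials


# Two binary variables; only three joint states are declared valid.
configs, log_potentials = configs_and_log_potentials(
    {(1, 1): 0.5, (0, 0): 1.0, (0, 1): -2.0}
)
print(configs)  # [(0, 0), (0, 1), (1, 1)]
print(log_potentials)  # [1.0, -2.0, 0.5]
```

    Building both sequences from one source means the order in which configurations are listed never has to be remembered separately.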
  • Add support for Python 3.9 and 3.10

    I don't know if this is a bug or a problem with my installation, but if I try to run a file containing only the line

    from pgmax.fg import graph
    

    I get the error

    % /opt/local/bin/python3 /Users/nathaniel/Dropbox/Code/PGMax/test01.py
    Traceback (most recent call last):
      File "/Users/nathaniel/Dropbox/Code/PGMax/test01.py", line 1, in <module>
        from pgmax.fg import graph
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/fg/graph.py", line 11, in <module>
        import pgmax.bp.infer as infer
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/bp/infer.py", line 6, in <module>
        import pgmax.bp.bp_utils as bp_utils
      File "/Users/nathaniel/Library/Python/3.9/lib/python/site-packages/pgmax/bp/bp_utils.py", line 11, in <module>
        @jax.partial(jax.jit, static_argnames="max_segment_length")
    AttributeError: module 'jax' has no attribute 'partial'
    

    This is with Python 3.9 installed using MacPorts on macOS 10.15.7. PGMax, JAX and the other prerequisites were installed with pip-3.9 install --user PGMax. I'm happy to give any other information about my installation if you can tell me how to obtain it.

    enhancement 
    opened by nathanielvirgo 6
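    The traceback points at jax.partial, which is missing from the installed JAX (it has been removed from JAX); the v0.2.1 release notes list the switch to the standard library's functools.partial (PR #83). functools.partial(jax.jit, static_argnames=...) builds the same decorator. A minimal sketch of the pattern, using a hypothetical fake_jit stand-in so that it runs without JAX installed:

```python
import functools


def fake_jit(fun, static_argnames=()):
    """Hypothetical stand-in for jax.jit: records the static argument names
    and otherwise calls the wrapped function unchanged (no compilation)."""

    @functools.wraps(fun)
    def wrapped(*args, **kwargs):
        return fun(*args, **kwargs)

    wrapped.static_argnames = static_argnames
    return wrapped


# With JAX installed, the decorator from the traceback would read:
#   @functools.partial(jax.jit, static_argnames="max_segment_length")
@functools.partial(fake_jit, static_argnames="max_segment_length")
def segment_max(values, max_segment_length):
    # Toy body: maximum over the first max_segment_length entries.
    return max(values[:max_segment_length])


print(segment_max([3, 1, 4, 1, 5], max_segment_length=3))  # 4
```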
  • Test sanity check example using new interface and inference modules, and put together the first unit test

    The unit test should run fast. One option is to cache new results. Another option is to just make the model really small.

    In the process, we should also:

    1. Deprecate the current contrib module and create a new examples directory to hold everything.
    2. Start figuring out what our user facing interface should look like.
    opened by StannisZhou 6
  • FactorGraph supports any type of factors + runs specialized inference for ORFactors

    In this PR we

    1. Redefine the factor graph abstraction by introducing factor types: factors in a graph are clustered in factor groups, which are grouped according to their factor types. See fg/graph.py
    2. Specify two types of factors: EnumerationFactor (this class already existed) and ORFactor (this new class inherits from the new LogicalFactor). Each Factor class must have its own methods to compile and concatenate Wirings for inference. See factors/enumeration.py and factors/logical.py.
    3. Make running inference in a graph agnostic to the types of factors it contains. New factor types can then be added without modifying graph.py.
    4. Implement a specialized inference for ORFactors (see pass_OR_fac_to_var_messages in factors/logical.py) and compare it with the existing one for EnumerationFactors in the unit test tests/factors/test_or.py
    opened by antoine-dedieu 5
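    For binary variables, the max-product messages of an OR factor (child = OR of the parents) can be computed in time linear in the number of parents instead of enumerating all 2^n parent configurations; this is the kind of shortcut a specialized ORFactor update can exploit. The sketch below is a plain-Python illustration in log space, not PGMax's pass_OR_fac_to_var_messages (which operates on flat JAX arrays). The message layout ([value 0, value 1] per variable) and the function names are assumptions. It checks the linear-time messages against brute-force enumeration, mirroring the comparison the PR describes.

```python
import itertools
import math


def or_messages_enum(parent_msgs, child_msg):
    """Brute force: max-product factor-to-variable messages for child = OR(parents).

    parent_msgs[j] and child_msg are [log msg for value 0, log msg for value 1]
    sent from the variables to the factor.
    """
    n = len(parent_msgs)
    to_parents = [[-math.inf, -math.inf] for _ in range(n)]
    to_child = [-math.inf, -math.inf]
    for xs in itertools.product((0, 1), repeat=n):
        y = int(any(xs))
        parents_score = sum(parent_msgs[j][xs[j]] for j in range(n))
        for i in range(n):
            # Exclude the message coming from the receiving variable itself.
            score = parents_score - parent_msgs[i][xs[i]] + child_msg[y]
            to_parents[i][xs[i]] = max(to_parents[i][xs[i]], score)
        to_child[y] = max(to_child[y], parents_score)
    return to_parents, to_child


def or_messages_fast(parent_msgs, child_msg):
    """Same messages in O(n log n) (the sort) using gains d_j = msg_j(1) - msg_j(0)."""
    n = len(parent_msgs)  # assumes at least one parent
    zeros = [m[0] for m in parent_msgs]
    gains = [m[1] - m[0] for m in parent_msgs]
    pos_gains = [max(0.0, g) for g in gains]
    total_zero, total_pos = sum(zeros), sum(pos_gains)
    # Largest and second-largest gain give a leave-one-out maximum.
    order = sorted(range(n), key=lambda j: gains[j], reverse=True)
    best = order[0]
    second = gains[order[1]] if n > 1 else -math.inf

    to_parents = []
    for i in range(n):
        others_zero = total_zero - zeros[i]
        others_pos = total_pos - pos_gains[i]
        max_other_gain = second if i == best else gains[best]
        # Best score of the other parents given that at least one of them is on.
        some_other_on = others_zero + others_pos + min(0.0, max_other_gain)
        msg_on = others_zero + others_pos + child_msg[1]  # x_i = 1 forces y = 1
        msg_off = max(others_zero + child_msg[0], some_other_on + child_msg[1])
        to_parents.append([msg_off, msg_on])
    to_child = [total_zero, total_zero + total_pos + min(0.0, gains[best])]
    return to_parents, to_child
```

    The sort is only used to find the top two gains; a single linear scan would do the same and make the whole update O(n).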
  • RCN example

    This PR contains an example implementation of RCN using the pgmax package. We load an RCN model pre-trained on a very small subset of MNIST (20 examples) and test it on a small subset of MNIST (20 examples). The reported accuracy is 0.80.

    opened by shrinuKushagra 5
  • Variables refactor

    We update the way of representing variables. In particular:

    • We get rid of variable names, as well as of the Variables and CompositeVariableGroup classes. A variable is now represented by a tuple (variable hash, variable num_states). In particular, a FactorGraph can now be instantiated directly as fg = graph.FactorGraph(variables=[hidden_variables, visible_variables]). Similarly, Factors are defined by directly passing the variables involved, e.g. [hidden_variables[ii], visible_variables[jj]].
    • We rewrite NDVariableArray so that users can access variables through numpy array indexing. We also optimize some downstream computations.
    opened by antoine-dedieu 4
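    A minimal plain-Python sketch of the representation described above: a variable is a (hash, num_states) tuple, and an N-dimensional group hands out hashes by row-major index arithmetic rather than by name. The class name ToyNDVarArray and the consecutive-integer hashing scheme are assumptions for illustration; the PR itself relies on numpy arrays for this.

```python
class ToyNDVarArray:
    """Hypothetical N-dimensional variable group.

    Each variable is a (hash, num_states) tuple. Hashes are consecutive integers
    starting at a per-group offset, in row-major order, so looking a variable up
    is index arithmetic instead of a name lookup.
    """

    _next_offset = 0  # shared counter so that different groups get disjoint hashes

    def __init__(self, shape, num_states):
        self.shape = tuple(shape)
        self.num_states = num_states
        self.offset = ToyNDVarArray._next_offset
        size = 1
        for dim in self.shape:
            size *= dim
        ToyNDVarArray._next_offset += size

    def __getitem__(self, index):
        if not isinstance(index, tuple):
            index = (index,)
        if len(index) != len(self.shape):
            raise IndexError(f"expected {len(self.shape)} indices, got {len(index)}")
        flat = 0
        for i, dim in zip(index, self.shape):
            if not 0 <= i < dim:
                raise IndexError(f"index {i} out of range for dimension of size {dim}")
            flat = flat * dim + i
        return (self.offset + flat, self.num_states)


hidden_variables = ToyNDVarArray((2, 3), num_states=2)
visible_variables = ToyNDVarArray((4,), num_states=3)
# A factor is specified by directly listing the variables it involves.
factor_variables = [hidden_variables[1, 2], visible_variables[0]]
print(factor_variables)  # [(5, 2), (6, 3)]
```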
  • Numba speedup for wiring + log potentials

    This PR continues https://github.com/vicariousinc/PGMax/pull/129 and is part of our effort to speed up adding FactorGroups and compiling the wiring.

    As https://github.com/vicariousinc/PGMax/pull/129 has moved most of the wiring computation to the FactorGroup level, we can now use numba to compute these wirings quickly.

    As a result:

    • adding factors for the RBM experiment takes 3s, and building run_bp takes 1s
    • adding factors for the convor experiment takes 2s, and building run_bp takes 1s
    opened by antoine-dedieu 4
  • RCN implementation on a small train and test set

    This PR contains the first implementation of the RCN example using the PGMax package. The file to run is examples/rcn/inference_pgmax_small.py. It contains the implementation and visualizations on a small set: the model is trained on 20 examples and tested on 20 examples.

    Inference has been separated from the model creation code. Saved models are added to /storage/users/skushagra/pgmax_rcn_artifacts/.

    Support for the full dataset will be added in a later PR.

    opened by shrinuKushagra 4
  • Make `FactorGraph` mutable to support interactive model building

    Should implement interface for:

    1. Add factors
    2. Set evidence for variables
    3. Initialize messages by setting messages for factors
    4. Initialize messages by spreading beliefs from variables
    enhancement 
    opened by StannisZhou 4
  • Add customized class for pairwise factors; Default to have uniform potentials

    Currently, users have to manually create an array of all possible configs and a uniform potential, but it would be nice to do this behind the scenes. Maybe we can make it so that if either of these is None during init, we assume all possible configs or a uniform potential, respectively, and create them automatically.

    enhancement 
    opened by NishanthJKumar 4
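    One way the proposed defaults could look, as a plain-Python sketch: when no configurations are given, enumerate every pair of states; when no potentials are given, use log potential 0 for every configuration, which is a uniform potential. The function name and signature are hypothetical, not PGMax's API.

```python
import itertools


def pairwise_factor_arrays(num_states_a, num_states_b, log_potential_matrix=None):
    """Hypothetical helper: (configs, log_potentials) for a pairwise factor.

    All num_states_a * num_states_b joint states are valid configurations,
    listed in row-major order. Without a matrix, every log potential is 0.0.
    """
    configs = list(itertools.product(range(num_states_a), range(num_states_b)))
    if log_potential_matrix is None:
        log_potentials = [0.0] * len(configs)
    else:
        log_potentials = [log_potential_matrix[a][b] for a, b in configs]
    return configs, log_potentials


configs, log_potentials = pairwise_factor_arrays(2, 3)
print(len(configs), log_potentials)  # 6 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

    Reading the matrix in the same row-major order as the configurations keeps log_potentials[k] paired with configs[k].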
  • Make BP closer to jax optimizer

    Resolves https://github.com/vicariousinc/PGMax/issues/124

    We make graph.BP behave more like the JAX optimizers: https://jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html

    opened by antoine-dedieu 3
  • Provide high-level syntax for creating factors

    One of the speed bottlenecks in creating a FactorGraph is building the variables_for_factors list, which is currently slow because we loop through the individual variables.

    However, when all the variable groups are NDVarArrays, we could speed up this step considerably with a generic get_factors interface: the user would define the general rule for the factors, and the corresponding list would be generated with numba.

    One option is to have a first argument consisting of the variable groups whose dimensions we loop over, and a second argument consisting of the variable groups we do not loop over. For example, get_factors({x:(i, j), y:(k, l)}, {z:(i+k, j+l)}) would mean

    factors = []
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for k in range(y.shape[0]):
                for l in range(y.shape[1]): 
                    factors.append((x[i, j], y[k, l], z[i+k, j+l]))
    
    opened by antoine-dedieu 0
  • Modify `vars_to_starts` representation

    Creating vars_to_starts as a dictionary mapping each variable to an int is expensive when we have a lot of variables.

    Instead, it could map a variable group to an array (for an NDVariableArray) or a list (for a VariableDict).

    opened by antoine-dedieu 0
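    A plain-Python sketch of the proposed layout, assuming each variable occupies num_states consecutive slots in one flat array: starts are stored per variable group as a list computed with a running sum, instead of one dictionary entry per variable. The issue suggests a numpy array for NDVariableArray; plain lists are used here. Group names and the helper name are hypothetical.

```python
from itertools import accumulate


def starts_per_group(num_states_per_group):
    """Hypothetical helper: map each variable group to the start offsets of its
    variables in one flat array, where variable k of a group occupies
    num_states[k] consecutive slots and groups are laid out one after another."""
    starts = {}
    offset = 0
    for group, num_states in num_states_per_group.items():
        # Running sum of earlier sizes; drop the final total (start of the next group).
        starts[group] = [offset + s for s in accumulate(num_states, initial=0)][:-1]
        offset += sum(num_states)
    return starts


starts = starts_per_group({"hidden": [2, 2, 2], "visible": [3, 3]})
print(starts)  # {'hidden': [0, 2, 4], 'visible': [6, 9]}
```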
  • Improve documentation for variables/variable groups

    Currently it's not clear what names PGMax assigns to different variables (e.g. https://github.com/vicariousinc/PGMax/issues/115). Add documentation to make this clearer.

    documentation 
    opened by StannisZhou 0
Releases (v0.4.1)
  • v0.4.1 (May 19, 2022)

    Highlights

    • Fix two minor issues when running BP with variable groups that have different numbers of states, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/144

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.4.0...v0.4.1

  • v0.4.0 (May 9, 2022)

    Breaking changes

    ⚠️ This release changes the high-level API as well as import paths ⚠️

    This release makes several major breaking changes to improve usability and efficiency of the package.

    1. Interacting with variables through VarGroup objects

    We no longer refer to variables by names, but instead directly interact with VarGroup objects. This change has several implications.

    • We no longer have a Variable class. Instead, we access individual variables by indexing into VarGroup objects.

    • FactorGraph can no longer be initialized with a dictionary of variable groups (as we no longer have names for variables). Instead, we initialize a FactorGraph by

    from pgmax import fgraph
    fg = fgraph.FactorGraph(variable_groups=variable_groups)
    

    where variable_groups is either a VarGroup or a list of VarGroups.

    • We can directly construct Factor/FactorGroup using individual variables, and have a unified add_factors interface for adding Factors and FactorGroups to the FactorGraph.

    For example, we can create a PairwiseFactorGroup via:

    from pgmax import fgroup
    pairwise_factors = fgroup.PairwiseFactorGroup(
        variables_for_factors=variables_for_factors,
        log_potential_matrix=log_potential_matrix,
    )
    

    where variables_for_factors is a list of lists of individual variables. We can then add factors to a FactorGraph fg by

    fg.add_factors(factors=factors)
    

    where factors can be individual Factor, individual FactorGroup, or a list of Factors and FactorGroups.

    • We access LBP results by indexing with VarGroup objects. For example, after running BP, we can get the MAP decoding for the VarGroup visible_variables via

    beliefs = bp.get_beliefs(bp_arrays)
    map_states_visible = infer.decode_map_states(beliefs)[visible_variables]
    

    2. Efficient construction of FactorGroup

    We have implemented efficient construction of FactorGroup. Going forward, we always recommend constructing FactorGroup instead of individual Factor.

    3. Improved LBP interface

    We first create the functions used to run BP with temperature T via

    from pgmax import infer
    bp = infer.BP(fg.bp_state, temperature=T)
    

    where bp contains functions that initialize or update the arrays involved in LBP.

    We can initialize bp_arrays by

    bp_arrays = bp.init()
    

    apply updates to log potentials, messages and evidence by

    bp_arrays = bp.update(
        bp_arrays=bp_arrays,
        log_potentials_updates=log_potentials_updates,
        ftov_msgs_updates=ftov_msgs_updates,
        evidence_updates=evidence_updates,
    )
    

    and run BP for a given number of iterations by

    bp_arrays = bp.run_bp(bp_arrays, num_iters=num_iters, damping=damping)
    

    Note that we can arbitrarily interleave bp.update with bp.run_bp, which allows flexible control over how we run LBP.
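    To make the run loop concrete, here is a self-contained toy in plain Python: loopy belief propagation on a small pairwise model, in log space, with the same knobs the interface above exposes (number of iterations, damping, temperature), followed by MAP decoding. It is not PGMax's implementation. The function names, the pairwise message schedule, the damping rule (new = damping * old + (1 - damping) * computed), and treating temperature 0 as max-product and 1 as sum-product are all assumptions of this sketch.

```python
import math


def soft_max(values, temperature):
    """T * log(sum(exp(v / T))); reduces to max(values) when T == 0."""
    top = max(values)
    if temperature == 0.0 or top == -math.inf:
        return top
    return top + temperature * math.log(
        sum(math.exp((v - top) / temperature) for v in values)
    )


def run_toy_lbp(unaries, edges, num_iters, damping=0.5, temperature=0.0):
    """Loopy BP on a pairwise model with parallel (synchronous) message updates.

    unaries: {var: [log potential per state]} (plays the role of evidence)
    edges: {(u, v): matrix} with matrix[x_u][x_v] the pairwise log potential
    Returns {var: [unnormalized log belief per state]}.
    """
    neighbors = {var: [] for var in unaries}
    pair = {}
    for (u, v), matrix in edges.items():
        neighbors[u].append(v)
        neighbors[v].append(u)
        pair[(u, v)] = matrix
        pair[(v, u)] = [list(row) for row in zip(*matrix)]  # transpose
    # msgs[(u, v)][x_v]: message sent from u to v, initialized to zeros.
    msgs = {(u, v): [0.0] * len(unaries[v]) for (u, v) in pair}

    for _ in range(num_iters):
        new_msgs = {}
        for (u, v), matrix in pair.items():
            # Everything u knows, except what v told it.
            incoming = [
                unaries[u][xu] + sum(msgs[(w, u)][xu] for w in neighbors[u] if w != v)
                for xu in range(len(unaries[u]))
            ]
            out = [
                soft_max([incoming[xu] + matrix[xu][xv] for xu in range(len(incoming))], temperature)
                for xv in range(len(unaries[v]))
            ]
            shift = max(out)  # normalize so that messages stay bounded
            new_msgs[(u, v)] = [
                damping * old + (1.0 - damping) * (o - shift)
                for old, o in zip(msgs[(u, v)], out)
            ]
        msgs = new_msgs

    return {
        var: [
            unaries[var][x] + sum(msgs[(w, var)][x] for w in neighbors[var])
            for x in range(len(unaries[var]))
        ]
        for var in unaries
    }


def decode_map(beliefs):
    """Pick the highest-belief state of every variable."""
    return {var: max(range(len(b)), key=b.__getitem__) for var, b in beliefs.items()}


# A 3-cycle of binary variables with attractive couplings; only "a" has evidence.
attract = [[1.0, 0.0], [0.0, 1.0]]
unaries = {"a": [0.0, 2.0], "b": [0.0, 0.0], "c": [0.0, 0.0]}
edges = {("a", "b"): attract, ("b", "c"): attract, ("c", "a"): attract}
beliefs = run_toy_lbp(unaries, edges, num_iters=20, damping=0.5)
print(decode_map(beliefs))  # {'a': 1, 'b': 1, 'c': 1}
```

    Because messages are plain data passed in and out of run_toy_lbp, a caller could edit unaries between calls, which loosely mirrors interleaving evidence updates with further BP iterations.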

    4. Improved high-level module organization

    Now we have 5 main high-level modules: fgraph for factor graphs, factor for factors, vgroup for variable groups, fgroup for factor groups, and infer for LBP.

    Details of what has changed:

    • Speed up the process of adding Factors and compiling wiring for a FactorGraph by moving all the computations to the FactorGroup level, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/129
    • Speed up the process of computing log potentials + wiring for FactorGroup with numba, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/133
    • Make the BP class behavior closer to JAX optimizers by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/135
    • Get rid of the Variables and CompositeVariableGroup classes + of the variable names + adopt a simpler representation for variables + rely on numpy arrays to make NDVarArray efficient, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/136
    • Overall module reorganization, by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/140

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.3.0...v0.4.0

  • v0.3.0 (Mar 25, 2022)

    Highlights

    • Refactors to support adding different factor types with specialized inference procedures by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/122
    • Specialized logical AND/OR factors by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/122 https://github.com/vicariousinc/PGMax/pull/126
    • New example on 2D binary blind deconvolution by @antoine-dedieu in https://github.com/vicariousinc/PGMax/pull/127

    New Contributors

    • @antoine-dedieu made his first contribution in https://github.com/vicariousinc/PGMax/pull/122

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.3...v0.3.0

  • v0.2.3 (Feb 19, 2022)

    What's Changed

    • Links to blog post and companion paper; Documentation updates by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/111
    • Get rid of redundant array shape for log_potentials by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/116
    • Support python 3.9/3.10; Improve documentation for add_factor; Bump up version for new release by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/120

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.2...v0.2.3

  • v0.2.2 (Jan 22, 2022)

    What's Changed

    • Update README by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/103
    • RCN example by @shrinuKushagra in https://github.com/vicariousinc/PGMax/pull/96
    • Add support for sum-product with temperature by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/104
    • Include Grid Markov Random Field example by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/107
    • Changes for blog post by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/109

    New Contributors

    • @shrinuKushagra made their first contribution in https://github.com/vicariousinc/PGMax/pull/96

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.1...v0.2.2

  • v0.2.1 (Dec 1, 2021)

    What's Changed

    • Bump versions for publishing by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/63
    • Example notebook with PMAP sampling of RBMs trained on MNIST digits by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/80
    • Use functools.partial instead of jax.partial by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/83
    • First pass for speeding up graph and evidence operations by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/84
    • Moving to a functional interface by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/88
    • Update README in preparation for making repo public by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/89
    • Pre commit ci test by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/90
    • [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/vicariousinc/PGMax/pull/91
    • Update dependency requirements to be less aggressive by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/92
    • adds codecov badge to README! by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/94
    • Fix GPU memory leak that came up in RCN example by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/97
    • [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/vicariousinc/PGMax/pull/95
    • Docs update by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/98
    • fixes bug where wrong conf.py path was specified by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/99
    • Rtd warning fix by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/100
    • includes minor changes to README and documentation by @NishanthJKumar in https://github.com/vicariousinc/PGMax/pull/101
    • Bump up version; Fixes for docs by @StannisZhou in https://github.com/vicariousinc/PGMax/pull/102

    Full Changelog: https://github.com/vicariousinc/PGMax/compare/v0.2.0...v0.2.1

  • v0.2.0 (Sep 6, 2021)

    Features

    • Efficient and scalable max-product belief propagation using a fully flat representation
    • A general factor graph interface that supports easy specification of PGMs with pairwise factors and higher-order factors based on explicit enumeration
    • Mechanisms for evidence and message manipulation
    • 3 example notebooks showcasing the functionalities of PGMax