Fast, general, and tested differentiable structured prediction in PyTorch

Overview

Torch-Struct: Structured Prediction Library

A library of tested, GPU implementations of core structured prediction algorithms for deep learning applications.

  • HMM / LinearChain-CRF
  • HSMM / SemiMarkov-CRF
  • Dependency Tree-CRF
  • PCFG Binary Tree-CRF
  • ...

Designed to be used as efficient batched layers in other PyTorch code.

Tutorial paper describing methodology.

Getting Started

!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt
def show(x): plt.imshow(x.detach())
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5) 
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])

# Compute marginals
show(dist.marginals[0])

# Compute argmax
show(dist.argmax.detach()[0])

# Compute scoring and enumeration (forward / inside)
log_partition = dist.partition
max_score = dist.log_prob(dist.argmax)  # log-probability of the highest-scoring structure
# Compute samples 
show(dist.sample((1,)).detach()[0, 0])

# Padding/Masking built into library.
dist = DependencyCRF(vals.log(), lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])

# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10)
chain[:, 0, :, 0] = 1
chain[:, -1, 9, :] = 1
chain = chain.log()

dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))

Library

Full docs: http://nlp.seas.harvard.edu/pytorch-struct/

Current distributions implemented:

  • LinearChainCRF
  • SemiMarkovCRF
  • DependencyCRF
  • NonProjectiveDependencyCRF
  • TreeCRF
  • NeuralPCFG / NeuralHMM

Each distribution includes:

  • Argmax, sampling, entropy, partition, masking, log_probs, k-max
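k-max decoding can ask for more structures than a short input admits. One of the comments further down reports that, for DependencyCRF with a two-word sentence and k = 5, only the top 3 predictions look valid. As a hedged sanity check, the brute-force counter below enumerates projective dependency trees over n words with a dummy root at position 0. `single_root=False` lets several words attach to the root (the 3-of-5 observation suggests that is what DependencyCRF does there, but that is an assumption); `single_root=True` enforces exactly one root word. All function names are hypothetical helpers, not part of torch_struct.

```python
from itertools import product


def _ancestor_or_self(heads, k, h):
    """True if h is k itself or an ancestor of k (index 0 is the dummy root)."""
    node = k
    while node != 0:
        if node == h:
            return True
        node = heads[node]
    return h == 0


def _is_tree(heads):
    """Every word must reach the root 0 without revisiting a node."""
    for d in range(1, len(heads)):
        seen, node = set(), d
        while node != 0:
            if node in seen:
                return False  # cycle
            seen.add(node)
            node = heads[node]
    return True


def _is_projective(heads):
    """Each word strictly between a head and its dependent must descend from that head."""
    for d in range(1, len(heads)):
        h = heads[d]
        for k in range(min(h, d) + 1, max(h, d)):
            if not _ancestor_or_self(heads, k, h):
                return False
    return True


def count_projective_trees(n, single_root=False):
    """Brute-force count over words 1..n (exponential time; tiny n only)."""
    count = 0
    for hs in product(range(n + 1), repeat=n):
        heads = (0,) + hs  # heads[d] is the head of word d; heads[0] is unused
        if any(heads[d] == d for d in range(1, n + 1)):
            continue  # no self-loops
        if single_root and sum(1 for d in range(1, n + 1) if heads[d] == 0) != 1:
            continue
        if _is_tree(heads) and _is_projective(heads):
            count += 1
    return count


def num_valid_kmax(n, k, single_root=False):
    """At most this many of the k requested structures can be distinct."""
    return min(k, count_projective_trees(n, single_root))


print(count_projective_trees(2), num_valid_kmax(2, 5))  # 3 3
print(count_projective_trees(2, single_root=True))       # 2
```

With multiple root attachments allowed, the counts grow as 3, 12, 55 for n = 2, 3, 4, so the cap only matters for very short sentences.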

Extensions:

  • Integration with torchtext, pytorch-transformers, dgl
  • Adapters for generative structured models (CFG / HMM / HSMM)
  • Common tree structured parameterizations TreeLSTM / SpanLSTM

Low-level API:

Everything is implemented through semiring dynamic programming.

  • Log Marginals
  • Max and MAP computation
  • Sampling through specialized backprop
  • Entropy and first-order semirings.
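A minimal pure-Python sketch of the semiring idea (not the library's code): the same linear-chain forward recursion yields the log-partition under the log semiring (logsumexp as "plus", + as "times") and the best-structure score under the max semiring (max, +). Representing a semiring by its "plus" function, and the names `chain_forward` and `brute_force`, are assumptions made for illustration.

```python
import math
from itertools import product


def logsumexp(xs):
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


# Each semiring is represented only by its "plus"; "times" is + because scores are in log space.
LOG_SEMIRING = logsumexp  # sum over structures -> log-partition
MAX_SEMIRING = max        # max over structures -> best (Viterbi) score


def chain_forward(emit, trans, plus):
    """emit[t][y]: log score of label y at position t; trans[y][y2]: log score of y -> y2."""
    C = len(emit[0])
    alpha = list(emit[0])
    for t in range(1, len(emit)):
        alpha = [plus([alpha[yp] + trans[yp][y] for yp in range(C)]) + emit[t][y]
                 for y in range(C)]
    return plus(alpha)


def brute_force(emit, trans, plus):
    """Enumerate every label sequence explicitly (exponential; for checking only)."""
    T, C = len(emit), len(emit[0])
    scores = []
    for ys in product(range(C), repeat=T):
        scores.append(emit[0][ys[0]] + sum(trans[ys[t - 1]][ys[t]] + emit[t][ys[t]]
                                           for t in range(1, T)))
    return plus(scores)


emit = [[0.1, -0.3], [0.5, 0.2], [-0.4, 0.7]]
trans = [[0.2, -0.1], [-0.5, 0.3]]
log_Z = chain_forward(emit, trans, LOG_SEMIRING)
best = chain_forward(emit, trans, MAX_SEMIRING)
print(log_Z, best, best - log_Z)
```

Here `best - log_Z` is the log-probability of the highest-scoring structure, which should correspond to what the `dist.log_prob(dist.argmax)` line in Getting Started computes.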

Examples

Citation

@misc{alex2020torchstruct,
    title={Torch-Struct: Deep Structured Prediction Library},
    author={Alexander M. Rush},
    year={2020},
    eprint={2002.00876},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

This work was partially supported by NSF grant IIS-1901030.

Comments
  • add tests for CKY

    This PR fixes several bugs in k-best parsing with dist.topk() and includes a simple test for the function.

    I made incremental changes so that existing modules relying on CKY will not be affected.

    opened by zhaoyanpeng 8
  • 1st order cky implementation

    Hi,

    I'd like to contribute this implementation of a first-order CKY-style CRF with anchored rule potentials: $\phi[i,j,k,A,B,C] := \phi(A_{i,j} \rightarrow B_{i,k}, C_{k+1,j})$.

    I also added code to the _Struct class that allows calculating marginals even if the input tensors don't require a gradient (i.e., after model.eval()).

    Please let me know if you'd like to see any changes.

    Thanks, Tom

    opened by teffland 6
  • Mini-batch setting with Semi Markov CRF

    I encounter learning instability when using a batch size > 1 with the semi-Markov CRF (the loss goes to a very large negative number), even when explicitly providing "lengths". I think the bug comes from the masking. The model trains well with batch size 1.

    opened by urchade 5
  • Release on PyPI?

    Is there any interest in releasing pytorch-struct (and genbmm) on the official Python Package Index?

    I ran into this because I distribute my constituency parser on PyPI, and I just recently pushed a new version that depends on pytorch-struct: https://pypi.org/project/benepar/0.2.0a0/

    It turns out that packages on PyPI aren't allowed to depend on packages hosted only on GitHub, so users of my parser can't just pip install benepar and have it work right away.

    opened by nikitakit 5
  • up sweep and down sweep

    I'm interested in the parallel scan algorithm for the linear-chain CRF.

    I read the related paper in the tutorial and found that there are two steps, an up-sweep and a down-sweep, to obtain all prefix sums.

    I think in this case we use that algorithm to obtain Z(x) for all the different lengths in a batch, but I couldn't find the down-sweep code in the repo. Could you point me to it?

    opened by allanj 5
  • [Bug] Implementation of Eisner's algorithm does not restrict the root number to 1

    Hey, I found that your implementation of Eisner's algorithm admits an arbitrary number of roots, which is a severe bug, since a dependency parse usually has only one root token.

    In your DepTree.dp() method, you convert the input so that the root token is the first token in the sentence. Imagine that the root x_{0} attaches word x_{i}: I_{0,0} + C_{1,i} = I_{0,i} and I_{0,i} + C_{i,j} = C_{0,j} for some j < L, where L is the length of the sentence. The complete span C_{0,j} can still attach a new word x_{k} for j < k <= L, making multiple root attachments possible.

    I made some changes to your code to restrict the number of roots to 1.

    # Presumably relies on the module-level names used in torch_struct's deptree.py:
    # the chart indices A, B, I, C, L, R and the Chart helper.
    def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
        if arc_scores_in.dim() not in (3, 4):
            raise ValueError("potentials must have dim of 3 (unlabeled) or 4 (labeled)")

        labeled = arc_scores_in.dim() == 4
        semiring = self.semiring
        # arc_scores_in = _convert(arc_scores_in)
        arc_scores_in, batch, N, lengths = self._check_potentials(
            arc_scores_in, lengths
        )
        arc_scores_in.requires_grad_(True)
        arc_scores = semiring.sum(arc_scores_in) if labeled else arc_scores_in
        alpha = [
            [
                [
                    Chart((batch, N, N), arc_scores, semiring, cache=cache)
                    for _ in range(2)
                ]
                for _ in range(2)
            ]
            for _ in range(2)
        ]

        semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
        semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
        semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
        semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)

        for k in range(1, N):
            f = torch.arange(N - k), torch.arange(k, N)
            ACL = alpha[A][C][L][: N - k, :k]
            ACR = alpha[A][C][R][: N - k, :k]
            BCL = alpha[B][C][L][k:, N - k :]
            BCR = alpha[B][C][R][k:, N - k :]
            x = semiring.dot(ACR, BCL)
            arcs_l = semiring.times(x, arc_scores[:, :, f[1], f[0]])
            alpha[A][I][L][: N - k, k] = arcs_l
            alpha[B][I][L][k:N, N - k - 1] = arcs_l
            arcs_r = semiring.times(x, arc_scores[:, :, f[0], f[1]])
            alpha[A][I][R][: N - k, k] = arcs_r
            alpha[B][I][R][k:N, N - k - 1] = arcs_r
            AIR = alpha[A][I][R][: N - k, 1 : k + 1]
            BIL = alpha[B][I][L][k:, N - k - 1 : N - 1]
            new = semiring.dot(ACL, BIL)
            alpha[A][C][L][: N - k, k] = new
            alpha[B][C][L][k:N, N - k - 1] = new
            new = semiring.dot(AIR, BCR)
            alpha[A][C][R][: N - k, k] = new
            alpha[B][C][R][k:N, N - k - 1] = new

        # Root handling after the main loop: exactly one word attaches to the root.
        root_incomplete_span = semiring.times(
            alpha[A][C][L][0, :], arc_scores[:, :, torch.arange(N), torch.arange(N)]
        )
        root = [Chart((batch,), arc_scores, semiring, cache=cache) for _ in range(N)]
        for k in range(N):
            AIR = root_incomplete_span[:, :, : k + 1]
            BCR = alpha[B][C][R][k, N - (k + 1) :]
            root[k] = semiring.dot(AIR, BCR)
        v = torch.stack([root[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
        return v, [arc_scores_in], alpha

    Basically, I no longer treat the first token as the root. I handle the root token just after the for-loop, so you may need to adjust the length variable (length = length - 1, since the root is no longer treated as part of the sentence). I tested the modified code and found no bugs.

    opened by sustcsonglin 4
  • Inference for the HMM model

    Hello! I was playing with the HMM distribution and obtained some results that I don't really understand. More precisely, I set the following parameters:

    import numpy as np
    import torch
    import torch_struct
    import matplotlib.pyplot as plt

    t = torch.tensor([[0.99, 0.01], [0.01, 0.99]]).log()
    e = torch.tensor([[0.50, 0.50], [0.50, 0.50]]).log()
    i = torch.tensor(np.array([0.99, 0.01])).log()
    x = torch.randint(0, 2, size=(1, 8))
    

    and I was expecting the model to stay in the hidden state 0 regardless of the observed data x – it starts in state 0 and the transition matrix makes it very likely to maintain it. But when plotting the argmax, it appears that the model jumps from one state to the other:

    def show_chain(chain):
        plt.imshow(chain.detach().sum(-1).transpose(0, 1))
    
    dist = torch_struct.HMM(t, e, i, x)
    show_chain(dist.argmax[0])
    

    I must be missing something obvious; but shouldn't dist.argmax correspond to argmax_z p(z | x, Θ)? Thank you!

    opened by danoneata 4
  • DependencyCRF partition function broken

    Getting the following in-place operation error when using the DependencyCRF:

    import torch
    from torch_struct import DependencyCRF

    B, N = 3, 50
    phi = torch.randn(B, N, N)
    DependencyCRF(phi).partition
    
    /usr/local/lib/python3.7/dist-packages/torch_struct/deptree.py in _check_potentials(self, arc_scores, lengths)
        121         arc_scores = semiring.convert(arc_scores)
        122         for b in range(batch):
    --> 123             semiring.zero_(arc_scores[:, b, lengths[b] + 1 :, :])
        124             semiring.zero_(arc_scores[:, b, :, lengths[b] + 1 :])
        125 
    
    /usr/local/lib/python3.7/dist-packages/torch_struct/semirings/semirings.py in zero_(xs)
        124     @staticmethod
        125     def zero_(xs):
    --> 126         return xs.fill_(-1e5)
        127 
        128     @staticmethod
    
    RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
    
    opened by teffland 3
  • [Question] How to compute a marginal probability over a (contiguous) set of nodes?

    Hi.

    Thank you for the great library. I have one question that I hope you could help with.

    How can I compute a marginal probability over a (contiguous) set of nodes? Right now, I am using your LinearChain-CRF to do NER. In addition to the best sequence itself, I also need to compute the model's confidence in its predicted labeling over a segment of the input. For example, what is the probability that a span of tokens constitutes a person name?

    I read your example and see how you get the marginal probability for each individual node, but I am not quite sure how to compute the marginal probability over a subset of nodes. Any hint would be great.

    Thank you.

    opened by kimdev95 3
  • Get the score of dist.topk()

    The topk() function returns the top k predictions from the distribution. How can I easily get the corresponding score of each prediction?

    Also, when sentences are short and k is large, how can I know how many predictions are valid? For example, with DependencyCRF, when the sentence length is 2 and k is 5, I think only the top 3 predictions are valid.

    opened by wangxinyu0922 3
  • Labeled projective dependency CRF

    This is work in progress and isn't ready to merge yet.

    This seems to work for partition, but argmax and marginals don't return what I expect. Both return tensors of shape (B, N, N); I'd expect (B, N, N, L) tensors instead. Any advice?

    opened by kmkurn 3
  • [Question] How to apply pytorch-struct for 2 dimensional data?

    I found examples of pytorch-struct usage for 1-D sequence data such as text or video frames, but I'm trying to parse table structure in PDF documents.

    Could you provide some hints where to start?

    opened by YuriyPryyma 4
  • end_class support for Autoregressive

    end_class is not used for the Autoregressive module: https://github.com/harvardnlp/pytorch-struct/blob/7146de5659ff17ad7be53023c025ffd099866412/torch_struct/autoregressive.py#L49

    opened by urchade 1
  • Update examples to use newer torchtext APIs

    opened by erip 2
  • Instable learning with SemiMarkov CRF

    Hi,

    First, thank you for fixing #110 (@da03); the SemiCRF works better now, and I was able to get good results on span extraction tasks. However, I still encounter learning instability: the loss (negative log-probability) becomes negative after several steps, and the accuracy starts to drop. The same problem occurs with batch_size = 1. The learning curves (F1 score and log loss) are below.

    (Maybe the bug comes from the masking of spans (length, length + span_width) where length + span_width exceeds the sequence length, but I am not sure.)

    Edit: I created a test, and the masking seems to be correct. Maybe the problem is in the log_prob computation or the to_parts function?

    [Figures: train_loss and score learning curves]

    opened by urchade 0
  • fix bug- missing assignment of spans from sentCFG in documentation

    I noticed a small bug in the documentation and example for SentCFG. dist.argmax returns (terms, rules, init, spans), but the example in the documentation only assigns (term, rules, init), which causes a dimension mismatch, so the example breaks when run. This fix resolves the issue.

    opened by jdegange 0
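The first-order CKY pull request in the comments above scores each binary rule application with an anchored potential φ[i,j,k,A,B,C] := φ(A_{i,j} → B_{i,k}, C_{k+1,j}). As a rough sketch of how such potentials could enter the inside algorithm (log semiring, pure Python, not the code from that PR), the block below assumes `phi` is a callable, `term[i][A]` is a leaf score for nonterminal A over word i, and any nonterminal may be the root with no extra root score. All three are assumptions for illustration.

```python
import math


def logsumexp(xs):
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def inside_first_order(phi, term):
    """Log-partition over labeled binary trees with anchored (first-order) rule potentials.

    term[i][A]: log score of nonterminal A spanning the single word i (assumed input).
    phi(i, j, k, A, B, C): log score of A_{i,j} -> B_{i,k} C_{k+1,j} (hypothetical callable).
    """
    n, num_nt = len(term), len(term[0])
    beta = {}  # beta[i, j, A]: log inside score of A over words i..j (inclusive)
    for i in range(n):
        for A in range(num_nt):
            beta[i, i, A] = term[i][A]
    for width in range(1, n):
        for i in range(n - width):
            j = i + width
            for A in range(num_nt):
                beta[i, j, A] = logsumexp([
                    phi(i, j, k, A, B, C) + beta[i, k, B] + beta[k + 1, j, C]
                    for k in range(i, j)        # split point
                    for B in range(num_nt)      # left child label
                    for C in range(num_nt)      # right child label
                ])
    return logsumexp([beta[0, n - 1, A] for A in range(num_nt)])


def zero_phi(i, j, k, A, B, C):
    return 0.0


# With all potentials zero, exp(log Z) counts trees: Catalan(n-1) shapes x NT**(2n-1) labelings.
log_Z = inside_first_order(zero_phi, [[0.0, 0.0] for _ in range(3)])
print(round(math.exp(log_Z)))  # 64 = 2 shapes x 2**5 labelings
```

A plain PCFG is the special case where `phi` ignores the positions (i, j, k); the anchoring is what makes the potentials first-order.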
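The Semi-Markov CRF comments above report a training loss (negative log-probability) that becomes negative. For any correctly normalized model, log Z is at least the score of every valid segmentation, so NLL = log Z - score(gold) can never go below zero; a negative value suggests that the partition function and the gold score are computed over different sets of structures (for example, the gold segmentation uses positions or span widths that the partition masks out). The pure-Python reference below is a sketch under assumed conventions: segments [s, e) of length 1..K tile positions 0..T-1, `seg_score(s, e, c)` is a hypothetical segment-score callable, and `trans` is a label transition table. It is not torch_struct's parameterization, only a small baseline to compare against.

```python
import math


def logsumexp(xs):
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def semimarkov_log_partition(seg_score, trans, T, C, K):
    # alpha[e][c]: log-sum over segmentations of [0, e) whose last segment has label c
    alpha = [[-math.inf] * C for _ in range(T + 1)]
    for e in range(1, T + 1):
        for c in range(C):
            terms = []
            for length in range(1, min(K, e) + 1):
                s = e - length
                if s == 0:
                    prev = 0.0  # first segment: no incoming transition
                else:
                    prev = logsumexp([alpha[s][cp] + trans[cp][c] for cp in range(C)])
                terms.append(prev + seg_score(s, e, c))
            alpha[e][c] = logsumexp(terms)
    return logsumexp(alpha[T])


def segmentation_score(segments, seg_score, trans):
    """Score of one segmentation given as ordered (s, e, c) triples."""
    total, prev = 0.0, None
    for s, e, c in segments:
        total += seg_score(s, e, c)
        if prev is not None:
            total += trans[prev][c]
        prev = c
    return total


def seg_score(s, e, c):
    # Arbitrary illustrative scores (hypothetical).
    return 0.3 * (e - s) * c - 0.1 * s


trans = [[0.2, -0.4], [0.1, 0.5]]
T, C, K = 4, 2, 3
log_Z = semimarkov_log_partition(seg_score, trans, T, C, K)
gold = [(0, 1, 0), (1, 4, 1)]  # must tile [0, T) with segments of length <= K
nll = log_Z - segmentation_score(gold, seg_score, trans)
print(nll >= 0)  # True for any valid gold segmentation
```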
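The up-sweep/down-sweep question above refers to the work-efficient (Blelloch) parallel prefix scan. A sequential pure-Python sketch of the two phases follows, for any associative operator `op` with identity `identity` (both hypothetical parameters). The operator need not be commutative: for a linear-chain CRF the natural choice would be a log-semiring matrix product of per-position transition matrices, so the exclusive prefixes become forward scores for every prefix length. The iterations of each inner `for` loop are independent, which is what would run in parallel; whether and how pytorch-struct itself performs a down-sweep is not confirmed here.

```python
def blelloch_exclusive_scan(xs, op, identity):
    """Work-efficient exclusive scan: out[0] = identity, out[i] = xs[0] op ... op xs[i-1].

    Assumes len(xs) is a power of two and op is associative (not necessarily commutative).
    """
    n = len(xs)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    a = list(xs)
    # Up-sweep (reduce): after each level, a[i] holds the reduction of the 2d-block ending at i.
    d = 1
    while d < n:
        for i in range(2 * d - 1, n, 2 * d):
            a[i] = op(a[i - d], a[i])
        d *= 2
    # Down-sweep: put the identity at the root, then push prefixes back down the tree.
    a[n - 1] = identity
    d = n // 2
    while d >= 1:
        for i in range(2 * d - 1, n, 2 * d):
            left = a[i - d]
            a[i - d] = a[i]        # left child's prefix = parent's prefix
            a[i] = op(a[i], left)  # right child's prefix = parent's prefix op left block
        d //= 2
    return a


print(blelloch_exclusive_scan([3, 1, 7, 0, 4, 1, 6, 3], lambda x, y: x + y, 0))
# [0, 3, 4, 11, 11, 15, 16, 22]
print(blelloch_exclusive_scan(["a", "b", "c", "d"], lambda x, y: x + y, ""))
# ['', 'a', 'ab', 'abc']  (order is preserved for a non-commutative op)
```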
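For the HMM inference question above, an independent reference decoder makes it easy to check what argmax_z p(z | x, Θ) should be. The minimal Viterbi sketch below assumes that `t[i][j]` is the log-probability of moving from state i to state j, `e[i][o]` the log-probability of emitting observation o from state i, and `init` the initial log-probabilities. These conventions are assumptions and may differ from how torch_struct.HMM orders its arguments. With the parameters from the question the emissions are uniform, so the best path should stay in state 0 for any observation sequence.

```python
import math


def viterbi(t, e, init, obs):
    """Most likely hidden-state sequence for an HMM given log-probability tables."""
    S = len(init)
    score = [init[s] + e[s][obs[0]] for s in range(S)]
    backptrs = []
    for o in obs[1:]:
        # Best predecessor for each current state s.
        prev_best = [max(range(S), key=lambda p: score[p] + t[p][s]) for s in range(S)]
        score = [score[prev_best[s]] + t[prev_best[s]][s] + e[s][o] for s in range(S)]
        backptrs.append(prev_best)
    # Trace back from the best final state.
    state = max(range(S), key=lambda s: score[s])
    path = [state]
    for prev_best in reversed(backptrs):
        state = prev_best[state]
        path.append(state)
    return path[::-1]


log = math.log
t = [[log(0.99), log(0.01)], [log(0.01), log(0.99)]]
e = [[log(0.5), log(0.5)], [log(0.5), log(0.5)]]
init = [log(0.99), log(0.01)]
x = [1, 0, 1, 1, 0, 1, 0, 0]
print(viterbi(t, e, init, x))  # [0, 0, 0, 0, 0, 0, 0, 0]
```

If a library argmax disagrees with this reference on the same parameters, the likely suspects are the argument order or the transition-matrix orientation rather than the decoding itself.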
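For the question above about the probability of a labeling over a contiguous span, one standard approach (sketched here, not taken from the library) is a ratio of partition functions: run the forward algorithm once unconstrained and once with each position in the span restricted to the queried label, then P(span labels | x) = exp(log Z_constrained - log Z). For BIO-style NER, "this span is exactly a person name" would also need the following token constrained to anything but I-PER; that detail is left out. With torch_struct one could presumably do the same by setting disallowed log-potentials to a very negative value and comparing `partition` values, but that usage is an assumption. `emit`, `trans`, and `allowed` are hypothetical inputs.

```python
import math
from itertools import product


def logsumexp(xs):
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_partition(emit, trans, allowed=None):
    """Linear-chain forward algorithm in log space.

    emit[t][y]: log score of label y at position t; trans[y][y2]: log score of y -> y2.
    allowed[t] (optional): set of labels permitted at position t; others get -inf.
    """
    C = len(emit[0])

    def ok(t, y):
        return allowed is None or y in allowed[t]

    alpha = [emit[0][y] if ok(0, y) else -math.inf for y in range(C)]
    for t in range(1, len(emit)):
        alpha = [
            logsumexp([alpha[yp] + trans[yp][y] for yp in range(C)]) + emit[t][y]
            if ok(t, y) else -math.inf
            for y in range(C)
        ]
    return logsumexp(alpha)


def span_prob(emit, trans, start, labels):
    """P(y[start + i] = labels[i] for all i), as a ratio of partition functions."""
    C = len(emit[0])
    allowed = [set(range(C)) for _ in emit]
    for i, y in enumerate(labels):
        allowed[start + i] = {y}
    return math.exp(log_partition(emit, trans, allowed) - log_partition(emit, trans))


def brute_force_span_prob(emit, trans, start, labels):
    """Exponential-time check on tiny inputs: enumerate every label sequence."""
    C, T = len(emit[0]), len(emit)
    total = match = 0.0
    for ys in product(range(C), repeat=T):
        w = math.exp(emit[0][ys[0]] + sum(trans[ys[t - 1]][ys[t]] + emit[t][ys[t]]
                                          for t in range(1, T)))
        total += w
        if all(ys[start + i] == y for i, y in enumerate(labels)):
            match += w
    return match / total


emit = [[0.2, 1.0, -0.5], [0.1, 0.4, 0.9], [-0.3, 0.8, 0.2], [0.5, -0.2, 0.0]]
trans = [[0.3, -0.2, 0.1], [0.0, 0.6, -0.4], [-0.1, 0.2, 0.5]]
p = span_prob(emit, trans, start=1, labels=[1, 1])  # P(y1 = 1 and y2 = 1)
print(p, brute_force_span_prob(emit, trans, 1, [1, 1]))
```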
Releases (v0.5)