A pure PyTorch implementation of the loss described in "Online Segment to Segment Neural Transduction"

Overview

ssnt-loss

ℹ️ This is a WIP project. the implementation is still being tested.

A pure PyTorch implementation of the loss described in "Online Segment to Segment Neural Transduction" https://arxiv.org/abs/1609.08194.

Usage

There are two versions, a normal version and a memory efficient version. They should give the same output, please inform me if they don't.

>> target_mask = targets.ne(pad) # (B, T) >>> targets = targets[target_mask] # (T_flat,) >>> log_probs = log_probs[target_mask] # (T_flat, S, V) Args: log_probs (Tensor): Word prediction log-probs, should be output of log_softmax. tensor with shape (T_flat, S, V) where T_flat is the summation of all target lengths, S is the maximum number of input frames and V is the vocabulary of labels. targets (Tensor): Tensor with shape (T_flat,) representing the reference target labels for all samples in the minibatch. log_p_choose (Tensor): emission log-probs, should be output of F.logsigmoid. tensor with shape (T_flat, S) where T_flat is the summation of all target lengths, S is the maximum number of input frames. source_lengths (Tensor): Tensor with shape (N,) representing the number of frames for each sample in the minibatch. target_lengths (Tensor): Tensor with shape (N,) representing the length of the transcription for each sample in the minibatch. neg_inf (float, optional): The constant representing -inf used for masking. Default: -1e4 reduction (string, optional): Specifies reduction. suppoerts mean / sum. Default: None. """">
def ssnt_loss_mem(
    log_probs: Tensor,
    targets: Tensor,
    log_p_choose: Tensor,
    source_lengths: Tensor,
    target_lengths: Tensor,
    neg_inf: float = -1e4,
    reduction="mean",
):
    """The memory efficient implementation concatenates along the targets
    dimension to reduce wasted computation on padding positions.

    Assuming the summation of all targets in the batch is T_flat, then
    the original B x T x ... tensor is reduced to T_flat x ...

    The input tensors can be obtained by using target mask:
    Example:
        >>> target_mask = targets.ne(pad)   # (B, T)
        >>> targets = targets[target_mask]  # (T_flat,)
        >>> log_probs = log_probs[target_mask]  # (T_flat, S, V)

    Args:
        log_probs (Tensor): Word prediction log-probs, should be output of log_softmax.
            tensor with shape (T_flat, S, V)
            where T_flat is the summation of all target lengths,
            S is the maximum number of input frames and V is
            the vocabulary of labels.
        targets (Tensor): Tensor with shape (T_flat,) representing the
            reference target labels for all samples in the minibatch.
        log_p_choose (Tensor): emission log-probs, should be output of F.logsigmoid.
            tensor with shape (T_flat, S)
            where T_flat is the summation of all target lengths,
            S is the maximum number of input frames.
        source_lengths (Tensor): Tensor with shape (N,) representing the
            number of frames for each sample in the minibatch.
        target_lengths (Tensor): Tensor with shape (N,) representing the
            length of the transcription for each sample in the minibatch.
        neg_inf (float, optional): The constant representing -inf used for masking.
            Default: -1e4
        reduction (string, optional): Specifies reduction. suppoerts mean / sum.
            Default: None.
    """

Minimal example

import torch
import torch.nn as nn
import torch.nn.functional as F
from ssnt_loss import ssnt_loss_mem, lengths_to_padding_mask
B, S, H, T, V = 2, 100, 256, 10, 2000

# model
transcriber = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
predictor = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
joiner_trans = nn.Linear(H, V, bias=False).cuda()
joiner_alpha = nn.Sequential(
    nn.Linear(H, 1, bias=True),
    nn.Tanh()
).cuda()

# inputs
src_embed = torch.rand(B, S, H).cuda().requires_grad_()
tgt_embed = torch.rand(B, T, H).cuda().requires_grad_()
targets = torch.randint(0, V, (B, T)).cuda()
adjust = lambda x, goal: x * goal // x.max()
source_lengths = adjust(torch.randint(1, S+1, (B,)).cuda(), S)
target_lengths = adjust(torch.randint(1, T+1, (B,)).cuda(), T)

# forward
src_feats, (h1, c1) = transcriber(src_embed.transpose(1, 0))
tgt_feats, (h2, c2) = predictor(tgt_embed.transpose(1, 0))

# memory efficient joint
mask = ~lengths_to_padding_mask(target_lengths)
lattice = F.relu(
    src_feats.transpose(0, 1).unsqueeze(1) + tgt_feats.transpose(0, 1).unsqueeze(2)
)[mask]
log_alpha = F.logsigmoid(joiner_alpha(lattice)).squeeze(-1)
lattice = joiner_trans(lattice).log_softmax(-1)

# normal ssnt loss
loss = ssnt_loss_mem(
    lattice,
    targets[mask],
    log_alpha,
    source_lengths=source_lengths,
    target_lengths=target_lengths,
    reduction="sum"
) / (B*T)
loss.backward()
print(loss.item())

Note

This implementation is based on the simplifying derivation proposed for monotonic attention, where they use parallelized cumsum and cumprod to compute the alignment. Based on the similarity of SSNT and monotonic attention, we can infer that the forward variable alpha(i,j) can be computed similarly.

Feel free to contact me if there are bugs in the code.

Reference

Owner
張致強
張致強
The Deep Learning with Julia book, using Flux.jl.

Deep Learning with Julia DL with Julia is a book about how to do various deep learning tasks using the Julia programming language and specifically the

Logan Kilpatrick 67 Dec 25, 2022
Buffon’s needle: one of the oldest problems in geometric probability

Buffon-s-Needle Buffon’s needle is one of the oldest problems in geometric proba

3 Feb 18, 2022
Trajectory Extraction of road users via Traffic Camera

Traffic Monitoring Citation The associated paper for this project will be published here as soon as possible. When using this software, please cite th

Julian Strosahl 14 Dec 17, 2022
Data, notebooks, and articles associated with the RSNA AI Deep Learning Lab at RSNA 2021

RSNA AI Deep Learning Lab 2021 Intro Welcome Deep Learners! This document provides all the information you need to participate in the RSNA AI Deep Lea

RSNA 65 Dec 16, 2022
CenterFace(size of 7.3MB) is a practical anchor-free face detection and alignment method for edge devices.

CenterFace Introduce CenterFace(size of 7.3MB) is a practical anchor-free face detection and alignment method for edge devices. Recent Update 2019.09.

StarClouds 1.2k Dec 21, 2022
Referring Video Object Segmentation

Awesome-Referring-Video-Object-Segmentation Welcome to starts ⭐ & comments 💹 & sharing 😀 !! - 2021.12.12: Recent papers (from 2021) - welcome to ad

Explorer 57 Dec 11, 2022
RIM: Reliable Influence-based Active Learning on Graphs.

RIM: Reliable Influence-based Active Learning on Graphs. This repository is the official implementation of RIM. Requirements To install requirements:

Wentao Zhang 4 Aug 29, 2022
An open source object detection toolbox based on PyTorch

MMDetection is an open source object detection toolbox based on PyTorch. It is a part of the OpenMMLab project.

Bo Chen 24 Dec 28, 2022
The source code and data of the paper "Instance-wise Graph-based Framework for Multivariate Time Series Forecasting".

IGMTF The source code and data of the paper "Instance-wise Graph-based Framework for Multivariate Time Series Forecasting". Requirements The framework

Wentao Xu 24 Dec 05, 2022
Object-aware Contrastive Learning for Debiased Scene Representation

Object-aware Contrastive Learning Official PyTorch implementation of "Object-aware Contrastive Learning for Debiased Scene Representation" by Sangwoo

43 Dec 14, 2022
Yolov5-lite - Minimal PyTorch implementation of YOLOv5

Yolov5-Lite: Minimal YOLOv5 + Deep Sort Overview This repo is a shortened versio

Kadir Nar 57 Nov 28, 2022
High-quality implementations of standard and SOTA methods on a variety of tasks.

Uncertainty Baselines The goal of Uncertainty Baselines is to provide a template for researchers to build on. The baselines can be a starting point fo

Google 1.1k Dec 30, 2022
Continuous Diffusion Graph Neural Network

We present Graph Neural Diffusion (GRAND) that approaches deep learning on graphs as a continuous diffusion process and treats Graph Neural Networks (GNNs) as discretisations of an underlying PDE.

Twitter Research 227 Jan 05, 2023
Pytorch implementation of SenFormer: Efficient Self-Ensemble Framework for Semantic Segmentation

SenFormer: Efficient Self-Ensemble Framework for Semantic Segmentation Efficient Self-Ensemble Framework for Semantic Segmentation by Walid Bousselham

61 Dec 26, 2022
Implementation for "Conditional entropy minimization principle for learning domain invariant representation features"

Implementation for "Conditional entropy minimization principle for learning domain invariant representation features". The code is reproduced from thi

1 Nov 02, 2022
The full training script for Enformer (Tensorflow Sonnet) on TPU clusters

Enformer TPU training script (wip) The full training script for Enformer (Tensorflow Sonnet) on TPU clusters, in an effort to migrate the model to pyt

Phil Wang 10 Oct 19, 2022
DTCN IJCAI - Sequential prediction learning framework and algorithm

DTCN This is the implementation of our paper "Sequential Prediction of Social Me

Bobby 2 Jan 24, 2022
Justmagic - Use a function as a method with this mystic script, like in Nim

justmagic Use a function as a method with this mystic script, like in Nim. Just

witer33 8 Oct 08, 2022
Car Price Predictor App used to predict the price of the car based on certain input parameters created using python's scikit-learn, fastapi, numpy and joblib packages.

Pricefy Car Price Predictor App used to predict the price of the car based on certain input parameters created using python's scikit-learn, fastapi, n

Siva Prakash 1 May 10, 2022
Implementation of SETR model, Original paper: Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers.

SETR - Pytorch Since the original paper (Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers.) has no official

zhaohu xing 112 Dec 16, 2022