ssnt-loss

ℹ️ This is a WIP project; the implementation is still being tested.

A pure PyTorch implementation of the loss described in "Online Segment to Segment Neural Transduction" https://arxiv.org/abs/1609.08194.

Usage

There are two versions: a normal version and a memory-efficient version. They should give the same output; please let me know if they don't.
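Both versions consume log-probabilities over the joint lattice; the memory-efficient version additionally expects the target dimension to be flattened with a padding mask, as documented in the docstring below. The following is a minimal, self-contained shape sketch of that flattening (the pad value and sizes here are illustrative, not taken from this repository):

import torch

B, T, S, V, pad = 2, 5, 7, 11, 0
targets = torch.tensor([[3, 4, 5, pad, pad],
                        [6, 7, 8, 9, 10]])           # (B, T) padded targets
log_probs = torch.randn(B, T, S, V).log_softmax(-1)  # padded joint output

target_mask = targets.ne(pad)            # (B, T), True at real tokens
flat_targets = targets[target_mask]      # (T_flat,) -> (8,)
flat_log_probs = log_probs[target_mask]  # (T_flat, S, V) -> (8, 7, 11)

# the padded tensor carries B*T = 10 rows, the flattened one only T_flat = 8
print(flat_targets.shape, flat_log_probs.shape)

The signature and docstring of the memory-efficient version are shown below.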

def ssnt_loss_mem(
    log_probs: Tensor,
    targets: Tensor,
    log_p_choose: Tensor,
    source_lengths: Tensor,
    target_lengths: Tensor,
    neg_inf: float = -1e4,
    reduction="mean",
):
    """The memory efficient implementation concatenates along the targets
    dimension to reduce wasted computation on padding positions.

    Assuming the summation of all targets in the batch is T_flat, then
    the original B x T x ... tensor is reduced to T_flat x ...

    The input tensors can be obtained by using target mask:
    Example:
        >>> target_mask = targets.ne(pad)   # (B, T)
        >>> targets = targets[target_mask]  # (T_flat,)
        >>> log_probs = log_probs[target_mask]  # (T_flat, S, V)

    Args:
        log_probs (Tensor): Word prediction log-probs; should be the output of
            log_softmax. Tensor with shape (T_flat, S, V), where T_flat is the
            sum of all target lengths, S is the maximum number of input frames,
            and V is the size of the label vocabulary.
        targets (Tensor): Tensor with shape (T_flat,) representing the
            reference target labels for all samples in the minibatch.
        log_p_choose (Tensor): Emission log-probs; should be the output of
            F.logsigmoid. Tensor with shape (T_flat, S), where T_flat is the
            sum of all target lengths and S is the maximum number of input frames.
        source_lengths (Tensor): Tensor with shape (N,) representing the
            number of frames for each sample in the minibatch.
        target_lengths (Tensor): Tensor with shape (N,) representing the
            length of the transcription for each sample in the minibatch.
        neg_inf (float, optional): The constant representing -inf used for masking.
            Default: -1e4
        reduction (string, optional): Specifies the reduction applied to the output.
            Supports "mean" and "sum". Default: "mean".
    """

Minimal example

import torch
import torch.nn as nn
import torch.nn.functional as F
from ssnt_loss import ssnt_loss_mem, lengths_to_padding_mask
B, S, H, T, V = 2, 100, 256, 10, 2000  # batch, source frames, hidden size, target length, vocab size

# model
transcriber = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
predictor = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
joiner_trans = nn.Linear(H, V, bias=False).cuda()
joiner_alpha = nn.Sequential(
    nn.Linear(H, 1, bias=True),
    nn.Tanh()
).cuda()

# inputs
src_embed = torch.rand(B, S, H).cuda().requires_grad_()
tgt_embed = torch.rand(B, T, H).cuda().requires_grad_()
targets = torch.randint(0, V, (B, T)).cuda()
# draw random lengths and rescale so the longest sample spans the full S / T
adjust = lambda x, goal: x * goal // x.max()
source_lengths = adjust(torch.randint(1, S+1, (B,)).cuda(), S)
target_lengths = adjust(torch.randint(1, T+1, (B,)).cuda(), T)

# forward (nn.LSTM expects time-first input by default, hence the transpose)
src_feats, (h1, c1) = transcriber(src_embed.transpose(1, 0))  # (S, B, H)
tgt_feats, (h2, c2) = predictor(tgt_embed.transpose(1, 0))    # (T, B, H)

# memory efficient joint
# keep only non-padding target positions: (B, T, S, H) -> (T_flat, S, H)
mask = ~lengths_to_padding_mask(target_lengths)  # (B, T), True at real tokens
lattice = F.relu(
    src_feats.transpose(0, 1).unsqueeze(1) + tgt_feats.transpose(0, 1).unsqueeze(2)
)[mask]
log_alpha = F.logsigmoid(joiner_alpha(lattice)).squeeze(-1)  # emission log-probs (T_flat, S)
lattice = joiner_trans(lattice).log_softmax(-1)              # word log-probs (T_flat, S, V)

# ssnt loss (memory-efficient version)
loss = ssnt_loss_mem(
    lattice,
    targets[mask],
    log_alpha,
    source_lengths=source_lengths,
    target_lengths=target_lengths,
    reduction="sum"
) / (B*T)
loss.backward()
print(loss.item())
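
The helper lengths_to_padding_mask imported above comes from the package. A minimal sketch of the behaviour assumed in this example (a (B, max_len) boolean mask that is True at padded positions) could look like the following; this is an illustration, not the package's actual implementation:

import torch

def lengths_to_padding_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # True wherever the position index is at or beyond the sample's length
    max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)  # (max_len,)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)     # (B, max_len)

# e.g. lengths = [3, 5] -> the first row masks the last two of five positions
print(lengths_to_padding_mask_sketch(torch.tensor([3, 5])))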

Note

This implementation is based on the simplifying derivation proposed for monotonic attention (Raffel et al., 2017), which uses parallelized cumsum and cumprod operations to compute the alignment. Given the similarity between SSNT and monotonic attention, the forward variable alpha(i,j) can be computed in the same way.
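
For reference, the recurrence from that derivation, written here with i indexing target positions, j indexing source frames, and p_{i,j} denoting the emission probability (how this repository folds in the word prediction term is not spelled out here):

\alpha_{i,j} = p_{i,j} \sum_{k=1}^{j} \alpha_{i-1,k} \prod_{l=k}^{j-1} \left( 1 - p_{i,l} \right)

which parallelizes over j as

\alpha_{i,:} = p_{i,:} \odot \operatorname{cumprod}\left( 1 - p_{i,:} \right) \odot \operatorname{cumsum}\!\left( \frac{\alpha_{i-1,:}}{\operatorname{cumprod}\left( 1 - p_{i,:} \right)} \right)

where cumprod is the exclusive cumulative product (its first element is 1). In log-space the same computation becomes a cumulative sum of log(1 - p) together with a logcumsumexp, which avoids underflow.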

Feel free to contact me if there are bugs in the code.

Reference

Lei Yu, Jan Buys, Phil Blunsom. "Online Segment to Segment Neural Transduction." EMNLP 2016. https://arxiv.org/abs/1609.08194

Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck. "Online and Linear-Time Attention by Enforcing Monotonic Alignments." ICML 2017. https://arxiv.org/abs/1704.00784
