PyTorch CUDA extension of grid_sample1d

Overview

Grid Sample 1d

A PyTorch CUDA extension implementing 1D grid sampling. PyTorch only supports 2D/3D grid sampling, so I implemented a 1D version for efficiency. The forward pass is 2 to 3x faster than PyTorch's grid_sample.

Setup

  • PyTorch == 1.7.1
  • CUDA == 10.1

Other versions of PyTorch or CUDA may work, but I haven't tested them.

You can either build the extension manually or compile it with JIT.

Build

python setup.py install

JIT

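In op.py, comment out

import grid_sample1d_cuda as grid_sample1d

and uncomment the load call (load comes from torch.utils.cpp_extension):

from torch.utils.cpp_extension import load  # add if op.py does not import it yet

grid_sample1d = load(
    'grid_sample1d_cuda', ['grid_sample1d_cuda.cpp', 'grid_sample1d_cuda_kernel.cu'], verbose=True)

load builds the extension from the C++/CUDA sources when op.py is imported and caches the build.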

Usage

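import torch
from grid_sample1d import GridSample1d

# padding_mode=True: border padding; align_corners as in F.grid_sample
grid_sample1d = GridSample1d(padding_mode=True, align_corners=True)
N = 16       # batch size
C = 256      # number of channels
L_in = 64    # input length
L_out = 128  # number of sampling points per batch element
input = torch.randn((N, C, L_in)).cuda()
# normalized coordinates: -1 and 1 are the two ends of the input;
# values outside [-1, 1] are handled according to padding_mode
grids = torch.randn((N, L_out)).cuda()
output = grid_sample1d(input, grids)  # N * C * L_out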

Options:

  • padding_mode: True for border padding, False for zero padding
  • align_corners: same as align_corners in torch.nn.functional.grid_sample
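For reference, here is a minimal pure-Python sketch of what 1D grid sampling computes for one batch element, assuming the extension follows the linear-interpolation and coordinate conventions of torch.nn.functional.grid_sample. The helpers unnormalize and grid_sample1d_ref are hypothetical and not part of this repository. The sketch shows how the two options act: align_corners changes how a normalized coordinate in [-1, 1] maps to an input index, and padding_mode decides whether out-of-range positions are clamped to the border or read as zeros.

```python
import math


def unnormalize(g, size, align_corners):
    """Map a normalized coordinate g in [-1, 1] to a fractional input index."""
    if align_corners:
        # -1 and 1 refer to the centers of the first and last elements.
        return (g + 1) / 2 * (size - 1)
    # -1 and 1 refer to the outer edges of the first and last elements.
    return ((g + 1) * size - 1) / 2


def grid_sample1d_ref(inp, grid, padding_mode=True, align_corners=True):
    """Hypothetical reference for one batch element (not the CUDA kernel).

    inp:  C lists of L_in floats; grid: L_out normalized coordinates.
    padding_mode: True for border padding, False for zero padding.
    Returns C lists of L_out floats.
    """
    L_in = len(inp[0])
    out = []
    for channel in inp:
        row = []
        for g in grid:
            pos = unnormalize(g, L_in, align_corners)
            if padding_mode:
                # Border padding: clamp the position into [0, L_in - 1].
                pos = min(max(pos, 0.0), L_in - 1)
            x0 = math.floor(pos)
            x1 = x0 + 1
            w1 = pos - x0  # weight of the right neighbour
            w0 = 1.0 - w1  # weight of the left neighbour
            # Out-of-range taps read as zero (zero padding, or a right tap
            # with zero weight when the position is clamped to L_in - 1).
            v0 = channel[x0] if 0 <= x0 < L_in else 0.0
            v1 = channel[x1] if 0 <= x1 < L_in else 0.0
            row.append(w0 * v0 + w1 * v1)
        out.append(row)
    return out
```

For example, with inp = [[0.0, 10.0, 20.0]] and align_corners=True, the coordinates -1, 0 and 1 land exactly on the three elements, so the call returns [[0.0, 10.0, 20.0]].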

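Difference

In the forward pass, the computation over the channel dimension C is done in parallel, whereas torch.nn.functional.grid_sample processes the channels serially. Parallel computation over C may cause round-off differences in the backward pass, where the gradient with respect to the grid sums contributions from all channels. So far, I have found that it does not affect the forward pass, in which each output channel depends only on its own input channel.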
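To see why only the backward pass is affected, here is a hypothetical pure-Python sketch of the gradient of the loss with respect to a single grid coordinate, assuming zero padding and the grid_sample-style unnormalization (the name grid_grad_one_point is made up for illustration). Each channel contributes one term; a serial loop adds the terms in a fixed order, while a channel-parallel kernel has to combine per-channel partial results, generally in a different order.

```python
import math


def grid_grad_one_point(inp, grad_out, g, align_corners=True, order=None):
    """d(loss)/d(g) for one sampling coordinate g, zero padding assumed.

    inp:      C lists of L_in floats (one batch element).
    grad_out: C floats, the upstream gradient at this output position.
    order:    channel order used for the reduction (default: 0..C-1).
    """
    L_in = len(inp[0])
    if align_corners:
        pos = (g + 1) / 2 * (L_in - 1)
        dpos_dg = (L_in - 1) / 2
    else:
        pos = ((g + 1) * L_in - 1) / 2
        dpos_dg = L_in / 2
    x0 = math.floor(pos)
    x1 = x0 + 1

    def tap(channel, i):
        return channel[i] if 0 <= i < L_in else 0.0

    # Per-channel term: between x0 and x1 the output is linear in pos with
    # slope channel[x1] - channel[x0]; the chain rule multiplies by dpos/dg.
    terms = [go * (tap(ch, x1) - tap(ch, x0)) * dpos_dg
             for ch, go in zip(inp, grad_out)]

    # Reduction over C: the rounded result can depend on the summation order.
    total = 0.0
    for c in (order if order is not None else range(len(terms))):
        total += terms[c]
    return total
```

Summing the same terms in forward and reverse channel order gives results that agree to within round-off, but they are not guaranteed to be bit-identical.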

Test

Accuracy Test

Grid sample 1d is a special case of grid sample 2d in most cases (the exception is when padding_mode and align_corners are both False; see Note). I therefore test the accuracy of my implementation against torch.nn.functional.grid_sample, using this helper:

import torch
import torch.nn.functional as F


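def gridsample1d_by2d(input, grid, padding_mode, align_corners):
    # padding_mode is a string here ('zeros' or 'border'), as F.grid_sample expects
    shape = grid.shape  # batch_size * L_out
    input = input.unsqueeze(-1)  # batch_size * C * L_in * 1 (H = L_in, W = 1)
    grid = grid.unsqueeze(1)  # batch_size * 1 * L_out
    # grid[..., 0] is x (along W = 1), fixed at -1; grid[..., 1] is y (along L_in)
    grid = torch.stack([-torch.ones_like(grid), grid], dim=-1)  # batch_size * 1 * L_out * 2
    z = F.grid_sample(input, grid, padding_mode=padding_mode, align_corners=align_corners)  # batch_size * C * 1 * L_out
    C = input.shape[1]
    out_shape = [shape[0], C, shape[1]]
    z = z.view(*out_shape)  # batch_size * C * L_out
    return z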

I recommend running the tests on your own machine, because I have only tested on CUDA 10.1 with a GTX 1080 Ti.

python test/acc_benchmark.py

The forward and backward results are identical in all cases except align_corners=True, padding_mode=False. The difference may be caused by round-off error, since floating-point numbers summed in different orders can give slightly different results.
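A tiny illustration of order-dependent rounding, in plain Python. Python floats are 64-bit while the CUDA kernels most likely work in 32-bit floats, but the effect is the same in kind.

```python
def sum_in_order(values):
    """Left-to-right floating-point sum, like a serial accumulation loop."""
    total = 0.0
    for v in values:
        total += v
    return total


xs = [0.1, 0.2, 0.3]
forward = sum_in_order(xs)         # (0.1 + 0.2) + 0.3 -> 0.6000000000000001
backward = sum_in_order(xs[::-1])  # (0.3 + 0.2) + 0.1 -> 0.6
print(forward == backward)  # False
```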

Deterministic Test

A determinism test is very important, because floating-point addition on computers is not associative: the result can depend on the order in which numbers are summed.

python test/check_deterministic.py
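The idea behind a determinism check can be sketched in plain Python: run the same computation repeatedly and require bit-identical results, not merely close ones. The names below (is_deterministic, ordered_sum, shuffled_sum) are hypothetical and are not taken from test/check_deterministic.py. shuffled_sum stands in for a reduction whose order depends on thread scheduling, such as accumulation with atomic adds.

```python
import random


def is_deterministic(fn, runs=10):
    """Return True if every call to fn returns a result identical to the first."""
    first = fn()
    return all(fn() == first for _ in range(runs - 1))


values = [0.1 * i for i in range(1, 1001)]


def ordered_sum():
    # Same order on every call, like a serial loop over channels.
    total = 0.0
    for v in values:
        total += v
    return total


_rng = random.Random(0)


def shuffled_sum():
    # Order changes from call to call, mimicking a scheduling-dependent reduction.
    order = values[:]
    _rng.shuffle(order)
    total = 0.0
    for v in order:
        total += v
    return total


print(is_deterministic(ordered_sum))  # True
print(is_deterministic(shuffled_sum))
```

Exact equality is the right criterion here: a tolerance-based comparison would hide exactly the order-dependent differences the test is meant to catch.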

Note

When padding_mode and align_corners are both False, grid sample 1d cannot be regarded as a special case of grid sample 2d in PyTorch. I have checked the CUDA kernel of grid_sample in PyTorch: in this case, the output of torch.nn.functional.grid_sample is half of the expected value. Hopefully this can be fixed one day.
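One possible explanation of the factor of one half comes from the unnormalization formula that torch.nn.functional.grid_sample applies to each coordinate, combined with the dummy width coordinate of -1 used in gridsample1d_by2d. The reshaped input has width $W = 1$. With align_corners=False, the x coordinate maps to

$$ i_x = \frac{(x + 1)\,W - 1}{2} = \frac{(-1 + 1)\cdot 1 - 1}{2} = -\frac{1}{2}, $$

so bilinear interpolation blends column $-1$ and column $0$ with equal weights $\tfrac12$. Zero padding reads column $-1$ as $0$, which gives

$$ z = \tfrac12 \cdot 0 + \tfrac12 \cdot v = \tfrac12\, v, $$

where $v$ is the value interpolated along $L_{in}$. With align_corners=True, $i_x = \frac{x+1}{2}(W-1) = 0$, and with border padding $i_x$ is clipped to $0$, so the full value $v$ is recovered in the other three cases. By the same formula, a dummy coordinate of $x = 0$ gives $i_x = 0$ for both align_corners settings, which suggests the 2D comparison could cover this case too.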

CPU support

Not supported; I have been too lazy to add it.

Speed & memory cost

Here are the speed test results for different input sizes.

Owner
lyricpoem