Implementing DropPath/StochasticDepth in PyTorch


DropPath is available in glasses, my computer vision library!

Introduction

Today we are going to implement Stochastic Depth, also known as Drop Path, in PyTorch! Stochastic Depth, introduced by Gao Huang et al., is a technique to "deactivate" some layers during training.

Let's take a look at a normal ResNet block that uses residual connections (like almost all models now). If you are not familiar with ResNet, I have an article showing how to implement it.

Basically, the block's output is added to its input: output = block(input) + input. This is called a residual connection.

Consider four ResNet-like blocks, one after the other.

Stochastic Depth/Drop Path will randomly deactivate some of these blocks during training.

The idea is to reduce the number of layers/blocks used during training, saving time and making the network generalize better.

Practically, this means setting the block's output to zero before adding it to the input.
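To make this concrete, here is a minimal sketch of a residual connection whose block can be switched off. The `residual_forward` helper and its `keep` flag are hypothetical names used only for illustration, and a single `nn.Linear` stands in for the block's layers.

```python
import torch
from torch import nn, Tensor

torch.manual_seed(0)

block = nn.Linear(8, 8)  # stand-in for the layers inside a residual block
x = torch.randn(2, 8)


def residual_forward(x: Tensor, keep: float) -> Tensor:
    # keep is 1.0 when the block is active and 0.0 when it is deactivated
    return block(x) * keep + x


active = residual_forward(x, keep=1.0)   # block(x) + x
dropped = residual_forward(x, keep=0.0)  # 0 + x, so the block is skipped
```

When the block's output is zeroed, the residual layer reduces to the identity, which is why `dropped` is the same as `x`.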

Implementation

Let's start by importing our best friend, torch.

import torch
from torch import nn
from torch import Tensor

We can define a 4D tensor (batch x channels x height x width), in our case let's just send 4 images with one pixel each :)

x = torch.ones((4, 1, 1, 1))

We need a tensor of shape batch x 1 x 1 x 1 that will be used to set some of the elements in the batch to zero with a given probability. Bernoulli to the rescue!

keep_prob: float = .5
mask: Tensor = x.new_empty(x.shape[0], 1, 1, 1).bernoulli_(keep_prob)
    
mask
tensor([[[[0.]]],


        [[[1.]]],


        [[[1.]]],


        [[[1.]]]])

Btw, this is equivalent to sampling from a uniform distribution and keeping the entries that fall below keep_prob

mask: Tensor = (torch.rand(x.shape[0], 1, 1, 1) < keep_prob).float()
mask
tensor([[[[1.]]],


        [[[1.]]],


        [[[1.]]],


        [[[1.]]]])
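As a quick sanity check, the sketch below compares the two sampling strategies empirically. It uses the comparison `rand < keep_prob`, so that each entry is 1 with probability `keep_prob`; the batch size `n`, the seed, and the value 0.8 are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)

keep_prob = 0.8
n = 100_000  # large batch so the empirical means are stable

# in-place Bernoulli sampling: each entry is 1 with probability keep_prob
mask_bernoulli = torch.empty(n, 1, 1, 1).bernoulli_(keep_prob)
# uniform sampling: rand() < keep_prob is True with probability keep_prob
mask_uniform = (torch.rand(n, 1, 1, 1) < keep_prob).float()

mean_bernoulli = mask_bernoulli.mean().item()
mean_uniform = mask_uniform.mean().item()
print(mean_bernoulli, mean_uniform)  # both should be close to keep_prob
```

The two masks are not identical sample by sample, they only follow the same distribution.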

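Before we multiply x by the mask, we need to divide x by keep_prob. This scales up the activations that survive, so their expected value during training matches what the network sees at inference time (inverted dropout, see cs231n). So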

x_scaled : Tensor = x / keep_prob
x_scaled
tensor([[[[2.]]],


        [[[2.]]],


        [[[2.]]],


        [[[2.]]]])

Finally

output: Tensor = x_scaled * mask
output
tensor([[[[2.]]],


        [[[2.]]],


        [[[2.]]],


        [[[2.]]]])

We can put it all together in a function

def drop_path(x: Tensor, keep_prob: float = 1.0) -> Tensor:
    mask: Tensor = x.new_empty(x.shape[0], 1, 1, 1).bernoulli_(keep_prob)
    x_scaled: Tensor = x / keep_prob
    return x_scaled * mask

drop_path(x, keep_prob=0.5)
tensor([[[[0.]]],


        [[[0.]]],


        [[[2.]]],


        [[[0.]]]])
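To see why the division by keep_prob matters, the sketch below applies the same function to a large batch of ones and looks at the average output. The batch size, the seed, and keep_prob = 0.8 are assumptions made for the example.

```python
import torch
from torch import Tensor

torch.manual_seed(0)


def drop_path(x: Tensor, keep_prob: float = 1.0) -> Tensor:
    mask = x.new_empty(x.shape[0], 1, 1, 1).bernoulli_(keep_prob)
    return x / keep_prob * mask


keep_prob = 0.8
x = torch.ones((100_000, 1, 1, 1))
out = drop_path(x, keep_prob)

# each sample is either 0 or 1 / keep_prob ...
values = out.unique().tolist()
# ... and on average the activation is unchanged
mean = out.mean().item()
print(values, mean)
```

Each sample ends up at 0 or at 1 / keep_prob, and the mean stays close to 1, the value of the untouched input.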

We can also do the operation in place

def drop_path(x: Tensor, keep_prob: float = 1.0) -> Tensor:
    mask: Tensor = x.new_empty(x.shape[0], 1, 1, 1).bernoulli_(keep_prob)
    x.div_(keep_prob)
    x.mul_(mask)
    return x


drop_path(x, keep_prob=0.5)
tensor([[[[2.]]],


        [[[2.]]],


        [[[0.]]],


        [[[0.]]]])
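One consequence of the in-place version is worth checking: the caller's tensor is overwritten. A small sketch, where the name `drop_path_` (with a trailing underscore, following PyTorch's naming of in-place ops) is a hypothetical rename for illustration:

```python
import torch
from torch import Tensor


def drop_path_(x: Tensor, keep_prob: float = 1.0) -> Tensor:
    # in-place variant: x itself is scaled and masked
    mask = x.new_empty(x.shape[0], 1, 1, 1).bernoulli_(keep_prob)
    x.div_(keep_prob)
    x.mul_(mask)
    return x


x = torch.ones((4, 1, 1, 1))
out = drop_path_(x, keep_prob=0.5)

# the function returns the very tensor it was given
same_object = out is x
# x no longer holds the original ones, every entry is now 0 or 2
x_was_modified = not torch.equal(x, torch.ones((4, 1, 1, 1)))
```

This saves memory, but any code that still expects the original values of `x` will see the dropped and rescaled ones instead.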

However, we may want to use x somewhere else, and dividing x or the mask by keep_prob gives the same result. Scaling the mask leaves x untouched unless we explicitly ask for the in-place operation. Let's arrive at the final implementation

def drop_path(x: Tensor, keep_prob: float = 1.0, inplace: bool = False) -> Tensor:
    mask: Tensor = x.new_empty(x.shape[0], 1, 1, 1).bernoulli_(keep_prob)
    mask.div_(keep_prob)
    if inplace:
        x.mul_(mask)
    else:
        x = x * mask
    return x

x = torch.ones((4, 1, 1, 1))
drop_path(x, keep_prob=0.8)
tensor([[[[1.2500]]],


        [[[1.2500]]],


        [[[1.2500]]],


        [[[1.2500]]]])

So far drop_path only works for 4D tensors (batches of 2D images), because the mask shape is hard-coded. We need to compute the number of dimensions from the input to make it work for data of any shape

from typing import Tuple

def drop_path(x: Tensor, keep_prob: float = 1.0, inplace: bool = False) -> Tensor:
    # one mask entry per sample, broadcast over every other dimension
    mask_shape: Tuple[int, ...] = (x.shape[0],) + (1,) * (x.ndim - 1)
    # remember tuples have the * operator -> (1,) * 3 = (1,1,1)
    mask: Tensor = x.new_empty(mask_shape).bernoulli_(keep_prob)
    mask.div_(keep_prob)
    if inplace:
        x.mul_(mask)
    else:
        x = x * mask
    return x

x = torch.ones((4, 1))
drop_path(x, keep_prob=0.8)
tensor([[0.],
        [0.],
        [0.],
        [0.]])
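The same function can now be tried on non-image data. The sketch below assumes a batch of 8 sequences with 5 tokens of 16 features each (shapes chosen only for illustration) and checks that every sample is either dropped as a whole or kept as a whole.

```python
import torch
from torch import Tensor

torch.manual_seed(0)


def drop_path(x: Tensor, keep_prob: float = 1.0) -> Tensor:
    # one mask entry per sample, broadcast over all remaining dimensions
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    return x * mask.div_(keep_prob)


x = torch.ones((8, 5, 16))
out = drop_path(x, keep_prob=0.5)

# every sample is either all zeros (dropped) or all 1 / keep_prob = 2 (kept)
per_sample = out.flatten(1)
dropped = (per_sample == 0).all(dim=1)
kept = (per_sample == 2).all(dim=1)
all_or_nothing = bool((dropped | kept).all())
```

Because the mask has shape (8, 1, 1), broadcasting applies the same 0 or 2 to all tokens and features of a sample.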

Let's create a nice DropPath nn.Module

class DropPath(nn.Module):
    def __init__(self, p: float = 0.5, inplace: bool = False):
        super().__init__()
        self.p = p
        self.inplace = inplace

    def forward(self, x: Tensor) -> Tensor:
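        # p is the probability of dropping the path (as in nn.Dropout),
        # while drop_path expects the probability of keeping it
        if self.training and self.p > 0:
            x = drop_path(x, 1.0 - self.p, self.inplace)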
        return x

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"

    
DropPath()(torch.ones((4, 1)))
tensor([[2.],
        [0.],
        [0.],
        [0.]])
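The `self.training` check means the module only drops paths in training mode. The sketch below illustrates this with a condensed, hypothetical stand-in (`TinyDropPath`, which takes the keep probability directly) rather than the exact class above.

```python
import torch
from torch import nn, Tensor


class TinyDropPath(nn.Module):
    """Hypothetical, condensed stand-in used only for this example."""

    def __init__(self, keep_prob: float = 0.5):
        super().__init__()
        self.keep_prob = keep_prob

    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            return x  # inference: the module is the identity
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(self.keep_prob)
        return x * mask.div_(self.keep_prob)


layer = TinyDropPath(keep_prob=0.5)
x = torch.ones((4, 1))

layer.train()
train_out = layer(x)  # entries are 0 or 2, chosen at random

layer.eval()
eval_out = layer(x)   # always equal to x
```

Calling `model.eval()` on a full network switches every such module at once, so no extra flag is needed at inference time.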

Usage with Residual Connections

We have our DropPath, cool, but how do we use it? We need a classic ResNet block, so let's implement our good old friend, the BottleNeck block

from torch import nn


class ConvBnAct(nn.Sequential):
    def __init__(self, in_features: int, out_features: int, kernel_size=1):
        super().__init__(
            nn.Conv2d(in_features, out_features, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(out_features),
            nn.ReLU()
        )
         

class BottleNeck(nn.Module):
    def __init__(self, in_features: int, out_features: int, reduction: int = 4):
        super().__init__()
        self.block = nn.Sequential(
            # wide -> narrow
            ConvBnAct(in_features, out_features // reduction, kernel_size=1),
            # narrow -> narrow
            ConvBnAct(out_features // reduction, out_features // reduction, kernel_size=3),
            # narrow -> wide
            ConvBnAct(out_features // reduction, out_features, kernel_size=1),
        )
        # I am lazy, no shortcut etc
        
    def forward(self, x: Tensor) -> Tensor:
        res = x
        x = self.block(x)
        return x + res
    
    
BottleNeck(64, 64)(torch.ones((1,64, 28, 28))).shape
torch.Size([1, 64, 28, 28])

To deactivate the block, the operation x + res must be equal to res, so our DropPath has to be applied to the block's output, right before the addition.

class BottleNeck(nn.Module):
    def __init__(self, in_features: int, out_features: int, reduction: int = 4):
        super().__init__()
        self.block = nn.Sequential(
            # wide -> narrow
            ConvBnAct(in_features, out_features // reduction, kernel_size=1),
            # narrow -> narrow
            ConvBnAct(out_features // reduction, out_features // reduction, kernel_size=3),
            # narrow -> wide
            ConvBnAct(out_features // reduction, out_features, kernel_size=1),
        )
        # I am lazy, no shortcut etc
        self.drop_path = DropPath()
        
    def forward(self, x: Tensor) -> Tensor:
        res = x
        x = self.block(x)
        x = self.drop_path(x)
        return x + res
    
BottleNeck(64, 64)(torch.ones((1,64, 28, 28)))
tensor([[[[1.0009, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0134, 1.0034, 1.0034,  ..., 1.0034, 1.0034, 1.0000],
          [1.0134, 1.0034, 1.0034,  ..., 1.0034, 1.0034, 1.0000],
          ...,
          [1.0134, 1.0034, 1.0034,  ..., 1.0034, 1.0034, 1.0000],
          [1.0134, 1.0034, 1.0034,  ..., 1.0034, 1.0034, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0005, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0421],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0421],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0421],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0421],
          [1.0000, 1.0011, 1.0011,  ..., 1.0011, 1.0011, 1.0247]],

         [[1.0203, 1.0123, 1.0123,  ..., 1.0123, 1.0123, 1.0299],
          [1.0000, 1.0005, 1.0005,  ..., 1.0005, 1.0005, 1.0548],
          [1.0000, 1.0005, 1.0005,  ..., 1.0005, 1.0005, 1.0548],
          ...,
          [1.0000, 1.0005, 1.0005,  ..., 1.0005, 1.0005, 1.0548],
          [1.0000, 1.0005, 1.0005,  ..., 1.0005, 1.0005, 1.0548],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         ...,

         [[1.0011, 1.0180, 1.0180,  ..., 1.0180, 1.0180, 1.0465],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0245],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0245],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0245],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0245],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0130, 1.0170, 1.0170,  ..., 1.0170, 1.0170, 1.0213],
          [1.0052, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0065],
          [1.0052, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0065],
          ...,
          [1.0052, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0065],
          [1.0052, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0065],
          [1.0012, 1.0139, 1.0139,  ..., 1.0139, 1.0139, 1.0065]],

         [[1.0103, 1.0181, 1.0181,  ..., 1.0181, 1.0181, 1.0539],
          [1.0001, 1.0016, 1.0016,  ..., 1.0016, 1.0016, 1.0231],
          [1.0001, 1.0016, 1.0016,  ..., 1.0016, 1.0016, 1.0231],
          ...,
          [1.0001, 1.0016, 1.0016,  ..., 1.0016, 1.0016, 1.0231],
          [1.0001, 1.0016, 1.0016,  ..., 1.0016, 1.0016, 1.0231],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]]]],
       grad_fn=<AddBackward0>)

Tada 🎉! Now, during training, our .block will be randomly skipped for some of the samples in the batch!
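To check the skipping on a batch with more than one sample, here is a small sketch. `ResidualDropPath` is a hypothetical, condensed layer (a single `nn.Linear` as the block, with the keep probability passed directly), not the BottleNeck above.

```python
import torch
from torch import nn, Tensor

torch.manual_seed(0)


class ResidualDropPath(nn.Module):
    """Hypothetical, condensed residual layer with per-sample drop path."""

    def __init__(self, block: nn.Module, keep_prob: float = 0.5):
        super().__init__()
        self.block = block
        self.keep_prob = keep_prob

    def forward(self, x: Tensor) -> Tensor:
        res = x
        x = self.block(x)
        if self.training:
            mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            mask = x.new_empty(mask_shape).bernoulli_(self.keep_prob)
            x = x * mask.div_(self.keep_prob)
        return x + res


layer = ResidualDropPath(nn.Linear(8, 8), keep_prob=0.5).train()
x = torch.randn(16, 8)
out = layer(x)

# samples whose block output was zeroed come out identical to the input
skipped = (out == x).all(dim=1)
print(f"{int(skipped.sum())} of {x.shape[0]} samples skipped the block")
```

For the skipped samples the layer behaves like the identity, while the remaining ones receive the block's output scaled by 1 / keep_prob.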


Owner
Francesco Saverio Zuppichini
Computer Vision Engineer @ 🤗 BSc informatics. MSc AI. Artificial Intelligence /Deep Learning Enthusiast & Full Stack developer