PyTorch Depthwise Separable Convolution and MobileNet_v1
2022-07-18 21:10:00 【Jiang Junze】
1. Depthwise separable convolution
Depthwise separable convolution takes a different approach: each input channel is convolved with its own kernel. It decomposes an ordinary convolution into two steps, Depthwise and Pointwise.

Ordinary convolution
Suppose the input is $N \times H \times W \times C$ and there are k convolution kernels of size $3 \times 3$ (each spanning all C input channels). With pad=1 and stride=1, the ordinary convolution output is $N \times H \times W \times k$.
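A minimal PyTorch sketch of this ordinary convolution follows. Note that `nn.Conv2d` expects channels-first input (N×C×H×W) rather than the N×H×W×C layout used in the text; the concrete sizes (N=2, C=3, H=W=32, k=16) are arbitrary assumptions for illustration.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed): batch N, input channels C, spatial H x W, k kernels
N, C, H, W, k = 2, 3, 32, 32, 16

x = torch.randn(N, C, H, W)  # PyTorch layout is N x C x H x W

# Ordinary convolution: k kernels, each 3x3 and spanning all C input channels
conv = nn.Conv2d(C, k, kernel_size=3, stride=1, padding=1, bias=False)

y = conv(x)
print(conv.weight.shape)  # torch.Size([16, 3, 3, 3]) -> k x C x 3 x 3
print(y.shape)            # torch.Size([2, 16, 32, 32]) -> N x k x H x W
```

With padding=1 and stride=1 the spatial size is preserved, so only the channel dimension changes from C to k.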

Depthwise step
Depthwise splits the $N \times H \times W \times C$ input into $group = C$ groups of one channel each, then applies a separate $3 \times 3$ convolution to each group. This collects the spatial features of each channel, i.e. the Depthwise features. With pad=1 and stride=1 the output is still $N \times H \times W \times C$.
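In PyTorch the depthwise step can be sketched with `nn.Conv2d(..., groups=C)`. The sizes below are assumptions; the last check confirms that output channel 0 depends only on input channel 0, convolved with its own single-channel kernel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

N, C, H, W = 2, 3, 32, 32  # illustrative sizes (assumed)
x = torch.randn(N, C, H, W)

# groups=C splits the input into C groups of one channel each, so every
# channel is convolved with its own single-channel 3x3 kernel
depthwise = nn.Conv2d(C, C, kernel_size=3, stride=1, padding=1, groups=C, bias=False)

with torch.no_grad():
    y = depthwise(x)
    # Convolve channel 0 alone with kernel 0 and compare with output channel 0
    single = F.conv2d(x[:, :1], depthwise.weight[:1], padding=1)

print(depthwise.weight.shape)  # torch.Size([3, 1, 3, 3]) -> one 1x3x3 kernel per channel
print(y.shape)                 # torch.Size([2, 3, 32, 32]) -> channel count unchanged
print(torch.allclose(y[:, :1], single, atol=1e-5))  # True
```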
Pointwise step
Pointwise applies k ordinary $1 \times 1$ convolutions to the $N \times H \times W \times C$ output of the depthwise step. This combines the channels at each spatial point, i.e. the Pointwise features. The final output of Depthwise + Pointwise is also $N \times H \times W \times k$.
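Chaining the two steps in a sketch (assumed sizes, channels-first layout) shows that Depthwise + Pointwise produces the same output shape as the ordinary convolution, and that the 1×1 pointwise convolution is simply a per-pixel matrix multiplication over channels, checked here with `torch.einsum`.

```python
import torch
import torch.nn as nn

N, C, H, W, k = 2, 3, 32, 32, 16  # illustrative sizes (assumed)
x = torch.randn(N, C, H, W)

depthwise = nn.Conv2d(C, C, kernel_size=3, padding=1, groups=C, bias=False)
# Pointwise: k ordinary 1x1 kernels, each mixing all C channels at one pixel
pointwise = nn.Conv2d(C, k, kernel_size=1, bias=False)

with torch.no_grad():
    z = depthwise(x)  # N x C x H x W: per-channel spatial features
    y = pointwise(z)  # N x k x H x W: channels mixed at every pixel
    # The same pointwise result written as an explicit per-pixel matmul over channels
    manual = torch.einsum("kc,nchw->nkhw", pointwise.weight[:, :, 0, 0], z)

print(y.shape)                                # torch.Size([2, 16, 32, 32])
print(torch.allclose(y, manual, atol=1e-5))   # True
```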
2. Advantages
Depthwise + Pointwise can be viewed as an approximate replacement for a single ordinary convolution:
- Ordinary convolution: 3x3 Conv + BN + ReLU
- MobileNet convolution: 3x3 Depthwise Conv + BN + ReLU, followed by 1x1 Pointwise Conv + BN + ReLU
Computational speedup
Fewer parameters
Suppose the input has 3 channels and the output must have 256 channels (biases ignored). There are two options:
- Apply 256 ordinary 3×3 kernels directly (each 3×3×3, covering all 3 input channels): 3×3×3×256 = 6,912 parameters
- Depthwise separable, done in two steps: 3×3×3 + 3×1×1×256 = 27 + 768 = 795 parameters. The depthwise step uses one 3×3 kernel per input channel (3 feature maps × one 3×3 kernel each), so each depthwise kernel has a depth of 1; the pointwise step uses 256 kernels of size 1×1×3.
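The two parameter counts can be checked by building both options in PyTorch and summing `numel()` over their parameters. `bias=False` is set so the counts match the bias-free arithmetic above, and `num_params` is a hypothetical helper name.

```python
import torch.nn as nn

def num_params(module):
    # Total number of learnable parameters in a module
    return sum(p.numel() for p in module.parameters())

C_in, C_out = 3, 256

# Option 1: 256 ordinary 3x3 kernels over 3 input channels
standard = nn.Conv2d(C_in, C_out, kernel_size=3, padding=1, bias=False)

# Option 2: depthwise 3x3 (one kernel per channel) followed by pointwise 1x1
separable = nn.Sequential(
    nn.Conv2d(C_in, C_in, kernel_size=3, padding=1, groups=C_in, bias=False),
    nn.Conv2d(C_in, C_out, kernel_size=1, bias=False),
)

print(num_params(standard))   # 6912
print(num_params(separable))  # 795
```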
Fewer multiplications
Compare the number of multiplications for each type of convolution (3×3 kernels, stride 1, output size H × W):
- Ordinary convolution: $H \times W \times C \times k \times 3 \times 3$
- Depthwise: $H \times W \times C \times 3 \times 3$
- Pointwise: $H \times W \times C \times k$
Splitting into Depthwise + Pointwise therefore reduces the cost of an ordinary convolution by the factor:
$$\frac{\text{depthwise}+\text{pointwise}}{\text{conv}}=\frac{H \times W \times C \times 3 \times 3+H \times W \times C \times k}{H \times W \times C \times k \times 3 \times 3}=\frac{1}{k}+\frac{1}{3 \times 3}$$
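A small plain-Python sketch of these multiplication counts, assuming stride 1 and same padding so the output stays H × W. The layer sizes are arbitrary and the helper names are hypothetical; the computed ratio matches $\frac{1}{k}+\frac{1}{3 \times 3}$.

```python
def mults_standard(H, W, C, k, K=3):
    # Each of the H*W*k outputs needs C*K*K multiplications
    return H * W * C * k * K * K

def mults_depthwise(H, W, C, K=3):
    # One K x K kernel per channel, no sum across channels
    return H * W * C * K * K

def mults_pointwise(H, W, C, k):
    # k 1x1 kernels over C channels at every pixel
    return H * W * C * k

H, W, C, k = 56, 56, 128, 128  # illustrative layer sizes (assumed)

ratio = (mults_depthwise(H, W, C) + mults_pointwise(H, W, C, k)) / mults_standard(H, W, C, k)
formula = 1 / k + 1 / 9

print(ratio)    # ≈ 0.119
print(formula)  # ≈ 0.119
```

For large k the 1/k term becomes small, so the cost approaches one ninth of the ordinary 3×3 convolution.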
Separation of channels and spatial regions
An ordinary convolution handles spatial regions and channels jointly in a single operation. Depthwise separable convolution separates the two: the depthwise step considers only spatial regions (within each channel), and the pointwise step then considers only the channels.
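One way to see this separation, and why Depthwise + Pointwise is only an approximation of an ordinary convolution: with no BN or activation in between, depthwise followed by pointwise is exactly an ordinary convolution whose kernel factors as W[o, c, i, j] = P[o, c] · D[c, i, j], a channel-mixing factor times a spatial factor. A general 3×3 kernel does not factor this way, so DW + PW covers only a restricted subset of ordinary convolutions. The sketch below checks the identity numerically with `torch.nn.functional.conv2d`; the random kernels D and P and all sizes are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W, k = 1, 4, 8, 8, 6  # illustrative sizes (assumed)
x = torch.randn(N, C, H, W)

D = torch.randn(C, 1, 3, 3)  # depthwise kernels: spatial only, one per channel
P = torch.randn(k, C, 1, 1)  # pointwise kernels: channel mixing only

# Depthwise then pointwise, with no BN/activation in between
y_sep = F.conv2d(F.conv2d(x, D, padding=1, groups=C), P)

# Equivalent ordinary kernel: W_full[o, c, i, j] = P[o, c] * D[c, i, j]
W_full = P[:, :, 0, 0][:, :, None, None] * D[:, 0][None, :, :, :]
y_std = F.conv2d(x, W_full, padding=1)

print(W_full.shape)                             # torch.Size([6, 4, 3, 3])
print(torch.allclose(y_sep, y_std, atol=1e-4))  # True
```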
3. Network structure
MobileNet v1 uses depthwise separable convolutions for speed. Its structure is as follows:
- First, an ordinary 3×3 convolution layer with stride 2 extracts initial features
- Then a series of depthwise separable convolutions (DW + PW) extract further features
- Finally, an average pooling layer and a fully connected layer, followed by a softmax function, produce the final output.

PyTorch implementation
```python
import torch
import torch.nn as nn


def conv_bn(in_channel, out_channel, stride=1):
    """Ordinary convolution block: 3x3 Conv + BN + Act"""
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, 3, stride, 1, bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU6(inplace=True),  # ReLU6(x) = min(max(x, 0), 6)
    )


def conv_dsc(in_channel, out_channel, stride=1):
    """Depthwise separable convolution: DW + BN + Act, then PW + BN + Act"""
    return nn.Sequential(
        # Depthwise: groups=in_channel gives one 3x3 kernel per input channel
        nn.Conv2d(in_channel, in_channel, 3, stride, 1, groups=in_channel, bias=False),
        nn.BatchNorm2d(in_channel),
        nn.ReLU6(inplace=True),
        # Pointwise: 1x1 conv mixes channels and sets the output channel count
        nn.Conv2d(in_channel, out_channel, 1, 1, 0, bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU6(inplace=True),
    )


class MobileNetV1(nn.Module):
    def __init__(self, in_dim=3, num_classes=1000):
        super().__init__()
        self.num_classes = num_classes
        self.stage1 = nn.Sequential(
            conv_bn(in_dim, 32, 2),  # first layer: ordinary 3x3 conv, stride 2
            conv_dsc(32, 64, 1),
            conv_dsc(64, 128, 2),
            conv_dsc(128, 128, 1),
            conv_dsc(128, 256, 2),
            conv_dsc(256, 256, 1),
        )
        self.stage2 = nn.Sequential(
            conv_dsc(256, 512, 2),
            conv_dsc(512, 512, 1),
            conv_dsc(512, 512, 1),
            conv_dsc(512, 512, 1),
            conv_dsc(512, 512, 1),
            conv_dsc(512, 512, 1),
        )
        self.stage3 = nn.Sequential(
            conv_dsc(512, 1024, 2),
            conv_dsc(1024, 1024, 1),
        )
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1024, self.num_classes)

    def forward(self, x):
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.avg(x)          # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)  # N x 1024
        x = self.fc(x)
        # Returns raw logits; apply softmax (or use nn.CrossEntropyLoss) outside the model
        return x
```
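To see how the strides in the code above shrink the feature map, the following plain-Python sketch applies the standard convolution output-size formula to each 3×3 layer, assuming a 224×224 input. `conv_out_size` is a hypothetical helper; the 1×1 pointwise convolutions use stride 1 and padding 0, so they keep the size and are omitted.

```python
def conv_out_size(size, kernel=3, stride=1, padding=1):
    # Output-size formula for a convolution with dilation 1
    return (size + 2 * padding - kernel) // stride + 1

# Strides of the 3x3 layers in MobileNetV1 above:
# conv_bn, then the depthwise convs of stage1, stage2 and stage3
strides = [2] + [1, 2, 1, 2, 1] + [2, 1, 1, 1, 1, 1] + [2, 1]

size = 224  # assumed input resolution
for s in strides:
    size = conv_out_size(size, stride=s)

print(size)  # 7
```

So a 224×224 input reaches the average pooling layer as a 1024×7×7 feature map. Because the model uses `nn.AdaptiveAvgPool2d((1, 1))`, other input sizes also work; they only change the spatial size before pooling.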