BiSeNetV1 face segmentation
2022-07-19 03:09:00 【HySmiley】
1、 The paper
BiSeNet: Bilateral Segmentation Network for Real-time Semantic Segmentation
https://arxiv.org/abs/1808.00897.pdf
The paper points out that reducing the spatial resolution to reach real-time inference speed leads to poor accuracy. It therefore proposes the Bilateral Segmentation Network (BiSeNet), made up of two parts: a Spatial Path and a Context Path.
Spatial Path: preserves spatial information and generates high-resolution features.
Context Path: uses a fast downsampling strategy to obtain a sufficiently large receptive field.
The authors summarize three approaches that existing real-time semantic segmentation methods use to speed up the model:
①、 Restrict the input size, reducing computational complexity by cropping or resizing. Although simple and effective, this loses spatial detail and hurts the predictions, especially near object boundaries, so both the metrics and the visual quality drop.
②、 Prune the channels of the network to speed up inference, especially in the early stages of the base model. However, this weakens the spatial capacity.
③、 ENet proposes dropping the last stage of the model in pursuit of a very compact framework. The drawback is obvious: because ENet abandons the downsampling operations of the last stage, the model's receptive field is not large enough to cover large objects, which results in poor discriminative ability.
For a more detailed explanation, see: https://blog.csdn.net/sinat_17456165/article/details/106152907
2、 Network structure

The modules:
①Spatial Path: consists of several Conv+BN+ReLU groups. In the paper it has three layers, each using a stride-2 convolution, so its output is 1/8 of the input resolution.
Characteristics: shallow network, wide channels. Purpose: retain rich spatial information to generate high-resolution features.
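import torch.nn as nn
# Minimal ConvBNRelu as assumed by the code below (its definition is not shown in the post):
# Conv2d(in, out, kernel, stride, padding) -> BatchNorm2d -> ReLU
class ConvBNRelu(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
        super(ConvBNRelu, self).__init__()
        self.conv = nn.Conv2d(in_chan, out_chan, ks, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
class SpatialPath(nn.Module):
    def __init__(self):
        super(SpatialPath, self).__init__()
        # Three stride-2 blocks (1/2, 1/4, 1/8 resolution), then a 1x1 conv to 128 channels
        self.cbnr1 = ConvBNRelu(3, 64, 7, 2, 3)
        self.cbnr2 = ConvBNRelu(64, 64, 3, 2, 1)
        self.cbnr3 = ConvBNRelu(64, 64, 3, 2, 1)
        self.cbnr4 = ConvBNRelu(64, 128, 1, 1, 0)
        self.init_weight()
    def init_weight(self):
        # modules() (not children()) reaches the Conv2d layers nested inside ConvBNRelu
        for ly in self.modules():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)
    def forward(self, x):
        x = self.cbnr1(x)
        x = self.cbnr2(x)
        x = self.cbnr3(x)
        x = self.cbnr4(x)
        return x
    def get_params(self):
        # Split parameters into those with weight decay (conv/linear weights)
        # and those without (biases and BatchNorm parameters)
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params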
②Context Path: a lightweight backbone (ResNet-18, Xception39, etc.) plus ARM modules.
Characteristics: deep network. Purpose: obtain a sufficiently large receptive field.
Taking ResNet-18 as an example:

Instead of using the ResNet-18 from the torchvision model library directly, the network is rewritten by hand and the torchvision pretrained weights are loaded into it.
The parameter names in the rewritten network may differ, but the layers must match the original in number and order, because the pretrained values are copied over by position.
Initialization: loading the pretrained parameters.
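import torch
from torchvision.models import resnet18
# Method of the hand-written ResNet-18 (Context Path backbone) class
def init_weight(self):
    # Load the torchvision ResNet-18 checkpoint into a reference model first;
    # removing fc before loading would make load_state_dict fail on fc.weight/fc.bias
    model = resnet18()
    model.load_state_dict(torch.load('resnet18-5c106cde.pth'))
    # Collect the pretrained tensors in order, skipping the classifier (fc)
    pretrained = [v for k, v in model.state_dict().items() if not k.startswith('fc.')]
    # Assigning into self.state_dict() alone does not change the model:
    # fill a local copy, then load it back with load_state_dict
    self_state_dict = self.state_dict()
    # Tensors are matched by position, so both networks need the same layers in the same order
    assert len(self_state_dict) == len(pretrained)
    for k, v in zip(list(self_state_dict.keys()), pretrained):
        self_state_dict[k] = v
    self.load_state_dict(self_state_dict)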
ARM (Attention Refinement Module):
Refines the features using global context; it adds negligible computation cost.
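A minimal PyTorch sketch of an ARM as the paper describes it: global average pooling, then a 1x1 convolution, batch normalization and a sigmoid produce per-channel weights that rescale the input feature map. The class name `AttentionRefinement` and keeping the channel count unchanged are assumptions for illustration, not the post's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionRefinement(nn.Module):
    """Sketch of BiSeNet's Attention Refinement Module (ARM)."""

    def __init__(self, channels):
        super().__init__()
        # 1x1 conv + BN + sigmoid turn the pooled vector into per-channel weights
        self.conv = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        # Global average pooling: (N, C, H, W) -> (N, C, 1, 1)
        atten = F.adaptive_avg_pool2d(x, 1)
        atten = torch.sigmoid(self.bn(self.conv(atten)))
        # Reweight the channels; broadcasting over H and W, no upsampling needed
        return x * atten


arm = AttentionRefinement(256)
feat16 = torch.randn(2, 256, 32, 32)  # e.g. a 1/16-resolution ResNet-18 feature
print(arm(feat16).shape)  # torch.Size([2, 256, 32, 32])
```

Because the attention vector is computed from a 1x1 pooled map, the extra cost is one 1x1 convolution on a C-length vector, which is why the module is cheap.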
③FFM (Feature Fusion Module)
Fuses the feature maps produced by the two paths.
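A minimal PyTorch sketch of the FFM following the paper's description: the two paths' features are concatenated, passed through Conv+BN+ReLU to balance their scales, and then reweighted with SENet-style channel attention, with the reweighted features added back to the fused ones. The class name `FeatureFusion`, the 1x1 fusion convolution and the reduction factor of 4 are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureFusion(nn.Module):
    """Sketch of BiSeNet's Feature Fusion Module (FFM)."""

    def __init__(self, in_channels, out_channels, reduction=4):
        super().__init__()
        # Concatenated features -> Conv+BN+ReLU to balance the two paths' scales
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # SENet-style channel attention: 1x1 conv -> ReLU -> 1x1 conv -> sigmoid
        self.atten = nn.Sequential(
            nn.Conv2d(out_channels, out_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels // reduction, out_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, spatial_feat, context_feat):
        # Both inputs must already share the same H x W (1/8 of the input image)
        feat = self.fuse(torch.cat([spatial_feat, context_feat], dim=1))
        weights = self.atten(F.adaptive_avg_pool2d(feat, 1))
        # Reweighted features added back to the fused features
        return feat + feat * weights


ffm = FeatureFusion(256, 256)
sp = torch.randn(2, 128, 64, 64)  # Spatial Path output for a 512x512 input
cp = torch.randn(2, 128, 64, 64)  # Context Path output upsampled to 1/8
print(ffm(sp, cp).shape)  # torch.Size([2, 256, 64, 64])
```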
3、 Data sets
The face segmentation dataset CelebAMask-HQ contains 30,000 face images, together with a segmentation mask for each facial part.
The dataset has 19 segmentation labels (including background): 'skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'.

The mask images are 24-bit PNGs, with a separate file for each label. They need to be converted to class indices and merged into a single image per face, saved as an 8-bit PNG.
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os.path as osp
import os
import cv2
import numpy as np
from PIL import Image
face_data = '/data/CelebAMask-HQ/CelebA-HQ-img'
face_sep_mask = '/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
mask_path = '/data/CelebAMask-HQ/mask'
os.makedirs(mask_path, exist_ok=True)
# Label order: background = 0, then 1..18 in this order
atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
        'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
counter = 0  # number of part-mask files found
total = 0    # number of part-mask files checked
# The annotations are stored in 15 sub-folders of 2000 images each (30,000 in total)
for i in range(15):
    for j in range(i * 2000, (i + 1) * 2000):
        # 8-bit single-channel label map, 512x512 like the part masks
        mask = np.zeros((512, 512), dtype=np.uint8)
        for l, att in enumerate(atts, 1):
            total += 1
            file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
            path = osp.join(face_sep_mask, str(i), file_name)
            if os.path.exists(path):
                counter += 1
                sep_mask = np.array(Image.open(path).convert('P'))
                # print(np.unique(sep_mask))
                # Pixels equal to 225 are treated as this part; later labels overwrite earlier ones
                mask[sep_mask == 225] = l
        cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
        print(j)
print(counter, total)
The merged mask image:

Dataset split: train:test = 9:1.
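The post does not show how the 9:1 split is made. A minimal sketch, assuming a random shuffle with a fixed seed over the 30,000 image indices; the function name `split_train_test` is hypothetical.

```python
import random


def split_train_test(ids, train_ratio=0.9, seed=0):
    """Shuffle sample ids and split them into train/test lists (sketch)."""
    ids = list(ids)
    # Fixed seed so the split is reproducible across runs
    random.Random(seed).shuffle(ids)
    n_train = round(len(ids) * train_ratio)
    return ids[:n_train], ids[n_train:]


# Indices 0..29999 match the merged mask file names written by the script above
train_ids, test_ids = split_train_test(range(30000))
print(len(train_ids), len(test_ids))  # 27000 3000
```

Saving the two index lists to text files keeps the split identical between training and evaluation runs.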
4、 Loading the dataset
①、 Data augmentation
Augmentations include random cropping, mirroring (horizontal flip), scaling, and color-space jitter.
Random crop: applied to both the image and its mask.
Mirroring: applied to both the image and its mask; in the mask, the left/right labels of the eyes, eyebrows, and ears must be swapped.
Scaling: applied to both the image and its mask.
Color space: adjusts the saturation, contrast, and brightness of the image only.
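A minimal NumPy sketch of the mirror augmentation, assuming the label indices produced by the merging script above (background 0, then the `atts` order starting at 1). The helper name `hflip_with_labels` is hypothetical; the point it shows is that left/right classes must be swapped after flipping, otherwise a left eye would end up labeled as a left eye on the right side of the face.

```python
import numpy as np

# Label indices follow the atts order used when merging masks (background = 0):
# 2 l_brow, 3 r_brow, 4 l_eye, 5 r_eye, 7 l_ear, 8 r_ear
LR_PAIRS = [(2, 3), (4, 5), (7, 8)]


def hflip_with_labels(image, mask):
    """Horizontally flip an HxW(xC) image and its HxW label mask, swapping left/right labels."""
    image = image[:, ::-1].copy()
    flipped = mask[:, ::-1].copy()
    swapped = flipped.copy()
    for left, right in LR_PAIRS:
        # Read from the unswapped copy so the two assignments do not interfere
        swapped[flipped == left] = right
        swapped[flipped == right] = left
    return image, swapped


# The left eye (4) and right eye (5) trade places, then their labels are swapped back
mask = np.array([[4, 0, 5], [1, 1, 1]], dtype=np.uint8)
_, flipped = hflip_with_labels(np.zeros((2, 3, 3), dtype=np.uint8), mask)
print(flipped.tolist())  # [[4, 0, 5], [1, 1, 1]]
```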
②、 Loading
A custom Dataset is combined with a DataLoader:
transform: applies the augmentations
image traversal: the Dataset reads each image/mask pair
batching: the DataLoader groups the images into batches
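A minimal sketch of such a Dataset, assuming images named `<index>.jpg` and merged masks named `<index>.png` (as written by the merging script above). The class `FaceMaskDataset`, its arguments, and the joint `transform(image, mask)` callable are hypothetical; mean/std normalization is omitted for brevity.

```python
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class FaceMaskDataset(Dataset):
    """Hypothetical Dataset pairing CelebA-HQ images with merged label masks."""

    def __init__(self, img_dir, mask_dir, ids, size=512, transform=None):
        self.img_dir, self.mask_dir = img_dir, mask_dir
        self.ids = list(ids)
        self.size = size
        # Optional joint transform: callable(PIL image, PIL mask) -> (image, mask)
        self.transform = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        i = self.ids[idx]
        img = Image.open(os.path.join(self.img_dir, f"{i}.jpg")).convert("RGB")
        mask = Image.open(os.path.join(self.mask_dir, f"{i}.png")).convert("L")
        # Bilinear for the image, nearest for the mask so class indices are not blended
        img = img.resize((self.size, self.size), Image.BILINEAR)
        mask = mask.resize((self.size, self.size), Image.NEAREST)
        if self.transform is not None:
            img, mask = self.transform(img, mask)
        # HWC uint8 -> CHW float in [0, 1]; mask -> int64 class indices
        img = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0).permute(2, 0, 1)
        mask = torch.from_numpy(np.asarray(mask, dtype=np.int64))
        return img, mask


# loader = DataLoader(FaceMaskDataset(img_dir, mask_dir, train_ids), batch_size=16,
#                     shuffle=True, num_workers=4, drop_last=True)
```

The mask stays as `int64` indices because that is the target format `nn.CrossEntropyLoss` expects.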
5、 Loss function

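L_i: per-pixel softmax loss, i.e. the negative log-softmax of the true class (cross-entropy)
l_p: principal loss, computed on the final output
l_i: auxiliary losses, supervising the outputs of the Context Path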
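In the paper the joint loss is L(X; W) = l_p(X; W) + α · Σ_{i=2..K} l_i(X_i; W), with α = 1 and K = 3, and every term is a softmax cross-entropy. A minimal sketch, assuming the model returns the main logits plus a list of auxiliary logits already upsampled to the label resolution; the function name `bisenet_loss` is hypothetical.

```python
import torch
import torch.nn as nn


def bisenet_loss(main_out, aux_outs, labels, alpha=1.0):
    """Principal loss plus alpha-weighted auxiliary losses (sketch).

    main_out: (N, C, H, W) logits from the final output
    aux_outs: list of (N, C, H, W) logits from the Context Path heads
    labels:   (N, H, W) int64 class indices
    """
    # CrossEntropyLoss = log-softmax + negative log-likelihood, averaged over pixels
    ce = nn.CrossEntropyLoss()
    loss = ce(main_out, labels)
    for aux in aux_outs:
        loss = loss + alpha * ce(aux, labels)
    return loss
```

With uniform logits over 19 classes each term equals log(19), so one main output plus two auxiliary outputs gives 3 · log(19).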
6、 Optimizer
Stochastic gradient descent (SGD), with hyperparameter settings and learning-rate updates during training.
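The paper trains with SGD (momentum 0.9, weight decay 1e-4, initial learning rate 2.5e-2) and the "poly" schedule lr = base_lr · (1 - iter / max_iter)^0.9. The defaults below follow those paper settings, not necessarily the values used for the results in this post. The parameter split mirrors the `get_params` method in the Spatial Path code; `build_sgd`, `poly_lr` and `set_lr` are hypothetical helper names.

```python
import torch
import torch.nn as nn


def poly_lr(base_lr, it, max_iter, power=0.9):
    """Poly learning-rate policy: base_lr * (1 - it / max_iter) ** power."""
    return base_lr * (1 - it / max_iter) ** power


def build_sgd(model, lr=2.5e-2, momentum=0.9, weight_decay=1e-4):
    # Weight decay only on conv/linear weights, not on biases or BN parameters
    wd_params, nowd_params = [], []
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            wd_params.append(module.weight)
            if module.bias is not None:
                nowd_params.append(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            nowd_params += list(module.parameters())
    return torch.optim.SGD(
        [{"params": wd_params}, {"params": nowd_params, "weight_decay": 0.0}],
        lr=lr, momentum=momentum, weight_decay=weight_decay)


def set_lr(optimizer, lr):
    # Update every parameter group before each optimizer.step()
    for group in optimizer.param_groups:
        group["lr"] = lr


# for it in range(max_iter):
#     set_lr(optimizer, poly_lr(2.5e-2, it, max_iter))
#     ...forward, loss.backward(), optimizer.step(), optimizer.zero_grad()
```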
7、 Logging
A logger (Python's logging library) records the training progress.
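A minimal sketch using Python's standard `logging` module, assuming the log goes both to the console and to a timestamped file; `setup_logger` and the logger name `bisenet` are hypothetical.

```python
import logging
import os
import time


def setup_logger(log_dir):
    """Log to the console and to a timestamped file in log_dir (sketch)."""
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, time.strftime("train-%Y%m%d-%H%M%S.log"))
    logger = logging.getLogger("bisenet")
    logger.setLevel(logging.INFO)
    logger.propagate = False  # avoid duplicate lines if the root logger is configured
    # Drop handlers from a previous call so messages are not written twice
    for h in list(logger.handlers):
        logger.removeHandler(h)
        h.close()
    fmt = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")
    for handler in (logging.FileHandler(log_file), logging.StreamHandler()):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger, log_file


# logger, _ = setup_logger("./logs")
# logger.info("it: %d/%d, lr: %.6f, loss: %.4f", it, max_iter, lr, loss.item())
```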
8、 Evaluation metrics
Confusion matrix layout:
Actual \ Predicted | Predicted positive | Predicted negative |
Actually positive | True positive (TP) | False negative (FN) |
Actually negative | False positive (FP) | True negative (TN) |
Building the confusion matrix:

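import numpy as np
def confusion_matrix(self, pre, lab):
    # Flatten the prediction and label maps into 1-D arrays
    P_pre = pre.flatten()
    L_lab = lab.flatten()
    # Keep only pixels with a valid label (e.g. ignore 255)
    mask = (L_lab >= 0) & (L_lab < self.num_class)
    confusion = np.zeros((self.num_class, self.num_class))
    # Encode each (label, prediction) pair as num_class * label + pred, count, reshape:
    # rows = ground-truth class, columns = predicted class
    confusion += np.bincount(self.num_class * L_lab[mask].astype(int) + P_pre[mask].astype(int),
                             minlength=self.num_class ** 2).reshape(self.num_class, self.num_class)
    return confusion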
The model's evaluation metrics are computed from the confusion matrix:

Pixel accuracy (PA):
def pixel_acc(self, confusion):
    # Correctly classified pixels (diagonal) / all pixels
    return np.diag(confusion).sum() / confusion.sum()
Per-class accuracy:
def class_acc(self, confusion):
    # Diagonal / row sums (ground-truth pixels per class); vector of length num_class
    return np.diag(confusion) / np.maximum(confusion.sum(axis=1), 1)
Mean per-class accuracy (MPA):
def mpa(self, cls_acc):
    return np.nanmean(cls_acc)
IoU (intersection over union) per class:
def iou(self, confusion):
    # TP / (TP + FN + FP) = diag / (row sum + column sum - diag)
    return np.diag(confusion) / np.maximum(np.sum(confusion, axis=1) + np.sum(confusion, axis=0) - np.diag(confusion), 1)
mIoU (mean intersection over union):
def miou(self, iou_):
    return np.nanmean(iou_)
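A small worked example of the formulas above on a toy 3-class prediction, written as plain functions instead of methods so it runs on its own (the function signature here is an assumption for self-containment).

```python
import numpy as np


def confusion_matrix(pre, lab, num_class):
    # Same computation as the confusion_matrix method above, as a plain function
    pre, lab = pre.flatten(), lab.flatten()
    valid = (lab >= 0) & (lab < num_class)
    idx = num_class * lab[valid].astype(int) + pre[valid].astype(int)
    return np.bincount(idx, minlength=num_class ** 2).reshape(num_class, num_class)


lab = np.array([[0, 0, 1], [1, 2, 2]])
pre = np.array([[0, 1, 1], [1, 2, 0]])
cm = confusion_matrix(pre, lab, 3)
# cm = [[1 1 0]
#       [0 2 0]
#       [1 0 1]]   rows: ground truth, columns: prediction
diag = np.diag(cm)
pa = diag.sum() / cm.sum()                                          # 4/6
cls_acc = diag / np.maximum(cm.sum(axis=1), 1)                      # [0.5, 1.0, 0.5]
iou = diag / np.maximum(cm.sum(axis=1) + cm.sum(axis=0) - diag, 1)  # [1/3, 2/3, 1/2]
print(pa, np.nanmean(cls_acc), np.nanmean(iou))  # PA, MPA, mIoU
```

Note that PA (about 0.67) is higher than mIoU (0.5): errors on small classes weigh more in the per-class averages, which is the same pattern as the gap between acc and mIoU in the results below.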
9、 Result analysis
After 80,000 training iterations:
acc=94.95%, macc=57.41%, mIoU=52.40%
Test results:


References:
GitHub - zllrunning/face-makeup.PyTorch: Lip and hair color editor using face parsing maps.