[PaddleSeg source code reading] On the small matter of PaddleSeg models returning a list
2022-07-18 22:59:00 【Master Fuwen】
1. Reasoning backwards from the inference code
In paddleseg/core/infer.py, look at the `inference` and `slide_inference` functions.
(`aug_inference` also calls `inference`, and `inference` in turn calls `slide_inference` internally.)
```python
import collections.abc

import numpy as np
import paddle

# Excerpt from paddleseg/core/infer.py; reverse_transform is defined elsewhere in that file.


def slide_inference(model, im, crop_size, stride):
    """
    Infer by sliding window.

    Args:
        model (paddle.nn.Layer): model to get logits of image.
        im (Tensor): the input image.
        crop_size (tuple|list). The size of sliding window, (w, h).
        stride (tuple|list). The size of stride, (w, h).

    Return:
        Tensor: The logit of input image.
    """
    h_im, w_im = im.shape[-2:]
    w_crop, h_crop = crop_size
    w_stride, h_stride = stride
    # calculate the crop nums (built-in int: np.int was removed in NumPy 1.24)
    rows = int(np.ceil(1.0 * (h_im - h_crop) / h_stride)) + 1
    cols = int(np.ceil(1.0 * (w_im - w_crop) / w_stride)) + 1
    # prevent negative sliding rounds when imgs after scaling << crop_size
    rows = 1 if h_im <= h_crop else rows
    cols = 1 if w_im <= w_crop else cols
    # TODO 'Tensor' object does not support item assignment. If support, use tensor to calculation.
    final_logit = None
    count = np.zeros([1, 1, h_im, w_im])
    for r in range(rows):
        for c in range(cols):
            h1 = r * h_stride
            w1 = c * w_stride
            h2 = min(h1 + h_crop, h_im)
            w2 = min(w1 + w_crop, w_im)
            h1 = max(h2 - h_crop, 0)
            w1 = max(w2 - w_crop, 0)
            im_crop = im[:, :, h1:h2, w1:w2]
            logits = model(im_crop)  # <--- look here and at the lines that follow
            if not isinstance(logits, collections.abc.Sequence):
                raise TypeError(
                    "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
                    .format(type(logits)))
            logit = logits[0].numpy()
            if final_logit is None:
                final_logit = np.zeros([1, logit.shape[1], h_im, w_im])
            final_logit[:, :, h1:h2, w1:w2] += logit[:, :, :h2 - h1, :w2 - w1]
            count[:, :, h1:h2, w1:w2] += 1
    if np.sum(count == 0) != 0:
        raise RuntimeError(
            'There are pixels not predicted. It is possible that stride is greater than crop_size'
        )
    final_logit = final_logit / count
    final_logit = paddle.to_tensor(final_logit)
    return final_logit


def inference(model,
              im,
              ori_shape=None,
              transforms=None,
              is_slide=False,
              stride=None,
              crop_size=None):
    """
    Inference for image.

    Args:
        model (paddle.nn.Layer): model to get logits of image.
        im (Tensor): the input image.
        ori_shape (list): Origin shape of image.
        transforms (list): Transforms for image.
        is_slide (bool): Whether to infer by sliding window. Default: False.
        crop_size (tuple|list). The size of sliding window, (w, h). It should be provided if is_slide is True.
        stride (tuple|list). The size of stride, (w, h). It should be provided if is_slide is True.

    Returns:
        Tensor: If ori_shape is not None, a prediction with shape (1, 1, h, w) is returned.
            If ori_shape is None, a logit with shape (1, num_classes, h, w) is returned.
    """
    if hasattr(model, 'data_format') and model.data_format == 'NHWC':
        im = im.transpose((0, 2, 3, 1))
    if not is_slide:
        logits = model(im)  # <--- look here and at the lines that follow
        if not isinstance(logits, collections.abc.Sequence):
            raise TypeError(
                "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
                .format(type(logits)))
        logit = logits[0]
    else:
        logit = slide_inference(model, im, crop_size=crop_size, stride=stride)
    if hasattr(model, 'data_format') and model.data_format == 'NHWC':
        logit = logit.transpose((0, 3, 1, 2))
    if ori_shape is not None:
        logit = reverse_transform(logit, ori_shape, transforms, mode='bilinear')
        pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
        return pred, logit
    else:
        return logit
```
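To make the window arithmetic in `slide_inference` concrete, here is a one-dimensional, pure-Python sketch of the same rows/cols count and clamping logic. The helpers `window_spans` and `coverage_counts` are hypothetical and not PaddleSeg APIs. Dividing `final_logit` by `count` averages the overlapping predictions, and a zero count (possible when stride > crop size) is exactly what triggers the `RuntimeError`.

```python
import math
from typing import List, Tuple


def window_spans(size: int, crop: int, stride: int) -> List[Tuple[int, int]]:
    """Return the (start, end) spans along one axis, mirroring slide_inference."""
    # ceil((size - crop) / stride) + 1 windows, but exactly one if the image
    # is not larger than the crop (avoids a negative count).
    n = 1 if size <= crop else math.ceil((size - crop) / stride) + 1
    spans = []
    for i in range(n):
        start = i * stride
        end = min(start + crop, size)
        # Shift the last window back so it still has the full crop size.
        start = max(end - crop, 0)
        spans.append((start, end))
    return spans


def coverage_counts(size: int, crop: int, stride: int) -> List[int]:
    """How many windows cover each position: the 1-D analogue of `count`."""
    counts = [0] * size
    for start, end in window_spans(size, crop, stride):
        for p in range(start, end):
            counts[p] += 1
    return counts
```

For example, `window_spans(5, 3, 2)` gives `[(0, 3), (2, 5)]`, so position 2 is covered twice and its logit is averaged over two windows; with `size=10, crop=3, stride=5`, positions 3 and 4 are never covered.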
If `is_slide` is False, this branch is taken (only `inference` is involved):
```python
logits = model(im)  # <--- look here and at the lines that follow
if not isinstance(logits, collections.abc.Sequence):
    raise TypeError(
        "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
        .format(type(logits)))
logit = logits[0]
```
If `is_slide` is True, this branch is taken (`inference` calls `slide_inference`):
```python
logits = model(im_crop)  # <--- look here and at the lines that follow
if not isinstance(logits, collections.abc.Sequence):
    raise TypeError(
        "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
        .format(type(logits)))
logit = logits[0].numpy()
```
1. `logit` is element 0 of `logits`.
2. `logits` must be a `collections.abc.Sequence`, such as a list or tuple.

In short, `model(input)` should return a list, most likely with only one element.
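A minimal sketch of that contract in plain Python, without Paddle: `FakeTensor` is a hypothetical stand-in for a tensor (like a real `paddle.Tensor`, it is not a `collections.abc.Sequence`), and `first_logit` reproduces the check used in both branches above. A one-element list or a tuple passes; a bare tensor is rejected, which is why models wrap their output.

```python
import collections.abc


class FakeTensor:
    """Hypothetical stand-in for a tensor; deliberately not a Sequence."""

    def __init__(self, name: str):
        self.name = name


def first_logit(logits):
    """Apply the same type check as inference/slide_inference, then take logits[0]."""
    if not isinstance(logits, collections.abc.Sequence):
        raise TypeError(
            "The type of logits must be one of collections.abc.Sequence, "
            "e.g. list, tuple. But received {}".format(type(logits)))
    return logits[0]
```

Calling `first_logit([t])` returns `t`, while `first_logit(t)` raises `TypeError`.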
2. Examples
Judging from `inference` alone is not entirely reliable, so let's look at a few models and check whether their `forward` returns a list or a Tensor.
First, the recently popular PP-LiteSeg model.
It lives in paddleseg/models/pp_liteseg.py.
We only need to find the class decorated with `@manager.MODELS.add_component`.
That is, the `PPLiteSeg` class.
Look directly at the return value of its `forward` function:
Both branches return a list. We don't need to know what is inside the list; it is enough to know that the return value is a list.
Also, when not training, the list contains only one element.
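The training/inference split described above can be sketched with a toy class. `ToyMultiHeadSegModel` is hypothetical and is not PP-LiteSeg's code; it only mimics the pattern: several logits (main plus auxiliary heads) while training, a single logit still wrapped in a list at inference. The `training` flag and `eval()` method imitate the ones on `paddle.nn.Layer`.

```python
from typing import Callable, List


class ToyMultiHeadSegModel:
    """Hypothetical model: all heads in training, only the main head in eval."""

    def __init__(self, heads: List[Callable[[float], float]]):
        self.heads = heads
        self.training = True  # imitates paddle.nn.Layer.training

    def eval(self):
        self.training = False
        return self

    def forward(self, x: float) -> list:
        if self.training:
            # One logit per head (main head first, then auxiliary heads).
            return [head(x) for head in self.heads]
        # Inference: only the main head, but still returned as a list.
        return [self.heads[0](x)]

    __call__ = forward
```

With three heads, the model returns three values in training mode and a one-element list after `eval()`, so `logits[0]` in `inference` always picks the main head's output.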
Next, the earlier popular PP-HumanSeg-Lite model, in paddleseg/models/pphumanseg_lite.py.
To find the main class, first look for the class decorated with `@manager.MODELS.add_component`;
second, check which names are listed in `__all__`:
```python
__all__ = ['PPHumanSegLite']
```
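Why does the decorator point to the main class? Roughly, it registers the class by name so that a config can refer to a model by its `type:` string. The `Registry` class below is a hypothetical, simplified sketch of that idea, not PaddleSeg's `manager` implementation; `TinySegModel` is likewise made up.

```python
from typing import Dict


class Registry:
    """Hypothetical, simplified registry in the spirit of manager.MODELS."""

    def __init__(self, name: str):
        self.name = name
        self.components: Dict[str, type] = {}

    def add_component(self, cls: type) -> type:
        # Register the class under its own name and return it unchanged,
        # so the decorated class can still be used normally.
        if cls.__name__ in self.components:
            raise KeyError(f"{cls.__name__} already registered in {self.name}")
        self.components[cls.__name__] = cls
        return cls

    def build(self, type_name: str, **kwargs):
        # Look up a class by the name a config would supply, then instantiate it.
        return self.components[type_name](**kwargs)


MODELS = Registry("models")


@MODELS.add_component
class TinySegModel:
    def __init__(self, num_classes: int = 2):
        self.num_classes = num_classes
```

`MODELS.build("TinySegModel", num_classes=3)` then creates the model from a plain string, which is how the decorated class becomes the model's entry point.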
Its `forward` function:
```python
import paddle
import paddle.nn.functional as F


# Method of the PPHumanSegLite class (paddleseg/models/pphumanseg_lite.py)
def forward(self, x):
    # Encoder
    input_shape = paddle.shape(x)[2:]
    x = self.conv_bn0(x)  # 1/2
    shortcut = self.conv_bn1(x)  # shortcut
    x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)  # 1/4
    x = self.block1(x)  # 1/8
    x = self.block2(x)  # 1/16

    # Decoder
    x = self.depthwise_separable0(x)
    shortcut_shape = paddle.shape(shortcut)[2:]
    x = F.interpolate(
        x,
        shortcut_shape,
        mode='bilinear',
        align_corners=self.align_corners)
    x = paddle.concat(x=[shortcut, x], axis=1)
    x = self.depthwise_separable1(x)

    logit = self.depthwise_separable2(x)
    logit = F.interpolate(
        logit,
        input_shape,
        mode='bilinear',
        align_corners=self.align_corners)

    return [logit]  # <--- logit is clearly a single Tensor; wrapping it in [] turns it into a list
```
3. Summary
- In `infer.inference`, `logit` is always element 0 of `logits`, which suggests the model's return value should be a sequence (a list).
- The `PPHumanSegLite` and `PPLiteSeg` models confirm that the return value is indeed a list.
- My guess is that this is for compatibility: in training a model such as PP-LiteSeg returns several logits, while at inference it returns one, so every model returns a list.
- Or maybe it is simply for uniformity, hahaha (in some cases the list really is unnecessary, hh).
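One way to read the compatibility guess: if every model returns a list, caller code never has to special-case single-output versus multi-output models. The hypothetical `total_loss` below sums a weighted loss over however many logits a model returns. PaddleSeg loss configs pair outputs with coefficients in a similar spirit (a `coef` list), but this sketch is a simplification, not PaddleSeg's training code.

```python
from typing import Callable, Sequence


def total_loss(logits: Sequence, label, loss_fn: Callable, coefs: Sequence[float]) -> float:
    """Hypothetical consumer: sum of coef_i * loss_fn(logit_i, label) over all outputs."""
    if len(logits) != len(coefs):
        raise ValueError(f"got {len(logits)} logits but {len(coefs)} coefficients")
    return sum(c * loss_fn(logit, label) for logit, c in zip(logits, coefs))
```

The same function handles a three-head training output and a one-element inference output without any branching, because both are lists.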