当前位置:网站首页>Pytorch中torch.unsqueeze()和torch.squeeze()函数解析
Pytorch中torch.unsqueeze()和torch.squeeze()函数解析
2022-07-15 17:32:00 【cv_lhp】
一. torch.squeeze()函数解析
1. 官网链接
torch.squeeze(),如下图所示:
2. torch.squeeze()函数解析
torch.squeeze(input, dim=None, out=None)
squeeze()函数的功能是维度压缩。返回一个tensor(张量),其中 input 中维度大小为1的所有维都已删除。
举个例子:如果 input 的形状为 (A×1×B×C×1×D),那么返回的tensor的形状则为 (A×B×C×D)
当给定 dim 时,那么只在给定的维度(dimension)上进行压缩操作,注意给定的维度大小必须是1,否则不能进行压缩。
举个例子:如果 input 的形状为 (A×1×B),squeeze(input, dim=0)后,返回的tensor不变,因为第0维的大小为A,不是1;squeeze(input, 1)后,返回的tensor将被压缩为 (A×B)。
3. 代码举例
3.1 输入size=(2, 1, 2, 1, 2)的张量
x = torch.randn(size=(2, 1, 2, 1, 2))
x.shape
输出结果如下:
torch.Size([2, 1, 2, 1, 2])
3.2 把x中维度大小为1的所有维都已删除
y = torch.squeeze(x)#表示把x中维度大小为1的所有维都已删除
y.shape
输出结果如下:
torch.Size([2, 2, 2])
3.3 把x中第一维删除,但是第一维大小为2,不为1,因此结果删除不掉
y = torch.squeeze(x,0)#表示把x中第一维删除,但是第一维大小为2,不为1,因此结果删除不掉
y.shape
输出结果如下:
torch.Size([2, 1, 2, 1, 2])
3.4 把x中第二维删除,因为第二维大小是1,因此可以删掉
y = torch.squeeze(x,1)#表示把x中第二维删除,因为第二维大小是1,因此可以删掉
y.shape
输出结果如下:
torch.Size([2, 2, 1, 2])
3.5 把x中最后一维删除,但是最后一维大小为2,不为1,因此结果删除不掉
y = torch.squeeze(x,dim=-1)#表示把x中最后一维删除,但是最后一维大小为2,不为1,因此结果删除不掉
y.shape
输出结果如下:
torch.Size([2, 1, 2, 1, 2])
二.torch.unsqueeze()函数解析
1. 官网链接
torch.unsqueeze(),如下图所示:
2. torch.unsqueeze()函数解析
torch.unsqueeze(input, dim) → Tensor
unsqueeze()函数起升维的作用,参数dim表示在哪个地方加一个维度,注意dim范围在:[-input.dim() - 1, input.dim() + 1]之间,比如输入input是一维,则dim=0时数据为行方向扩,dim=1时为列方向扩,再大错误。
3. 代码举例
3.1 输入一维张量,在第0维(行)扩展,第0维大小为1
x = torch.tensor([1, 2, 3, 4])
y = torch.unsqueeze(x, 0)#在第0维扩展,第0维大小为1
y,y.shape
输出结果如下:
(tensor([[1, 2, 3, 4]]), torch.Size([1, 4]))
3.2 在第1维(列)扩展,第1维大小为1
y = torch.unsqueeze(x, 1)#在第1维扩展,第1维大小为1
y,y.shape
输出结果如下:
(tensor([[1],
[2],
[3],
[4]]),
torch.Size([4, 1]))
3.3 在第最后一维(也就是倒数第一维进行)扩展,最后一维大小为1
y = torch.unsqueeze(x, -1)#在第最后一维扩展,最后一维大小为1
y,y.shape
输出结果如下:
(tensor([[1],
[2],
[3],
[4]]),
torch.Size([4, 1]))
边栏推荐
- Redis+caffeine two-level cache enables smooth access speed
- 利用chardet检测网页编码
- Gift from JRockit: JMC virtual machine diagnostic tool
- Record the use of yolov5 (1)
- PolarDB for PostgreSQL的分布式查询引擎是怎样的?
- C语言:【位域操作】(结构体中使用冒号)
- OpenHarmony相关知识学习
- How to greedy match in VIM
- 来自JRockit的礼物:JMC虚拟机诊断工具
- Openharmony module 2 file samgr_ Server resolution (1)
猜你喜欢
随机推荐
会用redis吗?那还不快来了解下redis protocol
How to make Sitemaps website map
Eight guidelines for modbus-rs485 wiring
MODBUS-RS485布线的8条准则
Cache penetration, cache avalanche, cache breakdown?
英特尔发布开源AI参考套件
VSCode【因为在此系统上禁止运行脚本】
Pytorch中torch.max()函数解析
AMD Ryzen 5 7600X 6核心和4.4GHz 'Zen 4 ' CPU现身跑分数据库
利用chardet检测网页编码
Get to know the three modules of openharmony
C#小技巧 获取枚举所有枚举值
C language (high level) static address book
Using chardet to detect web page coding
【英雄哥七月集训】第 15天:深度优先搜索
Illegal profits exceed one million, and new outlets in the industry are being cracked and eroded
Unit MySQL appears in MySQL Solution of service could not be found
Dictionary tree (trie tree)
概率沉思录:1.Plausible reasoning
Queue(单项队列)和Deque(双端队列)的知识点整理








