当前位置:网站首页>Pytorch中torch.max()函数解析
Pytorch中torch.max()函数解析
2022-07-15 17:31:00 【cv_lhp】
一. torch.max()函数解析
1. 官网链接
torch.max,如下图所示:

2. torch.max(input)函数解析
torch.max(input) → Tensor
将输入input张量,无论有几维,首先将其reshape排列成一个一维向量,然后找出这个一维向量里面最大值
3. 代码举例
3.1 输入一维张量,返回一维张量里面最大值
x = torch.randn(4)
y = torch.max(x)
x,y
输出结果如下:
(tensor([-0.6223, 0.0043, -0.8753, 1.4240]), tensor(1.4240))
3.2 输入二维张量,返回二维张量里面最大值
x = torch.randn(3,4)
y = torch.max(x)
x,y
输出结果如下:
(tensor([[-1.1052, 0.1026, 0.9994, -0.3092],
[-0.8400, 0.2004, 0.9212, 0.7807],
[-1.2979, -0.4327, 2.3044, 0.0140]]),
tensor(2.3044))
3.3 输入两个一维张量,输出这两个张量里面相应元素中的最大值
x = torch.randn(4)
z = torch.randn(4)
max = torch.max(x,z)
x,z,max
输出结果如下:
(tensor([-1.5147, -1.2790, -1.0159, -0.4732]),
tensor([-0.4547, -2.8545, 0.0554, -0.3548]),
tensor([-0.4547, -1.2790, 0.0554, -0.3548]))
3.4 输入两个张量,一个张量一维,一个张量二维,此时一维张量会进行广播成二维张量,然后再输出这两个张量里面相应元素中的最大值,输出张量为二维。
x = torch.randn(3,4)
z = torch.randn(4)
max = torch.max(x,z)
x,z,max
输出结果如下:
(tensor([[ 1.1917, 0.6338, 0.7590, -0.9802],
[ 0.2247, 0.3635, 1.3743, 1.6229],
[ 1.6165, 0.0634, 0.5259, 0.1285]]),
tensor([3.4765, 0.4480, 0.1502, 0.3738]),
tensor([[3.4765, 0.6338, 0.7590, 0.3738],
[3.4765, 0.4480, 1.3743, 1.6229],
[3.4765, 0.4480, 0.5259, 0.3738]]))
3.5 输入两个二维张量,输出这两个张量里面相应元素中的最大值,输出张量为二维。
x = torch.randn(3,4)
z = torch.randn(3,4)
max = torch.max(x,z)
x,z,max
输出结果如下:
(tensor([[-0.0835, 0.0718, -1.7404, -0.3218],
[ 0.0577, 0.6271, 1.4014, -0.6417],
[ 0.3917, 0.0761, 1.2479, -0.4352]]),
tensor([[-0.0717, 0.3822, 0.7256, 1.4147],
[-0.1271, 0.1503, 0.3934, 1.6760],
[-2.2341, 2.5286, -0.3500, -0.1751]]),
tensor([[-0.0717, 0.3822, 0.7256, 1.4147],
[ 0.0577, 0.6271, 1.4014, 1.6760],
[ 0.3917, 2.5286, 1.2479, -0.1751]]))
4. torch.max(input,dim)函数解析
torch.max(input, dim, keepdim=False, *, out=None)
输入input(二维)张量,当dim=0时表示找出每列的最大值,函数会返回两个tensor,第一个tensor是每列的最大值,第二个tensor是每列最大值的索引;当dim=1时表示找出每行的最大值,函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。
5. 代码举例
5.1 dim=0,找出每列的最大值,函数会返回两个tensor,第一个tensor是每列的最大值,第二个tensor是每列最大值的索引,两个tensor都是一维。
x = torch.randn(3,4)
max,indices = torch.max(x,dim=0)
x,max,indices
(tensor([[ 0.1806, 1.0274, 0.5138, -1.4184],
[ 0.5892, -0.7117, -1.2707, 0.7682],
[ 0.5152, -0.8803, 1.7604, 0.4852]]),
torch.return_types.max(
values=tensor([0.5892, 1.0274, 1.7604, 0.7682]),
indices=tensor([1, 0, 2, 1])))
输出结果如下:
(tensor([[ 0.0190, 0.8180, -1.0463, 1.7940],
[ 0.7537, -1.0291, -2.3431, 0.3906],
[ 0.3715, 1.6940, -1.1200, -0.4580]]),
tensor([ 0.7537, 1.6940, -1.0463, 1.7940]),
tensor([1, 2, 0, 0]))
5.2 dim=1,找出每行的最大值,函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引,两个tensor都是一维。
x = torch.randn(3,4)
max,indices = torch.max(x,dim=1)
x,max,indices
输出结果如下:
(tensor([[ 1.4832, 0.1886, -0.3044, -0.6111],
[-0.8998, 0.0610, 0.3388, 1.7176],
[ 1.6153, 0.6864, 2.3225, 1.3818]]),
tensor([1.4832, 1.7176, 2.3225]),
tensor([0, 3, 2]))
边栏推荐
- What is the distributed query engine of polardb for PostgreSQL?
- Introduction to C language (6)
- Redis distributed lock: what have you experienced from Xiaobai to Dashen?
- How to import TPC-H data through PSQL?
- LINQ implements dynamic orderby
- PolarDB for PostgreSQL的分布式查询引擎是怎样的?
- How to deploy polardb for PostgreSQL before HTAP capability accelerates TPC-H implementation?
- 实现一下几个简单的loader
- MySQL触发器
- Openharmony module II parsing of header files under interfaces (8)
猜你喜欢

想到多线程并发就心虚?先来巩固下这些线程基础知识吧!

How to greedy match in VIM

Unit MySQL appears in MySQL Solution of service could not be found

2、趋势科技2017校招开发岗试题

缓存穿透、缓存雪崩、缓存击穿?

Win10 right click the new column to add a new markdown file (typora.md)

Redis+Caffeine两级缓存,让访问速度纵享丝滑

OpenHarmony安全模块之AES加密学习

分库分表真的适合你的系统吗?聊聊分库分表和NewSQL如何选择

【示波器的基本使用】以及【示波器按键面板上各个按键含义的介绍】
随机推荐
2、趋势科技2017校招开发岗试题
Intel releases open source AI Reference Suite
福赛生物解读2022上半年大气环境变化,VOCs治理依然是破局关键
Getting started with compilation
Markdown in CSDN sets the width of table columns
In mysql, the decimal (10,2) format is written to Kafka through stream and becomes stri
根据经纬度计算两点之间的距离
使用 SSH 方式拉取代码
Preliminary analysis of openharmony module II
JVM调优命令大全及常用命令工具和实战步骤
模块二interfaces下头文件解析(3)
C语言:【位域操作】(结构体中使用冒号)
numpy获取二维数组某一行、某一列
(高频面试题)计算机网络
如何通过psql导入TPC-H数据?
VSCode【因为在此系统上禁止运行脚本】
模块二interfaces下头文件解析(2)
美团一面:为什么线程崩溃崩溃不会导致 JVM 崩溃?
Simple understanding of CAS and AQS
美团一面面经及详细答案