0


【torch.argmax与torch.max详解】

Pytorch常用函数

一、torch.max

1.调用方式

1)

torch.max(input)

:只需送入输入张量;

2)

torch.max(input, dim, keepdim=False, *, out=None)

:送入张量的同时,需要指定沿着哪个维度进行最大值运算;
这两种调用方式对输入张量的形状没有要求,一维数据或者多维数据都可以。

2.相关介绍

1)返回输入张量中最大值相关数据:

  • 方式一,即不指定dim时,默认将张量展开成一维张量,然后返回第一个最大值;
  • 方式二,即指定dim时,沿着指定的dim维进行最大值运算,输出结果由剩下的维度组成,比如原始维度为H,W,若指定dim=0(即H维),则输出结果由W个元素构成;

2)如果有多个最大值则返回第一个最大值;

3.代码实例及图示理解

首先定义一个简单的方法,当传入张量x和维度dim参数时,分别打印两种调用方式对应的输出:

defprint_maxvalue(x,dim=0):
    max_value=torch.max(x)print(max_value)print('-'*10)
    max_value,max_index=torch.max(x,dim=dim)print(max_value)print(max_index)

对于二维数据,其形状为(H,W)=(10,2):

x=torch.tensor([[0,1],[2,5],[7,3],[5,1],[8,7],[7,6],[9,6],[4,4],[2,0],[9,9]])
print_maxvalue(x,dim=0)

输出结果:

tensor(9)# 所有元素中的第一个最大值----------
tensor([9,9])# 沿着指定dim维进行最大值运算
tensor([6,9])# 沿着指定dim维进行最大值运算,并返回最大值对应的下标

结果分析:
(1)方式一
将张量展开成一维张量,其长度为L=10×2=20,然后返回第一个最大值9
在这里插入图片描述

(2)方式二
指定dim=0,此维度长度为10,表示沿着第0维进行最大值运算,分别对第0维的10个元素取最大值,并返回其对应下标
在这里插入图片描述

二、torch.argmax

1.调用方式

1)

torch.argmax(input)

:只需送入输入张量;

2)

torch.argmax(input, dim, keepdim=False)

:送入张量的同时,需要指定沿着哪个维度进行运算;
这两种调用方式对输入张量的形状没有要求,一维数据或者多维数据都可以。

2.相关介绍

1)返回输入张量中最大值的索引:

  • 方式一,即不指定dim时,默认将张量展开成一维张量,然后返回对应的下标;
  • 方式二,即指定dim时,沿着指定的dim维进行选择,输出结果由剩下的维度组成,比如原始维度为H,W,若指定dim=0(即H维),则输出结果由W个元素构成;

2)如果有多个最大值则返回第一个最大值的下标;
3)返回torch.max函数指定dim时返回的第二个值;

3.代码实例及图示理解

首先定义一个简单的方法,当传入张量x和维度dim参数时,分别打印两种调用方式对应的输出:

defprint_(x,dim=0):# print(x)# print(x.shape)print('-'*10)# 方式一
    max_index = torch.argmax(x)print(max_index)print('-'*10)# 方式二
    max_index = torch.argmax(x, dim=dim)print(max_index)print('-'*10)

1)一维数据:L

x=torch.tensor([8,2,7,15,1])
print_(x,dim=0)

输出结果:

tensor(3)
tensor(3)

结果分析:
在这里插入图片描述

这是最简单的一种方式,就类似一维数组查询最大元素对应下标的过程一致:

  • 对于方式一,传入一维张量后,直接返回第一个最大值15对应的下标3;
  • 对于方式二, 此时数据只有一个维度,故只能指定沿着维度dim=0进行运算,实质还是在所有元素中寻找最大值并返回其下标;

2)二维数据:(H,W)

x=torch.tensor([[0,1],[2,5],[7,3],[5,1],[8,7],[7,6],[9,6],[4,4],[2,0],[9,9]])
print_(x,dim=0)# print_(x,dim=1)

输出结果:

dim=0:H,W->W
tensor(12)
tensor([6,9])# 一般分类问题就适用这种情况,在一个批次的预测输出中确定每个样本的类别,输出结果中每个元素即表示批次中每个样本对应的类别
dim=1: H,W->H
tensor(12)
tensor([1,1,0,0,0,0,0,0,0,0])

结果分析:
(1)方式一
先将输入张量沿着所有维度展开为一维数据,然后返回第一个最大值9对应的下标12
在这里插入图片描述
(2)方式二
函数沿着指定的dim维度进行运算,
dim=0表示张量沿着第0维的方向进行运算,比如此处dim=0维长度为10,则表示在每列的10个元素中找到最大值并返回其下标:
此处第一列最大值为9,而其下标为6
在这里插入图片描述

dim=1表示张量沿着第1维的方向进行运算,比如此处dim=1维长度为2,则表示在每行的2个元素中找到最大值并返回其下标:
此处第一行最大值为1,而其下标为1
在这里插入图片描述

3)多维数据:(N,C,H,W)

x=torch.tensor([[[[1,3],[7,8]],[[8,1],[5,3]],[[2,8],[4,4]]],[[[3,0],[2,0]],[[0,4],[7,16]],[[4,8],[4,3]]]])

print_(x,dim=0)# print_(x,dim=1)# print_(x,dim=2)# print_(x,dim=3)

输出结果:

dim=0:N,C,H,W->C,H,W
tensor(19)
tensor([[[1,0],[0,0]],[[0,1],[1,1]],[[1,0],[0,0]]])

dim=1:N,C,H,W->N,H,W
tensor(19)
tensor([[[1,2],[0,0]],[[2,2],[1,1]]])

dim=2:N,C,H,W->N,C,W
tensor(19)
tensor([[[1,1],[0,1],[1,0]],[[0,0],[1,1],[0,0]]])

dim=3:N,C,H,W->N,C,H
tensor(19)
tensor([[[1,1],[0,0],[1,0]],[[0,0],[1,1],[1,0]]])

结果分析:
开始就说到了,

  • 当调用方式二,指定dim时,函数会沿着指定的维度进行运算,其输出结果的维度由剩余的维度决定;
  • 使用方式一时会直接将张量展开为一维数据,然后返回第一个最大值的下标;

(1)方式一
输入张量形状为(N,C,H,W)=(2,3,2,2),可以清晰地看到,将张量展开为一维数据为长度为L=2×3×2×2=24,且第一个最大值16此时对应的下标为19。
在这里插入图片描述
(2)方式二
dim=0维长度为2,剩余维度为(3,2,2)
在这里插入图片描述

dim=1维长度为3,剩余维度为(2,2,2)
在这里插入图片描述

依次类推…

总结:
其实该函数应用场景最多的是分类任务在进行测试时,判断预测结果的对应类别,此时函数的输入通常为二维数据,只需要使用torch.argmax(x,dim=1)即可达到想要的结果。

三、torch.max与torch.argmax的联系

1)torch.max在寻找输入张量中最大值,而torch.argmax则是寻找最大值对应的下标;
2)二者均使用第一种方式,即未指定dim时,直接将张量展开为一维数据,torch.max返回第一个最大值本身,而torch.argmax则返回最大值的下标;
3)二者均使用第二种方式,即指定dim时,torch.max沿着指定的dim维选取最大值,同时返回最大值本身及其对应下标,而torch.argmax只返回最大值对应的下标。换句话说,torch.argmax的输出结果其实是torch.max指定dim时返回结果中的第二个元素,对应最大值的下标索引;

举个例子:
对于输入张量:

x=torch.tensor([[0,1],[2,5],[7,3],[5,1],[8,7],[7,6],[9,6],[4,4],[2,0],[9,9]])

torch.argmax(x,dim=0)的输出结果为:

tensor([6,9])

torch.max(x,dim=0)的输出结果为:

torch.return_types.max(values=tensor([9,9]),indices=tensor([6,9]))

其中indices即表示指定dim时找到的最大值的对应下标。


本文转载自: https://blog.csdn.net/qq_43665602/article/details/127151890
版权归原作者 NorthSmile 所有, 如有侵权,请联系我们删除。

“【torch.argmax与torch.max详解】”的评论:

还没有评论