前言
这个本来是没打算写的,因为看了官方的解释以及在网上看了好几个教程都没理解什么意思,所以把自己理解的东西整理分享一下。
官方的解释
官网链接:torch.gather()
给个截图如下
常用的参数有3个,第一个
input
表示要从中选取元素,第二个
dim
表示操作的维度,第三个
index
表示选取元素的索引。
按照官方的解释我是没看懂的,后面去找教程也一知半解,所以自己琢磨了一下,终于悟了。
使用详解
结合着例子,直接看代码把:
import torch
a = torch.arange(3,12).view(3,3)print(a)# tensor([[ 3, 4, 5],# [ 6, 7, 8],# [ 9, 10, 11]])
index = torch.tensor([[2,1,0]])
b = torch.gather(a, dim=0, index=index)print(b)# tensor([[9, 7, 5]])# 1、将index中的各个元素的索引明确,获得具体坐标:# index = torch.tensor([[2, 1, 0]])中,# 2的索引(坐标)为(0,0),1的索引(坐标)为(0,1),0的索引(坐标)为(0,2)# 2、将具体坐标中对应的维度替换成index中的值:# 2的索引(坐标)为(0,0),将第0个维度的索引替换后的新坐标为(2, 0),用2替换掉0# 1的索引(坐标)为(0,1),将第0个维度的索引替换后的新坐标为(1, 1),用1替换掉0# 0的索引(坐标)为(0,2),将第0个维度的索引替换后的新坐标为(0, 2),用0替换掉0# 3、按照新的坐标取输入中的值:# tensor([[ 3, 4, 5],# [ 6, 7, 8],# [ 9, 10, 11]]),坐标(2,0)值为9,坐标(1,1)值为7,坐标(0,2)值为5,得到最后的结果[9,7,5].
index = torch.tensor([[2,1,0]])
c = torch.gather(a, dim=1, index=index)print(c)# tensor([[5, 4, 3]])# 1、获取具体坐标:(0,0),(0,1),(0,2)# 2、第1维度替换坐标:(0,2),(0,1),(0,0)# 3、找元素:[5,4,3]# 二维的情况也一样
index = torch.tensor([[0,2],[1,2]])
d = torch.gather(a, dim=1, index=index)print(d)# tensor([[3, 5],# [7, 8]])# 1、获取具体坐标:(0,0),(0,1),(1,0),(1,1)# 2、第1维度替换坐标:(0,0),(0,2),(1,1),(1,2)# 3、找元素:[[3, 5],[7, 8]]
怕在代码里面太暗了看不清楚,在这里再贴一次:
以第一个为例:
创建张量
a = torch.arange(3, 12).view(3, 3)
print(a)
index = torch.tensor([[2, 1, 0]])
b = torch.gather(a, dim=0, index=index)
print(b)
a 的值如下:
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
b 的值为tensor([[9, 7, 5]])
具体过程:
1、将index中的各个元素的索引明确,获得具体坐标:
index = torch.tensor([[2, 1, 0]])中,
2的索引(坐标)为(0,0),1的索引(坐标)为(0,1),0的索引(坐标)为(0,2)
2、将具体坐标中对应的维度替换成index中的值:
2的索引(坐标)为(0,0),将第0个维度的索引替换后的新坐标为(2, 0),用2替换掉0
1的索引(坐标)为(0,1),将第0个维度的索引替换后的新坐标为(1, 1),用1替换掉0
0的索引(坐标)为(0,2),将第0个维度的索引替换后的新坐标为(0, 2),用0替换掉0
3、按照新的坐标取输入中的值:
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]),坐标(2,0)值为9,坐标(1,1)值为7,坐标(0,2)值为5,得到最后的结果[9,7,5].
实战操作
torch.gather()这个函数通常用在批量的获取张量的某些数据,比如说,要获取一个大小为(b, n, c)的张量中的多个不连续的索引行向量,这种操作通常在下采样的过程中会用到。按照正常的思路,实现这样的操作需要写几个for循环,但for循环在训练时特别的慢,因此可以使用torch.gather来实现这一功能。
具体例子:
deftest_gather():
b, n, c =4,3,3
k =2# 下采样的个数
a = torch.rand((b, n, c))# 定义输入数据print('a:', a)
idx = torch.randint(low=0, high=n, size=(b, k))# (b, k),生成随机索引print('idx:', idx)# 进行维度扩展和复制
new_idx = idx.unsqueeze(-1).expand(-1,-1,3)print('idx after expand', new_idx)# 关键语句,按照一定的维度来取出数据
b = torch.gather(a, dim=1, index=new_idx)print('b', b)# 这个是for循环版本的操作
c = torch.stack([a[i][idx[i],:]for i inrange(len(a))])print(b == c)
打印信息如下:
a: tensor([[[0.8053,0.7751,0.7346],[0.4371,0.1006,0.6389],[0.9040,0.1699,0.3022]],[[0.7410,0.5656,0.9189],[0.4067,0.4953,0.1776],[0.9622,0.0738,0.3553]],[[0.5321,0.9538,0.5806],[0.2257,0.7163,0.7548],[0.2393,0.4100,0.2497]],[[0.2234,0.9685,0.7388],[0.7087,0.0933,0.7147],[0.1741,0.0103,0.6587]]])
idx: tensor([[0,2],[2,2],[0,0],[0,1]])
idx after expand tensor([[[0,0,0],[2,2,2]],[[2,2,2],[2,2,2]],[[0,0,0],[0,0,0]],[[0,0,0],[1,1,1]]])
b tensor([[[0.8053,0.7751,0.7346],[0.9040,0.1699,0.3022]],[[0.9622,0.0738,0.3553],[0.9622,0.0738,0.3553]],[[0.5321,0.9538,0.5806],[0.5321,0.9538,0.5806]],[[0.2234,0.9685,0.7388],[0.7087,0.0933,0.7147]]])
tensor([[[True,True,True],[True,True,True]],[[True,True,True],[True,True,True]],[[True,True,True],[True,True,True]],[[True,True,True],[True,True,True]]])
Process finished with exit code 0
再来详细的看一下是怎么选取数据的:
以b中的第一个行向量为例子,看一下是怎么得到的
1、明确坐标,idx[0][0] = 0,也就是说1的坐标为(0, 0),由于这里0代表的是第一行向量,但是在torch.gather中要精确到具体的坐标,0只是代表了一个维度而且,还有两个维度,因此要将其进行维度扩展和复制:
idx.unsqueeze(-1).expand(-1,-1,3)# unsqueeze函数进行维度扩张,expand对最后一个维度进行复制
具体效果可以从打印的信息看出:
维度扩展之前:idx: tensor([[0,2],
维度扩展之前:idx after expand tensor([[[0,0,0],[2,2,2]],
2、替换坐标,因此索引0的值在扩张之后是(0, 0, 0),索引2为(2, 2, 2)由于在本次的代码中torch.gather是对第一维度进行操作 b = torch.gather(a, dim=1, index=new_idx),因此对第一个维度进行变换,因此对第一个维度进行替换,取值过程如下,(0, 0, 0)的对应坐标为(0, 0, 0), (0, 0, 1), (0, 0, 2),对第一个维度进行替换得到new_idx = (0, 0, 0), (0, 0, 1), (0, 0, 2),因此就可以根据这new_idx去a中取值,也就是a[0][0][:]。同样的道理,(2, 2, 2)的坐标为(0, 1, 0), (0, 1, 1),(0, 1, 2) -> (0, 2, 0), (0, 2, 1),(0, 2, 2),因此根据这三个索引可以取出a中的值,输出结果可以从打印信息看出:
a[0][0]=[0.8053,0.7751,0.7346]
a[0][2]=[0.9040,0.1699,0.3022]
对应的取出的值为:
b tensor([[[0.8053,0.7751,0.7346],[0.9040,0.1699,0.3022]],
参考链接
图解PyTorch中的torch.gather函数
Pytorch系列(1):torch.gather()
pytorch之torch.gather方法
pytorch中的所有随机数(normal、rand、randn、randint、randperm) 以及 随机数种子(seed、manual_seed、initial_seed)
结束语
文章为分享、记录、整理自己的经历情况,水平有限,如有错误之处敬请指出。
版权归原作者 沉默的前行者 所有, 如有侵权,请联系我们删除。