0


torch.cat()中dim说明

torch.cat()

torch.cat(Tuple[Tensor],dim)->Tensor

输入为Tensor的List/Tuple,输出为一个Tensor

torch.cat()用于对张量的拼接,与数组拼接函数torch.stack()用法类似,二者区别在于输入的变量是数组还是张量。

其中初学者最费解的就是dim的选取,dim的取值范围由输入张量的维度决定,输入为n维张量,dim取值在[0,n-1],接下来我们以实验理解dim不同取值对应的不同操作结果。

初次接触众多博客对dim的讲解为,对于两个二维张量作为输入,dim取0结果为两个张量按行拼接,取1结果为按列拼接,但是对于高维来说就有点难以直观想象结果了,我们尝试三维情况进而总结规律。

先从一个简单的例子入手,输入两个张量为二维,dim取值分别为0和1 :

import torch
X=torch.tensor([[1,2,3],[4,5,6]])
Y=torch.tensor([[7,8,9],[1,4,7]])
input=[X,Y]
A=torch.cat(input,dim=0)
B=torch.cat(input,dim=1)
print("X:{}\nY:{}\ndim0:{}\ndim1:{}".format(X,Y,A,B))

结果如下

   ![](https://img-blog.csdnimg.cn/bbd453b0fca1427fbc0e99c57ddef9c1.png)

可以看出对于两个二维张量作为输入,dim取0结果为两个张量按行拼接,取1结果为按列拼接,但是对于高维来说就有点难以直观想象结果了,我们尝试三维情况进而总结规律。

import torch
X=torch.tensor([[[1,2],[3,4]],[[5,6],[7,8]]])
Y=torch.tensor([[[7,6],[5,4]],[[8,9],[9,10]]])
input=[X,Y]
A=torch.cat(input,dim=0)
B=torch.cat(input,dim=1)
C=torch.cat(input,dim=2)

print("X:{}\nY:{}\ndim0:{}\ndim1:{}\ndim2:{}".format(X,Y,A,B,C))

输入为两个三维张量:

输出:

可见对于dim=0,其输出结果为对两个张量的最高维度包含的内容进行拼接,此例中,X和Y均为三维张量,其最高维度包含的内容为二维,因此,dim=0结果是对其二维张量进行拼接组成的三维张量:

那么对于dim=1的情况,就是对次高维包含内容进行拼接,次高维为2维,其内容为1维,将1维进行拼接得到:

以此类推,对于dim=n-1的情况比较难理解,此例dim=2,对次次高维即1维的内容进行拼接,其中1维的内容是0维,可以理解为1维张量括号内的元素,即每个数字,将其进行拼接,得到结果:

至此,torch.cat()的dim作用已经讲清楚,建议动手实验一下就可以弄明白其中的奥秘!!!


本文转载自: https://blog.csdn.net/m0_60105894/article/details/128311630
版权归原作者 IT pupil 所有, 如有侵权,请联系我们删除。

“torch.cat()中dim说明”的评论:

还没有评论