0


关于torch.cat()与torch.stack()

关于torch.cat()与torch.stack()整理

代码中一直使用torch.cat()和torch.stack()进行tensor维度拼接,花点时间整理下。方便使用🤷‍♂️:

1.用法

torch.cat(): 用于连接两个相同大小的张量

torch.stack(): 用于连接两个相同大小的张量,并扩展维度

见代码示例更清晰:

import torch
a = torch.tensor(torch.arange(10)).reshape(3,3)
b = torch.tensor(torch.arange(10,100,10)).reshape(3,3)print(a)
Out[7]: 
tensor([[1,2,3],[4,5,6],[7,8,9]])print(b)
Out[10]: 
tensor([[10,20,30],[40,50,60],[70,80,90]])

对上面两个tensor进行操作
torch.cat()
拼接函数,将多个张量拼接成一个张量,保持维度不变。torch.cat()有两个参数,第一个是要拼接的张量的列表或是元组;第二个参数是拼接的维度。

使用不同的参数,输出的结果不同,首先填入一个会返回错误的参数:从返回报错原因可以看到,参数的返回必须是在[-2, 1]之间。

d3 = torch.cat((a, b), dim=2)# 返回输出如下
Traceback (most recent call last):
  File "/home/franklinpan/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3251,in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-23-b2602bd6230f>", line 1,in<module>
    d3 = torch.cat((a, b), dim=2)
IndexError: Dimension out of range(expected to be inrange of [-2,1], but got 2)

设置dim=-1,得到如下结果,当参数为-1时,与dim=1的返回结果相同
dim=-1,表示在第二维度进行拼接

d_1= torch.cat((a, b), dim=-1)print(d_1)
Out[25]: 
tensor([[1,2,3,10,20,30],[4,5,6,40,50,60],[7,8,9,70,80,90]])
 
d1 = torch.cat((a, b), dim=1)print(d1)
Out[22]: 
tensor([[1,2,3,10,20,30],[4,5,6,40,50,60],[7,8,9,70,80,90]])

设置dim=-2,与dim=0相同:
表示在第一维度进行拼接

d_2= torch.cat((a, b), dim=-2)print(d_2)
Out[27]: 
tensor([[1,2,3],[4,5,6],[7,8,9],[10,20,30],[40,50,60],[70,80,90]])
 
d1 = torch.cat((a, b), dim=0)print(d1)
Out[20]: 
tensor([[1,2,3],[4,5,6],[7,8,9],[10,20,30],[40,50,60],[70,80,90]])

可以看到,采用不同的参数,输出的张量维度仍然与原来张量的维度保持一致。
若输入参数的维度不一样,会产生什么结果呢?

当输出张量保持一个维度一致时,若在相同维度的方向进行连接torch.cat操作,则仍然可以张量的合并操作,若在维度不同的方向进行连接操作,会报错。(🤦‍♀️torch.cat操作没有广播机制

**torch.stack()**操作
拼接函数,是拼接以后,再扩展一维。torch.stack()有两个参数,第一个是要拼接的张量的列表或是元组;第二个参数是拼接的维度。
此处不再重复dim=-3 or -2等操作,当dim=0时

c1 = torch.stack((a, b), dim=0)
 
Out[12]: 
tensor([[[1,2,3],[4,5,6],[7,8,9]],[[10,20,30],[40,50,60],[70,80,90]]])

当dim=1时

c2 = torch.stack((a, b), dim=1)
 
Out[15]: 
tensor([[[1,2,3],[10,20,30]],[[4,5,6],[40,50,60]],[[7,8,9],[70,80,90]]])

当 dim=2时

c3 = torch.stack((a, b), dim=2)
 
Out[17]: 
tensor([[[1,10],[2,20],[3,30]],[[4,40],[5,50],[6,60]],[[7,70],[8,80],[9,90]]])

若在torch.stack中使用不同维度的输入,得到报错的反馈
从实例可见,torch.stack操作将会增加合并后张量的维度

总结:

torch.cat()与torch.stack()操作都是对张量进行拼接操作,不同点如下:

torch.stack()将对张量维度进行扩张

torch.cat()可以对只有一个方向维度相同的张量进行合并,而torch.stack()要求输入张量的维度必须一样。

stack与cat的区别在于,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。见下面代码:

A=torch.tensor([[1,2,3],[4,5,6],[7,8,9]],dtype=torch.float)print("A:",A)
B=torch.tensor([[-1,-2,-3],[-4,-5,-6],[-7,-8,-9]],dtype=torch.float)print("B:",B)print("*********************************")

c=torch.cat((A,B),dim=0)#保持维度不变print(c)print(c.shape)

d=torch.stack((A,B),dim=0)#多扩展一维度print(d)print(d.shape)

运行结果:
在这里插入图片描述

扩展:

torch.cat和torch.stack()的拼接为[]数据时:
见拼接列表数据


本文转载自: https://blog.csdn.net/qq_38765642/article/details/127842547
版权归原作者 幼稚园的扛把子~ 所有, 如有侵权,请联系我们删除。

“关于torch.cat()与torch.stack()”的评论:

还没有评论