张量操作
文章目录
张量操作
张量拼接与切分
torch.cat()
1、torch.cat():沿已有维度拼接张量
import torch
torch.manual_seed(1)  # fix the RNG seed so the random examples below are reproducible

# ======================================= example 1 =======================================
# torch.cat
flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))
    # torch.cat joins tensors along an EXISTING dimension; every other dim must match.
    t_0 = torch.cat([t, t], dim=0)     # two (2, 3) along dim 0 -> (4, 3)
    t_1 = torch.cat([t, t, t], dim=1)  # three (2, 3) along dim 1 -> (2, 9)
    print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))
OUT:
t_0:tensor([[1.,1.,1.],[1.,1.,1.],[1.,1.,1.],[1.,1.,1.]]) shape:torch.Size([4,3])
t_1:tensor([[1.,1.,1.,1.,1.,1.,1.,1.,1.],[1.,1.,1.,1.,1.,1.,1.,1.,1.]]) shape:torch.Size([2,9])
torch.stack()
2、torch.stack():在新增的维度上拼接张量
# ======================================= example 2 =======================================
# torch.stack
flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))
    # torch.stack inserts a NEW dimension at `dim` and joins the inputs along it.
    # New dim 0: three (2, 3) tensors -> (3, 2, 3).
    t_stack_0 = torch.stack([t, t, t], dim=0)
    # New dim 2 (the last position): three (2, 3) tensors -> (2, 3, 3).
    t_stack_1 = torch.stack([t, t, t], dim=2)
    print("t: {} shape: {}".format(t, t.shape))
    print("\nt_stack_0: {} shape: {}".format(t_stack_0, t_stack_0.shape))
    print("\nt_stack_1: {} shape: {}".format(t_stack_1, t_stack_1.shape))
OUT:
t: tensor([[1.,1.,1.],[1.,1.,1.]]) shape: torch.Size([2,3])
t_stack_0: tensor([[[1.,1.,1.],[1.,1.,1.]],[[1.,1.,1.],[1.,1.,1.]],[[1.,1.,1.],[1.,1.,1.]]]) shape: torch.Size([3,2,3])
t_stack_1: tensor([[[1.,1.,1.],[1.,1.,1.],[1.,1.,1.]],[[1.,1.,1.],[1.,1.,1.],[1.,1.,1.]]]) shape: torch.Size([2,3,3])
torch.chunk()
3、torch.chunk():沿指定维度将张量尽量均等地切分为若干块
# ======================================= example 3 =======================================
# torch.chunk
flag = True
# flag = False
if flag:
    a = torch.ones((2, 7))  # 7 columns, deliberately not divisible by 3
    # Split along dim=1 (columns) into 3 chunks: each chunk gets ceil(7 / 3) = 3
    # columns and the last one gets the remainder, so the widths are 3, 3, 1.
    list_of_tensors = torch.chunk(a, dim=1, chunks=3)
    for idx, chunk in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, chunk, chunk.shape))
OUT:
第1个张量:tensor([[1.,1.,1.],[1.,1.,1.]]), shape is torch.Size([2,3])
第2个张量:tensor([[1.,1.,1.],[1.,1.,1.]]), shape is torch.Size([2,3])
第3个张量:tensor([[1.],[1.]]), shape is torch.Size([2,1])
torch.split()
4、torch.split():沿指定维度将张量切分为指定长度的若干块
# ======================================= example 4 =======================================
# torch.split
flag = True
# flag = False
if flag:
    t = torch.ones((2, 5))
    # The section lengths [2, 1, 2] must sum to the size of dim=1 (5), otherwise split raises.
    list_of_tensors = torch.split(t, [2, 1, 2], dim=1)
    # Use a distinct loop name: reusing `t` here would overwrite the input tensor,
    # and the counter-example below would then split a (2, 2) chunk instead.
    for idx, chunk in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, chunk, chunk.shape))

    # Counter-example: sum([2, 1, 1]) == 4 != 5, so torch.split raises RuntimeError.
    # Catch it so the demo reports the error and the rest of the script keeps running.
    try:
        list_of_tensors = torch.split(t, [2, 1, 1], dim=1)
    except RuntimeError as err:
        print("RuntimeError: {}".format(err))
    else:
        for idx, chunk in enumerate(list_of_tensors):
            print("第{}个张量:{}, shape is {}".format(idx+1, chunk, chunk.shape))
OUT:
第1个张量:tensor([[1.,1.],[1.,1.]]), shape is torch.Size([2,2])
第2个张量:tensor([[1.],[1.]]), shape is torch.Size([2,1])
第3个张量:tensor([[1.,1.],[1.,1.]]), shape is torch.Size([2,2])
Traceback (most recent call last):
File "E:/Code/In_pytorch/pytorch_lesson_code/Hello_pytorch/lesson/lesson-03/lesson-03.py", line 70, in <module>
list_of_tensors = torch.split(t,[2,1,1], dim=1)
File "D:\ProgramData\Anaconda3\envs\py37\lib\site-packages\torch\functional.py", line 156, in split
return tensor.split(split_size_or_sections, dim)
File "D:\ProgramData\Anaconda3\envs\py37\lib\site-packages\torch\_tensor.py", line 518, in split
return super(Tensor, self).split_with_sizes(split_size, dim)
RuntimeError: start (2) + length (1) exceeds dimension size (2).
张量索引
torch.index_select()
5、torch.index_select()
# ======================================= example 5 =======================================
# torch.index_select
flag = True
# flag = False
if flag:
    t = torch.randint(0, 9, size=(3, 3))
    # The index must be an integer tensor (IntTensor/LongTensor); a float index is rejected.
    idx = torch.tensor([0, 2], dtype=torch.long)
    # dim=0: select rows 0 and 2.
    t_select_0 = torch.index_select(t, dim=0, index=idx)
    print("t:\n{}\nt_select_0:\n{}".format(t, t_select_0))
    print("\n")
    # dim=1: select columns 0 and 2.
    t_select_1 = torch.index_select(t, dim=1, index=idx)
    print("t:\n{}\nt_select_1:\n{}".format(t, t_select_1))
OUT:
t:
tensor([[4,5,0],[5,7,1],[2,5,8]])
t_select_0:
tensor([[4,5,0],[2,5,8]])
t:
tensor([[4,5,0],[5,7,1],[2,5,8]])
t_select_1:
tensor([[4,0],[5,1],[2,8]])
torch.masked_select()
6、torch.masked_select()
# ======================================= example 6 =======================================
# torch.masked_select
flag = True
# flag = False
if flag:
    t = torch.randint(0, 9, size=(3, 3))
    # Comparison helpers: le is <=, lt is <, ge is >=, gt is >.
    mask = t.le(5)  # boolean tensor, True where t <= 5
    # Returns a new 1-D tensor of the elements of t where mask is True.
    t_select = torch.masked_select(t, mask)
    print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))
OUT:
t:
tensor([[4,5,0],[5,7,1],[2,5,8]])
mask:
tensor([[True,True,True],[True,False,True],[True,True,False]])
t_select:
tensor([4,5,0,5,1,2,5])
张量变换
torch.reshape()
7、torch.reshape()
# ======================================= example 7 =======================================
# torch.reshape
flag = True
# flag = False
if flag:
    t = torch.randperm(8)
    # -1 lets torch infer that dimension: 8 / (2 * 2) = 2 -> shape (2, 2, 2).
    t_reshape = torch.reshape(t, (-1, 2, 2))
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))

    # t is contiguous, so reshape returns a view that shares t's storage:
    # a write through t is visible in t_reshape. (For non-contiguous inputs
    # reshape may return a copy instead.)
    t[0] = 1024
    print("\n")
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))
    # Compare the address of the underlying data with data_ptr(). `t.data` builds
    # a new temporary tensor object on each access, so equal id(t.data) values
    # only reflect CPython reusing freed memory, not shared storage.
    print("t.data 内存地址:{}".format(t.data_ptr()))
    print("t_reshape.data 内存地址:{}".format(t_reshape.data_ptr()))
OUT:
t:tensor([5,4,2,6,7,3,1,0])
t_reshape:
tensor([[[5,4],[2,6]],[[7,3],[1,0]]])
t:tensor([1024,4,2,6,7,3,1,0])
t_reshape:
tensor([[[1024,4],[2,6]],[[7,3],[1,0]]])
t.data 内存地址:1747722786440
t_reshape.data 内存地址:1747722786440
torch.transpose()
8、torch.transpose()
# ======================================= example 8 =======================================
# torch.transpose
flag = True
# flag = False
if flag:
    t = torch.rand((2, 3, 4))
    # Swap dims 1 and 2: (2, 3, 4) -> (2, 4, 3). Only the two named dims move;
    # reordering all axes (e.g. c*h*w -> h*w*c) needs t.permute(1, 2, 0) instead.
    t_transpose = torch.transpose(t, dim0=1, dim1=2)
    print("t shape:{}\nt_transpose shape: {}".format(t.shape, t_transpose.shape))
OUT:
t shape:torch.Size([2,3,4])
t_transpose shape: torch.Size([2,4,3])
torch.squeeze()
9、torch.squeeze()
# ======================================= example 9 =======================================
# torch.squeeze
flag = True
# flag = False
if flag:
    t = torch.rand((1, 2, 3, 1))
    t_sq = torch.squeeze(t)        # drop every size-1 dim -> (2, 3)
    t_0 = torch.squeeze(t, dim=0)  # dim 0 has size 1, so it is removed -> (2, 3, 1)
    t_1 = torch.squeeze(t, dim=1)  # dim 1 has size 2, so nothing changes -> (1, 2, 3, 1)
    print("t.shape: {}".format(t.shape))
    print("t_sq.shape: {}".format(t_sq.shape))
    print("t_0.shape: {}".format(t_0.shape))
    print("t_1.shape: {}".format(t_1.shape))
OUT:
t.shape: torch.Size([1,2,3,1])
t_sq.shape: torch.Size([2,3])
t_0.shape: torch.Size([2,3,1])
t_1.shape: torch.Size([1,2,3,1])
张量数学运算
torch.add()
10、torch.add()
# ======================================= example 10 =======================================
# torch.add
flag = True
# flag = False
if flag:
    t_0 = torch.randn((3, 3))
    t_1 = torch.ones_like(t_0)
    # Computes t_0 + alpha * t_1, i.e. t_0 + 10 here. The positional form
    # torch.add(t_0, 10, t_1) is a deprecated overload; pass the scale as alpha.
    t_add = torch.add(t_0, t_1, alpha=10)
    print("t_0:\n{}\nt_1:\n{}\nt_add_10:\n{}".format(t_0, t_1, t_add))
OUT:
t_0:
tensor([[0.6614,0.2669,0.0617],[0.6213,-0.4519,-0.1661],[-1.5228,0.3817,-1.0276]])
t_1:
tensor([[1.,1.,1.],[1.,1.,1.],[1.,1.,1.]])
t_add_10:
tensor([[10.6614,10.2669,10.0617],[10.6213,9.5481,9.8339],[8.4772,10.3817,8.9724]])
本文转载自: https://blog.csdn.net/weixin_54546190/article/details/125024663
版权归原作者 ☞源仔 所有, 如有侵权,请联系我们删除。