0


Pytorch框架学习路径(二:张量操作)

张量操作

文章目录

张量操作

张量拼接与切分

torch.cat()

在这里插入图片描述
1、torch.cat创建张量

import torch
torch.manual_seed(1)  # fix the RNG seed so the random examples are reproducible

# ======================================= example 1 =======================================
# torch.cat: concatenate tensors along an existing dimension.

flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))

    # Concatenating along dim=0 grows the row count: two (2, 3) -> (4, 3).
    t_0 = torch.cat([t, t], dim=0)
    # Concatenating along dim=1 grows the column count: three (2, 3) -> (2, 9).
    t_1 = torch.cat([t, t, t], dim=1)

    print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))

OUT:

t_0:tensor([[1.,1.,1.],[1.,1.,1.],[1.,1.,1.],[1.,1.,1.]]) shape:torch.Size([4,3])
t_1:tensor([[1.,1.,1.,1.,1.,1.,1.,1.,1.],[1.,1.,1.,1.,1.,1.,1.,1.,1.]]) shape:torch.Size([2,9])

torch.stack()

2、 torch.stack()创建张量

# ======================================= example 2 =======================================
# torch.stack: join tensors along a newly created dimension.

flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))

    # Insert a new dimension at index 0 and stack along it: (2, 3) -> (3, 2, 3).
    t_stack_0 = torch.stack([t, t, t], dim=0)
    # Insert a new dimension at index 2 and stack along it: (2, 3) -> (2, 3, 3).
    t_stack_1 = torch.stack([t, t, t], dim=2)

    print("t: {} shape: {}".format(t, t.shape))
    print("\nt_stack_0: {} shape: {}".format(t_stack_0, t_stack_0.shape))
    print("\nt_stack_1: {} shape: {}".format(t_stack_1, t_stack_1.shape))

OUT:

t: tensor([[1.,1.,1.],[1.,1.,1.]]) shape: torch.Size([2,3])

t_stack_0: tensor([[[1.,1.,1.],[1.,1.,1.]],[[1.,1.,1.],[1.,1.,1.]],[[1.,1.,1.],[1.,1.,1.]]]) shape: torch.Size([3,2,3])

t_stack_1: tensor([[[1.,1.,1.],[1.,1.,1.],[1.,1.,1.]],[[1.,1.,1.],[1.,1.,1.],[1.,1.,1.]]]) shape: torch.Size([2,3,3])

torch.chunk()

在这里插入图片描述
3、torch.chunk()创建张量

# ======================================= example 3 =======================================
# torch.chunk: split a tensor into a requested number of chunks.

flag = True
# flag = False
if flag:
    # 7 columns do not divide evenly by 3, so the last chunk ends up smaller.
    a = torch.ones((2, 7))
    # Split along dim=1 (the columns) into 3 chunks.
    list_of_tensors = torch.chunk(a, dim=1, chunks=3)

    for idx, chunk in enumerate(list_of_tensors):
        print("第{}个张量：{}, shape is {}".format(idx+1, chunk, chunk.shape))

OUT:

第1个张量:tensor([[1.,1.,1.],[1.,1.,1.]]), shape is torch.Size([2,3])
第2个张量:tensor([[1.,1.,1.],[1.,1.,1.]]), shape is torch.Size([2,3])
第3个张量:tensor([[1.],[1.]]), shape is torch.Size([2,1])

torch.split()

在这里插入图片描述
4、torch.split()创建张量

# ======================================= example 4 =======================================
# torch.split: split a tensor into sections of explicitly given sizes.

flag = True
# flag = False
if flag:
    t = torch.ones((2, 5))

    # The section sizes [2, 1, 2] must sum to the size of dim=1 (5);
    # otherwise torch.split raises an error.
    list_of_tensors = torch.split(t, [2, 1, 2], dim=1)
    # The loop variable is deliberately NOT named `t`, so the source tensor
    # is still intact for the counter-example below.
    for idx, section in enumerate(list_of_tensors):
        print("第{}个张量：{}, shape is {}".format(idx+1, section, section.shape))

    # Counter-example: sum([2, 1, 1]) = 4, which is not equal to 5.
    # The failure is caught and printed so the rest of the script keeps running;
    # the exact error text depends on the installed PyTorch version.
    try:
        bad_sections = torch.split(t, [2, 1, 1], dim=1)
    except RuntimeError as err:
        print("RuntimeError: {}".format(err))
    else:
        for idx, section in enumerate(bad_sections):
            print("第{}个张量：{}, shape is {}".format(idx+1, section, section.shape))

OUT:

第1个张量:tensor([[1.,1.],[1.,1.]]), shape is torch.Size([2,2])
第2个张量:tensor([[1.],[1.]]), shape is torch.Size([2,1])
第3个张量:tensor([[1.,1.],[1.,1.]]), shape is torch.Size([2,2])
Traceback (most recent call last):
  File "E:/Code/In_pytorch/pytorch_lesson_code/Hello_pytorch/lesson/lesson-03/lesson-03.py", line 70, in <module>
    list_of_tensors = torch.split(t,[2,1,1], dim=1)
  File "D:\ProgramData\Anaconda3\envs\py37\lib\site-packages\torch\functional.py", line 156, in split
    return tensor.split(split_size_or_sections, dim)
  File "D:\ProgramData\Anaconda3\envs\py37\lib\site-packages\torch\_tensor.py", line 518, in split
    return super(Tensor, self).split_with_sizes(split_size, dim)
RuntimeError: start (2) + length (1) exceeds dimension size (2).

张量索引

torch.index_select()

在这里插入图片描述
5、torch.index_select()

# ======================================= example 5 =======================================
# torch.index_select: pick entries along one dimension by index.

flag = True
# flag = False
if flag:
    t = torch.randint(0, 9, size=(3, 3))
    # The index tensor must have an integer dtype (torch.long here), not float.
    idx = torch.tensor([0, 2], dtype=torch.long)

    # dim=0: keep rows 0 and 2.
    t_select_0 = torch.index_select(t, dim=0, index=idx)
    print("t:\n{}\nt_select_0:\n{}".format(t, t_select_0))
    print("\n")

    # dim=1: keep columns 0 and 2.
    t_select_1 = torch.index_select(t, dim=1, index=idx)
    print("t:\n{}\nt_select_1:\n{}".format(t, t_select_1))

OUT:

t:
tensor([[4,5,0],[5,7,1],[2,5,8]])
t_select_0:
tensor([[4,5,0],[2,5,8]])

t:
tensor([[4,5,0],[5,7,1],[2,5,8]])
t_select_1:
tensor([[4,0],[5,1],[2,8]])

torch.masked_select()

在这里插入图片描述
6、torch.masked_select()

# ======================================= example 6 =======================================
# torch.masked_select: gather the elements where a boolean mask is True.

flag = True
# flag = False
if flag:

    t = torch.randint(0, 9, size=(3, 3))
    # le: less than or equal (so 5 itself is selected).
    # Related comparisons: ge = greater than or equal, gt = greater than,
    # lt = less than.
    mask = t.le(5)
    # The result is a 1-D tensor holding only the selected elements.
    t_select = torch.masked_select(t, mask)
    print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))

OUT:

t:
tensor([[4,5,0],[5,7,1],[2,5,8]])
mask:
tensor([[True,True,True],[True,False,True],[True,True,False]])
t_select:
tensor([4,5,0,5,1,2,5])

张量变换

torch.reshape()

在这里插入图片描述
7、torch.reshape()

# ======================================= example 7 =======================================
# torch.reshape: change a tensor's shape.

flag = True
# flag = False
if flag:
    t = torch.randperm(8)
    # -1 lets PyTorch infer that dimension from the element count: 8 / (2 * 2) = 2.
    t_reshape = torch.reshape(t, (-1, 2, 2))
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))

    # Writing through `t` is visible in `t_reshape`: for this contiguous tensor
    # reshape returns a view that shares the same underlying memory.
    t[0] = 1024
    print("\n")
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))
    # data_ptr() is the address of the tensor's first element. id() on the
    # objects returned by `.data` would only compare short-lived Python
    # wrappers, which proves nothing about shared storage.
    print("t.data 内存地址:{}".format(t.data_ptr()))
    print("t_reshape.data 内存地址:{}".format(t_reshape.data_ptr()))

OUT:

t:tensor([5,4,2,6,7,3,1,0])
t_reshape:
tensor([[[5,4],[2,6]],[[7,3],[1,0]]])

t:tensor([1024,4,2,6,7,3,1,0])
t_reshape:
tensor([[[1024,4],[2,6]],[[7,3],[1,0]]])
t.data 内存地址:1747722786440
t_reshape.data 内存地址:1747722786440

torch.transpose()

在这里插入图片描述
8、torch.transpose()

# ======================================= example 8 =======================================
# torch.transpose: swap two dimensions of a tensor.

# Enabled like every other example, so the output shown for it is actually produced.
flag = True
# flag = False
if flag:
    t = torch.rand((2, 3, 4))
    # Swap dims 1 and 2: (2, 3, 4) -> (2, 4, 3).
    # Original author's note: c*h*w / h*w*c (reordering image-style axes).
    t_transpose = torch.transpose(t, dim0=1, dim1=2)
    print("t shape:{}\nt_transpose shape: {}".format(t.shape, t_transpose.shape))

OUT:

t shape:torch.Size([2,3,4])
t_transpose shape: torch.Size([2,4,3])

torch.squeeze()

在这里插入图片描述
9、torch.squeeze()

# ======================================= example 9 =======================================
# torch.squeeze: remove dimensions of size 1.

flag = True
# flag = False
if flag:
    t = torch.rand((1, 2, 3, 1))
    # No dim given: every size-1 dimension is removed -> (2, 3).
    t_sq = torch.squeeze(t)
    # dim=0 has size 1, so it is removed -> (2, 3, 1).
    t_0 = torch.squeeze(t, dim=0)
    # dim=1 has size 2, so nothing is removed -> (1, 2, 3, 1).
    t_1 = torch.squeeze(t, dim=1)

    print("t.shape: {}".format(t.shape))
    print("t_sq.shape: {}".format(t_sq.shape))
    print("t_0.shape: {}".format(t_0.shape))
    print("t_1.shape: {}".format(t_1.shape))

OUT:

t.shape: torch.Size([1,2,3,1])
t_sq.shape: torch.Size([2,3])
t_0.shape: torch.Size([2,3,1])
t_1.shape: torch.Size([1,2,3,1])

张量数学运算

在这里插入图片描述

torch.add()

在这里插入图片描述

# ======================================= example 10 =======================================
# torch.add: element-wise addition with an optional multiplier on the second operand.

flag = True
# flag = False
if flag:
    t_0 = torch.randn((3, 3))
    t_1 = torch.ones_like(t_0)
    # Computes t_0 + alpha * t_1. The multiplier is passed as the `alpha`
    # keyword; the positional form torch.add(t_0, 10, t_1) is a deprecated
    # overload.
    t_add = torch.add(t_0, t_1, alpha=10)

    print("t_0:\n{}\nt_1:\n{}\nt_add_10:\n{}".format(t_0, t_1, t_add))

OUT:

t_0:
tensor([[0.6614,0.2669,0.0617],[0.6213,-0.4519,-0.1661],[-1.5228,0.3817,-1.0276]])
t_1:
tensor([[1.,1.,1.],[1.,1.,1.],[1.,1.,1.]])
t_add_10:
tensor([[10.6614,10.2669,10.0617],[10.6213,9.5481,9.8339],[8.4772,10.3817,8.9724]])

本文转载自: https://blog.csdn.net/weixin_54546190/article/details/125024663
版权归原作者 ☞源仔 所有, 如有侵权,请联系我们删除。

“Pytorch框架学习路径(二:张量操作)”的评论:

还没有评论