torch.Tensor.repeat
repeat
可以形象地理解为将已有的张量多次重复以组成 “分块矩阵”。
import torch
""" Example 1 """
t = torch.arange(3)print(t.repeat((2,)))# tensor([0, 1, 2, 0, 1, 2])print(t.repeat((2,2)))# tensor([[0, 1, 2, 0, 1, 2],# [0, 1, 2, 0, 1, 2]])""" Example 2 """
t = torch.arange(4).reshape(2,2)print(t.repeat((2,)))# RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensorprint(t.repeat((2,2)))# tensor([[0, 1, 0, 1],# [2, 3, 2, 3],# [0, 1, 0, 1],# [2, 3, 2, 3]])print(t.repeat((2,3,4)))# tensor([[[0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3],# [0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3],# [0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3]],# # [[0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3],# [0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3],# [0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3]]])
可以看出要
repeat
的维度不能低于张量本身的维度。
torch.Tensor.tile
大部分情况下,
tile
与
repeat
的作用相同,如下:
""" Example 1 """
t = torch.arange(3)print(t.tile((2,)))# tensor([0, 1, 2, 0, 1, 2])print(t.tile((2,2)))# tensor([[0, 1, 2, 0, 1, 2],# [0, 1, 2, 0, 1, 2]])""" Example 2 """
t = torch.arange(4).reshape(2,2)print(t.tile((2,)))# tensor([[0, 1, 0, 1],# [2, 3, 2, 3]])print(t.tile((2,2)))# tensor([[0, 1, 0, 1],# [2, 3, 2, 3],# [0, 1, 0, 1],# [2, 3, 2, 3]])print(t.tile((2,3,4)))# tensor([[[0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3],# [0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3],# [0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3]],# # [[0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3],# [0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3],# [0, 1, 0, 1, 0, 1, 0, 1],# [2, 3, 2, 3, 2, 3, 2, 3]]])
与
repeat
不同的是,当要重复的维度低于张量的维度时,
tile
也能够处理,此时
tile
会使用前置
1
1
1 自动补齐维度。
torch.Tensor.repeat_interleave
之前提到的
repeat
和
tile
都是重复整个张量,而这次的
repeat_interleave
则是重复张量中的元素。
参数如下:
torch.Tensor.repeat_interleave(repeats, dim=None)
repeats
:代表张量中每个元素将要重复的次数。可以为整数或张量;dim
:决定了沿哪一个轴去重复数字。默认情况下会将输入展平再进行重复,最后输出展平的张量。
""" Example 1 """
t = torch.arange(3)print(t.repeat_interleave(repeats=3))# tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])""" Example 2 """
t = torch.arange(4).reshape(2,2)print(t.repeat_interleave(repeats=3))# tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])print(t.repeat_interleave(repeats=3, dim=0))# tensor([[0, 1],# [0, 1],# [0, 1],# [2, 3],# [2, 3],# [2, 3]])print(t.repeat_interleave(repeats=3, dim=1))# tensor([[0, 0, 0, 1, 1, 1],# [2, 2, 2, 3, 3, 3]])""" Example 3 """
t = torch.arange(4).reshape(2,2)print(t.repeat_interleave(repeats=torch.tensor([2,3]), dim=0))# t的第一行重复2次,第2行重复3次# tensor([[0, 1],# [0, 1],# [2, 3],# [2, 3],# [2, 3]])print(t.repeat_interleave(repeats=torch.tensor([3,2]), dim=1))# t的第一列重复3次,第2列重复2次# tensor([[0, 0, 0, 1, 1],# [2, 2, 2, 3, 3]])
版权归原作者 raelum 所有, 如有侵权,请联系我们删除。