0


PyTorch中repeat、tile与repeat_interleave的区别

torch.Tensor.repeat

repeat

可以形象地理解为将已有的张量多次重复以组成 “分块矩阵”

import torch

""" Example 1 """
t = torch.arange(3)print(t.repeat((2,)))# tensor([0, 1, 2, 0, 1, 2])print(t.repeat((2,2)))# tensor([[0, 1, 2, 0, 1, 2],#         [0, 1, 2, 0, 1, 2]])""" Example 2 """
t = torch.arange(4).reshape(2,2)print(t.repeat((2,)))# RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensorprint(t.repeat((2,2)))# tensor([[0, 1, 0, 1],#         [2, 3, 2, 3],#         [0, 1, 0, 1],#         [2, 3, 2, 3]])print(t.repeat((2,3,4)))# tensor([[[0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3],#          [0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3],#          [0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3]],# #         [[0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3],#          [0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3],#          [0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3]]])

可以看出要

repeat

的维度不能低于张量本身的维度。

torch.Tensor.tile

大部分情况下,

tile

repeat

的作用相同,如下:

""" Example 1 """
t = torch.arange(3)print(t.tile((2,)))# tensor([0, 1, 2, 0, 1, 2])print(t.tile((2,2)))# tensor([[0, 1, 2, 0, 1, 2],#         [0, 1, 2, 0, 1, 2]])""" Example 2 """
t = torch.arange(4).reshape(2,2)print(t.tile((2,)))# tensor([[0, 1, 0, 1],#         [2, 3, 2, 3]])print(t.tile((2,2)))# tensor([[0, 1, 0, 1],#         [2, 3, 2, 3],#         [0, 1, 0, 1],#         [2, 3, 2, 3]])print(t.tile((2,3,4)))# tensor([[[0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3],#          [0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3],#          [0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3]],# #         [[0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3],#          [0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3],#          [0, 1, 0, 1, 0, 1, 0, 1],#          [2, 3, 2, 3, 2, 3, 2, 3]]])

repeat

不同的是,当要重复的维度低于张量的维度时,

tile

也能够处理,此时

tile

会使用前置

    1
   
  
  
   1
  
 
1 自动补齐维度。

torch.Tensor.repeat_interleave

之前提到的

repeat

tile

都是重复整个张量,而这次的

repeat_interleave

则是重复张量中的元素

参数如下:

torch.Tensor.repeat_interleave(repeats, dim=None)
  • repeats:代表张量中每个元素将要重复的次数。可以为整数或张量;
  • dim:决定了沿哪一个轴去重复数字。默认情况下会将输入展平再进行重复,最后输出展平的张量。
""" Example 1 """
t = torch.arange(3)print(t.repeat_interleave(repeats=3))# tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])""" Example 2 """
t = torch.arange(4).reshape(2,2)print(t.repeat_interleave(repeats=3))# tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])print(t.repeat_interleave(repeats=3, dim=0))# tensor([[0, 1],#         [0, 1],#         [0, 1],#         [2, 3],#         [2, 3],#         [2, 3]])print(t.repeat_interleave(repeats=3, dim=1))# tensor([[0, 0, 0, 1, 1, 1],#         [2, 2, 2, 3, 3, 3]])""" Example 3 """
t = torch.arange(4).reshape(2,2)print(t.repeat_interleave(repeats=torch.tensor([2,3]), dim=0))# t的第一行重复2次,第2行重复3次# tensor([[0, 1],#         [0, 1],#         [2, 3],#         [2, 3],#         [2, 3]])print(t.repeat_interleave(repeats=torch.tensor([3,2]), dim=1))# t的第一列重复3次,第2列重复2次# tensor([[0, 0, 0, 1, 1],#         [2, 2, 2, 3, 3]])

本文转载自: https://blog.csdn.net/raelum/article/details/125554754
版权归原作者 raelum 所有, 如有侵权,请联系我们删除。

“PyTorch中repeat、tile与repeat_interleave的区别”的评论:

还没有评论