0


Pytorch基础:Tensor的连续性

相关阅读

Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


    在Pytorch中,一个连续的张量指的是张量中各数据元素在底层的存储顺序与其在张量中的位置一致。这意味着每一个元素的地址可以通过下面的线性映射公式来确定:

address(i_{0},i_{2},...,i_{n-1}) = base\_address+\sum_{k=0}^{n-1}(i_{k}\cdot stride(k))

    其中,![i_{0}](https://latex.csdn.net/eq?i_%7B0%7D)是第k维的索引,![stride(k)](https://latex.csdn.net/eq?stride%28k%29)是第![k](https://latex.csdn.net/eq?k)维的步长(就是第![k](https://latex.csdn.net/eq?k)维的数据在存储时,相邻数据在底层线性存储时相隔的数据数),![base\_address](https://latex.csdn.net/eq?base%5C_address)是张量底层数据存储的起始地址。

    对于一个连续的张量,其![stride](https://latex.csdn.net/eq?stride)应该符合从最内层维度(第0维度)到最外层维度递减的模式。更准确地说,如果一个张量有![n](https://latex.csdn.net/eq?n)个维度,并且每个维度的大小是![](https://latex.csdn.net/eq?)![s_{0},s_{2},...,s_{n-1}](https://latex.csdn.net/eq?s_%7B0%7D%2Cs_%7B2%7D%2C...%2Cs_%7Bn-1%7D),那么其连续性可以使用下面的方式判定:

stride(n-1)=1

stride(k)=stride(k+1)\times s_{k+1} \ for\ k = n-2,n-1,...,0

    当你创建一个新的张量时,默认情况下,它是连续的,这意味着它的元素在内存中是按照顺序存储的。可以通过size()方法和stride()方法获得一个张量的形状和步长,利用公式判断是否连续;也可以使用storage()方法,直接获得一个张量在底层的线性存储结果;最方便的是使用.is_contiguous()方法,他会直接返回一个布尔值。
import torch

x = torch.tensor([[1, 2], [3, 4]])
print(list(x.size()))
print(x.stride())
print(x.storage())
print("Is x contiguous?", x.is_contiguous()) 

输出:
[2, 2]
(2, 1)
 1
 2
 3
 4
[torch.LongStorage of size 4]
Is x contiguous? True
    为什么会出现非连续张量呢?在PyTorch中,非连续张量的出现往往与张量视图(tensor views)的概念密切相关。张量视图允许一个新张量作为原张量的视图存在,其中新张量与其原张量共享相同的底层数据。这种设计旨在避免显式的数据复制,从而实现快速且内存高效的操作(例如切片、转置)。

    既然两个不同的张量共享底层存储,如果其中一个张量是连续的,另一个必然是不连续的,
import torch

x = torch.tensor([[1, 2], [3, 4]])
print(list(x.size()))
print(x.stride())
print(x.storage())
print("Is x contiguous?", x.is_contiguous()) 

y = x.t()
print(list(y.size()))
print(y.stride())
print(y.storage())
print("Is y contiguous?", y.is_contiguous())

输出:
[2, 2]
(2, 1)
 1
 2
 3
 4
[torch.LongStorage of size 4]
Is x contiguous? True
[2, 2]
(1, 2)
 1
 2
 3
 4
[torch.LongStorage of size 4]
Is y contiguous? False
    除了t(),Pytorch中有下面这些会返回视图的操作(因此可能出现非连续张量)。
基本的切片和索引, 例如tensor[0, 2:, 1:7:2]
adjoint()
as_strided()
detach()
diagonal()
expand()
expand_as()
movedim()
narrow()
permute()
select()
squeeze()
transpose()
t()
T
H
mT
mH
real
imag
view_as_real()
unflatten()
unfold()
unsqueeze()
view()
view_as()
unbind()
split()
hsplit()
vsplit()
tensor_split()
split_with_sizes()
swapaxes()
swapdims()
chunk()
indices() (仅限稀疏张量)
values() (仅限稀疏张量)
    除此之外,reshape(),reshape_as()和flatten() 既有可能返回张量的视图,也可能返回一个拥有独立存储空间的新张量。这取决于一些特定的条件,具体可见关于这些操作的文章。

Pytorch基础:Tensor的reshape方法_pytorch reshape-CSDN博客https://chenzhang.blog.csdn.net/article/details/133445832Pytorch基础:Tensor的flatten方法_tensor.flatten-CSDN博客https://chenzhang.blog.csdn.net/article/details/136570774 最后值得一提的是,contiguous()方法能返回一个连续的张量,如果原张量已连续,则会返回原张量。

import torch

x = torch.tensor([[1, 2], [3, 4]])
print(list(x.size()))
print(x.stride())
print(x.storage())
print("Is x contiguous?", x.is_contiguous())

y = x.t()

z = y.contiguous()
print(list(z.size()))
print(z.stride())
print(z.storage())
print("Is z contiguous?", z.is_contiguous())

输出:
[2, 2]
(2, 1)
 1
 2
 3
 4
[torch.LongStorage of size 4]
Is x contiguous? True
[2, 2]
(2, 1)
 1
 3
 2
 4
[torch.LongStorage of size 4]
Is z contiguous? True

本文转载自: https://blog.csdn.net/weixin_45791458/article/details/140736700
版权归原作者 日晨难再 所有, 如有侵权,请联系我们删除。

“Pytorch基础:Tensor的连续性”的评论:

还没有评论