0


torch.einsum() 用法说明

关联网站:
einops官网


torch.einsum( equation , ** operands* ) → Tensor

对输入元素

operands

沿指定的维度、使用爱因斯坦求和符号的乘积求和。

参数:

  • equation ( string ) – 爱因斯坦求和的下标。
  • operandsList[ Tensor *]*)——计算爱因斯坦求和的张量。

​ Einsum允许计算许多常见的多维线性代数数组运算,方法是根据由

equation

给出的爱因斯坦求和约定,以速记(short-hand)格式表示它们。这种格式的细节在下面描述,但通常想法是

operands

用一些下标标记输入的每个维度,并定义哪些下标是输出的一部分,

operands

然后通过对下标不属于输出维度的元素的乘积求和来计算输出。例如,矩阵乘法可以使用einsum计算为torch.einsum(“ij,jk->ik”, A, B)。这里,j 是求和下标,i 和 k 是输出下标(有关原因的更多详细信息,请参见下面的部分)。

equation 参数说明:

equation

字符串以与维度相同的顺序指定输入的每个维度的下标( [a-z,A-Z]

operands

中的字母) ,用逗号 (‘,’) 分隔每个操作数的下标,例如’ij,jk’指定两个二维操作数的下标。标有相同下标的维度必须是可广播的,即它们的大小必须匹配或为1。例外情况是,如果对相同的输入操作数重复下标,在这种情况下,此操作数的标有此下标的维度必须在大小上匹配,并且操作数将被其沿这些维度的对角线替换。

equation

中只出现一次的下标将是输出的一部分,按字母顺序递增排序。输出是通过按元素乘以输入来计算的

operands

,它们的维度根据下标对齐,然后对下标不属于输出的维度求和。

​ 或者,可以通过在等式末尾添加箭头 (

->

) 后跟输出下标来显式定义输出下标。例如,以下等式计算矩阵乘法的转置:‘ij,jk->ki’。对于某些输入操作数,输出下标必须至少出现一次,而对于输出则最多出现一次。

​ 可以使用省略号 (

...

) 代替下标来广播省略号所涵盖的维度。每个输入操作数最多可以包含一个省略号,它将覆盖下标未覆盖的维度,例如,对于具有 5 维的输入操作数,等式“ab…c”中的省略号覆盖第三和第四维。省略号不需要覆盖

operands

中相同数量的维度,但省略号的“形状”(它们覆盖的维度的大小)必须一起传播。如果未使用箭头 (

->

) 表示法显式定义输出,则省略号将首先出现在输出(最左侧的维度)中,位于输入操作数仅出现一次的下标标签之前。例如下面的等式实现批量矩阵乘法’…ij,…jk’。

​ 最后几点注意事项:

equation

可能在不同元素(下标、省略号、箭头和逗号)之间包含空格,但类似“…”的内容无效。空字符串 ’ ’ 对标量operands有效。

注:

  1. torch.einsum处理省略号 (‘…’) 的方式与 NumPy 不同,因为它允许对省略号覆盖的维度求和,也就是说,省略号不需要是输出的一部分。
  2. 此函数不会优化给定的表达式,因此用于相同计算的不同公式可能会运行得更快或消耗更少的内存。像 opt_einsum ( https://optimized-einsum.readthedocs.io/en/stable/ )这样的项目可以为你优化公式。

例:

# trace(迹)>>> torch.einsum('ii', torch.randn(4,4))
tensor(-1.4157)# diagonal(对角线)>>> torch.einsum('ii->i', torch.randn(4,4))
tensor([0.0266,2.4750,-1.0881,-1.3075])# outer product(外积)>>> x = torch.randn(5)
tensor([-0.3550,-0.6059,-1.3375,-1.5649,0.2675])>>> y = torch.randn(4)
tensor([-0.2202,-1.5290,-2.0062,0.9600])>>> torch.einsum('i,j->ij', x, y)
tensor([[0.0782,0.5428,0.7122,-0.3408],[0.1334,0.9264,1.2156,-0.5817],[0.2945,2.0451,2.6834,-1.2840],[0.3445,2.3927,3.1396,-1.5023],[-0.0589,-0.4089,-0.5366,0.2568]])# batch matrix multiplication(批量矩阵乘法)>>> As = torch.randn(3,2,5)
tensor([[[-0.0306,0.8251,0.0157,-0.4563,0.5550],[-1.4550,0.0762,0.9258,0.1198,-1.1737]],[[-0.4460,-0.7224,0.7260,0.7552,0.0326],[-0.3904,-1.2392,0.4848,-0.4756,0.2301]],[[1.5307,0.7668,-1.9426,1.7473,-0.6258],[0.6758,1.8240,-0.2053,0.0973,-0.6118]]])>>> Bs = torch.randn(3,5,4)
tensor([[[-0.7054,-0.2155,-1.5458,-0.8236],[-1.4957,-2.2604,0.6897,-1.0360],[1.2924,0.2798,1.0544,0.3656],[-0.3993,-1.2463,-0.6601,0.2706],[1.0727,0.5418,-0.2516,-0.1133]],[[0.4215,1.5712,-0.2351,1.3741],[1.6418,0.9806,-1.0259,-1.1297],[0.7326,0.4989,0.4404,0.2975],[-0.6866,0.5696,-0.8942,0.6815],[1.7486,0.5344,0.0538,0.5258]],[[1.6280,-1.3989,-0.2900,0.0936],[-0.9436,-0.1766,0.6780,0.3152],[0.9645,-0.1199,-1.1644,-1.0290],[-0.2791,-0.8086,0.2161,0.7901],[1.3222,-1.4023,-2.4181,-1.2875]]])>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-0.4147,-0.9847,0.7946,-1.0103],[0.8020,-0.3849,3.4942,1.6233]],[[-1.3035,-0.5993,0.4922,0.9511],[-1.1150,-1.7346,2.0142,0.8047]],[[-1.4202,-2.5790,4.2288,4.5702],[-1.6549,-0.4636,2.7802,1.7141]]])# with sublist format and ellipsis(带有子列表格式和省略号)>>> torch.einsum(As,[...,0,1], Bs,[...,1,2],[...,0,2])
tensor([[[-0.4147,-0.9847,0.7946,-1.0103],[0.8020,-0.3849,3.4942,1.6233]],[[-1.3035,-0.5993,0.4922,0.9511],[-1.1150,-1.7346,2.0142,0.8047]],[[-1.4202,-2.5790,4.2288,4.5702],[-1.6549,-0.4636,2.7802,1.7141]]])# batch permute(批量交换)>>> A = torch.randn(2,3,4,5)>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2,3,5,4])# equivalent to torch.nn.functional.bilinear(等价于torch.nn.functional.bilinear)>>> A = torch.randn(3,5,4)>>> l = torch.randn(2,5)>>> r = torch.randn(2,4)>>> torch.einsum('bn,anm,bm->ba', l, A, r)
tensor([[-0.3430,-5.2405,0.4494],[0.3311,5.5201,-3.0356]])

本文转载自: https://blog.csdn.net/Dust_Evc/article/details/128434089
版权归原作者 Dust_Evc 所有, 如有侵权,请联系我们删除。

“torch.einsum() 用法说明”的评论:

还没有评论