提示:计算复杂度的简单理解(第一次写博客)
计算复杂度
计算复杂度
我们以Vicinity Vision Transformer论文中的图为例。
图注:标准自注意力(左)和线性化自注意力(右)的图示。
N
N
N表示输入图像的
p
a
t
c
h
patch
patch数,
d
d
d是特征维度。使
N
≫
d
N\gg d
N≫d,线性化自注意力的计算复杂度相对于输入长度线性增长,而标准自注意力的计算复杂度是二次的。
从输入到输出可以这样计算:
(
N
×
d
)
×
(
d
×
N
)
=
N
×
N
×
(
d
×
N
)
×
(
N
×
N
)
=
d
×
N
(N\times d)\times (d\times N)=N\times N\times (d\times N)\times (N\times N)=d\times N
(N×d)×(d×N)=N×N×(d×N)×(N×N)=d×N
(
d
×
N
)
×
(
N
×
d
)
=
d
×
d
×
(
d
×
d
)
×
(
d
×
N
)
=
d
×
N
(d\times N)\times (N\times d)=d\times d\times (d\times d)\times (d\times N)=d\times N
(d×N)×(N×d)=d×d×(d×d)×(d×N)=d×N
关于计算复杂度:其实可以认为是乘法次数。我们给出最直观的解释。
假设有两个矩阵做乘法,如下:
[
1
2
3
4
5
6
]
×
[
1
2
3
4
5
6
]
=
[
1
2
3
4
5
6
7
8
9
]
\left[\begin{matrix}1&2\\3&4\\5&6\\\end{matrix}\right]\times\left[\begin{matrix}1&2&3\\4&5&6\\\end{matrix}\right]=\left[\begin{matrix}1&2&3\\4&5&6\\7&8&9\\\end{matrix}\right]
⎣⎡135246⎦⎤×[142536]=⎣⎡147258369⎦⎤,其中行数为
N
N
N,列数为
d
d
d。
(
3
×
2
)
×
(
2
×
3
)
=
(
3
×
3
)
×
(
N
×
d
)
×
(
d
×
N
)
=
(
N
×
N
)
(3\times 2)\times (2\times 3)=(3\times 3)\times (N\times d)\times (d\times N)=(N\times N)
(3×2)×(2×3)=(3×3)×(N×d)×(d×N)=(N×N)
3
×
3
3\times 3
3×3矩阵第一个元素涉及的乘法次数:
1
×
1
+
2
×
4
=
9
1\times 1+2\times 4=9
1×1+2×4=9 共2次乘法;其它元素是一样的。最后可以得到
2
×
9
=
2
×
3
×
3
=
d
×
N
×
N
=
N
2
d
2\times 9=2\times 3\times 3=d\times N\times N=N^{2}d
2×9=2×3×3=d×N×N=N2d.
假设又有两个矩阵做乘法,如下:
[
1
2
3
4
5
6
]
×
[
1
2
3
4
5
6
]
=
[
1
2
3
4
]
\left[\begin{matrix}1&2&3\\4&5&6\\\end{matrix}\right]\times\left[\begin{matrix}1&2\\3&4\\5&6\\\end{matrix}\right]=\left[\begin{matrix}1&2\\3&4\\\end{matrix}\right]
[142536]×⎣⎡135246⎦⎤=[1324],其中行数为
d
d
d,列数为
N
N
N。
(
2
×
3
)
×
(
3
×
2
)
=
(
2
×
2
)
×
(
d
×
N
)
×
(
N
×
d
)
=
(
d
×
d
)
(2\times 3)\times (3\times 2)=(2\times 2)\times (d\times N)\times (N\times d)=(d\times d)
(2×3)×(3×2)=(2×2)×(d×N)×(N×d)=(d×d)
2
×
2
2\times 2
2×2矩阵第一个元素涉及的乘法次数:
1
×
1
+
2
×
3
+
2
×
5
=
17
1\times 1+2\times 3+2\times 5=17
1×1+2×3+2×5=17 共3次乘法;其它元素是一样的。最后可以得到
3
×
4
=
3
×
2
×
2
=
N
×
d
×
d
=
N
d
2
3\times 4=3\times 2\times 2=N\times d\times d=Nd^2
3×4=3×2×2=N×d×d=Nd2 .
为什么会有这种情况呢?以第二个例子为例,可以观察到,所得结果的一个元素的乘法数量和消失的维度大小有关,也就是列数
N
N
N,或者说,列数
N
N
N就是所得结果一个元素的乘法次数。那么多少个元素呢?元素个数就要看你是如何进行的乘法操作,其实就是矩阵大小。比如
(
2
×
3
)
×
(
3
×
2
)
=
(
2
×
2
)
×
(
d
×
N
)
×
(
N
×
d
)
=
(
d
×
d
)
(2\times 3)\times (3\times 2)=(2\times 2)\times (d\times N)\times (N\times d)=(d\times d)
(2×3)×(3×2)=(2×2)×(d×N)×(N×d)=(d×d),那么就是
d
2
d^2
d2个元素,最后乘法次数就是
N
d
2
Nd^2
Nd2。
乘法次数=消失的维度 × 所得矩阵大小。
那么计算复杂度呢?我们不要去管
O
(
∙
)
O(\bullet)
O(∙)具体代表什么,这不重要。
以第一个图为例,乘法次数1:
(
N
×
d
)
×
(
d
×
N
)
=
N
2
d
(N\times d)\times (d\times N)=N^{2}d
(N×d)×(d×N)=N2d;乘法次数
2
2
2:
(
N
×
d
)
×
(
d
×
N
)
=
N
2
d
(N\times d)\times (d\times N)=N^{2}d
(N×d)×(d×N)=N2d。
O
(
N
2
d
+
N
2
d
)
=
O
(
N
2
)
O(N^{2}d+N^{2}d)=O(N^2)
O(N2d+N2d)=O(N2)。因为
N
≫
d
N\gg d
N≫d,所以
d
d
d(还有常数
2
2
2)被省略了,即
O
(
N
2
)
O(N^2)
O(N2)。
以第二个图为例,乘法次数1:
(
d
×
N
)
×
(
N
×
d
)
=
N
d
2
(d\times N)\times (N\times d)=Nd^2
(d×N)×(N×d)=Nd2;乘法次数2:
(
d
×
d
)
×
(
d
×
N
)
=
N
d
2
(d\times d)\times (d\times N)=Nd^2
(d×d)×(d×N)=Nd2。
O
(
N
d
2
+
N
d
2
)
=
O
(
N
)
O(Nd^2+Nd^2)=O(N)
O(Nd2+Nd2)=O(N)。因为
N
≫
d
N\gg d
N≫d,所以
d
d
d(还有常数2)被省略了,即
O
(
N
)
O(N)
O(N)。
事实告诉我们,我们两个的结果一样,但是我们可以通过控制中间过程减少计算复杂度。
版权归原作者 qq_42584216 所有, 如有侵权,请联系我们删除。