0


计算复杂度

提示:计算复杂度的简单理解(第一次写博客)

计算复杂度

计算复杂度

我们以Vicinity Vision Transformer论文中的图为例。
在这里插入图片描述图注:标准自注意力(左)和线性化自注意力(右)的图示。

    N
   
  
  
   N
  
 
N表示输入图像的

 
  
   
    p
   
   
    a
   
   
    t
   
   
    c
   
   
    h
   
  
  
   patch
  
 
patch数,

 
  
   
    d
   
  
  
   d
  
 
d是特征维度。使

 
  
   
    N
   
   
    ≫
   
   
    d
   
  
  
   N\gg d
  
 
N≫d,线性化自注意力的计算复杂度相对于输入长度线性增长,而标准自注意力的计算复杂度是二次的。

从输入到输出可以这样计算:

    (
   
   
    N
   
   
    ×
   
   
    d
   
   
    )
   
   
    ×
   
   
    (
   
   
    d
   
   
    ×
   
   
    N
   
   
    )
   
   
    =
   
   
    N
   
   
    ×
   
   
    N
   
   
    ×
   
   
    (
   
   
    d
   
   
    ×
   
   
    N
   
   
    )
   
   
    ×
   
   
    (
   
   
    N
   
   
    ×
   
   
    N
   
   
    )
   
   
    =
   
   
    d
   
   
    ×
   
   
    N
   
  
  
   (N\times d)\times (d\times N)=N\times N\times (d\times N)\times (N\times N)=d\times N
  
 
(N×d)×(d×N)=N×N×(d×N)×(N×N)=d×N


 
  
   
    (
   
   
    d
   
   
    ×
   
   
    N
   
   
    )
   
   
    ×
   
   
    (
   
   
    N
   
   
    ×
   
   
    d
   
   
    )
   
   
    =
   
   
    d
   
   
    ×
   
   
    d
   
   
    ×
   
   
    (
   
   
    d
   
   
    ×
   
   
    d
   
   
    )
   
   
    ×
   
   
    (
   
   
    d
   
   
    ×
   
   
    N
   
   
    )
   
   
    =
   
   
    d
   
   
    ×
   
   
    N
   
  
  
   (d\times N)\times (N\times d)=d\times d\times (d\times d)\times (d\times N)=d\times N
  
 
(d×N)×(N×d)=d×d×(d×d)×(d×N)=d×N

关于计算复杂度:其实可以认为是乘法次数。我们给出最直观的解释。

假设有两个矩阵做乘法,如下:

     [
    
    
     
      
       
        
         1
        
       
      
      
       
        
         2
        
       
      
     
     
      
       
        
         3
        
       
      
      
       
        
         4
        
       
      
     
     
      
       
        
         5
        
       
      
      
       
        
         6
        
       
      
     
    
    
     ]
    
   
   
    ×
   
   
    
     [
    
    
     
      
       
        
         1
        
       
      
      
       
        
         2
        
       
      
      
       
        
         3
        
       
      
     
     
      
       
        
         4
        
       
      
      
       
        
         5
        
       
      
      
       
        
         6
        
       
      
     
    
    
     ]
    
   
   
    =
   
   
    
     [
    
    
     
      
       
        
         1
        
       
      
      
       
        
         2
        
       
      
      
       
        
         3
        
       
      
     
     
      
       
        
         4
        
       
      
      
       
        
         5
        
       
      
      
       
        
         6
        
       
      
     
     
      
       
        
         7
        
       
      
      
       
        
         8
        
       
      
      
       
        
         9
        
       
      
     
    
    
     ]
    
   
  
  
   \left[\begin{matrix}1&2\\3&4\\5&6\\\end{matrix}\right]\times\left[\begin{matrix}1&2&3\\4&5&6\\\end{matrix}\right]=\left[\begin{matrix}1&2&3\\4&5&6\\7&8&9\\\end{matrix}\right]
  
 
⎣⎡​135​246​⎦⎤​×[14​25​36​]=⎣⎡​147​258​369​⎦⎤​,其中行数为

 
  
   
    N
   
  
  
   N
  
 
N,列数为

 
  
   
    d
   
  
  
   d
  
 
d。


 
  
   
    (
   
   
    3
   
   
    ×
   
   
    2
   
   
    )
   
   
    ×
   
   
    (
   
   
    2
   
   
    ×
   
   
    3
   
   
    )
   
   
    =
   
   
    (
   
   
    3
   
   
    ×
   
   
    3
   
   
    )
   
   
    ×
   
   
    (
   
   
    N
   
   
    ×
   
   
    d
   
   
    )
   
   
    ×
   
   
    (
   
   
    d
   
   
    ×
   
   
    N
   
   
    )
   
   
    =
   
   
    (
   
   
    N
   
   
    ×
   
   
    N
   
   
    )
   
  
  
   (3\times 2)\times (2\times 3)=(3\times 3)\times (N\times d)\times (d\times N)=(N\times N)
  
 
(3×2)×(2×3)=(3×3)×(N×d)×(d×N)=(N×N)


 
  
   
    3
   
   
    ×
   
   
    3
   
  
  
   3\times 3
  
 
3×3矩阵第一个元素涉及的乘法次数:

 
  
   
    1
   
   
    ×
   
   
    1
   
   
    +
   
   
    2
   
   
    ×
   
   
    4
   
   
    =
   
   
    9
   
  
  
   1\times 1+2\times 4=9
  
 
1×1+2×4=9 共2次乘法;其它元素是一样的。最后可以得到

 
  
   
    2
   
   
    ×
   
   
    9
   
   
    =
   
   
    2
   
   
    ×
   
   
    3
   
   
    ×
   
   
    3
   
   
    =
   
   
    d
   
   
    ×
   
   
    N
   
   
    ×
   
   
    N
   
   
    =
   
   
    
     N
    
    
     2
    
   
   
    d
   
  
  
   2\times 9=2\times 3\times 3=d\times N\times N=N^{2}d
  
 
2×9=2×3×3=d×N×N=N2d.

假设又有两个矩阵做乘法,如下:

     [
    
    
     
      
       
        
         1
        
       
      
      
       
        
         2
        
       
      
      
       
        
         3
        
       
      
     
     
      
       
        
         4
        
       
      
      
       
        
         5
        
       
      
      
       
        
         6
        
       
      
     
    
    
     ]
    
   
   
    ×
   
   
    
     [
    
    
     
      
       
        
         1
        
       
      
      
       
        
         2
        
       
      
     
     
      
       
        
         3
        
       
      
      
       
        
         4
        
       
      
     
     
      
       
        
         5
        
       
      
      
       
        
         6
        
       
      
     
    
    
     ]
    
   
   
    =
   
   
    
     [
    
    
     
      
       
        
         1
        
       
      
      
       
        
         2
        
       
      
     
     
      
       
        
         3
        
       
      
      
       
        
         4
        
       
      
     
    
    
     ]
    
   
  
  
   \left[\begin{matrix}1&2&3\\4&5&6\\\end{matrix}\right]\times\left[\begin{matrix}1&2\\3&4\\5&6\\\end{matrix}\right]=\left[\begin{matrix}1&2\\3&4\\\end{matrix}\right]
  
 
[14​25​36​]×⎣⎡​135​246​⎦⎤​=[13​24​],其中行数为

 
  
   
    d
   
  
  
   d
  
 
d,列数为

 
  
   
    N
   
  
  
   N
  
 
N。


 
  
   
    (
   
   
    2
   
   
    ×
   
   
    3
   
   
    )
   
   
    ×
   
   
    (
   
   
    3
   
   
    ×
   
   
    2
   
   
    )
   
   
    =
   
   
    (
   
   
    2
   
   
    ×
   
   
    2
   
   
    )
   
   
    ×
   
   
    (
   
   
    d
   
   
    ×
   
   
    N
   
   
    )
   
   
    ×
   
   
    (
   
   
    N
   
   
    ×
   
   
    d
   
   
    )
   
   
    =
   
   
    (
   
   
    d
   
   
    ×
   
   
    d
   
   
    )
   
  
  
   (2\times 3)\times (3\times 2)=(2\times 2)\times (d\times N)\times (N\times d)=(d\times d)
  
 
(2×3)×(3×2)=(2×2)×(d×N)×(N×d)=(d×d)


 
  
   
    2
   
   
    ×
   
   
    2
   
  
  
   2\times 2
  
 
2×2矩阵第一个元素涉及的乘法次数:

 
  
   
    1
   
   
    ×
   
   
    1
   
   
    +
   
   
    2
   
   
    ×
   
   
    3
   
   
    +
   
   
    2
   
   
    ×
   
   
    5
   
   
    =
   
   
    17
   
  
  
   1\times 1+2\times 3+2\times 5=17
  
 
1×1+2×3+2×5=17 共3次乘法;其它元素是一样的。最后可以得到

 
  
   
    3
   
   
    ×
   
   
    4
   
   
    =
   
   
    3
   
   
    ×
   
   
    2
   
   
    ×
   
   
    2
   
   
    =
   
   
    N
   
   
    ×
   
   
    d
   
   
    ×
   
   
    d
   
   
    =
   
   
    N
   
   
    
     d
    
    
     2
    
   
  
  
   3\times 4=3\times 2\times 2=N\times d\times d=Nd^2
  
 
3×4=3×2×2=N×d×d=Nd2 .

为什么会有这种情况呢?以第二个例子为例,可以观察到,所得结果的一个元素的乘法数量和消失的维度大小有关,也就是列数

    N
   
  
  
   N
  
 
N,或者说,列数

 
  
   
    N
   
  
  
   N
  
 
N就是所得结果一个元素的乘法次数。那么多少个元素呢?元素个数就要看你是如何进行的乘法操作,其实就是矩阵大小。比如

 
  
   
    (
   
   
    2
   
   
    ×
   
   
    3
   
   
    )
   
   
    ×
   
   
    (
   
   
    3
   
   
    ×
   
   
    2
   
   
    )
   
   
    =
   
   
    (
   
   
    2
   
   
    ×
   
   
    2
   
   
    )
   
   
    ×
   
   
    (
   
   
    d
   
   
    ×
   
   
    N
   
   
    )
   
   
    ×
   
   
    (
   
   
    N
   
   
    ×
   
   
    d
   
   
    )
   
   
    =
   
   
    (
   
   
    d
   
   
    ×
   
   
    d
   
   
    )
   
  
  
   (2\times 3)\times (3\times 2)=(2\times 2)\times (d\times N)\times (N\times d)=(d\times d)
  
 
(2×3)×(3×2)=(2×2)×(d×N)×(N×d)=(d×d),那么就是

 
  
   
    
     d
    
    
     2
    
   
  
  
   d^2
  
 
d2个元素,最后乘法次数就是

 
  
   
    N
   
   
    
     d
    
    
     2
    
   
  
  
   Nd^2
  
 
Nd2。

乘法次数=消失的维度 × 所得矩阵大小

那么计算复杂度呢?我们不要去管

    O
   
   
    (
   
   
    ∙
   
   
    )
   
  
  
   O(\bullet)
  
 
O(∙)具体代表什么,这不重要。

以第一个图为例,乘法次数1:

    (
   
   
    N
   
   
    ×
   
   
    d
   
   
    )
   
   
    ×
   
   
    (
   
   
    d
   
   
    ×
   
   
    N
   
   
    )
   
   
    =
   
   
    
     N
    
    
     2
    
   
   
    d
   
  
  
   (N\times d)\times (d\times N)=N^{2}d
  
 
(N×d)×(d×N)=N2d;乘法次数

 
  
   
    2
   
  
  
   2
  
 
2:

 
  
   
    (
   
   
    N
   
   
    ×
   
   
    d
   
   
    )
   
   
    ×
   
   
    (
   
   
    d
   
   
    ×
   
   
    N
   
   
    )
   
   
    =
   
   
    
     N
    
    
     2
    
   
   
    d
   
  
  
   (N\times d)\times (d\times N)=N^{2}d
  
 
(N×d)×(d×N)=N2d。

 
  
   
    O
   
   
    (
   
   
    
     N
    
    
     2
    
   
   
    d
   
   
    +
   
   
    
     N
    
    
     2
    
   
   
    d
   
   
    )
   
   
    =
   
   
    O
   
   
    (
   
   
    
     N
    
    
     2
    
   
   
    )
   
  
  
   O(N^{2}d+N^{2}d)=O(N^2)
  
 
O(N2d+N2d)=O(N2)。因为

 
  
   
    N
   
   
    ≫
   
   
    d
   
  
  
   N\gg d
  
 
N≫d,所以

 
  
   
    d
   
  
  
   d
  
 
d(还有常数

 
  
   
    2
   
  
  
   2
  
 
2)被省略了,即

 
  
   
    O
   
   
    (
   
   
    
     N
    
    
     2
    
   
   
    )
   
  
  
   O(N^2)
  
 
O(N2)。

以第二个图为例,乘法次数1:

    (
   
   
    d
   
   
    ×
   
   
    N
   
   
    )
   
   
    ×
   
   
    (
   
   
    N
   
   
    ×
   
   
    d
   
   
    )
   
   
    =
   
   
    N
   
   
    
     d
    
    
     2
    
   
  
  
   (d\times N)\times (N\times d)=Nd^2
  
 
(d×N)×(N×d)=Nd2;乘法次数2:

 
  
   
    (
   
   
    d
   
   
    ×
   
   
    d
   
   
    )
   
   
    ×
   
   
    (
   
   
    d
   
   
    ×
   
   
    N
   
   
    )
   
   
    =
   
   
    N
   
   
    
     d
    
    
     2
    
   
  
  
   (d\times d)\times (d\times N)=Nd^2
  
 
(d×d)×(d×N)=Nd2。

 
  
   
    O
   
   
    (
   
   
    N
   
   
    
     d
    
    
     2
    
   
   
    +
   
   
    N
   
   
    
     d
    
    
     2
    
   
   
    )
   
   
    =
   
   
    O
   
   
    (
   
   
    N
   
   
    )
   
  
  
   O(Nd^2+Nd^2)=O(N)
  
 
O(Nd2+Nd2)=O(N)。因为

 
  
   
    N
   
   
    ≫
   
   
    d
   
  
  
   N\gg d
  
 
N≫d,所以

 
  
   
    d
   
  
  
   d
  
 
d(还有常数2)被省略了,即

 
  
   
    O
   
   
    (
   
   
    N
   
   
    )
   
  
  
   O(N)
  
 
O(N)。

事实告诉我们,我们两个的结果一样,但是我们可以通过控制中间过程减少计算复杂度。


本文转载自: https://blog.csdn.net/qq_42584216/article/details/125533848
版权归原作者 qq_42584216 所有, 如有侵权,请联系我们删除。

“计算复杂度”的评论:

还没有评论