0


基于CKKS的非交互式安全Transformer推理实现

Secure Transformer Inference Made Non-interactive

本文介绍了如何使用CKKS来计算transformer推理的每个部分。同时给出了一系列优化算法。主要涉及到的计算算法有以下几种: 密文的压缩与分解技术、SIMD槽折叠技术、Sgn()、QuickSum、QuickMax、密文-明文矩阵相乘、密文-密文矩阵相乘法、Softmax算法、归一化、GELU函数、Argmax函数等等。其中密文的压缩与分解技术,和SIMD槽折叠技术是本文的核心创新算法。

Abstract

随着ChatGPT的普及,安全transformer推理已经成为一个突出了研究主题。已有的解决方法通常是交互式的,涉及到客户端和服务端之间大量的通信负载和交互轮次。

本文提出NEXUS,这是第一个用于安全transformer推理的非交互式协议,其中客户端仅需要提交一个加密输入,然后等待来自服务器的加密结果即可。NEXUS的核心是两个创新的技术:SIMD密文压缩和分解技术,以及SIMD槽折叠技术。此外,同24年的另外一个解决方案相比,本方法达到了2.8倍的加速,且减少了368.6倍的带宽消耗。

1 Introduction

Transformers,例如GPT和BERT,已经彻底改变了AI领域。Transformer擅长于广泛领域的应用,比如语言翻译,内容生成以及问题回答。然而这些应用总是涉及到敏感数据,从而导致越来越多地关于用户隐私的担忧。例,OpenAI开发的ChatGPT作为一种在线推理服务,以及为开发人员提供的远程API,其中使用者通过提交prompts或者消息可以很容易地访问这些服务。尽管这些方法是方便的,但是由于使用者提交的数据可能包含敏感信息,故而造成了严重的隐私风险。

Secure inference是一种两方密码协议,该协议使模型推理以如下方式处理运行,即服务器S不会了解到关于客户C提交的输入的任何信息,且C不会了解到关于S的模型的任何信息,仅仅能得到最终的推理结果。

该协议大多被设计于安全CNNs推

     [ 
    
   
     2 
    
   
     , 
    
   
     27 
    
   
     , 
    
   
     30 
    
   
     , 
    
   
     36 
    
   
     ] 
    
   
  
    [2,27,30,36] 
   
  
[2,27,30,36],最近的许多工作也支持基于Transformer的模型 
 
  
   
   
     [ 
    
   
     10 
    
   
     , 
    
   
     24 
    
   
     , 
    
   
     26 
    
   
     , 
    
   
     35 
    
   
     , 
    
   
     38 
    
   
     , 
    
   
     40 
    
   
     ] 
    
   
  
    [10,24,26,35,38,40] 
   
  
[10,24,26,35,38,40]​,值得注意的是,这些安全Transformer模型大多都是交互式的,因此会导致巨大的通信开销和交互轮次,这里我们必须强调非交互式安全Transformer推理的重要性。
本文贡献:

本文中,我们提出了NEXUS,第一个secure transformer inference的非交互协议。通过NEXUS,C使用RNS-CKKS加密输入,S对FHE加密数据执行transformer。CKKS的SIMD技术被应用于批处理

     N 
    
   
     = 
    
    
    
      2 
     
    
      15 
     
    
   
  
    N=2^{15} 
   
  
N=215个数据,多项式近似可以用于处理非线性函数,比如GELU,softmax,层归一化和argmax。

NEXUS不需要对模型进行任何重训练与微调,且为了提高NEXUS的效率,我们提出了两种新颖的且基础的技术。

  • SIMD密文压缩与分解:该技术可以将2N个SIMD密文压缩为一个密文,然后可以使用4N个密文—明文乘法和替换将其解压回来。该技术可以大大减少客户端和服务器之间传输的密文数量,而不会为后续计算带来任何额外的开销。

在这里插入图片描述

  • SIMD槽折叠:在所有SIMD槽中计算关联函数f(),例如sum和max。结果值会自动的填充SIMD密文的槽,允许将其应用于原始密文的每个槽。

本文贡献总结如下:

  • secure transformer inference的第一个非交互协议
  • 用于密文打包的SIMD密文压缩与分解技术
  • SIMD槽折叠技术,以高效操作SIMD密文的槽
  • 综合的实现与评估

2 Preliminaries

符号系统描述如下:
NotationDescriptionNotationDescriptionCclientSserver

        E 
       
      
        ( 
       
      
        ∗ 
       
      
        ) 
       
      
     
       E(*) 
      
     
   E(∗)encryption 
    
     
      
      
        π 
       
      
        ( 
       
      
        ∗ 
       
      
        ) 
       
      
     
       \pi(*) 
      
     
   π(∗)encoding 
    
     
      
      
        E 
       
      
        n 
       
      
        c 
       
      
        ( 
       
      
        ∗ 
       
      
        ) 
       
      
     
       Enc(*) 
      
     
   Enc(∗)encoding+encryption 
    
     
      
       
       
         a 
        
       
         ~ 
        
       
      
     
       \tilde{a} 
      
     
   a~FHE ciphertext 
    
     
      
      
        R 
       
      
        o 
       
      
        t 
       
      
        L 
       
      
        ( 
       
      
        ∗ 
       
      
        ) 
       
      
        / 
       
      
        R 
       
      
        o 
       
      
        t 
       
      
        R 
       
      
        ( 
       
      
        ∗ 
       
      
        ) 
       
      
     
       RotL(*)/RotR(*) 
      
     
   RotL(∗)/RotR(∗)左旋转和右旋转 
    
     
      
      
        S 
       
      
        u 
       
      
        b 
       
      
        s 
       
      
        ( 
       
      
        ∗ 
       
      
        ) 
       
      
     
       Subs(*) 
      
     
   Subs(∗)替换操作 
    
     
      
      
        S 
       
      
        g 
       
      
        n 
       
      
        ( 
       
      
        ∗ 
       
      
        ) 
       
      
     
       Sgn(*) 
      
     
   Sgn(∗)sign操作 
    
     
      
      
        L 
       
      
     
       L 
      
     
   L乘法深度 
    
     
      
       
       
         N 
        
       
         ′ 
        
       
      
     
       N' 
      
     
   N′CKKS的环维数 
    
     
      
      
        N 
       
      
     
       N 
      
     
   N 
    
     
      
      
        N 
       
      
        = 
       
       
       
         N 
        
       
         ′ 
        
       
      
        / 
       
      
        2 
       
      
     
       N=N'/2 
      
     
   N=N′/2 
    
     
      
      
        A 
       
      
     
       A 
      
     
   A输入矩阵 
    
     
      
      
        W 
       
      
     
       W 
      
     
   W权重矩阵

2.1 安全推理和威胁模型

安全推理是一个两方密码学协议,其可以在C和S之间进行模型推理,与此同时还可以保护两个参与方输入隐私。它的正式定义如下:

Definition 1:

针对两方参与者,其中

     S 
    
   
  
    S 
   
  
S持有模型 
 
  
   
   
     M 
    
   
  
    M 
   
  
M,且 
 
  
   
   
     C 
    
   
  
    C 
   
  
C持有输入 
 
  
   
   
     A 
    
   
  
    A 
   
  
A的协议 
 
  
   
   
     Π 
    
   
  
    \Pi 
   
  
Π是安全推理协议,当且仅当以下条件满足时:

(1) 正确性: 该协议的最终输出是正确的推理结果

     M 
    
   
     ( 
    
   
     A 
    
   
     ) 
    
   
  
    M(A) 
   
  
M(A)。

(2) 安全性:

     V 
    
   
     i 
    
   
     e 
    
    
    
      w 
     
    
      C 
     
    
      Π 
     
    
    
    
      ≈ 
     
    
      c 
     
    
   
     S 
    
   
     i 
    
    
    
      m 
     
    
      C 
     
    
   
     ( 
    
   
     A 
    
   
     , 
    
   
     o 
    
   
     u 
    
   
     t 
    
   
     ) 
    
   
  
    View^{\Pi}_C\approx_c Sim_C(A,out) 
   
  
ViewCΠ​≈c​SimC​(A,out),其中 
 
  
   
   
     V 
    
   
     i 
    
   
     e 
    
    
    
      w 
     
    
      C 
     
    
      Π 
     
    
   
  
    View^{\Pi}_C 
   
  
ViewCΠ​表示协议 
 
  
   
   
     Π 
    
   
  
    \Pi 
   
  
Π执行期间 
 
  
   
   
     C 
    
   
  
    C 
   
  
C的视角, 
 
  
   
   
     o 
    
   
     u 
    
   
     t 
    
   
  
    out 
   
  
out表示推理的结果。


 
  
   
   
     V 
    
   
     i 
    
   
     e 
    
    
    
      w 
     
    
      S 
     
    
      Π 
     
    
    
    
      ≈ 
     
    
      S 
     
    
   
     S 
    
   
     i 
    
    
    
      m 
     
    
      S 
     
    
   
     ( 
    
   
     M 
    
   
     ) 
    
   
  
    View^{\Pi}_S\approx_S Sim_S(M) 
   
  
ViewSΠ​≈S​SimS​(M),其中 
 
  
   
   
     V 
    
   
     i 
    
   
     e 
    
    
    
      w 
     
    
      S 
     
    
      Π 
     
    
   
  
    View^{\Pi}_S 
   
  
ViewSΠ​表示协议 
 
  
   
   
     Π 
    
   
  
    \Pi 
   
  
Π执行期间 
 
  
   
   
     S 
    
   
  
    S 
   
  
S​的视角。


 
  
   
   
     S 
    
   
     i 
    
    
    
      m 
     
    
      ∗ 
     
    
   
  
    Sim_* 
   
  
Sim∗​可以理解为理想状态下希望实体 
 
  
   
   
     ∗ 
    
   
  
    * 
   
  
∗可以得到的信息。

假设

     C 
    
   
  
    C 
   
  
C或 
 
  
   
   
     S 
    
   
  
    S 
   
  
S为半诚实对手,其在遵守协议规范的同时也尽可能的在执行过程中手机额外的信息。且假设对手在计算上是有限的。

2.2 Transformer

这里简单介绍一下Transformerd。

图1是transformer的结构与工作流程。它将一个表示为矩阵的嵌入传递给注意层和前馈神经网络,最后根据最终对数最大值输出一个选择向量,且,LayerNorm层被应用于每个块之后。

在这里插入图片描述

transformer的结构和工作流程

Attention:

使用三个矩阵(

      W 
     
    
      Q 
     
    
   
     ∈ 
    
    
    
      R 
     
     
     
       n 
      
     
       × 
      
     
       k 
      
     
    
   
     , 
    
    
    
      W 
     
    
      K 
     
    
   
     ∈ 
    
    
    
      R 
     
     
     
       n 
      
     
       × 
      
     
       k 
      
     
    
   
     , 
    
    
    
      W 
     
    
      V 
     
    
   
     ∈ 
    
    
    
      R 
     
     
     
       n 
      
     
       × 
      
     
       k 
      
     
    
   
  
    W_Q\in\mathbb{R}^{n\times k},W_K\in\mathbb{R}^{n\times k},W_V\in\mathbb{R}^{n\times k} 
   
  
WQ​∈Rn×k,WK​∈Rn×k,WV​∈Rn×k)乘嵌入矩阵 
 
  
   
   
     A 
    
   
     ∈ 
    
    
    
      R 
     
     
     
       m 
      
     
       × 
      
     
       n 
      
     
    
   
  
    A\in \mathbb{R}^{m\times n} 
   
  
A∈Rm×n,生成一个query矩阵 
 
  
   
   
     Q 
    
   
     = 
    
   
     A 
    
   
     ⋅ 
    
    
    
      W 
     
    
      Q 
     
    
   
  
    Q = A·W_Q 
   
  
Q=A⋅WQ​,一个key矩阵 
 
  
   
   
     K 
    
   
     = 
    
   
     A 
    
   
     ⋅ 
    
    
    
      W 
     
    
      K 
     
    
   
  
    K=A·W_K 
   
  
K=A⋅WK​和一个value矩阵 
 
  
   
   
     V 
    
   
     = 
    
   
     A 
    
   
     ⋅ 
    
    
    
      W 
     
    
      V 
     
    
   
  
    V=A·W_V 
   
  
V=A⋅WV​。即对于Attention层的单元,transformer会学习到三个权重矩阵。

attention可以被表示为:

      A 
     
    
      t 
     
    
      t 
     
    
      e 
     
    
      n 
     
    
      t 
     
    
      i 
     
    
      o 
     
    
      n 
     
    
      ( 
     
    
      Q 
     
    
      , 
     
    
      K 
     
    
      , 
     
    
      V 
     
    
      ) 
     
    
      = 
     
    
      S 
     
    
      o 
     
    
      f 
     
    
      t 
     
    
      m 
     
    
      a 
     
    
      x 
     
    
      ( 
     
     
      
      
        Q 
       
       
       
         K 
        
       
         T 
        
       
      
      
      
        k 
       
      
     
    
      ) 
     
    
      ⋅ 
     
    
      V 
     
    
   
     Attention(Q,K,V) = Softmax({QK^T\over{\sqrt k}})·V 
    
   
 Attention(Q,K,V)=Softmax(k​QKT​)⋅V
Layer normalization

该层的输入为

     a 
    
   
     ∈ 
    
    
    
      R 
     
    
      n 
     
    
   
  
    a\in \mathbb{R}^n 
   
  
a∈Rn,均值和标准差分别为 
 
  
   
   
     μ 
    
   
  
    \mu 
   
  
μ和 
 
  
   
   
     σ 
    
   
  
    \sigma 
   
  
σ,则该层的输出 
 
  
   
   
     y 
    
   
     ∈ 
    
    
    
      R 
     
    
      n 
     
    
   
  
    y\in\mathbb{R}^n 
   
  
y∈Rn可以表示为:

  
   
    
     
     
       y 
      
     
       i 
      
     
    
      = 
     
    
      γ 
     
    
      ⋅ 
     
     
      
       
       
         x 
        
       
         i 
        
       
      
        − 
       
      
        μ 
       
      
     
       σ 
      
     
    
      + 
     
    
      β 
     
    
   
     y_i=\gamma·{x_i-\mu\over\sigma}+\beta 
    
   
 yi​=γ⋅σxi​−μ​+β

其中,

     γ 
    
   
     , 
    
   
     β 
    
   
     ∈ 
    
   
     R 
    
   
  
    \gamma,\beta\in\mathbb{R} 
   
  
γ,β∈R​是两个超参数。
Feed-forward

全连接前馈网络层包含两个线性变换以及一个GELU激活函数:

      F 
     
    
      e 
     
    
      e 
     
    
      d 
     
    
      F 
     
    
      o 
     
    
      r 
     
    
      w 
     
    
      a 
     
    
      r 
     
    
      d 
     
    
      ( 
     
    
      X 
     
    
      ) 
     
    
      = 
     
    
      G 
     
    
      E 
     
    
      L 
     
    
      U 
     
    
      ( 
     
    
      X 
     
     
     
       W 
      
     
       1 
      
     
    
      + 
     
     
     
       b 
      
     
       1 
      
     
    
      ) 
     
    
      ⋅ 
     
     
     
       W 
      
     
       2 
      
     
    
      + 
     
     
     
       b 
      
     
       2 
      
     
    
   
     FeedForward(X)=GELU(XW_1+b_1)·W_2+b_2 
    
   
 FeedForward(X)=GELU(XW1​+b1​)⋅W2​+b2​

其中GELU函数计算如下:

      G 
     
    
      E 
     
    
      L 
     
    
      U 
     
    
      ( 
     
    
      x 
     
    
      ) 
     
    
      = 
     
     
     
       1 
      
     
       2 
      
     
    
      x 
     
    
      ⋅ 
     
    
      ( 
     
    
      1 
     
    
      + 
     
    
      e 
     
    
      r 
     
    
      f 
     
    
      ( 
     
     
     
       x 
      
      
      
        2 
       
      
     
    
      ) 
     
    
      ) 
     
    
   
     GELU(x)={1\over 2}x·(1+erf({x\over \sqrt 2})) 
    
   
 GELU(x)=21​x⋅(1+erf(2​x​))

式中,高斯误差函数为

     e 
    
   
     r 
    
   
     f 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     = 
    
    
    
      2 
     
     
     
       π 
      
     
    
    
    
      ∫ 
     
    
      0 
     
    
      x 
     
    
    
    
      e 
     
     
     
       − 
      
      
      
        t 
       
      
        2 
       
      
     
    
   
     d 
    
   
     t 
    
   
  
    erf(x)={2\over\sqrt{\pi}}\int_0^xe^{-t^2}dt 
   
  
erf(x)=π​2​∫0x​e−t2dt​。由于其良好的曲率和非单调性,它被用作激活函数。
Argmax

根据最终对数最大值输出一个选择向量

可以看到,只要我们能够使用FHE实现各个层的计算,就可以实现一个安全Transformer。

2.3 Fully Homomorphic Encryption

FHE可以对加密数据执行任意操作,故FHE是使得我们构建非交互式安全transformer推理得主要工具。RNS-CKKS属于级全同态加密,其可以支持L级深度的乘法。RNS-CKKS的明文和密文均是多项式环

      R 
     
    
      Q 
     
    
   
     = 
    
    
    
      Z 
     
    
      Q 
     
    
   
     [ 
    
   
     X 
    
   
     ] 
    
   
     / 
    
   
     ( 
    
    
    
      X 
     
     
     
       N 
      
     
       ′ 
      
     
    
   
     + 
    
   
     1 
    
   
     ) 
    
   
  
    R_Q=\Z_Q[X]/(X^{N'}+1) 
   
  
RQ​=ZQ​[X]/(XN′+1)上的元素。其中 
 
  
   
   
     Q 
    
   
     = 
    
    
    
      Π 
     
     
     
       i 
      
     
       = 
      
     
       0 
      
     
    
      L 
     
    
    
    
      q 
     
    
      i 
     
    
   
  
    Q=\Pi^L_{i=0}q_i 
   
  
Q=Πi=0L​qi​,且 
 
  
   
    
    
      q 
     
    
      i 
     
    
   
  
    q_i 
   
  
qi​​之间互素。若密文的级别变得太低,则可以运行自举操作来刷新密文到高的级别,以允许更多的计算。

简单地说,自举即利用自同构

      R 
     
     
     
       q 
      
     
       0 
      
     
    
   
     ≅ 
    
    
    
      R 
     
     
     
       q 
      
     
       0 
      
     
    
   
     × 
    
    
    
      R 
     
     
     
       q 
      
     
       1 
      
     
    
   
     × 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     × 
    
    
    
      R 
     
     
     
       q 
      
     
       L 
      
     
    
   
  
    R_{q_0}\cong R_{q_0}\times R_{q_1}\times ... \times R_{q_L} 
   
  
Rq0​​≅Rq0​​×Rq1​​×...×RqL​​,来将密文模从 
 
  
   
    
    
      q 
     
    
      0 
     
    
   
  
    q_0 
   
  
q0​提升到 
 
  
   
    
    
      q 
     
    
      L 
     
    
   
  
    q_L 
   
  
qL​,以及对密文同态评估解密电路。若自举本身消耗K个级别,则刷新后的密文支持 
 
  
   
   
     L 
    
   
     − 
    
   
     K 
    
   
  
    L-K 
   
  
L−K个级深度的计算。

RNS-CKKS支持SIMD操作,其可以加密向量

     a 
    
   
     ∈ 
    
    
    
      R 
     
    
      N 
     
    
   
  
    a\in \R^N 
   
  
a∈RN到一个密文中,且批处理这些加密元素,而不引入其他操作。为了以SIMD格式加密,首先使用编码算法 
 
  
   
   
     π 
    
   
     ( 
    
   
     ∗ 
    
   
     ) 
    
   
  
    \pi(*) 
   
  
π(∗)将向量 
 
  
   
   
     a 
    
   
  
    a 
   
  
a编码为一个 
 
  
   
    
    
      R 
     
    
      Q 
     
    
   
  
    R_Q 
   
  
RQ​上的多项式,然后使用加密算法 
 
  
   
   
     E 
    
   
     ( 
    
   
     ∗ 
    
   
     ) 
    
   
  
    E(*) 
   
  
E(∗)​加密该多项式。

在整篇文章中,我们使用

     E 
    
   
     ( 
    
   
     ∗ 
    
   
     ) 
    
   
  
    E(*) 
   
  
E(∗)表示加密多项式,使用 
 
  
   
   
     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     ∗ 
    
   
     ) 
    
   
  
    Enc(*) 
   
  
Enc(∗)表示以SIMD格式加密向量,即 
 
  
   
   
     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     a 
    
   
     ) 
    
   
     = 
    
   
     E 
    
   
     ( 
    
   
     π 
    
   
     ( 
    
   
     a 
    
   
     ) 
    
   
     ) 
    
   
  
    Enc(a)=E(\pi(a)) 
   
  
Enc(a)=E(π(a)),其中 
 
  
   
   
     a 
    
   
  
    a 
   
  
a是一个向量。

一个特殊的FHE操作:

     c 
    
    
    
      t 
     
    
      ′ 
     
    
   
     ← 
    
   
     S 
    
   
     u 
    
   
     b 
    
   
     s 
    
   
     ( 
    
   
     c 
    
   
     t 
    
   
     , 
    
   
     k 
    
   
     ) 
    
   
  
    ct'\leftarrow Subs(ct,k) 
   
  
ct′←Subs(ct,k):替换操作,该操作以密文 
 
  
   
   
     c 
    
   
     t 
    
   
     = 
    
   
     E 
    
   
     ( 
    
   
     p 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     ) 
    
   
  
    ct=E(p(x)) 
   
  
ct=E(p(x))以及一个奇整数 
 
  
   
   
     k 
    
   
  
    k 
   
  
k作为输入,然后得到新的密文 
 
  
   
   
     c 
    
    
    
      t 
     
    
      ′ 
     
    
   
     = 
    
   
     E 
    
   
     ( 
    
   
     p 
    
   
     ( 
    
    
    
      x 
     
    
      k 
     
    
   
     ) 
    
   
     ) 
    
   
  
    ct'=E(p(x^k)) 
   
  
ct′=E(p(xk))​​​​。

这里的

     S 
    
   
     u 
    
   
     b 
    
   
     s 
    
   
     ( 
    
   
     c 
    
   
     t 
    
   
     , 
    
   
     k 
    
   
     ) 
    
   
  
    Subs(ct,k) 
   
  
Subs(ct,k)应该是一种密钥交换操作,可以描述如下:

已知密文:

     c 
    
   
     t 
    
   
     = 
    
   
     ( 
    
   
     − 
    
   
     a 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     s 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     + 
    
   
     e 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     + 
    
   
     p 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     , 
    
   
     a 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     ) 
    
   
  
    ct=(-a(x)s(x)+e(x)+p(x),a(x)) 
   
  
ct=(−a(x)s(x)+e(x)+p(x),a(x))

将该密文进行自同构操作:

      κ 
     
    
      k 
     
    
   
     ( 
    
   
     c 
    
   
     t 
    
   
     ) 
    
   
     = 
    
   
     ( 
    
   
     − 
    
   
     a 
    
   
     ( 
    
    
    
      x 
     
    
      k 
     
    
   
     ) 
    
   
     s 
    
   
     ( 
    
    
    
      x 
     
    
      k 
     
    
   
     ) 
    
   
     + 
    
   
     e 
    
   
     ( 
    
    
    
      x 
     
    
      k 
     
    
   
     ) 
    
   
     + 
    
   
     p 
    
   
     ( 
    
    
    
      x 
     
    
      k 
     
    
   
     ) 
    
   
     , 
    
   
     a 
    
   
     ( 
    
    
    
      x 
     
    
      k 
     
    
   
     ) 
    
   
     ) 
    
   
  
    \kappa_k(ct)=(-a(x^k)s(x^k)+e(x^k)+p(x^k),a(x^k)) 
   
  
κk​(ct)=(−a(xk)s(xk)+e(xk)+p(xk),a(xk))。

然后得到用户提供的交换密钥:

     k 
    
   
     e 
    
   
     y 
    
   
     = 
    
   
     ( 
    
   
     − 
    
   
     a 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     s 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     + 
    
   
     e 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     + 
    
   
     P 
    
   
     ⋅ 
    
   
     s 
    
   
     ( 
    
    
    
      x 
     
    
      k 
     
    
   
     ) 
    
   
     , 
    
   
     a 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     ) 
    
   
  
    key = (-a(x)s(x)+e(x)+P·s(x^k),a(x)) 
   
  
key=(−a(x)s(x)+e(x)+P⋅s(xk),a(x))

然后执行密钥交换操作:

     c 
    
    
    
      t 
     
    
      ′ 
     
    
   
     = 
    
   
     ( 
    
    
    
      κ 
     
    
      k 
     
    
   
     ( 
    
   
     c 
    
   
     t 
    
   
     ) 
    
   
     [ 
    
   
     0 
    
   
     ] 
    
   
     , 
    
   
     0 
    
   
     ) 
    
   
     + 
    
   
     ( 
    
   
     ⌊ 
    
    
    
      P 
     
     
     
       − 
      
     
       1 
      
     
    
   
     ⋅ 
    
    
    
      κ 
     
    
      k 
     
    
   
     ( 
    
   
     c 
    
   
     t 
    
   
     ) 
    
   
     [ 
    
   
     1 
    
   
     ] 
    
   
     ⋅ 
    
   
     k 
    
   
     e 
    
   
     y 
    
   
     ⌉ 
    
   
     ) 
    
   
  
    ct'=(\kappa_k(ct)[0],0)+(\lfloor P^{-1}·\kappa_k(ct)[1]·key\rceil) 
   
  
ct′=(κk​(ct)[0],0)+(⌊P−1⋅κk​(ct)[1]⋅key⌉)

此时新的密文即

     c 
    
    
    
      t 
     
    
      ′ 
     
    
   
     = 
    
   
     ( 
    
   
     − 
    
   
     a 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     s 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     + 
    
   
     e 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     + 
    
   
     p 
    
   
     ( 
    
    
    
      x 
     
    
      k 
     
    
   
     ) 
    
   
     , 
    
   
     a 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     ) 
    
   
  
    ct'=(-a(x)s(x)+e(x)+p(x^k),a(x)) 
   
  
ct′=(−a(x)s(x)+e(x)+p(xk),a(x))

注意,这里的

     a 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     , 
    
   
     e 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
  
    a(x),e(x) 
   
  
a(x),e(x)是变化的,也就是不同的密文中,这是不同的。

2.4 Homomorphic sign function

由于FHE仅支持线性函数,所以为了实现在FHE下对加密数据的比较,本文需要利用sign函数的多项式近似,即:

      s 
     
    
      i 
     
    
      g 
     
    
      n 
     
    
      ( 
     
    
      x 
     
    
      ) 
     
    
      = 
     
     
     
       f 
      
      
      
        d 
       
      
        f 
       
      
     
    
      ( 
     
     
     
       g 
      
      
      
        d 
       
      
        g 
       
      
     
    
      ( 
     
    
      x 
     
    
      ) 
     
    
      ) 
     
    
      = 
     
     
     
       { 
      
      
       
        
         
          
          
            − 
           
          
            1 
           
          
         
        
        
         
          
          
            ( 
           
          
            − 
           
          
            1 
           
          
            ≤ 
           
          
            x 
           
          
            ≤ 
           
          
            − 
           
           
           
             2 
            
            
            
              − 
             
            
              α 
             
            
           
          
            ) 
           
          
         
        
       
       
        
         
         
           0 
          
         
        
        
         
          
          
            ( 
           
          
            x 
           
          
            = 
           
          
            0 
           
          
            ) 
           
          
         
        
       
       
        
         
         
           1 
          
         
        
        
         
          
          
            ( 
           
           
           
             2 
            
            
            
              − 
             
            
              α 
             
            
           
          
            ≤ 
           
          
            x 
           
          
            ≤ 
           
          
            1 
           
          
            ) 
           
          
         
        
       
      
     
    
   
     sign(x)=f^{d_f}(g^{d_g}(x))=\begin{cases} -1 &(-1\leq x \leq -2^{-\alpha}) \\ 0 &(x = 0) \\ 1 &(2^{-\alpha}\leq x \leq 1) \\ \end{cases} 
    
   
 sign(x)=fdf​(gdg​(x))=⎩⎨⎧​−101​(−1≤x≤−2−α)(x=0)(2−α≤x≤1)​

其中,

     f 
    
   
     ( 
    
   
     ) 
    
   
     , 
    
   
     g 
    
   
     ( 
    
   
     ) 
    
   
  
    f(),g() 
   
  
f(),g()为两个多项式, 
 
  
   
    
    
      d 
     
    
      f 
     
    
   
     , 
    
    
    
      d 
     
    
      g 
     
    
   
  
    d_f,d_g 
   
  
df​,dg​为这两个多项式重复的次数。注意,该多项式近似要求输入x取值范围为[-1,1]。因此,对任何输入 
 
  
   
   
     a 
    
   
     ∈ 
    
   
     [ 
    
    
    
      a 
     
     
     
       m 
      
     
       i 
      
     
       n 
      
     
    
   
     , 
    
    
    
      a 
     
     
     
       m 
      
     
       a 
      
     
       x 
      
     
    
   
     ] 
    
   
  
    a\in [a_{min},a_{max}] 
   
  
a∈[amin​,amax​]都需要进行归一化处理:

  
   
    
    
      x 
     
    
      : 
     
    
      = 
     
    
      a 
     
    
      / 
     
    
      m 
     
    
      a 
     
    
      x 
     
    
      { 
     
    
      ∣ 
     
     
     
       a 
      
      
      
        m 
       
      
        a 
       
      
        x 
       
      
     
    
      ∣ 
     
    
      , 
     
    
      ∣ 
     
     
     
       a 
      
      
      
        m 
       
      
        i 
       
      
        n 
       
      
     
    
      ∣ 
     
    
      } 
     
    
   
     x := a/max\{|a_{max}|,|a_{min}|\} 
    
   
 x:=a/max{∣amax​∣,∣amin​∣}

这里,我们使用Sgn()表示在SIMD密文上同时运行归一化与sign近似函数:

       b 
      
     
       ~ 
      
     
    
      ← 
     
    
      S 
     
    
      g 
     
    
      n 
     
    
      ( 
     
     
     
       a 
      
     
       ~ 
      
     
    
      ) 
     
    
      : 
     
     
     
       b 
      
     
       i 
      
     
    
      = 
     
     
     
       f 
      
      
      
        d 
       
      
        f 
       
      
     
    
      ( 
     
     
     
       g 
      
      
      
        d 
       
      
        g 
       
      
     
    
      ( 
     
     
      
      
        a 
       
      
        i 
       
      
      
      
        m 
       
      
        a 
       
      
        x 
       
      
        { 
       
      
        ∣ 
       
       
       
         a 
        
        
        
          m 
         
        
          a 
         
        
          x 
         
        
       
      
        ∣ 
       
      
        , 
       
      
        ∣ 
       
       
       
         a 
        
        
        
          m 
         
        
          i 
         
        
          n 
         
        
       
      
        ∣ 
       
      
        } 
       
      
     
    
      ) 
     
    
      ) 
     
    
         
     
    
      ∀ 
     
    
      i 
     
    
      ∈ 
     
    
      [ 
     
    
      N 
     
    
      ] 
     
    
   
     \widetilde b\leftarrow Sgn(\widetilde a): b_i =f^{d_f}(g^{d_g}({a_i\over{max\{|a_{max}|,|a_{min}|\}}})) \ \ \forall i \in [N] 
    
   
 b←Sgn(a):bi​=fdf​(gdg​(max{∣amax​∣,∣amin​∣}ai​​))  ∀i∈[N]

在本文的实现中,使用的是9次的

     f 
    
   
     ( 
    
   
     ∗ 
    
   
     ) 
    
   
  
    f(*) 
   
  
f(∗)和 
 
  
   
   
     g 
    
   
     ( 
    
   
     ∗ 
    
   
     ) 
    
   
  
    g(*) 
   
  
g(∗),且设计 
 
  
   
   
     α 
    
   
     = 
    
   
     16 
    
   
     , 
    
    
    
      d 
     
    
      f 
     
    
   
     = 
    
   
     2 
    
   
     , 
    
    
    
      d 
     
    
      g 
     
    
   
     = 
    
   
     2 
    
   
  
    \alpha=16,d_f=2,d_g=2 
   
  
α=16,df​=2,dg​=2​​,然后使用BSGS算法来评估多项式。

3 Basic design

本节介绍NEXUS的基础设计,即在不优化的情况下实现上述transformer的每一层计算,在之后的章节中会对本节的算法进行优化。

3.1 Attention

3.1.1 Matrix multiplication(ciphertext-plaintext)

在Attention层的第一个MatMul步骤,我们需要计算三个密文—明文矩阵乘法:

      Q 
     
    
      : 
     
    
      = 
     
    
      A 
     
    
      ⋅ 
     
     
     
       W 
      
     
       Q 
      
     
    
      ; 
     
     
    
      K 
     
    
      : 
     
    
      = 
     
    
      A 
     
    
      ⋅ 
     
     
     
       W 
      
     
       K 
      
     
    
      ; 
     
     
    
      V 
     
    
      : 
     
    
      = 
     
    
      A 
     
    
      ⋅ 
     
     
     
       W 
      
     
       V 
      
     
    
      ; 
     
    
   
     Q:=A·W_Q;\\ K:=A·W_K;\\ V:=A·W_V; 
    
   
 Q:=A⋅WQ​;K:=A⋅WK​;V:=A⋅WV​;

其中A是我们的输入,

      W 
     
    
      Q 
     
    
   
     , 
    
    
    
      W 
     
    
      K 
     
    
   
     , 
    
    
    
      W 
     
    
      V 
     
    
   
  
    W_Q,W_K,W_V 
   
  
WQ​,WK​,WV​是三个给定矩阵,下面以 
 
  
   
   
     A 
    
   
     ⋅ 
    
    
    
      W 
     
    
      Q 
     
    
   
  
    A·W_Q 
   
  
A⋅WQ​为例来描述这个密文—明文矩阵乘法,该过程同样适用于 
 
  
   
    
    
      W 
     
    
      K 
     
    
   
  
    W_K 
   
  
WK​和 
 
  
   
    
    
      W 
     
    
      V 
     
    
   
  
    W_V 
   
  
WV​。

给定矩阵

     A 
    
   
     ∈ 
    
    
    
      R 
     
     
     
       m 
      
     
       × 
      
     
       n 
      
     
    
   
  
    A\in \mathbb{R}^{m\times n} 
   
  
A∈Rm×n和矩阵 
 
  
   
    
    
      W 
     
    
      Q 
     
    
   
     ∈ 
    
    
    
      R 
     
     
     
       n 
      
     
       × 
      
     
       k 
      
     
    
   
  
    W_Q\in \mathbb{R}^{n\times k} 
   
  
WQ​∈Rn×k,计算矩阵 
 
  
   
   
     Q 
    
   
     : 
    
   
     = 
    
   
     A 
    
   
     ⋅ 
    
    
    
      W 
     
    
      Q 
     
    
   
  
    Q:=A·W_Q 
   
  
Q:=A⋅WQ​。

      a 
     
     
     
       i 
      
     
       , 
      
     
       j 
      
     
    
   
     ∈ 
    
   
     R 
    
   
  
    a_{i,j}\in \mathbb{R} 
   
  
ai,j​∈R表示矩阵A的第i行第j列的元素, 
 
  
   
    
    
      w 
     
    
      j 
     
    
   
     ∈ 
    
    
    
      R 
     
    
      k 
     
    
   
  
    w_j\in \mathbb{R}^k 
   
  
wj​∈Rk表示矩阵 
 
  
   
    
    
      W 
     
    
      Q 
     
    
   
  
    W_Q 
   
  
WQ​的第j行的元素向量, 
 
  
   
    
    
      q 
     
    
      i 
     
    
   
     ∈ 
    
    
    
      R 
     
    
      k 
     
    
   
  
    q_i\in \mathbb{R}^k 
   
  
qi​∈Rk是矩阵 
 
  
   
   
     Q 
    
   
  
    Q 
   
  
Q的第i行的元素向量,即:

  
   
    
     
     
       q 
      
     
       i 
      
     
    
      = 
     
     
     
       ∑ 
      
      
      
        j 
       
      
        ∈ 
       
      
        [ 
       
      
        n 
       
      
        ] 
       
      
     
     
     
       a 
      
      
      
        i 
       
      
        , 
       
      
        j 
       
      
     
    
      ⋅ 
     
     
     
       w 
      
     
       j 
      
     
    
   
     q_i=\sum_{j\in [n]}a_{i,j}·w_j 
    
   
 qi​=j∈[n]∑​ai,j​⋅wj​

因此,上述过程可以描述为,C将A中的每个元素

      a 
     
     
     
       i 
      
     
       , 
      
     
       j 
      
     
    
   
  
    a_{i,j} 
   
  
ai,j​均单独加密为密文发送给S,然后S同态评估MatrixMul,一个演示的示例如下:

在这里插入图片描述

图2 SIMD-based matrix multiplication


在上述描述中,C需要发送

     m 
    
   
     × 
    
   
     n 
    
   
  
    m\times n 
   
  
m×n个密文给S,从某一方面来说,这种开销是比较大的,因此本文在第4节提出一种算法可以将如此类型的 
 
  
   
   
     m 
    
   
     × 
    
   
     n 
    
   
  
    m\times n 
   
  
m×n个密文压缩为 
 
  
   
    
     
     
       m 
      
     
       × 
      
     
       n 
      
     
     
     
       N 
      
     
       ′ 
      
     
    
   
  
    m\times n\over{N'} 
   
  
N′m×n​个密文,即一个密文中存放 
 
  
   
    
    
      N 
     
    
      ′ 
     
    
   
  
    N' 
   
  
N′个元素,随后S可以将压缩后的密文恢复为压缩前的密文形式。

3.1.2 Matrix multiplication(ciphertext-ciphertext)

经过上述步骤后,可以获得加密的

     ( 
    
   
     Q 
    
   
     , 
    
   
     K 
    
   
     , 
    
   
     V 
    
   
     ) 
    
   
  
    (Q,K,V) 
   
  
(Q,K,V),在Attention的第二个MatMul块,S需要计算 
 
  
   
   
     Q 
    
   
     ⋅ 
    
    
    
      K 
     
    
      T 
     
    
   
  
    Q·K^T 
   
  
Q⋅KT。很明显,现在Q的每一行和 
 
  
   
    
    
      K 
     
    
      T 
     
    
   
  
    K^T 
   
  
KT的每一列已经以SIMD的形式加密为 
 
  
   
   
     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     q 
    
   
     ) 
    
   
     , 
    
   
     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
    
    
      k 
     
    
      T 
     
    
   
     ) 
    
   
  
    Enc(q) , Enc(k^T) 
   
  
Enc(q),Enc(kT)。如果S可以计算 
 
  
   
   
     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     q 
    
   
     ) 
    
   
  
    Enc(q) 
   
  
Enc(q)和 
 
  
   
   
     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
    
    
      k 
     
    
      T 
     
    
   
     ) 
    
   
  
    Enc(k^T) 
   
  
Enc(kT)的内积,则可以获得 
 
  
   
   
     Q 
    
   
     ⋅ 
    
    
    
      K 
     
    
      T 
     
    
   
  
    Q·K^T 
   
  
Q⋅KT的加密结果。

由于SIMD,S可以很容易的计算得到Enc(u),其中

     u 
    
   
     = 
    
   
     [ 
    
    
    
      u 
     
    
      0 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      u 
     
     
     
       k 
      
     
       − 
      
     
       1 
      
     
    
   
     ] 
    
   
  
    u=[u_0,...,u_{k-1}] 
   
  
u=[u0​,...,uk−1​]是q和 
 
  
   
    
    
      k 
     
    
      T 
     
    
   
  
    k^T 
   
  
kT的元素级的乘法,现在为了计算内积,S仅仅需要在SIMD下计算 
 
  
   
   
     s 
    
   
     : 
    
   
     = 
    
    
    
      ∑ 
     
     
     
       i 
      
     
       = 
      
     
       0 
      
     
     
     
       k 
      
     
       − 
      
     
       1 
      
     
    
    
    
      u 
     
    
      i 
     
    
   
  
    s:=\sum_{i=0}^{k-1}u_i 
   
  
s:=∑i=0k−1​ui​。

为了计算这个和,我们可以通过k-1次的旋转及加和来计算,从而获得密文Enc([s,s,…,s]),但是本文提出了’QuickSum’算法,该算法仅仅需要logk次旋转就可达到这个目标。'QuickSum’算法在第5节介绍。


进一步,S将计算得到每一行的的m个密文组合到单一密文中,计算方法如下:

       ∑ 
      
      
      
        i 
       
      
        = 
       
      
        0 
       
      
      
      
        m 
       
      
        − 
       
      
        1 
       
      
     
    
      ( 
     
    
      E 
     
    
      n 
     
    
      c 
     
    
      ( 
     
     
     
       s 
      
     
       i 
      
     
    
      , 
     
     
     
       s 
      
     
       i 
      
     
    
      , 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      , 
     
     
     
       s 
      
     
       i 
      
     
    
      ) 
     
    
      ⋅ 
     
     
     
       b 
      
     
       i 
      
     
    
      ) 
     
    
   
     \sum_{i=0}^{m-1}(Enc(s_i,s_i,...,s_i)·b_i) 
    
   
 i=0∑m−1​(Enc(si​,si​,...,si​)⋅bi​)

其中

      b 
     
    
      i 
     
    
   
  
    b_i 
   
  
bi​仅在第i个槽的位置是1,其余槽均为0。

易知,输出矩阵为

     A 
    
   
     ∈ 
    
    
    
      R 
     
     
     
       m 
      
     
       × 
      
     
       m 
      
     
    
   
  
    A\in \mathbb{R}^{m\times m} 
   
  
A∈Rm×m,其中A的每行向量以SIMD形式加密,将该结果作为Softmax的输入。
3.1.3 Softmax

Softmax函数需要被应用于A的每一行,该函数评估如下:

          y 
         
        
          i 
         
        
       
         = 
        
        
         
         
           e 
          
         
           x 
          
         
           p 
          
         
           ( 
          
          
          
            a 
           
          
            i 
           
          
         
           − 
          
          
          
            a 
           
           
           
             m 
            
           
             a 
            
           
             x 
            
           
          
         
           ) 
          
         
         
          
          
            ∑ 
           
           
           
             j 
            
           
             = 
            
           
             0 
            
           
           
           
             m 
            
           
             − 
            
           
             1 
            
           
          
         
           e 
          
         
           x 
          
         
           p 
          
         
           ( 
          
          
          
            a 
           
          
            j 
           
          
         
           − 
          
          
          
            a 
           
           
           
             m 
            
           
             a 
            
           
             x 
            
           
          
         
           ) 
          
         
        
       
      
      
      
      
        (1) 
       
      
     
    
   
     y_i={exp(a_i-a_{max})\over{\sum_{j=0}^{m-1}exp(a_j-a_{max})}}\tag{1} 
    
   
 yi​=∑j=0m−1​exp(aj​−amax​)exp(ai​−amax​)​(1)

其中

      a 
     
     
     
       m 
      
     
       a 
      
     
       x 
      
     
    
   
     = 
    
   
     m 
    
   
     a 
    
   
     x 
    
   
     ( 
    
    
    
      a 
     
    
      0 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       m 
      
     
       − 
      
     
       1 
      
     
    
   
     ) 
    
   
  
    a_{max}=max(a_0,...,a_{m-1}) 
   
  
amax​=max(a0​,...,am−1​),从而确保指数函数的每个输入 
 
  
   
   
     ( 
    
    
    
      a 
     
    
      j 
     
    
   
     − 
    
    
    
      a 
     
     
     
       m 
      
     
       a 
      
     
       x 
      
     
    
   
     ) 
    
   
  
    (a_j-a_{max}) 
   
  
(aj​−amax​)​是非正数,保证稳定性。

本文提出了’QuickMax’算法,该算法以

     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     [ 
    
    
    
      a 
     
    
      0 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       m 
      
     
       − 
      
     
       1 
      
     
    
   
     ] 
    
   
     ) 
    
   
  
    Enc([a_0,...,a_{m-1}]) 
   
  
Enc([a0​,...,am−1​])为输入,并输出 
 
  
   
   
     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     [ 
    
    
    
      a 
     
     
     
       m 
      
     
       a 
      
     
       x 
      
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       m 
      
     
       a 
      
     
       x 
      
     
    
   
     ] 
    
   
     ) 
    
   
  
    Enc([a_{max},...,a_{max}]) 
   
  
Enc([amax​,...,amax​])​,且,该算法仅需要logm-1次Sgn操作与logm次旋转操作。该算法描述在第5节。

给定

     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     [ 
    
    
    
      a 
     
    
      0 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       m 
      
     
       − 
      
     
       1 
      
     
    
   
     ] 
    
   
     ) 
    
   
  
    Enc([a_0,...,a_{m-1}]) 
   
  
Enc([a0​,...,am−1​])和 
 
  
   
   
     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     [ 
    
    
    
      a 
     
     
     
       m 
      
     
       a 
      
     
       x 
      
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       m 
      
     
       a 
      
     
       x 
      
     
    
   
     ] 
    
   
     ) 
    
   
  
    Enc([a_{max},...,a_{max}]) 
   
  
Enc([amax​,...,amax​])。

S进行如下步骤计算:

      E 
     
    
      n 
     
    
      c 
     
    
      ( 
     
    
      [ 
     
     
     
       a 
      
     
       0 
      
     
       ′ 
      
     
    
      , 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      , 
     
     
     
       a 
      
      
      
        m 
       
      
        − 
       
      
        1 
       
      
     
       ′ 
      
     
    
      ] 
     
    
      ) 
     
    
      = 
     
    
      E 
     
    
      n 
     
    
      c 
     
    
      ( 
     
    
      [ 
     
     
     
       a 
      
     
       0 
      
     
    
      , 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      , 
     
     
     
       a 
      
      
      
        m 
       
      
        − 
       
      
        1 
       
      
     
    
      ] 
     
    
      ) 
     
    
      − 
     
    
      E 
     
    
      n 
     
    
      c 
     
    
      ( 
     
    
      [ 
     
     
     
       a 
      
      
      
        m 
       
      
        a 
       
      
        x 
       
      
     
    
      , 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      , 
     
     
     
       a 
      
      
      
        m 
       
      
        a 
       
      
        x 
       
      
     
    
      ] 
     
    
      ) 
     
    
   
     Enc([a'_0,...,a'_{m-1}])=Enc([a_0,...,a_{m-1}])-Enc([a_{max},...,a_{max}]) 
    
   
 Enc([a0′​,...,am−1′​])=Enc([a0​,...,am−1​])−Enc([amax​,...,amax​])

然后根据如下公式计算指数函数,这里使用泰勒展开:

      e 
     
    
      x 
     
    
      p 
     
    
      ( 
     
    
      x 
     
    
      ) 
     
    
      ≈ 
     
    
      ( 
     
    
      1 
     
    
      + 
     
     
     
       x 
      
      
      
        2 
       
      
        r 
       
      
     
     
     
       ) 
      
      
      
        2 
       
      
        r 
       
      
     
    
      , 
     
    
      x 
     
    
      ≤ 
     
    
      0 
     
    
   
     exp(x)\approx(1+{x\over{2^r}})^{2^r},x\leq 0 
    
   
 exp(x)≈(1+2rx​)2r,x≤0

其中

     r 
    
   
     = 
    
   
     6 
    
   
  
    r=6 
   
  
r=6,此时平均误差被限制在 
 
  
   
   
     1 
    
    
    
      0 
     
     
     
       − 
      
     
       5 
      
     
    
   
  
    10^{-5} 
   
  
10−5,即S以SIMD格式计算指数函数:

  
   
    
    
      E 
     
    
      n 
     
    
      c 
     
    
      ( 
     
     
     
       e 
      
     
       0 
      
     
    
      , 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      , 
     
     
     
       e 
      
      
      
        m 
       
      
        − 
       
      
        1 
       
      
     
    
      ) 
     
    
      = 
     
    
      e 
     
    
      x 
     
    
      p 
     
    
      ( 
     
    
      E 
     
    
      n 
     
    
      c 
     
    
      ( 
     
    
      [ 
     
     
     
       a 
      
     
       0 
      
     
       ′ 
      
     
    
      , 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      , 
     
     
     
       a 
      
      
      
        m 
       
      
        − 
       
      
        1 
       
      
     
       ′ 
      
     
    
      ] 
     
    
      ) 
     
    
      ) 
     
    
   
     Enc(e_0,...,e_{m-1})=exp(Enc([a'_0,...,a'_{m-1}])) 
    
   
 Enc(e0​,...,em−1​)=exp(Enc([a0′​,...,am−1′​]))

很明显,这里

      e 
     
    
      j 
     
    
   
     = 
    
   
     e 
    
   
     x 
    
   
     p 
    
   
     ( 
    
    
    
      a 
     
    
      j 
     
    
      ′ 
     
    
   
     ) 
    
   
  
    e_j=exp(a'_j) 
   
  
ej​=exp(aj′​)。

接下来,S应用

     Q 
    
   
     u 
    
   
     i 
    
   
     c 
    
   
     k 
    
   
     S 
    
   
     u 
    
   
     m 
    
   
     ( 
    
   
     ∗ 
    
   
     ) 
    
   
  
    QuickSum(*) 
   
  
QuickSum(∗)算法来获得 
 
  
   
   
     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     [ 
    
    
    
      ∑ 
     
     
     
       j 
      
     
       = 
      
     
       0 
      
     
     
     
       m 
      
     
       − 
      
     
       1 
      
     
    
    
    
      e 
     
    
      j 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      ∑ 
     
     
     
       j 
      
     
       = 
      
     
       0 
      
     
     
     
       m 
      
     
       − 
      
     
       1 
      
     
    
    
    
      e 
     
    
      j 
     
    
   
     ] 
    
   
     ) 
    
   
  
    Enc([\sum^{m-1}_{j=0}e_j,...,\sum^{m-1}_{j=0}e_j]) 
   
  
Enc([∑j=0m−1​ej​,...,∑j=0m−1​ej​])​。

进一步的,S使用文献[21,24]中的Goldschmidt除法算法来计算:

      E 
     
    
      n 
     
    
      c 
     
    
      ( 
     
     
     
       y 
      
     
       0 
      
     
    
      , 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      , 
     
     
     
       y 
      
      
      
        m 
       
      
        − 
       
      
        1 
       
      
     
    
      ) 
     
    
      = 
     
     
      
      
        E 
       
      
        n 
       
      
        c 
       
      
        ( 
       
       
       
         e 
        
       
         0 
        
       
      
        , 
       
      
        . 
       
      
        . 
       
      
        . 
       
      
        , 
       
       
       
         e 
        
        
        
          m 
         
        
          − 
         
        
          1 
         
        
       
      
        ) 
       
      
      
      
        E 
       
      
        n 
       
      
        c 
       
      
        ( 
       
      
        [ 
       
       
       
         ∑ 
        
        
        
          j 
         
        
          = 
         
        
          0 
         
        
        
        
          m 
         
        
          − 
         
        
          1 
         
        
       
       
       
         e 
        
       
         j 
        
       
      
        , 
       
      
        . 
       
      
        . 
       
      
        . 
       
      
        , 
       
       
       
         ∑ 
        
        
        
          j 
         
        
          = 
         
        
          0 
         
        
        
        
          m 
         
        
          − 
         
        
          1 
         
        
       
       
       
         e 
        
       
         j 
        
       
      
        ] 
       
      
        ) 
       
      
     
    
   
     Enc(y_0,...,y_{m-1})={Enc(e_0,...,e_{m-1})\over Enc([\sum^{m-1}_{j=0}e_j,...,\sum^{m-1}_{j=0}e_j])} 
    
   
 Enc(y0​,...,ym−1​)=Enc([∑j=0m−1​ej​,...,∑j=0m−1​ej​])Enc(e0​,...,em−1​)​

Softmax算法的详细描述如算法1所示:

在这里插入图片描述

3.1.4 Matrix multiplication(ciphertext-ciphertext)

这里是Attention的最后一个MatMul块,该块的计算原理同3.1.2节完全一致。

3.2 Layer normalization

本文的归一化表示如下(但是不太清楚这个归一化使用的是什么计算公式):

在这里插入图片描述

3.3 Feed forward

前馈网络层涉及到两个矩阵乘法以及一个GELU。矩阵乘法如上文所述来计算。GELU可以使用下述分段多项式来近似,当输入

     x 
    
   
     ∈ 
    
   
     [ 
    
   
     − 
    
   
     60 
    
   
     , 
    
   
     60 
    
   
     ] 
    
   
  
    x\in [-60,60] 
   
  
x∈[−60,60],则可以确保误差在 
 
  
   
   
     1 
    
    
    
      0 
     
     
     
       − 
      
     
       3 
      
     
    
   
  
    10^{-3} 
   
  
10−3内。

  
   
    
    
      G 
     
    
      E 
     
    
      L 
     
    
      U 
     
    
      ( 
     
    
      x 
     
    
      ) 
     
    
      = 
     
    
      ∈ 
     
     
     
       { 
      
      
       
        
         
         
           0 
          
         
        
        
         
          
          
            ( 
           
          
            x 
           
          
            ≤ 
           
          
            − 
           
          
            4 
           
          
            ) 
           
          
         
        
       
       
        
         
          
          
            P 
           
          
            ( 
           
          
            x 
           
          
            ) 
           
          
            = 
           
           
           
             ∑ 
            
            
            
              i 
             
            
              = 
             
            
              0 
             
            
            
            
              i 
             
            
              = 
             
            
              3 
             
            
           
           
           
             c 
            
           
             i 
            
           
           
           
             x 
            
           
             i 
            
           
          
         
        
        
         
          
          
            ( 
           
          
            − 
           
          
            4 
           
          
            < 
           
          
            x 
           
          
            ≤ 
           
          
            − 
           
          
            1.95 
           
          
            ) 
           
          
         
        
       
       
        
         
          
          
            Q 
           
          
            ( 
           
          
            x 
           
          
            ) 
           
          
            = 
           
           
           
             ∑ 
            
            
            
              i 
             
            
              = 
             
            
              0 
             
            
            
            
              i 
             
            
              = 
             
            
              6 
             
            
           
           
           
             d 
            
           
             i 
            
           
           
           
             x 
            
           
             i 
            
           
          
         
        
        
         
          
          
            ( 
           
          
            − 
           
          
            1.95 
           
          
            < 
           
          
            x 
           
          
            ≤ 
           
          
            3 
           
          
            ) 
           
          
         
        
       
       
        
         
         
           x 
          
         
        
        
         
          
          
            ( 
           
          
            x 
           
          
            > 
           
          
            3 
           
          
            ) 
           
          
         
        
       
      
     
    
   
     GELU(x)=\in \begin{cases} 0 &(x\leq -4) \\ P(x)=\sum_{i=0}^{i=3}c_ix^i &(-4<x\leq -1.95) \\ Q(x)=\sum_{i=0}^{i=6}d_ix^i &(-1.95<x\leq 3) \\ x &(x>3) \end{cases} 
    
   
 GELU(x)=∈⎩⎨⎧​0P(x)=∑i=0i=3​ci​xiQ(x)=∑i=0i=6​di​xix​(x≤−4)(−4<x≤−1.95)(−1.95<x≤3)(x>3)​

首先,使用Sgn操作获得四个加密bit:

      b 
     
    
      0 
     
    
   
     , 
    
    
    
      b 
     
    
      1 
     
    
   
     , 
    
    
    
      b 
     
    
      2 
     
    
   
     , 
    
    
    
      b 
     
    
      3 
     
    
   
  
    b_0,b_1,b_2,b_3 
   
  
b0​,b1​,b2​,b3​,当且仅当输入x属于第i段时, 
 
  
   
    
    
      b 
     
    
      i 
     
    
   
     = 
    
   
     1 
    
   
  
    b_i=1 
   
  
bi​=1,否则 
 
  
   
    
    
      b 
     
    
      i 
     
    
   
     = 
    
   
     0 
    
   
  
    b_i=0 
   
  
bi​=0,如此,GELU(x)函数可以表示为: 
 
  
   
   
     G 
    
   
     E 
    
   
     L 
    
   
     U 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     : 
    
   
     = 
    
    
    
      b 
     
    
      0 
     
    
   
     ⋅ 
    
   
     0 
    
   
     + 
    
    
    
      b 
     
    
      1 
     
    
   
     ⋅ 
    
   
     P 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     + 
    
    
    
      b 
     
    
      2 
     
    
   
     ⋅ 
    
   
     Q 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     + 
    
    
    
      b 
     
    
      3 
     
    
   
     ⋅ 
    
   
     x 
    
   
  
    GELU(x):=b_0·0+b_1·P(x)+b_2·Q(x)+b_3·x 
   
  
GELU(x):=b0​⋅0+b1​⋅P(x)+b2​⋅Q(x)+b3​⋅x​。

完整的Secure GELU算法可以表示如下:

在这里插入图片描述

3.4 Argmax

transformer最终的输出应该是一个选择向量

     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     [ 
    
    
    
      b 
     
    
      0 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      b 
     
     
     
       m 
      
     
       − 
      
     
       1 
      
     
    
   
     ] 
    
   
     ) 
    
   
  
    Enc([b_0,...,b_{m-1}]) 
   
  
Enc([b0​,...,bm−1​]),其中 
 
  
   
    
    
      b 
     
    
      i 
     
    
   
     = 
    
   
     1 
    
   
       
    
   
     i 
    
   
     f 
    
   
       
    
    
    
      a 
     
    
      i 
     
    
   
     = 
    
   
     m 
    
   
     a 
    
   
     x 
    
   
     ( 
    
    
    
      a 
     
    
      0 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       m 
      
     
       − 
      
     
       1 
      
     
    
   
     ) 
    
   
  
    b_i=1 \ if \ a_i=max(a_0,...,a_{m-1}) 
   
  
bi​=1 if ai​=max(a0​,...,am−1​),其他情况下 
 
  
   
    
    
      b 
     
    
      i 
     
    
   
     = 
    
   
     0 
    
   
  
    b_i=0 
   
  
bi​=0​。

因此,本文的Secure Argmax算法如下:

在这里插入图片描述

3.5 Placement of bootstrapping

由于bootstrapping操作是昂贵的,因此合理的放置bootstrapping的位置是至关重要的。

在这里插入图片描述

图4 Placement of bootstrapping for a BERT-base transformer

4. SIMD密文的压缩和分解

假设C想要发送N’个密文给S,且每个密文以SIMD方式加密N个相同的值,Enc([

      a 
     
    
      0 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
    
      0 
     
    
   
  
    a_0,...,a_0 
   
  
a0​,...,a0​]),…,Enc([ 
 
  
   
    
    
      a 
     
     
      
      
        N 
       
      
        ′ 
       
      
     
       − 
      
     
       1 
      
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
      
      
        N 
       
      
        ′ 
       
      
     
       − 
      
     
       1 
      
     
    
   
  
    a_{N'-1},...,a_{N'-1} 
   
  
aN′−1​,...,aN′−1​​​])。

SIMD密文的压缩算法

C将向量[

      a 
     
    
      0 
     
    
   
     , 
    
    
    
      a 
     
    
      1 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
      
      
        N 
       
      
        ′ 
       
      
     
       − 
      
     
       1 
      
     
    
   
  
    a_0,a_1,...,a_{N'-1} 
   
  
a0​,a1​,...,aN′−1​]的各个元素打包到一个多项式的系数中,即:

  
   
    
    
      p 
     
    
      ( 
     
    
      x 
     
    
      ) 
     
    
      = 
     
     
     
       a 
      
     
       0 
      
     
    
      + 
     
     
     
       a 
      
     
       1 
      
     
    
      x 
     
    
      + 
     
     
     
       a 
      
     
       2 
      
     
     
     
       x 
      
     
       2 
      
     
    
      + 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      + 
     
     
     
       a 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        − 
       
      
        1 
       
      
     
     
     
       x 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        − 
       
      
        1 
       
      
     
    
   
     p(x)=a_0+a_1x+a_2x^2+...+a_{N'-1}x^{N'-1} 
    
   
 p(x)=a0​+a1​x+a2​x2+...+aN′−1​xN′−1

然后将该多项式加密

       p 
      
     
       ~ 
      
     
    
      0 
     
    
   
     = 
    
   
     E 
    
   
     ( 
    
   
     p 
    
   
     ( 
    
   
     x 
    
   
     ) 
    
   
     ) 
    
   
  
    \widetilde p_0=E(p(x)) 
   
  
p​0​=E(p(x))发送给S。

然后S可以对密文

       p 
      
     
       ~ 
      
     
    
      0 
     
    
   
  
    \widetilde p_0 
   
  
p​0​分解从而得到压缩前 
 
  
   
    
    
      N 
     
    
      ′ 
     
    
   
  
    N' 
   
  
N′个SIMD密文。

S分解密文

       p 
      
     
       ~ 
      
     
    
      0 
     
    
   
  
    \widetilde p_0 
   
  
p​0​过程如下:

SIMD密文的分解算法:

(1)执行

     S 
    
   
     u 
    
   
     b 
    
   
     s 
    
   
     ( 
    
    
     
     
       p 
      
     
       ~ 
      
     
    
      0 
     
    
   
     , 
    
    
    
      N 
     
    
      ′ 
     
    
   
     + 
    
   
     1 
    
   
     ) 
    
   
  
    Subs(\widetilde p_0, N'+1) 
   
  
Subs(p​0​,N′+1)返回:

  
   
    
    
      E 
     
    
      ( 
     
     
     
       a 
      
     
       0 
      
     
    
      + 
     
     
     
       a 
      
     
       1 
      
     
     
     
       x 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        + 
       
      
        1 
       
      
     
    
      + 
     
     
     
       a 
      
     
       2 
      
     
     
     
       x 
      
      
      
        ( 
       
       
       
         N 
        
       
         ′ 
        
       
      
        + 
       
      
        1 
       
       
       
         ) 
        
       
         2 
        
       
      
     
    
      + 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      + 
     
     
     
       a 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        − 
       
      
        1 
       
      
     
     
     
       x 
      
      
      
        ( 
       
       
       
         N 
        
       
         ′ 
        
       
      
        + 
       
      
        1 
       
       
       
         ) 
        
        
         
         
           N 
          
         
           ′ 
          
         
        
          − 
         
        
          1 
         
        
       
      
     
    
      ) 
     
     
    
      = 
     
    
      E 
     
    
      ( 
     
     
     
       a 
      
     
       0 
      
     
    
      + 
     
     
     
       a 
      
     
       1 
      
     
    
      ( 
     
    
      − 
     
    
      x 
     
    
      ) 
     
    
      + 
     
     
     
       a 
      
     
       2 
      
     
    
      ( 
     
    
      − 
     
    
      x 
     
     
     
       ) 
      
     
       2 
      
     
    
      ) 
     
    
      + 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      + 
     
     
     
       a 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        − 
       
      
        1 
       
      
     
    
      ( 
     
    
      − 
     
    
      x 
     
     
     
       ) 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        − 
       
      
        1 
       
      
     
    
      ) 
     
    
   
     E(a_0+a_1x^{N'+1}+a_2x^{(N'+1)^2}+...+a_{N'-1}x^{(N'+1)^{N'-1}}) \\ =E(a_0+a_1(-x)+a_2(-x)^2)+...+a_{N'-1}(-x)^{N'-1}) 
    
   
 E(a0​+a1​xN′+1+a2​x(N′+1)2+...+aN′−1​x(N′+1)N′−1)=E(a0​+a1​(−x)+a2​(−x)2)+...+aN′−1​(−x)N′−1)

注意,

      x 
     
     
     
       N 
      
     
       ′ 
      
     
    
   
     + 
    
   
     1 
    
   
     ≡ 
    
   
     0 
    
   
        
    
   
     ( 
    
   
     m 
    
   
     o 
    
   
     d 
    
   
        
    
    
    
      x 
     
     
     
       N 
      
     
       ′ 
      
     
    
   
     + 
    
   
     1 
    
   
     ) 
    
   
  
    x^{N'}+1 \equiv 0 \ \ (mod \ \ x^{N'} + 1) 
   
  
xN′+1≡0  (mod  xN′+1),因此 
 
  
   
    
    
      x 
     
     
      
      
        N 
       
      
        ′ 
       
      
     
       + 
      
     
       1 
      
     
    
   
     = 
    
    
    
      x 
     
     
     
       N 
      
     
       ′ 
      
     
    
   
     ∗ 
    
   
     x 
    
   
     = 
    
   
     − 
    
   
     x 
    
   
       
    
   
     ( 
    
   
     m 
    
   
     o 
    
   
     d 
    
   
        
    
    
    
      x 
     
     
     
       N 
      
     
       ′ 
      
     
    
   
     + 
    
   
     1 
    
   
     ) 
    
   
  
    x^{N'+1} = x^{N'} * x = -x \ (mod \ \ x^{N'}+1) 
   
  
xN′+1=xN′∗x=−x (mod  xN′+1)​,这里的 
 
  
   
   
     N 
    
   
     ’ 
    
   
  
    N’ 
   
  
N’也就是分圆环的次数。

(2)执行

       p 
      
     
       ~ 
      
     
    
      0 
     
    
   
     + 
    
   
     S 
    
   
     u 
    
   
     b 
    
   
     s 
    
   
     ( 
    
    
     
     
       p 
      
     
       ~ 
      
     
    
      0 
     
    
   
     , 
    
    
    
      N 
     
    
      ′ 
     
    
   
     + 
    
   
     1 
    
   
     ) 
    
   
  
    \widetilde p_0+Subs(\widetilde p_0,N'+1) 
   
  
p​0​+Subs(p​0​,N′+1)​操作,移除p(x)的所有奇数项。

  
   
    
     
     
       a 
      
     
       0 
      
     
    
      + 
     
     
     
       a 
      
     
       1 
      
     
    
      x 
     
    
      + 
     
     
     
       a 
      
     
       2 
      
     
     
     
       x 
      
     
       2 
      
     
    
      + 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      + 
     
     
     
       a 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        − 
       
      
        1 
       
      
     
     
     
       x 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        + 
       
      
        1 
       
      
     
     
    
      + 
     
     
     
       a 
      
     
       0 
      
     
    
      + 
     
     
     
       a 
      
     
       1 
      
     
    
      ( 
     
    
      − 
     
    
      x 
     
    
      ) 
     
    
      + 
     
     
     
       a 
      
     
       2 
      
     
    
      ( 
     
    
      − 
     
    
      x 
     
     
     
       ) 
      
     
       2 
      
     
    
      ) 
     
    
      + 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      + 
     
     
     
       a 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        − 
       
      
        1 
       
      
     
    
      ( 
     
    
      − 
     
    
      x 
     
     
     
       ) 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        − 
       
      
        1 
       
      
     
     
    
      = 
     
     
     
       a 
      
     
       0 
      
     
    
      + 
     
    
      0 
     
    
      x 
     
    
      + 
     
     
     
       a 
      
     
       2 
      
     
     
     
       x 
      
     
       2 
      
     
    
      + 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      + 
     
     
     
       a 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        − 
       
      
        2 
       
      
     
     
     
       x 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        − 
       
      
        2 
       
      
     
    
      + 
     
    
      0 
     
     
     
       x 
      
      
       
       
         N 
        
       
         ′ 
        
       
      
        − 
       
      
        1 
       
      
     
    
   
     a_0+a_1x+a_2x^2+...+a_{N'-1}x^{N'+1} \\+ a_0+a_1(-x)+a_2(-x)^2)+...+a_{N'-1}(-x)^{N'-1}\\= a_0+0x+a_2x^2+...+a_{N'-2}x^{N'-2}+0x^{N'-1} 
    
   
 a0​+a1​x+a2​x2+...+aN′−1​xN′+1+a0​+a1​(−x)+a2​(−x)2)+...+aN′−1​(−x)N′−1=a0​+0x+a2​x2+...+aN′−2​xN′−2+0xN′−1

(3)通过

     l 
    
   
     o 
    
   
     g 
    
    
    
      N 
     
    
      ′ 
     
    
   
  
    log N' 
   
  
logN′次 
 
  
   
   
     S 
    
   
     u 
    
   
     b 
    
   
     s 
    
   
     ( 
    
   
     ) 
    
   
  
    Subs() 
   
  
Subs()操作,S可以提取得到密文: 
 
  
   
   
     E 
    
   
     ( 
    
    
    
      a 
     
    
      0 
     
    
   
     + 
    
   
     0 
    
    
    
      x 
     
    
      1 
     
    
   
     + 
    
   
     0 
    
    
    
      x 
     
    
      2 
     
    
   
     + 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     + 
    
   
     0 
    
    
    
      x 
     
     
      
      
        N 
       
      
        ′ 
       
      
     
       − 
      
     
       1 
      
     
    
   
     ) 
    
   
  
    E(a_0+0x^1+0x^2+...+0x^{N'-1}) 
   
  
E(a0​+0x1+0x2+...+0xN′−1),实际上,这就是密文Enc([ 
 
  
   
    
    
      a 
     
    
      0 
     
    
   
     , 
    
    
    
      a 
     
    
      0 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
    
      0 
     
    
   
  
    a_0,a_0,...,a_0 
   
  
a0​,a0​,...,a0​])。完整的操作流程如下:

在这里插入图片描述

类似地,为了提取E(

      a 
     
    
      1 
     
    
   
     + 
    
   
     0 
    
    
    
      x 
     
    
      1 
     
    
   
     + 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     + 
    
   
     0 
    
    
    
      x 
     
     
      
      
        N 
       
      
        ′ 
       
      
     
       − 
      
     
       1 
      
     
    
   
  
    a_1+0x^1+...+0x^{N'-1} 
   
  
a1​+0x1+...+0xN′−1),S应该左旋明文多项式p(x)一个单位,通过乘以 
 
  
   
    
    
      x 
     
     
     
       − 
      
     
       1 
      
     
    
   
  
    x^{-1} 
   
  
x−1,然后再次执行上述的提取过程。通过执行 
 
  
   
   
     N 
    
   
     ‘ 
    
   
  
    N‘ 
   
  
N‘次该提取过程,S可以获得向量[ 
 
  
   
    
    
      a 
     
    
      0 
     
    
   
     , 
    
    
    
      a 
     
    
      1 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
      
      
        N 
       
      
        ′ 
       
      
     
       − 
      
     
       1 
      
     
    
   
  
    a_0,a_1,...,a_{N'-1} 
   
  
a0​,a1​,...,aN′−1​​]中每个元素的单独SIMD格式加密。

然而上述过程需要执行

     ( 
    
    
    
      N 
     
    
      ′ 
     
    
   
     ⋅ 
    
   
     l 
    
   
     o 
    
   
     g 
    
    
    
      N 
     
    
      ′ 
     
    
   
     ) 
    
   
  
    (N'·logN') 
   
  
(N′⋅logN′)次 
 
  
   
   
     S 
    
   
     u 
    
   
     b 
    
   
     s 
    
   
     ( 
    
   
     ) 
    
   
  
    Subs() 
   
  
Subs()操作。对比之下,本文提出一种算法,可以实现相同的目标,但是仅需要 
 
  
   
   
     2 
    
    
    
      N 
     
    
      ′ 
     
    
   
  
    2N' 
   
  
2N′次 
 
  
   
   
     S 
    
   
     u 
    
   
     b 
    
   
     s 
    
   
     ( 
    
   
     ) 
    
   
  
    Subs() 
   
  
Subs()操作。该算法可以简单地描述如下:

在这里插入图片描述


算法5是Secure Decompression的详细描述:

在这里插入图片描述

下面提供上述分解操作的理论证明:

Theorem 1:

仅有常数项的多项式的加密E(

      a 
     
    
      s 
     
    
   
     + 
    
   
     0 
    
    
    
      x 
     
    
      1 
     
    
   
     + 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     + 
    
   
     0 
    
    
    
      x 
     
     
      
      
        N 
       
      
        ′ 
       
      
     
       − 
      
     
       1 
      
     
    
   
  
    a_s+0x^1+...+0x^{N'-1} 
   
  
as​+0x1+...+0xN′−1)是向量[ 
 
  
   
    
    
      a 
     
    
      s 
     
    
   
     , 
    
    
    
      a 
     
    
      s 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
    
      s 
     
    
   
  
    a_s,a_s,...,a_s 
   
  
as​,as​,...,as​]的加密Enc([ 
 
  
   
    
    
      a 
     
    
      s 
     
    
   
     , 
    
    
    
      a 
     
    
      s 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
    
      s 
     
    
   
  
    a_s,a_s,...,a_s 
   
  
as​,as​,...,as​​])。

在这里插入图片描述


4.1 Application to matrix multiplication

压缩分解技术可以自然地应用于MatrixMul,此外,基于下面的观察结果,本文进一步优化了矩阵乘法,观察到在transformer推理过程中,对于不同输入的矩阵

     A 
    
   
     ∈ 
    
    
    
      R 
     
     
     
       m 
      
     
       × 
      
     
       n 
      
     
    
   
  
    A\in \R^{m\times n} 
   
  
A∈Rm×n需要乘以相同的矩阵 
 
  
   
   
     W 
    
   
     ∈ 
    
    
    
      R 
     
     
     
       n 
      
     
       × 
      
     
       k 
      
     
    
   
  
    W\in \R^{n\times k} 
   
  
W∈Rn×k​。

     A 
    
   
     = 
    
   
     [ 
    
    
    
      a 
     
    
      0 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       n 
      
     
       − 
      
     
       1 
      
     
    
   
     ] 
    
   
  
    A = [a_0,...,a_{n-1}] 
   
  
A=[a0​,...,an−1​],其中 
 
  
   
    
    
      a 
     
    
      i 
     
    
   
     ∈ 
    
    
    
      R 
     
    
      m 
     
    
   
  
    a_i\in \R^m 
   
  
ai​∈Rm表示矩阵 
 
  
   
   
     A 
    
   
  
    A 
   
  
A的第 
 
  
   
   
     i 
    
   
  
    i 
   
  
i行。假设S和C需要生成t个响应词,即有t个输入矩阵:

  
   
    
     
     
       A 
      
     
       0 
      
     
    
      = 
     
    
      [ 
     
     
     
       a 
      
      
      
        0 
       
      
        , 
       
      
        0 
       
      
     
    
      , 
     
     
     
       a 
      
      
      
        0 
       
      
        , 
       
      
        1 
       
      
     
    
      , 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      , 
     
     
     
       a 
      
      
      
        0 
       
      
        , 
       
      
        n 
       
      
        − 
       
      
        1 
       
      
     
    
      ] 
     
     
     
     
       A 
      
     
       1 
      
     
    
      = 
     
    
      [ 
     
     
     
       a 
      
      
      
        1 
       
      
        , 
       
      
        0 
       
      
     
    
      , 
     
     
     
       a 
      
      
      
        1 
       
      
        , 
       
      
        1 
       
      
     
    
      , 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      , 
     
     
     
       a 
      
      
      
        1 
       
      
        , 
       
      
        n 
       
      
        − 
       
      
        1 
       
      
     
    
      ] 
     
     
    
      . 
     
    
      . 
     
    
      . 
     
     
     
     
       A 
      
     
       0 
      
     
    
      = 
     
    
      [ 
     
     
     
       a 
      
      
      
        t 
       
      
        − 
       
      
        1 
       
      
        , 
       
      
        0 
       
      
     
    
      , 
     
     
     
       a 
      
      
      
        t 
       
      
        − 
       
      
        1 
       
      
        , 
       
      
        1 
       
      
     
    
      , 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      , 
     
     
     
       a 
      
      
      
        t 
       
      
        − 
       
      
        1 
       
      
        , 
       
      
        n 
       
      
        − 
       
      
        1 
       
      
     
    
      ] 
     
    
   
     A_0=[a_{0,0},a_{0,1},...,a_{0,n-1}] \\ A_1=[a_{1,0},a_{1,1},...,a_{1,n-1}] \\ ... \\ A_0=[a_{t-1,0},a_{t-1,1},...,a_{t-1,n-1}] 
    
   
 A0​=[a0,0​,a0,1​,...,a0,n−1​]A1​=[a1,0​,a1,1​,...,a1,n−1​]...A0​=[at−1,0​,at−1,1​,...,at−1,n−1​]

      a 
     
    
      i 
     
    
      ′ 
     
    
   
     = 
    
    
    
      [ 
     
     
      
       
        
         
         
           a 
          
          
          
            0 
           
          
            , 
           
          
            i 
           
          
         
        
       
      
      
       
        
         
         
           a 
          
          
          
            1 
           
          
            , 
           
          
            i 
           
          
         
        
       
      
      
       
        
         
         
           . 
          
         
           . 
          
         
           . 
          
         
        
       
      
      
       
        
         
         
           a 
          
          
          
            t 
           
          
            − 
           
          
            1 
           
          
            , 
           
          
            i 
           
          
         
        
       
      
     
    
      ] 
     
    
   
  
    a'_i=\left[\begin{matrix} a_{0,i} \\ a_{1,i} \\ ... \\ a_{t-1,i} \end{matrix} \right] 
   
  
ai′​=​a0,i​a1,i​...at−1,i​​​和 
 
  
   
    
    
      q 
     
    
      j 
     
    
      ′ 
     
    
   
     : 
    
   
     = 
    
    
    
      ∑ 
     
     
     
       i 
      
     
       = 
      
     
       0 
      
     
     
     
       n 
      
     
       − 
      
     
       1 
      
     
    
    
    
      a 
     
    
      i 
     
    
      ′ 
     
    
    
    
      w 
     
     
     
       i 
      
     
       , 
      
     
       j 
      
     
    
   
         
    
   
     ∀ 
    
   
     j 
    
   
     ∈ 
    
   
     [ 
    
   
     k 
    
   
     ] 
    
   
  
    q'_j:=\sum^{n-1}_{i=0}a'_iw_{i,j}\ \ \ \forall j\in [k] 
   
  
qj′​:=∑i=0n−1​ai′​wi,j​   ∀j∈[k],则有

  
   
    
     
     
       Q 
      
     
       ′ 
      
     
    
      = 
     
     
     
       q 
      
     
       0 
      
     
       ′ 
      
     
    
      ∣ 
     
    
      ∣ 
     
     
     
       q 
      
     
       1 
      
     
       ′ 
      
     
    
      ∣ 
     
    
      ∣ 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      ∣ 
     
    
      ∣ 
     
     
     
       q 
      
      
      
        k 
       
      
        − 
       
      
        1 
       
      
     
       ′ 
      
     
    
      = 
     
     
     
       [ 
      
      
       
        
         
          
           
           
             A 
            
           
             0 
            
           
          
            W 
           
          
         
        
       
       
        
         
          
           
           
             A 
            
           
             1 
            
           
          
            W 
           
          
         
        
       
       
        
         
          
          
            . 
           
          
            . 
           
          
            . 
           
          
         
        
       
       
        
         
          
           
           
             A 
            
            
            
              t 
             
            
              − 
             
            
              1 
             
            
           
          
            W 
           
          
         
        
       
      
     
       ] 
      
     
    
   
     Q'=q'_0||q'_1||...||q'_{k-1}=\left[\begin{matrix} A_0W \\ A_1W \\ ... \\ A_{t-1}W \end{matrix} \right] 
    
   
 Q′=q0′​∣∣q1′​∣∣...∣∣qk−1′​=​A0​WA1​W...At−1​W​​

在这里插入图片描述


预计算阶段:

这里,我们引入一个预计算阶段,其中S使用上述提到的密文压缩技术,将压缩后的密文

     ( 
    
   
     E 
    
   
     n 
    
    
    
      c 
     
    
      S 
     
    
   
     ( 
    
   
     [ 
    
    
     
      
       
       
         w 
        
        
        
          i 
         
        
          , 
         
        
          j 
         
        
       
      
        , 
       
       
       
         w 
        
        
        
          i 
         
        
          , 
         
        
          j 
         
        
       
      
        , 
       
      
        . 
       
      
        . 
       
      
        . 
       
      
        , 
       
       
       
         w 
        
        
        
          i 
         
        
          . 
         
        
          j 
         
        
       
      
     
       ⏟ 
      
     
     
     
       t 
      
     
       × 
      
     
       m 
      
     
    
   
     ] 
    
   
     ) 
    
   
        
    
   
     ∀ 
    
   
     i 
    
   
     ∈ 
    
   
     [ 
    
   
     n 
    
   
     ] 
    
   
     , 
    
   
     j 
    
   
     ∈ 
    
   
     [ 
    
   
     k 
    
   
     ] 
    
   
     ) 
    
   
  
    (Enc_S([\underbrace{w_{i,j},w_{i,j},...,w_{i.j}}_{t\times m}])\ \ \forall i \in [n],j\in [k]) 
   
  
(EncS​([t×mwi,j​,wi,j​,...,wi.j​​​])  ∀i∈[n],j∈[k])发送给C。注意,该传输仅只发生一次,除非模型发生改变。接下来,C对压缩的密文执行分解技术,以获得 
 
  
   
   
     E 
    
   
     n 
    
    
    
      c 
     
    
      S 
     
    
   
     ( 
    
   
     [ 
    
    
     
      
       
       
         w 
        
        
        
          i 
         
        
          , 
         
        
          j 
         
        
       
      
        , 
       
       
       
         w 
        
        
        
          i 
         
        
          , 
         
        
          j 
         
        
       
      
        , 
       
      
        . 
       
      
        . 
       
      
        . 
       
      
        , 
       
       
       
         w 
        
        
        
          i 
         
        
          . 
         
        
          j 
         
        
       
      
     
       ⏟ 
      
     
     
     
       t 
      
     
       × 
      
     
       m 
      
     
    
   
     ] 
    
   
     ) 
    
   
        
    
   
     ∀ 
    
   
     i 
    
   
     ∈ 
    
   
     [ 
    
   
     n 
    
   
     ] 
    
   
     , 
    
   
     j 
    
   
     ∈ 
    
   
     [ 
    
   
     k 
    
   
     ] 
    
   
  
    Enc_S([\underbrace{w_{i,j},w_{i,j},...,w_{i.j}}_{t\times m}])\ \ \forall i \in [n],j\in [k] 
   
  
EncS​([t×mwi,j​,wi,j​,...,wi.j​​​])  ∀i∈[n],j∈[k]。在预计算阶段C并没有关于输入的信息,采样 
 
  
   
   
     U 
    
   
     ∈ 
    
    
    
      R 
     
     
     
       ( 
      
     
       t 
      
     
       m 
      
     
       ) 
      
     
       × 
      
     
       n 
      
     
    
   
  
    U\in \R^{(tm)\times n} 
   
  
U∈R(tm)×n,然后计算:

  
   
    
    
      E 
     
    
      n 
     
     
     
       c 
      
     
       S 
      
     
    
      ( 
     
     
     
       v 
      
     
       j 
      
     
    
      ) 
     
    
      ← 
     
     
     
       ∑ 
      
      
      
        i 
       
      
        = 
       
      
        0 
       
      
      
      
        n 
       
      
        − 
       
      
        1 
       
      
     
    
      ( 
     
     
     
       u 
      
     
       i 
      
     
    
      × 
     
    
      E 
     
    
      n 
     
     
     
       c 
      
     
       S 
      
     
    
      ( 
     
    
      [ 
     
     
     
       w 
      
      
      
        i 
       
      
        , 
       
      
        j 
       
      
     
    
      , 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      , 
     
     
     
       w 
      
      
      
        i 
       
      
        , 
       
      
        j 
       
      
     
    
      ] 
     
    
      ) 
     
    
      ) 
     
    
          
     
    
      ∀ 
     
    
      j 
     
    
      ∈ 
     
    
      [ 
     
    
      k 
     
    
      ] 
     
    
   
     Enc_S(v_j)\leftarrow\sum^{n-1}_{i=0}(u_i\times Enc_S([w_{i,j},...,w_{i,j}]))\ \ \ \forall j\in[k] 
    
   
 EncS​(vj​)←i=0∑n−1​(ui​×EncS​([wi,j​,...,wi,j​]))   ∀j∈[k]

其中

      u 
     
    
      i 
     
    
   
  
    u_i 
   
  
ui​是矩阵 
 
  
   
   
     U 
    
   
  
    U 
   
  
U的第i列。接下来,C使用自己的密钥来加密 
 
  
   
   
     E 
    
   
     n 
    
    
    
      c 
     
    
      S 
     
    
   
     ( 
    
    
    
      v 
     
    
      j 
     
    
   
     ) 
    
   
  
    Enc_S(v_j) 
   
  
EncS​(vj​)以获得 
 
  
   
   
     E 
    
   
     n 
    
    
    
      c 
     
    
      C 
     
    
   
     ( 
    
   
     E 
    
   
     n 
    
    
    
      c 
     
    
      S 
     
    
   
     ( 
    
    
    
      v 
     
    
      j 
     
    
   
     ) 
    
   
     ) 
    
   
  
    Enc_C(Enc_S(v_j)) 
   
  
EncC​(EncS​(vj​)),并将其发送给S。注意 
 
  
   
   
     E 
    
   
     n 
    
    
    
      c 
     
    
      S 
     
    
   
     ( 
    
   
     E 
    
   
     n 
    
    
    
      c 
     
    
      C 
     
    
   
     ( 
    
    
    
      v 
     
    
      j 
     
    
   
     ) 
    
   
     ) 
    
   
     = 
    
   
     E 
    
   
     n 
    
    
    
      c 
     
    
      C 
     
    
   
     ( 
    
   
     E 
    
   
     n 
    
    
    
      c 
     
    
      S 
     
    
   
     ( 
    
    
    
      v 
     
    
      j 
     
    
   
     ) 
    
   
     ) 
    
   
  
    Enc_S(Enc_C(v_j))=Enc_C(Enc_S(v_j)) 
   
  
EncS​(EncC​(vj​))=EncC​(EncS​(vj​)),故S可以对其进行解密,从而获得 
 
  
   
   
     E 
    
   
     n 
    
    
    
      c 
     
    
      C 
     
    
   
     ( 
    
    
    
      v 
     
    
      j 
     
    
   
     ) 
    
   
  
    Enc_C(v_j) 
   
  
EncC​(vj​)。注意,这里的 
 
  
   
    
    
      v 
     
    
      j 
     
    
   
  
    v_j 
   
  
vj​是矩阵 
 
  
   
   
     U 
    
   
     ⋅ 
    
   
     W 
    
   
  
    U·W 
   
  
U⋅W的第 
 
  
   
   
     j 
    
   
  
    j 
   
  
j列。

切换不同用户加密密钥的过程计算如下:

给定

     c 
    
    
    
      t 
     
    
      S 
     
    
   
     = 
    
   
     ( 
    
   
     − 
    
   
     a 
    
    
    
      s 
     
    
      S 
     
    
   
     + 
    
   
     m 
    
   
     + 
    
   
     e 
    
   
     ) 
    
   
  
    ct_S=(-as_S+m+e) 
   
  
ctS​=(−asS​+m+e),

使用

      s 
     
    
      C 
     
    
   
  
    s_C 
   
  
sC​加密有 
 
  
   
   
     c 
    
    
    
      t 
     
     
     
       C 
      
     
       , 
      
     
       S 
      
     
    
   
     = 
    
   
     ( 
    
   
     − 
    
   
     a 
    
    
    
      s 
     
    
      S 
     
    
   
     − 
    
   
     a 
    
    
    
      s 
     
    
      C 
     
    
   
     + 
    
   
     m 
    
   
     + 
    
   
     e 
    
   
     + 
    
    
    
      e 
     
    
      ′ 
     
    
   
     , 
    
   
     a 
    
   
     ) 
    
   
  
    ct_{C,S}=(-as_S-as_C+m+e+e',a) 
   
  
ctC,S​=(−asS​−asC​+m+e+e′,a),

使用

      s 
     
    
      S 
     
    
   
  
    s_S 
   
  
sS​解密有:  
 
  
   
   
     c 
    
    
    
      t 
     
    
      C 
     
    
   
     = 
    
   
     ( 
    
   
     − 
    
   
     a 
    
    
    
      s 
     
    
      S 
     
    
   
     − 
    
   
     a 
    
    
    
      s 
     
    
      C 
     
    
   
     + 
    
   
     m 
    
   
     + 
    
   
     e 
    
   
     + 
    
    
    
      e 
     
    
      ′ 
     
    
   
     , 
    
   
     a 
    
   
     ) 
    
   
     + 
    
   
     ( 
    
   
     a 
    
    
    
      s 
     
    
      S 
     
    
   
     , 
    
   
     0 
    
   
     ) 
    
   
     = 
    
   
     ( 
    
   
     − 
    
   
     a 
    
    
    
      s 
     
    
      C 
     
    
   
     + 
    
   
     m 
    
   
     + 
    
   
     e 
    
   
     + 
    
    
    
      e 
     
    
      ′ 
     
    
   
     ) 
    
   
  
    ct_C=(-as_S-as_C+m+e+e',a)+(as_S,0)=(-as_C+m+e+e') 
   
  
ctC​=(−asS​−asC​+m+e+e′,a)+(asS​,0)=(−asC​+m+e+e′).

在线处理阶段:

此时,C知道输入的信息

      A 
     
    
      ′ 
     
    
   
     = 
    
    
    
      a 
     
    
      0 
     
    
      ′ 
     
    
   
     ∣ 
    
   
     ∣ 
    
    
    
      a 
     
    
      1 
     
    
      ′ 
     
    
   
     ∣ 
    
   
     ∣ 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     ∣ 
    
   
     ∣ 
    
    
    
      a 
     
     
     
       n 
      
     
       − 
      
     
       1 
      
     
    
      ′ 
     
    
   
  
    A'=a'_0||a'_1||...||a'_{n-1} 
   
  
A′=a0′​∣∣a1′​∣∣...∣∣an−1′​,然后C将明文 
 
  
   
   
     ( 
    
    
    
      A 
     
    
      ′ 
     
    
   
     − 
    
   
     U 
    
   
     ) 
    
   
  
    (A'-U) 
   
  
(A′−U)发送给S,注意,由于S不知道U的值,故S也不清楚 
 
  
   
    
    
      A 
     
    
      ′ 
     
    
   
  
    A' 
   
  
A′的值,然后S可以计算:

  
   
    
    
      ( 
     
     
     
       A 
      
     
       ′ 
      
     
    
      − 
     
    
      U 
     
    
      ) 
     
    
      ⋅ 
     
    
      W 
     
    
      + 
     
    
      ( 
     
    
      E 
     
    
      n 
     
     
     
       c 
      
     
       C 
      
     
    
      ( 
     
     
     
       v 
      
     
       0 
      
     
    
      ) 
     
    
      ∣ 
     
    
      ∣ 
     
    
      E 
     
    
      n 
     
     
     
       c 
      
     
       C 
      
     
    
      ( 
     
     
     
       v 
      
     
       1 
      
     
    
      ) 
     
    
      ∣ 
     
    
      ∣ 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      ∣ 
     
    
      ∣ 
     
    
      E 
     
    
      n 
     
     
     
       c 
      
     
       C 
      
     
    
      ( 
     
     
     
       v 
      
      
      
        k 
       
      
        − 
       
      
        1 
       
      
     
    
      ) 
     
    
      ) 
     
     
    
      = 
     
    
      ( 
     
     
     
       A 
      
     
       ′ 
      
     
    
      W 
     
    
      − 
     
    
      V 
     
    
      ) 
     
    
      + 
     
    
      ( 
     
    
      E 
     
    
      n 
     
     
     
       c 
      
     
       C 
      
     
    
      ( 
     
     
     
       v 
      
     
       0 
      
     
    
      ) 
     
    
      ∣ 
     
    
      ∣ 
     
    
      E 
     
    
      n 
     
     
     
       c 
      
     
       C 
      
     
    
      ( 
     
     
     
       v 
      
     
       1 
      
     
    
      ) 
     
    
      ∣ 
     
    
      ∣ 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      ∣ 
     
    
      ∣ 
     
    
      E 
     
    
      n 
     
     
     
       c 
      
     
       C 
      
     
    
      ( 
     
     
     
       v 
      
      
      
        k 
       
      
        − 
       
      
        1 
       
      
     
    
      ) 
     
    
      ) 
     
     
    
      = 
     
    
      ( 
     
    
      E 
     
    
      n 
     
     
     
       c 
      
     
       C 
      
     
    
      ( 
     
     
     
       q 
      
     
       0 
      
     
       ′ 
      
     
    
      ) 
     
    
      ∣ 
     
    
      ∣ 
     
    
      E 
     
    
      n 
     
     
     
       c 
      
     
       C 
      
     
    
      ( 
     
     
     
       q 
      
     
       1 
      
     
       ′ 
      
     
    
      ) 
     
    
      ∣ 
     
    
      ∣ 
     
    
      . 
     
    
      . 
     
    
      . 
     
    
      ∣ 
     
    
      ∣ 
     
    
      E 
     
    
      n 
     
     
     
       c 
      
     
       C 
      
     
    
      ( 
     
     
     
       q 
      
      
      
        k 
       
      
        − 
       
      
        1 
       
      
     
       ′ 
      
     
    
      ) 
     
    
      ) 
     
    
   
     (A'-U)·W + (Enc_C(v_0)||Enc_C(v_1)||...||Enc_C(v_{k-1}))\\ =(A'W-V)+(Enc_C(v_0)||Enc_C(v_1)||...||Enc_C(v_{k-1})) \\ =(Enc_C(q'_0)||Enc_C(q'_1)||...||Enc_C(q'_{k-1})) 
    
   
 (A′−U)⋅W+(EncC​(v0​)∣∣EncC​(v1​)∣∣...∣∣EncC​(vk−1​))=(A′W−V)+(EncC​(v0​)∣∣EncC​(v1​)∣∣...∣∣EncC​(vk−1​))=(EncC​(q0′​)∣∣EncC​(q1′​)∣∣...∣∣EncC​(qk−1′​))

其中

      q 
     
    
      j 
     
    
      ′ 
     
    
   
  
    q'_j 
   
  
qj′​是矩阵 
 
  
   
    
    
      Q 
     
    
      ′ 
     
    
   
  
    Q' 
   
  
Q′的第 
 
  
   
   
     j 
    
   
  
    j 
   
  
j列。算法6描述了优化后的矩阵乘法细节:

在这里插入图片描述

可以注意到,只需要预计算的过程交互一次,此后C可以直接向S发送明文信息

      A 
     
    
      ′ 
     
    
   
     − 
    
   
     U 
    
   
  
    A'-U 
   
  
A′−U,而不会泄露 
 
  
   
    
    
      A 
     
    
      ′ 
     
    
   
  
    A' 
   
  
A′的信息。

5 SIMD槽折叠算法

回想矩阵Q,K,V的行向量是使用SIMD方式加密的。而上述介绍的一系列操作,如内积,Softmax,LayerNorm和Argmax等, 均涉及到利用所有槽元素计算函数

     f 
    
   
     ( 
    
   
     ∗ 
    
   
     ) 
    
   
  
    f(*) 
   
  
f(∗),并将得到的结果放置到所有槽上。例如给定 
 
  
   
   
     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     [ 
    
    
    
      a 
     
    
      0 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       N 
      
     
       − 
      
     
       1 
      
     
    
   
     ] 
    
   
     ) 
    
   
  
    Enc([a_0,...,a_{N-1}]) 
   
  
Enc([a0​,...,aN−1​]),然后想要获得 
 
  
   
   
     E 
    
   
     n 
    
   
     c 
    
   
     ( 
    
   
     [ 
    
   
     s 
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
   
     s 
    
   
     ] 
    
   
     ) 
    
   
  
    Enc([s,...,s]) 
   
  
Enc([s,...,s]),其中 
 
  
   
   
     s 
    
   
     = 
    
    
    
      ∑ 
     
     
     
       i 
      
     
       = 
      
     
       0 
      
     
     
     
       N 
      
     
       − 
      
     
       1 
      
     
    
    
    
      a 
     
    
      i 
     
    
   
  
    s=\sum^{N-1}_{i=0}a_i 
   
  
s=∑i=0N−1​ai​,此时 
 
  
   
   
     f 
    
   
     ( 
    
   
     ∗ 
    
   
     ) 
    
   
  
    f(*) 
   
  
f(∗)即求和函数。

本节提供了一种通用的解决方案,只要函数

     f 
    
   
     ( 
    
   
     ∗ 
    
   
     ) 
    
   
  
    f(*) 
   
  
f(∗)满足:

  
   
    
    
      f 
     
    
      ( 
     
    
      f 
     
    
      ( 
     
     
     
       a 
      
     
       0 
      
     
    
      , 
     
     
     
       a 
      
     
       1 
      
     
    
      ) 
     
    
      , 
     
     
     
       a 
      
     
       2 
      
     
    
      ) 
     
    
      = 
     
    
      f 
     
    
      ( 
     
     
     
       a 
      
     
       0 
      
     
    
      , 
     
    
      f 
     
    
      ( 
     
     
     
       a 
      
     
       1 
      
     
    
      , 
     
     
     
       a 
      
     
       2 
      
     
    
      ) 
     
    
      ) 
     
    
   
     f(f(a_0,a_1),a_2)=f(a_0,f(a_1,a_2)) 
    
   
 f(f(a0​,a1​),a2​)=f(a0​,f(a1​,a2​))

算法7描述了槽折叠算法的细节:

在这里插入图片描述

这里是一个简单的例子,可以看到算法7的实现流程:

在这里插入图片描述

5.1 QuickSum

给定

     [ 
    
    
    
      a 
     
    
      0 
     
    
   
     , 
    
    
    
      a 
     
    
      1 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       n 
      
     
       − 
      
     
       1 
      
     
    
   
     , 
    
   
     0 
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
   
     0 
    
   
     ] 
    
   
  
    [a_0,a_1,...,a_{n-1},0,...,0] 
   
  
[a0​,a1​,...,an−1​,0,...,0],为了获得 
 
  
   
   
     [ 
    
    
    
      ∑ 
     
     
     
       i 
      
     
       = 
      
     
       0 
      
     
     
     
       N 
      
     
       − 
      
     
       1 
      
     
    
    
    
      a 
     
    
      i 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      ∑ 
     
     
     
       i 
      
     
       = 
      
     
       0 
      
     
     
     
       N 
      
     
       − 
      
     
       1 
      
     
    
    
    
      a 
     
    
      i 
     
    
   
     , 
    
   
     0 
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
   
     0 
    
   
     ] 
    
   
  
    [\sum^{N-1}_{i=0}a_i,...,\sum^{N-1}_{i=0}a_i,0,...,0] 
   
  
[∑i=0N−1​ai​,...,∑i=0N−1​ai​,0,...,0],可以将算法7的第5行替换为 
 
  
   
    
    
      s 
     
    
      ~ 
     
    
   
     ← 
    
    
    
      s 
     
    
      ~ 
     
    
   
     + 
    
    
    
      a 
     
    
      ~ 
     
    
   
  
    \tilde{s}\leftarrow\tilde{s}+\tilde{a} 
   
  
s~←s~+a~。

5.2 QuickMax

给定

     [ 
    
    
    
      a 
     
    
      0 
     
    
   
     , 
    
    
    
      a 
     
    
      1 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       n 
      
     
       − 
      
     
       1 
      
     
    
   
     , 
    
   
     0 
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
   
     0 
    
   
     ] 
    
   
  
    [a_0,a_1,...,a_{n-1},0,...,0] 
   
  
[a0​,a1​,...,an−1​,0,...,0],为了获得 
 
  
   
    
    
      a 
     
     
     
       m 
      
     
       a 
      
     
       x 
      
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       m 
      
     
       a 
      
     
       x 
      
     
    
   
     , 
    
   
     0 
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
   
     0 
    
   
  
    a_{max},...,a_{max},0,...,0 
   
  
amax​,...,amax​,0,...,0,其中 
 
  
   
    
    
      a 
     
     
     
       m 
      
     
       a 
      
     
       x 
      
     
    
   
     = 
    
   
     m 
    
   
     a 
    
   
     x 
    
   
     ( 
    
    
    
      a 
     
    
      0 
     
    
   
     , 
    
    
    
      a 
     
    
      1 
     
    
   
     , 
    
   
     . 
    
   
     . 
    
   
     . 
    
   
     , 
    
    
    
      a 
     
     
     
       n 
      
     
       − 
      
     
       1 
      
     
    
   
     ) 
    
   
  
    a_{max}=max(a_0,a_1,...,a_{n-1}) 
   
  
amax​=max(a0​,a1​,...,an−1​),很明显 
 
  
   
   
     m 
    
   
     a 
    
   
     x 
    
   
     ( 
    
   
     a 
    
   
     , 
    
   
     b 
    
   
     ) 
    
   
  
    max(a,b) 
   
  
max(a,b)可以表示为:

  
   
    
    
      m 
     
    
      a 
     
    
      x 
     
    
      ( 
     
    
      a 
     
    
      , 
     
    
      b 
     
    
      ) 
     
    
      = 
     
     
      
      
        a 
       
      
        + 
       
      
        b 
       
      
        + 
       
      
        ( 
       
      
        a 
       
      
        − 
       
      
        b 
       
      
        ) 
       
      
        ⋅ 
       
      
        S 
       
      
        g 
       
      
        n 
       
      
        ( 
       
      
        a 
       
      
        − 
       
      
        b 
       
      
        ) 
       
      
     
       2 
      
     
    
   
     max(a,b)={a+b+(a-b)·Sgn(a-b)\over 2} 
    
   
 max(a,b)=2a+b+(a−b)⋅Sgn(a−b)​

因此可以将算法7的第5行替换为:

       s 
      
     
       ~ 
      
     
    
      ← 
     
    
      0.5 
     
    
      ⊗ 
     
    
      ( 
     
     
     
       a 
      
     
       ~ 
      
     
    
      ⊕ 
     
     
     
       s 
      
     
       ~ 
      
     
    
      ⊕ 
     
    
      ( 
     
     
     
       a 
      
     
       ~ 
      
     
    
      ⊖ 
     
     
     
       s 
      
     
       ~ 
      
     
    
      ) 
     
    
      ⊗ 
     
    
      S 
     
    
      g 
     
    
      n 
     
    
      ( 
     
     
     
       a 
      
     
       ~ 
      
     
    
      ⊖ 
     
     
     
       s 
      
     
       ~ 
      
     
    
      ) 
     
    
      ) 
     
    
   
     \tilde{s}\leftarrow 0.5\otimes(\tilde{a}\oplus\tilde{s}\oplus(\tilde{a}\ominus\tilde{s})\otimes Sgn(\tilde{a}\ominus\tilde{s})) 
    
   
 s~←0.5⊗(a~⊕s~⊕(a~⊖s~)⊗Sgn(a~⊖s~))

6. Conclusion

本文提出了NEXUS系统,可以说是第一个不需要客户端和服务器进行交互的安全transformer推理协议。本文提出了适用于RNS-CKKS的一系列新协议,以使得服务器可以高效且精确的在加密数据上计算transformer的每一层。


本文转载自: https://blog.csdn.net/Wwj2981/article/details/142500904
版权归原作者 lucky_wjie 所有, 如有侵权,请联系我们删除。

“基于CKKS的非交互式安全Transformer推理实现”的评论:

还没有评论