0


基于CKKS的非交互式安全Transformer推理实现

Secure Transformer Inference Made Non-interactive

本文介绍了如何使用CKKS来计算transformer推理的每个部分。同时给出了一系列优化算法。主要涉及到的计算算法有以下几种: 密文的压缩与分解技术、SIMD槽折叠技术、Sgn()、QuickSum、QuickMax、密文-明文矩阵相乘、密文-密文矩阵相乘法、Softmax算法、归一化、GELU函数、Argmax函数等等。其中密文的压缩与分解技术,和SIMD槽折叠技术是本文的核心创新算法。

Abstract

随着ChatGPT的普及,安全transformer推理已经成为一个突出了研究主题。已有的解决方法通常是交互式的,涉及到客户端和服务端之间大量的通信负载和交互轮次。

本文提出NEXUS,这是第一个用于安全transformer推理的非交互式协议,其中客户端仅需要提交一个加密输入,然后等待来自服务器的加密结果即可。NEXUS的核心是两个创新的技术:SIMD密文压缩和分解技术,以及SIMD槽折叠技术。此外,同24年的另外一个解决方案相比,本方法达到了2.8倍的加速,且减少了368.6倍的带宽消耗。

1 Introduction

Transformers,例如GPT和BERT,已经彻底改变了AI领域。Transformer擅长于广泛领域的应用,比如语言翻译,内容生成以及问题回答。然而这些应用总是涉及到敏感数据,从而导致越来越多地关于用户隐私的担忧。例,OpenAI开发的ChatGPT作为一种在线推理服务,以及为开发人员提供的远程API,其中使用者通过提交prompts或者消息可以很容易地访问这些服务。尽管这些方法是方便的,但是由于使用者提交的数据可能包含敏感信息,故而造成了严重的隐私风险。

Secure inference是一种两方密码协议,该协议使模型推理以如下方式处理运行,即服务器S不会了解到关于客户C提交的输入的任何信息,且C不会了解到关于S的模型的任何信息,仅仅能得到最终的推理结果。

该协议大多被设计于安全CNNs推

  1. [
  2. 2
  3. ,
  4. 27
  5. ,
  6. 30
  7. ,
  8. 36
  9. ]
  10. [2,27,30,36]
  11. [2,27,30,36],最近的许多工作也支持基于Transformer的模型
  12. [
  13. 10
  14. ,
  15. 24
  16. ,
  17. 26
  18. ,
  19. 35
  20. ,
  21. 38
  22. ,
  23. 40
  24. ]
  25. [10,24,26,35,38,40]
  26. [10,24,26,35,38,40]​,值得注意的是,这些安全Transformer模型大多都是交互式的,因此会导致巨大的通信开销和交互轮次,这里我们必须强调非交互式安全Transformer推理的重要性。
本文贡献:

本文中,我们提出了NEXUS,第一个secure transformer inference的非交互协议。通过NEXUS,C使用RNS-CKKS加密输入,S对FHE加密数据执行transformer。CKKS的SIMD技术被应用于批处理

  1. N
  2. =
  3. 2
  4. 15
  5. N=2^{15}
  6. N=215个数据,多项式近似可以用于处理非线性函数,比如GELUsoftmax,层归一化和argmax

NEXUS不需要对模型进行任何重训练与微调,且为了提高NEXUS的效率,我们提出了两种新颖的且基础的技术。

  • SIMD密文压缩与分解:该技术可以将2N个SIMD密文压缩为一个密文,然后可以使用4N个密文—明文乘法和替换将其解压回来。该技术可以大大减少客户端和服务器之间传输的密文数量,而不会为后续计算带来任何额外的开销。

在这里插入图片描述

  • SIMD槽折叠:在所有SIMD槽中计算关联函数f(),例如sum和max。结果值会自动的填充SIMD密文的槽,允许将其应用于原始密文的每个槽。

本文贡献总结如下:

  • secure transformer inference的第一个非交互协议
  • 用于密文打包的SIMD密文压缩与分解技术
  • SIMD槽折叠技术,以高效操作SIMD密文的槽
  • 综合的实现与评估

2 Preliminaries

符号系统描述如下:
NotationDescriptionNotationDescriptionCclientSserver

  1. E
  2. (
  3. )
  4. E(*)
  5. E(∗)encryption
  6. π
  7. (
  8. )
  9. \pi(*)
  10. π(∗)encoding
  11. E
  12. n
  13. c
  14. (
  15. )
  16. Enc(*)
  17. Enc(∗)encoding+encryption
  18. a
  19. ~
  20. \tilde{a}
  21. a~FHE ciphertext
  22. R
  23. o
  24. t
  25. L
  26. (
  27. )
  28. /
  29. R
  30. o
  31. t
  32. R
  33. (
  34. )
  35. RotL(*)/RotR(*)
  36. RotL(∗)/RotR(∗)左旋转和右旋转
  37. S
  38. u
  39. b
  40. s
  41. (
  42. )
  43. Subs(*)
  44. Subs(∗)替换操作
  45. S
  46. g
  47. n
  48. (
  49. )
  50. Sgn(*)
  51. Sgn(∗)sign操作
  52. L
  53. L
  54. L乘法深度
  55. N
  56. N'
  57. N′CKKS的环维数
  58. N
  59. N
  60. N
  61. N
  62. =
  63. N
  64. /
  65. 2
  66. N=N'/2
  67. N=N′/2
  68. A
  69. A
  70. A输入矩阵
  71. W
  72. W
  73. W权重矩阵

2.1 安全推理和威胁模型

安全推理是一个两方密码学协议,其可以在C和S之间进行模型推理,与此同时还可以保护两个参与方输入隐私。它的正式定义如下:

Definition 1:

针对两方参与者,其中

  1. S
  2. S
  3. S持有模型
  4. M
  5. M
  6. M,且
  7. C
  8. C
  9. C持有输入
  10. A
  11. A
  12. A的协议
  13. Π
  14. \Pi
  15. Π是安全推理协议,当且仅当以下条件满足时:

(1) 正确性: 该协议的最终输出是正确的推理结果

  1. M
  2. (
  3. A
  4. )
  5. M(A)
  6. M(A)。

(2) 安全性:

  1. V
  2. i
  3. e
  4. w
  5. C
  6. Π
  7. c
  8. S
  9. i
  10. m
  11. C
  12. (
  13. A
  14. ,
  15. o
  16. u
  17. t
  18. )
  19. View^{\Pi}_C\approx_c Sim_C(A,out)
  20. ViewCΠ​≈cSimC​(A,out),其中
  21. V
  22. i
  23. e
  24. w
  25. C
  26. Π
  27. View^{\Pi}_C
  28. ViewCΠ​表示协议
  29. Π
  30. \Pi
  31. Π执行期间
  32. C
  33. C
  34. C的视角,
  35. o
  36. u
  37. t
  38. out
  39. out表示推理的结果。
  40. V
  41. i
  42. e
  43. w
  44. S
  45. Π
  46. S
  47. S
  48. i
  49. m
  50. S
  51. (
  52. M
  53. )
  54. View^{\Pi}_S\approx_S Sim_S(M)
  55. ViewSΠ​≈SSimS​(M),其中
  56. V
  57. i
  58. e
  59. w
  60. S
  61. Π
  62. View^{\Pi}_S
  63. ViewSΠ​表示协议
  64. Π
  65. \Pi
  66. Π执行期间
  67. S
  68. S
  69. S​的视角。
  70. S
  71. i
  72. m
  73. Sim_*
  74. Sim∗​可以理解为理想状态下希望实体
  75. *
  76. ∗可以得到的信息。

假设

  1. C
  2. C
  3. C
  4. S
  5. S
  6. S为半诚实对手,其在遵守协议规范的同时也尽可能的在执行过程中手机额外的信息。且假设对手在计算上是有限的。

2.2 Transformer

这里简单介绍一下Transformerd。

图1是transformer的结构与工作流程。它将一个表示为矩阵的嵌入传递给注意层和前馈神经网络,最后根据最终对数最大值输出一个选择向量,且,LayerNorm层被应用于每个块之后。

在这里插入图片描述

transformer的结构和工作流程

Attention:

使用三个矩阵(

  1. W
  2. Q
  3. R
  4. n
  5. ×
  6. k
  7. ,
  8. W
  9. K
  10. R
  11. n
  12. ×
  13. k
  14. ,
  15. W
  16. V
  17. R
  18. n
  19. ×
  20. k
  21. W_Q\in\mathbb{R}^{n\times k},W_K\in\mathbb{R}^{n\times k},W_V\in\mathbb{R}^{n\times k}
  22. WQ​∈Rn×k,WK​∈Rn×k,WV​∈Rn×k)乘嵌入矩阵
  23. A
  24. R
  25. m
  26. ×
  27. n
  28. A\in \mathbb{R}^{m\times n}
  29. ARm×n,生成一个query矩阵
  30. Q
  31. =
  32. A
  33. W
  34. Q
  35. Q = A·W_Q
  36. Q=AWQ​,一个key矩阵
  37. K
  38. =
  39. A
  40. W
  41. K
  42. K=A·W_K
  43. K=AWK​和一个value矩阵
  44. V
  45. =
  46. A
  47. W
  48. V
  49. V=A·W_V
  50. V=AWV​。即对于Attention层的单元,transformer会学习到三个权重矩阵。

attention可以被表示为:

  1. A
  2. t
  3. t
  4. e
  5. n
  6. t
  7. i
  8. o
  9. n
  10. (
  11. Q
  12. ,
  13. K
  14. ,
  15. V
  16. )
  17. =
  18. S
  19. o
  20. f
  21. t
  22. m
  23. a
  24. x
  25. (
  26. Q
  27. K
  28. T
  29. k
  30. )
  31. V
  32. Attention(Q,K,V) = Softmax({QK^T\over{\sqrt k}})·V
  33. Attention(Q,K,V)=Softmax(kQKT​)⋅V
Layer normalization

该层的输入为

  1. a
  2. R
  3. n
  4. a\in \mathbb{R}^n
  5. aRn,均值和标准差分别为
  6. μ
  7. \mu
  8. μ和
  9. σ
  10. \sigma
  11. σ,则该层的输出
  12. y
  13. R
  14. n
  15. y\in\mathbb{R}^n
  16. yRn可以表示为:
  17. y
  18. i
  19. =
  20. γ
  21. x
  22. i
  23. μ
  24. σ
  25. +
  26. β
  27. y_i=\gamma·{x_i-\mu\over\sigma}+\beta
  28. yi​=γ⋅σxi​−μ​+β

其中,

  1. γ
  2. ,
  3. β
  4. R
  5. \gamma,\beta\in\mathbb{R}
  6. γ,β∈R​是两个超参数。
Feed-forward

全连接前馈网络层包含两个线性变换以及一个GELU激活函数:

  1. F
  2. e
  3. e
  4. d
  5. F
  6. o
  7. r
  8. w
  9. a
  10. r
  11. d
  12. (
  13. X
  14. )
  15. =
  16. G
  17. E
  18. L
  19. U
  20. (
  21. X
  22. W
  23. 1
  24. +
  25. b
  26. 1
  27. )
  28. W
  29. 2
  30. +
  31. b
  32. 2
  33. FeedForward(X)=GELU(XW_1+b_1W_2+b_2
  34. FeedForward(X)=GELU(XW1​+b1​)⋅W2​+b2

其中GELU函数计算如下:

  1. G
  2. E
  3. L
  4. U
  5. (
  6. x
  7. )
  8. =
  9. 1
  10. 2
  11. x
  12. (
  13. 1
  14. +
  15. e
  16. r
  17. f
  18. (
  19. x
  20. 2
  21. )
  22. )
  23. GELU(x)={1\over 2}x·(1+erf({x\over \sqrt 2}))
  24. GELU(x)=21x⋅(1+erf(2x​))

式中,高斯误差函数为

  1. e
  2. r
  3. f
  4. (
  5. x
  6. )
  7. =
  8. 2
  9. π
  10. 0
  11. x
  12. e
  13. t
  14. 2
  15. d
  16. t
  17. erf(x)={2\over\sqrt{\pi}}\int_0^xe^{-t^2}dt
  18. erf(x)=π​2​∫0xet2dt​。由于其良好的曲率和非单调性,它被用作激活函数。
Argmax

根据最终对数最大值输出一个选择向量

可以看到,只要我们能够使用FHE实现各个层的计算,就可以实现一个安全Transformer。

2.3 Fully Homomorphic Encryption

FHE可以对加密数据执行任意操作,故FHE是使得我们构建非交互式安全transformer推理得主要工具。RNS-CKKS属于级全同态加密,其可以支持L级深度的乘法。RNS-CKKS的明文和密文均是多项式环

  1. R
  2. Q
  3. =
  4. Z
  5. Q
  6. [
  7. X
  8. ]
  9. /
  10. (
  11. X
  12. N
  13. +
  14. 1
  15. )
  16. R_Q=\Z_Q[X]/(X^{N'}+1)
  17. RQ​=ZQ​[X]/(XN′+1)上的元素。其中
  18. Q
  19. =
  20. Π
  21. i
  22. =
  23. 0
  24. L
  25. q
  26. i
  27. Q=\Pi^L_{i=0}q_i
  28. Q=Πi=0L​qi​,且
  29. q
  30. i
  31. q_i
  32. qi​​之间互素。若密文的级别变得太低,则可以运行自举操作来刷新密文到高的级别,以允许更多的计算。

简单地说,自举即利用自同构

  1. R
  2. q
  3. 0
  4. R
  5. q
  6. 0
  7. ×
  8. R
  9. q
  10. 1
  11. ×
  12. .
  13. .
  14. .
  15. ×
  16. R
  17. q
  18. L
  19. R_{q_0}\cong R_{q_0}\times R_{q_1}\times ... \times R_{q_L}
  20. Rq0​​≅Rq0​​×Rq1​​×...×RqL​​,来将密文模从
  21. q
  22. 0
  23. q_0
  24. q0​提升到
  25. q
  26. L
  27. q_L
  28. qL​,以及对密文同态评估解密电路。若自举本身消耗K个级别,则刷新后的密文支持
  29. L
  30. K
  31. L-K
  32. LK个级深度的计算。

RNS-CKKS支持SIMD操作,其可以加密向量

  1. a
  2. R
  3. N
  4. a\in \R^N
  5. aRN到一个密文中,且批处理这些加密元素,而不引入其他操作。为了以SIMD格式加密,首先使用编码算法
  6. π
  7. (
  8. )
  9. \pi(*)
  10. π(∗)将向量
  11. a
  12. a
  13. a编码为一个
  14. R
  15. Q
  16. R_Q
  17. RQ​上的多项式,然后使用加密算法
  18. E
  19. (
  20. )
  21. E(*)
  22. E(∗)​加密该多项式。

在整篇文章中,我们使用

  1. E
  2. (
  3. )
  4. E(*)
  5. E(∗)表示加密多项式,使用
  6. E
  7. n
  8. c
  9. (
  10. )
  11. Enc(*)
  12. Enc(∗)表示以SIMD格式加密向量,即
  13. E
  14. n
  15. c
  16. (
  17. a
  18. )
  19. =
  20. E
  21. (
  22. π
  23. (
  24. a
  25. )
  26. )
  27. Enc(a)=E(\pi(a))
  28. Enc(a)=E(π(a)),其中
  29. a
  30. a
  31. a是一个向量。

一个特殊的FHE操作:

  1. c
  2. t
  3. S
  4. u
  5. b
  6. s
  7. (
  8. c
  9. t
  10. ,
  11. k
  12. )
  13. ct'\leftarrow Subs(ct,k)
  14. ct′←Subs(ct,k):替换操作,该操作以密文
  15. c
  16. t
  17. =
  18. E
  19. (
  20. p
  21. (
  22. x
  23. )
  24. )
  25. ct=E(p(x))
  26. ct=E(p(x))以及一个奇整数
  27. k
  28. k
  29. k作为输入,然后得到新的密文
  30. c
  31. t
  32. =
  33. E
  34. (
  35. p
  36. (
  37. x
  38. k
  39. )
  40. )
  41. ct'=E(p(x^k))
  42. ct′=E(p(xk))​​​​。

这里的

  1. S
  2. u
  3. b
  4. s
  5. (
  6. c
  7. t
  8. ,
  9. k
  10. )
  11. Subs(ct,k)
  12. Subs(ct,k)应该是一种密钥交换操作,可以描述如下:

已知密文:

  1. c
  2. t
  3. =
  4. (
  5. a
  6. (
  7. x
  8. )
  9. s
  10. (
  11. x
  12. )
  13. +
  14. e
  15. (
  16. x
  17. )
  18. +
  19. p
  20. (
  21. x
  22. )
  23. ,
  24. a
  25. (
  26. x
  27. )
  28. )
  29. ct=(-a(x)s(x)+e(x)+p(x),a(x))
  30. ct=(−a(x)s(x)+e(x)+p(x),a(x))

将该密文进行自同构操作:

  1. κ
  2. k
  3. (
  4. c
  5. t
  6. )
  7. =
  8. (
  9. a
  10. (
  11. x
  12. k
  13. )
  14. s
  15. (
  16. x
  17. k
  18. )
  19. +
  20. e
  21. (
  22. x
  23. k
  24. )
  25. +
  26. p
  27. (
  28. x
  29. k
  30. )
  31. ,
  32. a
  33. (
  34. x
  35. k
  36. )
  37. )
  38. \kappa_k(ct)=(-a(x^k)s(x^k)+e(x^k)+p(x^k),a(x^k))
  39. κk​(ct)=(−a(xk)s(xk)+e(xk)+p(xk),a(xk))。

然后得到用户提供的交换密钥:

  1. k
  2. e
  3. y
  4. =
  5. (
  6. a
  7. (
  8. x
  9. )
  10. s
  11. (
  12. x
  13. )
  14. +
  15. e
  16. (
  17. x
  18. )
  19. +
  20. P
  21. s
  22. (
  23. x
  24. k
  25. )
  26. ,
  27. a
  28. (
  29. x
  30. )
  31. )
  32. key = (-a(x)s(x)+e(x)+P·s(x^k),a(x))
  33. key=(−a(x)s(x)+e(x)+Ps(xk),a(x))

然后执行密钥交换操作:

  1. c
  2. t
  3. =
  4. (
  5. κ
  6. k
  7. (
  8. c
  9. t
  10. )
  11. [
  12. 0
  13. ]
  14. ,
  15. 0
  16. )
  17. +
  18. (
  19. P
  20. 1
  21. κ
  22. k
  23. (
  24. c
  25. t
  26. )
  27. [
  28. 1
  29. ]
  30. k
  31. e
  32. y
  33. )
  34. ct'=(\kappa_k(ct)[0],0)+(\lfloor P^{-1}·\kappa_k(ct)[1]·key\rceil)
  35. ct′=(κk​(ct)[0],0)+(⌊P−1⋅κk​(ct)[1]⋅key⌉)

此时新的密文即

  1. c
  2. t
  3. =
  4. (
  5. a
  6. (
  7. x
  8. )
  9. s
  10. (
  11. x
  12. )
  13. +
  14. e
  15. (
  16. x
  17. )
  18. +
  19. p
  20. (
  21. x
  22. k
  23. )
  24. ,
  25. a
  26. (
  27. x
  28. )
  29. )
  30. ct'=(-a(x)s(x)+e(x)+p(x^k),a(x))
  31. ct′=(−a(x)s(x)+e(x)+p(xk),a(x))

注意,这里的

  1. a
  2. (
  3. x
  4. )
  5. ,
  6. e
  7. (
  8. x
  9. )
  10. a(x),e(x)
  11. a(x),e(x)是变化的,也就是不同的密文中,这是不同的。

2.4 Homomorphic sign function

由于FHE仅支持线性函数,所以为了实现在FHE下对加密数据的比较,本文需要利用sign函数的多项式近似,即:

  1. s
  2. i
  3. g
  4. n
  5. (
  6. x
  7. )
  8. =
  9. f
  10. d
  11. f
  12. (
  13. g
  14. d
  15. g
  16. (
  17. x
  18. )
  19. )
  20. =
  21. {
  22. 1
  23. (
  24. 1
  25. x
  26. 2
  27. α
  28. )
  29. 0
  30. (
  31. x
  32. =
  33. 0
  34. )
  35. 1
  36. (
  37. 2
  38. α
  39. x
  40. 1
  41. )
  42. sign(x)=f^{d_f}(g^{d_g}(x))=\begin{cases} -1 &(-1\leq x \leq -2^{-\alpha}) \\ 0 &(x = 0) \\ 1 &(2^{-\alpha}\leq x \leq 1) \\ \end{cases}
  43. sign(x)=fdf​(gdg​(x))=⎩⎨⎧​−101​(−1x≤−2−α)(x=0)(2−α≤x1)​

其中,

  1. f
  2. (
  3. )
  4. ,
  5. g
  6. (
  7. )
  8. f(),g()
  9. f(),g()为两个多项式,
  10. d
  11. f
  12. ,
  13. d
  14. g
  15. d_f,d_g
  16. df​,dg​为这两个多项式重复的次数。注意,该多项式近似要求输入x取值范围为[-1,1]。因此,对任何输入
  17. a
  18. [
  19. a
  20. m
  21. i
  22. n
  23. ,
  24. a
  25. m
  26. a
  27. x
  28. ]
  29. a\in [a_{min},a_{max}]
  30. a∈[amin​,amax​]都需要进行归一化处理:
  31. x
  32. :
  33. =
  34. a
  35. /
  36. m
  37. a
  38. x
  39. {
  40. a
  41. m
  42. a
  43. x
  44. ,
  45. a
  46. m
  47. i
  48. n
  49. }
  50. x := a/max\{|a_{max}|,|a_{min}|\}
  51. x:=a/max{∣amax​∣,∣amin​∣}

这里,我们使用Sgn()表示在SIMD密文上同时运行归一化与sign近似函数:

  1. b
  2. ~
  3. S
  4. g
  5. n
  6. (
  7. a
  8. ~
  9. )
  10. :
  11. b
  12. i
  13. =
  14. f
  15. d
  16. f
  17. (
  18. g
  19. d
  20. g
  21. (
  22. a
  23. i
  24. m
  25. a
  26. x
  27. {
  28. a
  29. m
  30. a
  31. x
  32. ,
  33. a
  34. m
  35. i
  36. n
  37. }
  38. )
  39. )
  40. i
  41. [
  42. N
  43. ]
  44. \widetilde b\leftarrow Sgn(\widetilde a): b_i =f^{d_f}(g^{d_g}({a_i\over{max\{|a_{max}|,|a_{min}|\}}})) \ \ \forall i \in [N]
  45. bSgn(a):bi​=fdf​(gdg​(max{∣amax​∣,∣amin​∣}ai​​)) i∈[N]

在本文的实现中,使用的是9次的

  1. f
  2. (
  3. )
  4. f(*)
  5. f(∗)和
  6. g
  7. (
  8. )
  9. g(*)
  10. g(∗),且设计
  11. α
  12. =
  13. 16
  14. ,
  15. d
  16. f
  17. =
  18. 2
  19. ,
  20. d
  21. g
  22. =
  23. 2
  24. \alpha=16,d_f=2,d_g=2
  25. α=16,df​=2,dg​=2​​,然后使用BSGS算法来评估多项式。

3 Basic design

本节介绍NEXUS的基础设计,即在不优化的情况下实现上述transformer的每一层计算,在之后的章节中会对本节的算法进行优化。

3.1 Attention

3.1.1 Matrix multiplication(ciphertext-plaintext)

在Attention层的第一个MatMul步骤,我们需要计算三个密文—明文矩阵乘法:

  1. Q
  2. :
  3. =
  4. A
  5. W
  6. Q
  7. ;
  8. K
  9. :
  10. =
  11. A
  12. W
  13. K
  14. ;
  15. V
  16. :
  17. =
  18. A
  19. W
  20. V
  21. ;
  22. Q:=A·W_Q;\\ K:=A·W_K;\\ V:=A·W_V;
  23. Q:=AWQ​;K:=AWK​;V:=AWV​;

其中A是我们的输入,

  1. W
  2. Q
  3. ,
  4. W
  5. K
  6. ,
  7. W
  8. V
  9. W_Q,W_K,W_V
  10. WQ​,WK​,WV​是三个给定矩阵,下面以
  11. A
  12. W
  13. Q
  14. A·W_Q
  15. AWQ​为例来描述这个密文—明文矩阵乘法,该过程同样适用于
  16. W
  17. K
  18. W_K
  19. WK​和
  20. W
  21. V
  22. W_V
  23. WV​。

给定矩阵

  1. A
  2. R
  3. m
  4. ×
  5. n
  6. A\in \mathbb{R}^{m\times n}
  7. ARm×n和矩阵
  8. W
  9. Q
  10. R
  11. n
  12. ×
  13. k
  14. W_Q\in \mathbb{R}^{n\times k}
  15. WQ​∈Rn×k,计算矩阵
  16. Q
  17. :
  18. =
  19. A
  20. W
  21. Q
  22. Q:=A·W_Q
  23. Q:=AWQ​。

  1. a
  2. i
  3. ,
  4. j
  5. R
  6. a_{i,j}\in \mathbb{R}
  7. ai,j​∈R表示矩阵A的第i行第j列的元素,
  8. w
  9. j
  10. R
  11. k
  12. w_j\in \mathbb{R}^k
  13. wj​∈Rk表示矩阵
  14. W
  15. Q
  16. W_Q
  17. WQ​的第j行的元素向量,
  18. q
  19. i
  20. R
  21. k
  22. q_i\in \mathbb{R}^k
  23. qi​∈Rk是矩阵
  24. Q
  25. Q
  26. Q的第i行的元素向量,即:
  27. q
  28. i
  29. =
  30. j
  31. [
  32. n
  33. ]
  34. a
  35. i
  36. ,
  37. j
  38. w
  39. j
  40. q_i=\sum_{j\in [n]}a_{i,jw_j
  41. qi​=j∈[n]∑​ai,j​⋅wj

因此,上述过程可以描述为,C将A中的每个元素

  1. a
  2. i
  3. ,
  4. j
  5. a_{i,j}
  6. ai,j​均单独加密为密文发送给S,然后S同态评估MatrixMul,一个演示的示例如下:

在这里插入图片描述

图2 SIMD-based matrix multiplication


在上述描述中,C需要发送

  1. m
  2. ×
  3. n
  4. m\times n
  5. m×n个密文给S,从某一方面来说,这种开销是比较大的,因此本文在第4节提出一种算法可以将如此类型的
  6. m
  7. ×
  8. n
  9. m\times n
  10. m×n个密文压缩为
  11. m
  12. ×
  13. n
  14. N
  15. m\times n\over{N'}
  16. N′m×n​个密文,即一个密文中存放
  17. N
  18. N'
  19. N′个元素,随后S可以将压缩后的密文恢复为压缩前的密文形式。

3.1.2 Matrix multiplication(ciphertext-ciphertext)

经过上述步骤后,可以获得加密的

  1. (
  2. Q
  3. ,
  4. K
  5. ,
  6. V
  7. )
  8. (Q,K,V)
  9. (Q,K,V),在Attention的第二个MatMul块,S需要计算
  10. Q
  11. K
  12. T
  13. Q·K^T
  14. QKT。很明显,现在Q的每一行和
  15. K
  16. T
  17. K^T
  18. KT的每一列已经以SIMD的形式加密为
  19. E
  20. n
  21. c
  22. (
  23. q
  24. )
  25. ,
  26. E
  27. n
  28. c
  29. (
  30. k
  31. T
  32. )
  33. Enc(q) , Enc(k^T)
  34. Enc(q),Enc(kT)。如果S可以计算
  35. E
  36. n
  37. c
  38. (
  39. q
  40. )
  41. Enc(q)
  42. Enc(q)和
  43. E
  44. n
  45. c
  46. (
  47. k
  48. T
  49. )
  50. Enc(k^T)
  51. Enc(kT)的内积,则可以获得
  52. Q
  53. K
  54. T
  55. Q·K^T
  56. QKT的加密结果。

由于SIMD,S可以很容易的计算得到Enc(u),其中

  1. u
  2. =
  3. [
  4. u
  5. 0
  6. ,
  7. .
  8. .
  9. .
  10. ,
  11. u
  12. k
  13. 1
  14. ]
  15. u=[u_0,...,u_{k-1}]
  16. u=[u0​,...,uk1​]是q
  17. k
  18. T
  19. k^T
  20. kT的元素级的乘法,现在为了计算内积,S仅仅需要在SIMD下计算
  21. s
  22. :
  23. =
  24. i
  25. =
  26. 0
  27. k
  28. 1
  29. u
  30. i
  31. s:=\sum_{i=0}^{k-1}u_i
  32. s:=∑i=0k1ui​。

为了计算这个和,我们可以通过k-1次的旋转及加和来计算,从而获得密文Enc([s,s,…,s]),但是本文提出了’QuickSum’算法,该算法仅仅需要logk次旋转就可达到这个目标。'QuickSum’算法在第5节介绍。


进一步,S将计算得到每一行的的m个密文组合到单一密文中,计算方法如下:

  1. i
  2. =
  3. 0
  4. m
  5. 1
  6. (
  7. E
  8. n
  9. c
  10. (
  11. s
  12. i
  13. ,
  14. s
  15. i
  16. ,
  17. .
  18. .
  19. .
  20. ,
  21. s
  22. i
  23. )
  24. b
  25. i
  26. )
  27. \sum_{i=0}^{m-1}(Enc(s_i,s_i,...,s_ib_i)
  28. i=0m1​(Enc(si​,si​,...,si​)⋅bi​)

其中

  1. b
  2. i
  3. b_i
  4. bi​仅在第i个槽的位置是1,其余槽均为0

易知,输出矩阵为

  1. A
  2. R
  3. m
  4. ×
  5. m
  6. A\in \mathbb{R}^{m\times m}
  7. ARm×m,其中A的每行向量以SIMD形式加密,将该结果作为Softmax的输入。
3.1.3 Softmax

Softmax函数需要被应用于A的每一行,该函数评估如下:

  1. y
  2. i
  3. =
  4. e
  5. x
  6. p
  7. (
  8. a
  9. i
  10. a
  11. m
  12. a
  13. x
  14. )
  15. j
  16. =
  17. 0
  18. m
  19. 1
  20. e
  21. x
  22. p
  23. (
  24. a
  25. j
  26. a
  27. m
  28. a
  29. x
  30. )
  31. (1)
  32. y_i={exp(a_i-a_{max})\over{\sum_{j=0}^{m-1}exp(a_j-a_{max})}}\tag{1}
  33. yi​=∑j=0m1exp(aj​−amax​)exp(ai​−amax​)​(1)

其中

  1. a
  2. m
  3. a
  4. x
  5. =
  6. m
  7. a
  8. x
  9. (
  10. a
  11. 0
  12. ,
  13. .
  14. .
  15. .
  16. ,
  17. a
  18. m
  19. 1
  20. )
  21. a_{max}=max(a_0,...,a_{m-1})
  22. amax​=max(a0​,...,am1​),从而确保指数函数的每个输入
  23. (
  24. a
  25. j
  26. a
  27. m
  28. a
  29. x
  30. )
  31. (a_j-a_{max})
  32. (aj​−amax​)​是非正数,保证稳定性。

本文提出了’QuickMax’算法,该算法以

  1. E
  2. n
  3. c
  4. (
  5. [
  6. a
  7. 0
  8. ,
  9. .
  10. .
  11. .
  12. ,
  13. a
  14. m
  15. 1
  16. ]
  17. )
  18. Enc([a_0,...,a_{m-1}])
  19. Enc([a0​,...,am1​])为输入,并输出
  20. E
  21. n
  22. c
  23. (
  24. [
  25. a
  26. m
  27. a
  28. x
  29. ,
  30. .
  31. .
  32. .
  33. ,
  34. a
  35. m
  36. a
  37. x
  38. ]
  39. )
  40. Enc([a_{max},...,a_{max}])
  41. Enc([amax​,...,amax​])​,且,该算法仅需要logm-1Sgn操作与logm次旋转操作。该算法描述在第5节。

给定

  1. E
  2. n
  3. c
  4. (
  5. [
  6. a
  7. 0
  8. ,
  9. .
  10. .
  11. .
  12. ,
  13. a
  14. m
  15. 1
  16. ]
  17. )
  18. Enc([a_0,...,a_{m-1}])
  19. Enc([a0​,...,am1​])和
  20. E
  21. n
  22. c
  23. (
  24. [
  25. a
  26. m
  27. a
  28. x
  29. ,
  30. .
  31. .
  32. .
  33. ,
  34. a
  35. m
  36. a
  37. x
  38. ]
  39. )
  40. Enc([a_{max},...,a_{max}])
  41. Enc([amax​,...,amax​])。

S进行如下步骤计算:

  1. E
  2. n
  3. c
  4. (
  5. [
  6. a
  7. 0
  8. ,
  9. .
  10. .
  11. .
  12. ,
  13. a
  14. m
  15. 1
  16. ]
  17. )
  18. =
  19. E
  20. n
  21. c
  22. (
  23. [
  24. a
  25. 0
  26. ,
  27. .
  28. .
  29. .
  30. ,
  31. a
  32. m
  33. 1
  34. ]
  35. )
  36. E
  37. n
  38. c
  39. (
  40. [
  41. a
  42. m
  43. a
  44. x
  45. ,
  46. .
  47. .
  48. .
  49. ,
  50. a
  51. m
  52. a
  53. x
  54. ]
  55. )
  56. Enc([a'_0,...,a'_{m-1}])=Enc([a_0,...,a_{m-1}])-Enc([a_{max},...,a_{max}])
  57. Enc([a0′​,...,am1′​])=Enc([a0​,...,am1​])−Enc([amax​,...,amax​])

然后根据如下公式计算指数函数,这里使用泰勒展开:

  1. e
  2. x
  3. p
  4. (
  5. x
  6. )
  7. (
  8. 1
  9. +
  10. x
  11. 2
  12. r
  13. )
  14. 2
  15. r
  16. ,
  17. x
  18. 0
  19. exp(x)\approx(1+{x\over{2^r}})^{2^r},x\leq 0
  20. exp(x)≈(1+2rx​)2r,x0

其中

  1. r
  2. =
  3. 6
  4. r=6
  5. r=6,此时平均误差被限制在
  6. 1
  7. 0
  8. 5
  9. 10^{-5}
  10. 105,即SSIMD格式计算指数函数:
  11. E
  12. n
  13. c
  14. (
  15. e
  16. 0
  17. ,
  18. .
  19. .
  20. .
  21. ,
  22. e
  23. m
  24. 1
  25. )
  26. =
  27. e
  28. x
  29. p
  30. (
  31. E
  32. n
  33. c
  34. (
  35. [
  36. a
  37. 0
  38. ,
  39. .
  40. .
  41. .
  42. ,
  43. a
  44. m
  45. 1
  46. ]
  47. )
  48. )
  49. Enc(e_0,...,e_{m-1})=exp(Enc([a'_0,...,a'_{m-1}]))
  50. Enc(e0​,...,em1​)=exp(Enc([a0′​,...,am1′​]))

很明显,这里

  1. e
  2. j
  3. =
  4. e
  5. x
  6. p
  7. (
  8. a
  9. j
  10. )
  11. e_j=exp(a'_j)
  12. ej​=exp(aj′​)。

接下来,S应用

  1. Q
  2. u
  3. i
  4. c
  5. k
  6. S
  7. u
  8. m
  9. (
  10. )
  11. QuickSum(*)
  12. QuickSum(∗)算法来获得
  13. E
  14. n
  15. c
  16. (
  17. [
  18. j
  19. =
  20. 0
  21. m
  22. 1
  23. e
  24. j
  25. ,
  26. .
  27. .
  28. .
  29. ,
  30. j
  31. =
  32. 0
  33. m
  34. 1
  35. e
  36. j
  37. ]
  38. )
  39. Enc([\sum^{m-1}_{j=0}e_j,...,\sum^{m-1}_{j=0}e_j])
  40. Enc([∑j=0m1ej​,...,∑j=0m1ej​])​。

进一步的,S使用文献[21,24]中的Goldschmidt除法算法来计算:

  1. E
  2. n
  3. c
  4. (
  5. y
  6. 0
  7. ,
  8. .
  9. .
  10. .
  11. ,
  12. y
  13. m
  14. 1
  15. )
  16. =
  17. E
  18. n
  19. c
  20. (
  21. e
  22. 0
  23. ,
  24. .
  25. .
  26. .
  27. ,
  28. e
  29. m
  30. 1
  31. )
  32. E
  33. n
  34. c
  35. (
  36. [
  37. j
  38. =
  39. 0
  40. m
  41. 1
  42. e
  43. j
  44. ,
  45. .
  46. .
  47. .
  48. ,
  49. j
  50. =
  51. 0
  52. m
  53. 1
  54. e
  55. j
  56. ]
  57. )
  58. Enc(y_0,...,y_{m-1})={Enc(e_0,...,e_{m-1})\over Enc([\sum^{m-1}_{j=0}e_j,...,\sum^{m-1}_{j=0}e_j])}
  59. Enc(y0​,...,ym1​)=Enc([∑j=0m1ej​,...,∑j=0m1ej​])Enc(e0​,...,em1​)​

Softmax算法的详细描述如算法1所示:

在这里插入图片描述

3.1.4 Matrix multiplication(ciphertext-ciphertext)

这里是Attention的最后一个MatMul块,该块的计算原理同3.1.2节完全一致。

3.2 Layer normalization

本文的归一化表示如下(但是不太清楚这个归一化使用的是什么计算公式):

在这里插入图片描述

3.3 Feed forward

前馈网络层涉及到两个矩阵乘法以及一个GELU。矩阵乘法如上文所述来计算。GELU可以使用下述分段多项式来近似,当输入

  1. x
  2. [
  3. 60
  4. ,
  5. 60
  6. ]
  7. x\in [-60,60]
  8. x∈[−60,60],则可以确保误差在
  9. 1
  10. 0
  11. 3
  12. 10^{-3}
  13. 103内。
  14. G
  15. E
  16. L
  17. U
  18. (
  19. x
  20. )
  21. =
  22. {
  23. 0
  24. (
  25. x
  26. 4
  27. )
  28. P
  29. (
  30. x
  31. )
  32. =
  33. i
  34. =
  35. 0
  36. i
  37. =
  38. 3
  39. c
  40. i
  41. x
  42. i
  43. (
  44. 4
  45. <
  46. x
  47. 1.95
  48. )
  49. Q
  50. (
  51. x
  52. )
  53. =
  54. i
  55. =
  56. 0
  57. i
  58. =
  59. 6
  60. d
  61. i
  62. x
  63. i
  64. (
  65. 1.95
  66. <
  67. x
  68. 3
  69. )
  70. x
  71. (
  72. x
  73. >
  74. 3
  75. )
  76. GELU(x)=\in \begin{cases} 0 &(x\leq -4) \\ P(x)=\sum_{i=0}^{i=3}c_ix^i &(-4<x\leq -1.95) \\ Q(x)=\sum_{i=0}^{i=6}d_ix^i &(-1.95<x\leq 3) \\ x &(x>3) \end{cases}
  77. GELU(x)=∈⎩⎨⎧​0P(x)=∑i=0i=3cixiQ(x)=∑i=0i=6dixix​(x≤−4)(−4<x≤−1.95)(−1.95<x3)(x>3)​

首先,使用Sgn操作获得四个加密bit:

  1. b
  2. 0
  3. ,
  4. b
  5. 1
  6. ,
  7. b
  8. 2
  9. ,
  10. b
  11. 3
  12. b_0,b_1,b_2,b_3
  13. b0​,b1​,b2​,b3​,当且仅当输入x属于第i段时,
  14. b
  15. i
  16. =
  17. 1
  18. b_i=1
  19. bi​=1,否则
  20. b
  21. i
  22. =
  23. 0
  24. b_i=0
  25. bi​=0,如此,GELU(x)函数可以表示为:
  26. G
  27. E
  28. L
  29. U
  30. (
  31. x
  32. )
  33. :
  34. =
  35. b
  36. 0
  37. 0
  38. +
  39. b
  40. 1
  41. P
  42. (
  43. x
  44. )
  45. +
  46. b
  47. 2
  48. Q
  49. (
  50. x
  51. )
  52. +
  53. b
  54. 3
  55. x
  56. GELU(x):=b_0·0+b_1·P(x)+b_2·Q(x)+b_3·x
  57. GELU(x):=b0​⋅0+b1​⋅P(x)+b2​⋅Q(x)+b3​⋅x​。

完整的Secure GELU算法可以表示如下:

在这里插入图片描述

3.4 Argmax

transformer最终的输出应该是一个选择向量

  1. E
  2. n
  3. c
  4. (
  5. [
  6. b
  7. 0
  8. ,
  9. .
  10. .
  11. .
  12. ,
  13. b
  14. m
  15. 1
  16. ]
  17. )
  18. Enc([b_0,...,b_{m-1}])
  19. Enc([b0​,...,bm1​]),其中
  20. b
  21. i
  22. =
  23. 1
  24. i
  25. f
  26. a
  27. i
  28. =
  29. m
  30. a
  31. x
  32. (
  33. a
  34. 0
  35. ,
  36. .
  37. .
  38. .
  39. ,
  40. a
  41. m
  42. 1
  43. )
  44. b_i=1 \ if \ a_i=max(a_0,...,a_{m-1})
  45. bi​=1 if ai​=max(a0​,...,am1​),其他情况下
  46. b
  47. i
  48. =
  49. 0
  50. b_i=0
  51. bi​=0​。

因此,本文的Secure Argmax算法如下:

在这里插入图片描述

3.5 Placement of bootstrapping

由于bootstrapping操作是昂贵的,因此合理的放置bootstrapping的位置是至关重要的。

在这里插入图片描述

图4 Placement of bootstrapping for a BERT-base transformer

4. SIMD密文的压缩和分解

假设C想要发送N’个密文给S,且每个密文以SIMD方式加密N个相同的值,Enc([

  1. a
  2. 0
  3. ,
  4. .
  5. .
  6. .
  7. ,
  8. a
  9. 0
  10. a_0,...,a_0
  11. a0​,...,a0​]),…,Enc([
  12. a
  13. N
  14. 1
  15. ,
  16. .
  17. .
  18. .
  19. ,
  20. a
  21. N
  22. 1
  23. a_{N'-1},...,a_{N'-1}
  24. aN′−1​,...,aN′−1​​​])。

SIMD密文的压缩算法

C将向量[

  1. a
  2. 0
  3. ,
  4. a
  5. 1
  6. ,
  7. .
  8. .
  9. .
  10. ,
  11. a
  12. N
  13. 1
  14. a_0,a_1,...,a_{N'-1}
  15. a0​,a1​,...,aN′−1​]的各个元素打包到一个多项式的系数中,即:
  16. p
  17. (
  18. x
  19. )
  20. =
  21. a
  22. 0
  23. +
  24. a
  25. 1
  26. x
  27. +
  28. a
  29. 2
  30. x
  31. 2
  32. +
  33. .
  34. .
  35. .
  36. +
  37. a
  38. N
  39. 1
  40. x
  41. N
  42. 1
  43. p(x)=a_0+a_1x+a_2x^2+...+a_{N'-1}x^{N'-1}
  44. p(x)=a0​+a1​x+a2​x2+...+aN′−1​xN′−1

然后将该多项式加密

  1. p
  2. ~
  3. 0
  4. =
  5. E
  6. (
  7. p
  8. (
  9. x
  10. )
  11. )
  12. \widetilde p_0=E(p(x))
  13. p0​=E(p(x))发送给S

然后S可以对密文

  1. p
  2. ~
  3. 0
  4. \widetilde p_0
  5. p0​分解从而得到压缩前
  6. N
  7. N'
  8. N′个SIMD密文。

S分解密文

  1. p
  2. ~
  3. 0
  4. \widetilde p_0
  5. p0​过程如下:

SIMD密文的分解算法:

(1)执行

  1. S
  2. u
  3. b
  4. s
  5. (
  6. p
  7. ~
  8. 0
  9. ,
  10. N
  11. +
  12. 1
  13. )
  14. Subs(\widetilde p_0, N'+1)
  15. Subs(p​0​,N′+1)返回:
  16. E
  17. (
  18. a
  19. 0
  20. +
  21. a
  22. 1
  23. x
  24. N
  25. +
  26. 1
  27. +
  28. a
  29. 2
  30. x
  31. (
  32. N
  33. +
  34. 1
  35. )
  36. 2
  37. +
  38. .
  39. .
  40. .
  41. +
  42. a
  43. N
  44. 1
  45. x
  46. (
  47. N
  48. +
  49. 1
  50. )
  51. N
  52. 1
  53. )
  54. =
  55. E
  56. (
  57. a
  58. 0
  59. +
  60. a
  61. 1
  62. (
  63. x
  64. )
  65. +
  66. a
  67. 2
  68. (
  69. x
  70. )
  71. 2
  72. )
  73. +
  74. .
  75. .
  76. .
  77. +
  78. a
  79. N
  80. 1
  81. (
  82. x
  83. )
  84. N
  85. 1
  86. )
  87. E(a_0+a_1x^{N'+1}+a_2x^{(N'+1)^2}+...+a_{N'-1}x^{(N'+1)^{N'-1}}) \\ =E(a_0+a_1(-x)+a_2(-x)^2)+...+a_{N'-1}(-x)^{N'-1})
  88. E(a0​+a1xN′+1+a2x(N′+1)2+...+aN′−1x(N′+1)N′−1)=E(a0​+a1​(−x)+a2​(−x)2)+...+aN′−1​(−x)N′−1)

注意,

  1. x
  2. N
  3. +
  4. 1
  5. 0
  6. (
  7. m
  8. o
  9. d
  10. x
  11. N
  12. +
  13. 1
  14. )
  15. x^{N'}+1 \equiv 0 \ \ (mod \ \ x^{N'} + 1)
  16. xN′+10 (mod xN′+1),因此
  17. x
  18. N
  19. +
  20. 1
  21. =
  22. x
  23. N
  24. x
  25. =
  26. x
  27. (
  28. m
  29. o
  30. d
  31. x
  32. N
  33. +
  34. 1
  35. )
  36. x^{N'+1} = x^{N'} * x = -x \ (mod \ \ x^{N'}+1)
  37. xN′+1=xN′∗x=−x (mod xN′+1)​,这里的
  38. N
  39. N’
  40. N’也就是分圆环的次数。

(2)执行

  1. p
  2. ~
  3. 0
  4. +
  5. S
  6. u
  7. b
  8. s
  9. (
  10. p
  11. ~
  12. 0
  13. ,
  14. N
  15. +
  16. 1
  17. )
  18. \widetilde p_0+Subs(\widetilde p_0,N'+1)
  19. p​0​+Subs(p​0​,N′+1)​操作,移除p(x)的所有奇数项。
  20. a
  21. 0
  22. +
  23. a
  24. 1
  25. x
  26. +
  27. a
  28. 2
  29. x
  30. 2
  31. +
  32. .
  33. .
  34. .
  35. +
  36. a
  37. N
  38. 1
  39. x
  40. N
  41. +
  42. 1
  43. +
  44. a
  45. 0
  46. +
  47. a
  48. 1
  49. (
  50. x
  51. )
  52. +
  53. a
  54. 2
  55. (
  56. x
  57. )
  58. 2
  59. )
  60. +
  61. .
  62. .
  63. .
  64. +
  65. a
  66. N
  67. 1
  68. (
  69. x
  70. )
  71. N
  72. 1
  73. =
  74. a
  75. 0
  76. +
  77. 0
  78. x
  79. +
  80. a
  81. 2
  82. x
  83. 2
  84. +
  85. .
  86. .
  87. .
  88. +
  89. a
  90. N
  91. 2
  92. x
  93. N
  94. 2
  95. +
  96. 0
  97. x
  98. N
  99. 1
  100. a_0+a_1x+a_2x^2+...+a_{N'-1}x^{N'+1} \\+ a_0+a_1(-x)+a_2(-x)^2)+...+a_{N'-1}(-x)^{N'-1}\\= a_0+0x+a_2x^2+...+a_{N'-2}x^{N'-2}+0x^{N'-1}
  101. a0​+a1x+a2x2+...+aN′−1xN′+1+a0​+a1​(−x)+a2​(−x)2)+...+aN′−1​(−x)N′−1=a0​+0x+a2x2+...+aN′−2xN′−2+0xN′−1

(3)通过

  1. l
  2. o
  3. g
  4. N
  5. log N'
  6. logN′次
  7. S
  8. u
  9. b
  10. s
  11. (
  12. )
  13. Subs()
  14. Subs()操作,S可以提取得到密文:
  15. E
  16. (
  17. a
  18. 0
  19. +
  20. 0
  21. x
  22. 1
  23. +
  24. 0
  25. x
  26. 2
  27. +
  28. .
  29. .
  30. .
  31. +
  32. 0
  33. x
  34. N
  35. 1
  36. )
  37. E(a_0+0x^1+0x^2+...+0x^{N'-1})
  38. E(a0​+0x1+0x2+...+0xN′−1),实际上,这就是密文Enc([
  39. a
  40. 0
  41. ,
  42. a
  43. 0
  44. ,
  45. .
  46. .
  47. .
  48. ,
  49. a
  50. 0
  51. a_0,a_0,...,a_0
  52. a0​,a0​,...,a0​])。完整的操作流程如下:

在这里插入图片描述

类似地,为了提取E(

  1. a
  2. 1
  3. +
  4. 0
  5. x
  6. 1
  7. +
  8. .
  9. .
  10. .
  11. +
  12. 0
  13. x
  14. N
  15. 1
  16. a_1+0x^1+...+0x^{N'-1}
  17. a1​+0x1+...+0xN′−1),S应该左旋明文多项式p(x)一个单位,通过乘以
  18. x
  19. 1
  20. x^{-1}
  21. x−1,然后再次执行上述的提取过程。通过执行
  22. N
  23. N‘
  24. N‘次该提取过程,S可以获得向量[
  25. a
  26. 0
  27. ,
  28. a
  29. 1
  30. ,
  31. .
  32. .
  33. .
  34. ,
  35. a
  36. N
  37. 1
  38. a_0,a_1,...,a_{N'-1}
  39. a0​,a1​,...,aN′−1​​]中每个元素的单独SIMD格式加密。

然而上述过程需要执行

  1. (
  2. N
  3. l
  4. o
  5. g
  6. N
  7. )
  8. (N'·logN')
  9. (N′⋅logN′)次
  10. S
  11. u
  12. b
  13. s
  14. (
  15. )
  16. Subs()
  17. Subs()操作。对比之下,本文提出一种算法,可以实现相同的目标,但是仅需要
  18. 2
  19. N
  20. 2N'
  21. 2N′次
  22. S
  23. u
  24. b
  25. s
  26. (
  27. )
  28. Subs()
  29. Subs()操作。该算法可以简单地描述如下:

在这里插入图片描述


算法5是Secure Decompression的详细描述:

在这里插入图片描述

下面提供上述分解操作的理论证明:

Theorem 1:

仅有常数项的多项式的加密E(

  1. a
  2. s
  3. +
  4. 0
  5. x
  6. 1
  7. +
  8. .
  9. .
  10. .
  11. +
  12. 0
  13. x
  14. N
  15. 1
  16. a_s+0x^1+...+0x^{N'-1}
  17. as​+0x1+...+0xN′−1)是向量[
  18. a
  19. s
  20. ,
  21. a
  22. s
  23. ,
  24. .
  25. .
  26. .
  27. ,
  28. a
  29. s
  30. a_s,a_s,...,a_s
  31. as​,as​,...,as​]的加密Enc([
  32. a
  33. s
  34. ,
  35. a
  36. s
  37. ,
  38. .
  39. .
  40. .
  41. ,
  42. a
  43. s
  44. a_s,a_s,...,a_s
  45. as​,as​,...,as​​])。

在这里插入图片描述


4.1 Application to matrix multiplication

压缩分解技术可以自然地应用于MatrixMul,此外,基于下面的观察结果,本文进一步优化了矩阵乘法,观察到在transformer推理过程中,对于不同输入的矩阵

  1. A
  2. R
  3. m
  4. ×
  5. n
  6. A\in \R^{m\times n}
  7. ARm×n需要乘以相同的矩阵
  8. W
  9. R
  10. n
  11. ×
  12. k
  13. W\in \R^{n\times k}
  14. WRn×k​。

  1. A
  2. =
  3. [
  4. a
  5. 0
  6. ,
  7. .
  8. .
  9. .
  10. ,
  11. a
  12. n
  13. 1
  14. ]
  15. A = [a_0,...,a_{n-1}]
  16. A=[a0​,...,an1​],其中
  17. a
  18. i
  19. R
  20. m
  21. a_i\in \R^m
  22. ai​∈Rm表示矩阵
  23. A
  24. A
  25. A的第
  26. i
  27. i
  28. i行。假设SC需要生成t个响应词,即有t个输入矩阵:
  29. A
  30. 0
  31. =
  32. [
  33. a
  34. 0
  35. ,
  36. 0
  37. ,
  38. a
  39. 0
  40. ,
  41. 1
  42. ,
  43. .
  44. .
  45. .
  46. ,
  47. a
  48. 0
  49. ,
  50. n
  51. 1
  52. ]
  53. A
  54. 1
  55. =
  56. [
  57. a
  58. 1
  59. ,
  60. 0
  61. ,
  62. a
  63. 1
  64. ,
  65. 1
  66. ,
  67. .
  68. .
  69. .
  70. ,
  71. a
  72. 1
  73. ,
  74. n
  75. 1
  76. ]
  77. .
  78. .
  79. .
  80. A
  81. 0
  82. =
  83. [
  84. a
  85. t
  86. 1
  87. ,
  88. 0
  89. ,
  90. a
  91. t
  92. 1
  93. ,
  94. 1
  95. ,
  96. .
  97. .
  98. .
  99. ,
  100. a
  101. t
  102. 1
  103. ,
  104. n
  105. 1
  106. ]
  107. A_0=[a_{0,0},a_{0,1},...,a_{0,n-1}] \\ A_1=[a_{1,0},a_{1,1},...,a_{1,n-1}] \\ ... \\ A_0=[a_{t-1,0},a_{t-1,1},...,a_{t-1,n-1}]
  108. A0​=[a0,0​,a0,1​,...,a0,n1​]A1​=[a1,0​,a1,1​,...,a1,n1​]...A0​=[at1,0​,at1,1​,...,at1,n1​]

  1. a
  2. i
  3. =
  4. [
  5. a
  6. 0
  7. ,
  8. i
  9. a
  10. 1
  11. ,
  12. i
  13. .
  14. .
  15. .
  16. a
  17. t
  18. 1
  19. ,
  20. i
  21. ]
  22. a'_i=\left[\begin{matrix} a_{0,i} \\ a_{1,i} \\ ... \\ a_{t-1,i} \end{matrix} \right]
  23. ai′​=​a0,i​a1,i​...at−1,i​​​和
  24. q
  25. j
  26. :
  27. =
  28. i
  29. =
  30. 0
  31. n
  32. 1
  33. a
  34. i
  35. w
  36. i
  37. ,
  38. j
  39. j
  40. [
  41. k
  42. ]
  43. q'_j:=\sum^{n-1}_{i=0}a'_iw_{i,j}\ \ \ \forall j\in [k]
  44. qj′​:=∑i=0n−1​ai′​wi,j​ ∀j∈[k],则有
  45. Q
  46. =
  47. q
  48. 0
  49. q
  50. 1
  51. .
  52. .
  53. .
  54. q
  55. k
  56. 1
  57. =
  58. [
  59. A
  60. 0
  61. W
  62. A
  63. 1
  64. W
  65. .
  66. .
  67. .
  68. A
  69. t
  70. 1
  71. W
  72. ]
  73. Q'=q'_0||q'_1||...||q'_{k-1}=\left[\begin{matrix} A_0W \\ A_1W \\ ... \\ A_{t-1}W \end{matrix} \right]
  74. Q′=q0′​∣∣q1′​∣∣...∣∣qk−1′​=​A0​WA1​W...At−1​W​​

在这里插入图片描述


预计算阶段:

这里,我们引入一个预计算阶段,其中S使用上述提到的密文压缩技术,将压缩后的密文

  1. (
  2. E
  3. n
  4. c
  5. S
  6. (
  7. [
  8. w
  9. i
  10. ,
  11. j
  12. ,
  13. w
  14. i
  15. ,
  16. j
  17. ,
  18. .
  19. .
  20. .
  21. ,
  22. w
  23. i
  24. .
  25. j
  26. t
  27. ×
  28. m
  29. ]
  30. )
  31. i
  32. [
  33. n
  34. ]
  35. ,
  36. j
  37. [
  38. k
  39. ]
  40. )
  41. (Enc_S([\underbrace{w_{i,j},w_{i,j},...,w_{i.j}}_{t\times m}])\ \ \forall i \in [n],j\in [k])
  42. (EncS​([t×mwi,j​,wi,j​,...,wi.j​​​]) i∈[n],j∈[k])发送给C。注意,该传输仅只发生一次,除非模型发生改变。接下来,C对压缩的密文执行分解技术,以获得
  43. E
  44. n
  45. c
  46. S
  47. (
  48. [
  49. w
  50. i
  51. ,
  52. j
  53. ,
  54. w
  55. i
  56. ,
  57. j
  58. ,
  59. .
  60. .
  61. .
  62. ,
  63. w
  64. i
  65. .
  66. j
  67. t
  68. ×
  69. m
  70. ]
  71. )
  72. i
  73. [
  74. n
  75. ]
  76. ,
  77. j
  78. [
  79. k
  80. ]
  81. Enc_S([\underbrace{w_{i,j},w_{i,j},...,w_{i.j}}_{t\times m}])\ \ \forall i \in [n],j\in [k]
  82. EncS​([t×mwi,j​,wi,j​,...,wi.j​​​]) i∈[n],j∈[k]。在预计算阶段C并没有关于输入的信息,采样
  83. U
  84. R
  85. (
  86. t
  87. m
  88. )
  89. ×
  90. n
  91. U\in \R^{(tm)\times n}
  92. UR(tmn,然后计算:
  93. E
  94. n
  95. c
  96. S
  97. (
  98. v
  99. j
  100. )
  101. i
  102. =
  103. 0
  104. n
  105. 1
  106. (
  107. u
  108. i
  109. ×
  110. E
  111. n
  112. c
  113. S
  114. (
  115. [
  116. w
  117. i
  118. ,
  119. j
  120. ,
  121. .
  122. .
  123. .
  124. ,
  125. w
  126. i
  127. ,
  128. j
  129. ]
  130. )
  131. )
  132. j
  133. [
  134. k
  135. ]
  136. Enc_S(v_j)\leftarrow\sum^{n-1}_{i=0}(u_i\times Enc_S([w_{i,j},...,w_{i,j}]))\ \ \ \forall j\in[k]
  137. EncS​(vj​)←i=0n1​(ui​×EncS​([wi,j​,...,wi,j​])) j∈[k]

其中

  1. u
  2. i
  3. u_i
  4. ui​是矩阵
  5. U
  6. U
  7. U的第i列。接下来,C使用自己的密钥来加密
  8. E
  9. n
  10. c
  11. S
  12. (
  13. v
  14. j
  15. )
  16. Enc_S(v_j)
  17. EncS​(vj​)以获得
  18. E
  19. n
  20. c
  21. C
  22. (
  23. E
  24. n
  25. c
  26. S
  27. (
  28. v
  29. j
  30. )
  31. )
  32. Enc_C(Enc_S(v_j))
  33. EncC​(EncS​(vj​)),并将其发送给S。注意
  34. E
  35. n
  36. c
  37. S
  38. (
  39. E
  40. n
  41. c
  42. C
  43. (
  44. v
  45. j
  46. )
  47. )
  48. =
  49. E
  50. n
  51. c
  52. C
  53. (
  54. E
  55. n
  56. c
  57. S
  58. (
  59. v
  60. j
  61. )
  62. )
  63. Enc_S(Enc_C(v_j))=Enc_C(Enc_S(v_j))
  64. EncS​(EncC​(vj​))=EncC​(EncS​(vj​)),故S可以对其进行解密,从而获得
  65. E
  66. n
  67. c
  68. C
  69. (
  70. v
  71. j
  72. )
  73. Enc_C(v_j)
  74. EncC​(vj​)。注意,这里的
  75. v
  76. j
  77. v_j
  78. vj​是矩阵
  79. U
  80. W
  81. U·W
  82. UW的第
  83. j
  84. j
  85. j列。

切换不同用户加密密钥的过程计算如下:

给定

  1. c
  2. t
  3. S
  4. =
  5. (
  6. a
  7. s
  8. S
  9. +
  10. m
  11. +
  12. e
  13. )
  14. ct_S=(-as_S+m+e)
  15. ctS​=(−asS​+m+e),

使用

  1. s
  2. C
  3. s_C
  4. sC​加密有
  5. c
  6. t
  7. C
  8. ,
  9. S
  10. =
  11. (
  12. a
  13. s
  14. S
  15. a
  16. s
  17. C
  18. +
  19. m
  20. +
  21. e
  22. +
  23. e
  24. ,
  25. a
  26. )
  27. ct_{C,S}=(-as_S-as_C+m+e+e',a)
  28. ctC,S​=(−asS​−asC​+m+e+e′,a),

使用

  1. s
  2. S
  3. s_S
  4. sS​解密有:
  5. c
  6. t
  7. C
  8. =
  9. (
  10. a
  11. s
  12. S
  13. a
  14. s
  15. C
  16. +
  17. m
  18. +
  19. e
  20. +
  21. e
  22. ,
  23. a
  24. )
  25. +
  26. (
  27. a
  28. s
  29. S
  30. ,
  31. 0
  32. )
  33. =
  34. (
  35. a
  36. s
  37. C
  38. +
  39. m
  40. +
  41. e
  42. +
  43. e
  44. )
  45. ct_C=(-as_S-as_C+m+e+e',a)+(as_S,0)=(-as_C+m+e+e')
  46. ctC​=(−asS​−asC​+m+e+e′,a)+(asS​,0)=(−asC​+m+e+e′).

在线处理阶段:

此时,C知道输入的信息

  1. A
  2. =
  3. a
  4. 0
  5. a
  6. 1
  7. .
  8. .
  9. .
  10. a
  11. n
  12. 1
  13. A'=a'_0||a'_1||...||a'_{n-1}
  14. A′=a0′​∣∣a1′​∣∣...∣∣an1′​,然后C将明文
  15. (
  16. A
  17. U
  18. )
  19. (A'-U)
  20. (A′−U)发送给S,注意,由于S不知道U的值,故S也不清楚
  21. A
  22. A'
  23. A′的值,然后S可以计算:
  24. (
  25. A
  26. U
  27. )
  28. W
  29. +
  30. (
  31. E
  32. n
  33. c
  34. C
  35. (
  36. v
  37. 0
  38. )
  39. E
  40. n
  41. c
  42. C
  43. (
  44. v
  45. 1
  46. )
  47. .
  48. .
  49. .
  50. E
  51. n
  52. c
  53. C
  54. (
  55. v
  56. k
  57. 1
  58. )
  59. )
  60. =
  61. (
  62. A
  63. W
  64. V
  65. )
  66. +
  67. (
  68. E
  69. n
  70. c
  71. C
  72. (
  73. v
  74. 0
  75. )
  76. E
  77. n
  78. c
  79. C
  80. (
  81. v
  82. 1
  83. )
  84. .
  85. .
  86. .
  87. E
  88. n
  89. c
  90. C
  91. (
  92. v
  93. k
  94. 1
  95. )
  96. )
  97. =
  98. (
  99. E
  100. n
  101. c
  102. C
  103. (
  104. q
  105. 0
  106. )
  107. E
  108. n
  109. c
  110. C
  111. (
  112. q
  113. 1
  114. )
  115. .
  116. .
  117. .
  118. E
  119. n
  120. c
  121. C
  122. (
  123. q
  124. k
  125. 1
  126. )
  127. )
  128. (A'-U)·W + (Enc_C(v_0)||Enc_C(v_1)||...||Enc_C(v_{k-1}))\\ =(A'W-V)+(Enc_C(v_0)||Enc_C(v_1)||...||Enc_C(v_{k-1})) \\ =(Enc_C(q'_0)||Enc_C(q'_1)||...||Enc_C(q'_{k-1}))
  129. (A′−U)⋅W+(EncC​(v0​)∣∣EncC​(v1​)∣∣...∣∣EncC​(vk−1​))=(A′W−V)+(EncC​(v0​)∣∣EncC​(v1​)∣∣...∣∣EncC​(vk−1​))=(EncC​(q0′​)∣∣EncC​(q1′​)∣∣...∣∣EncC​(qk−1′​))

其中

  1. q
  2. j
  3. q'_j
  4. qj′​是矩阵
  5. Q
  6. Q'
  7. Q′的第
  8. j
  9. j
  10. j列。算法6描述了优化后的矩阵乘法细节:

在这里插入图片描述

可以注意到,只需要预计算的过程交互一次,此后C可以直接向S发送明文信息

  1. A
  2. U
  3. A'-U
  4. A′−U,而不会泄露
  5. A
  6. A'
  7. A′的信息。

5 SIMD槽折叠算法

回想矩阵Q,K,V的行向量是使用SIMD方式加密的。而上述介绍的一系列操作,如内积,Softmax,LayerNorm和Argmax等, 均涉及到利用所有槽元素计算函数

  1. f
  2. (
  3. )
  4. f(*)
  5. f(∗),并将得到的结果放置到所有槽上。例如给定
  6. E
  7. n
  8. c
  9. (
  10. [
  11. a
  12. 0
  13. ,
  14. .
  15. .
  16. .
  17. ,
  18. a
  19. N
  20. 1
  21. ]
  22. )
  23. Enc([a_0,...,a_{N-1}])
  24. Enc([a0​,...,aN1​]),然后想要获得
  25. E
  26. n
  27. c
  28. (
  29. [
  30. s
  31. ,
  32. .
  33. .
  34. .
  35. ,
  36. s
  37. ]
  38. )
  39. Enc([s,...,s])
  40. Enc([s,...,s]),其中
  41. s
  42. =
  43. i
  44. =
  45. 0
  46. N
  47. 1
  48. a
  49. i
  50. s=\sum^{N-1}_{i=0}a_i
  51. s=∑i=0N1ai​,此时
  52. f
  53. (
  54. )
  55. f(*)
  56. f(∗)即求和函数。

本节提供了一种通用的解决方案,只要函数

  1. f
  2. (
  3. )
  4. f(*)
  5. f(∗)满足:
  6. f
  7. (
  8. f
  9. (
  10. a
  11. 0
  12. ,
  13. a
  14. 1
  15. )
  16. ,
  17. a
  18. 2
  19. )
  20. =
  21. f
  22. (
  23. a
  24. 0
  25. ,
  26. f
  27. (
  28. a
  29. 1
  30. ,
  31. a
  32. 2
  33. )
  34. )
  35. f(f(a_0,a_1),a_2)=f(a_0,f(a_1,a_2))
  36. f(f(a0​,a1​),a2​)=f(a0​,f(a1​,a2​))

算法7描述了槽折叠算法的细节:

在这里插入图片描述

这里是一个简单的例子,可以看到算法7的实现流程:

在这里插入图片描述

5.1 QuickSum

给定

  1. [
  2. a
  3. 0
  4. ,
  5. a
  6. 1
  7. ,
  8. .
  9. .
  10. .
  11. ,
  12. a
  13. n
  14. 1
  15. ,
  16. 0
  17. ,
  18. .
  19. .
  20. .
  21. ,
  22. 0
  23. ]
  24. [a_0,a_1,...,a_{n-1},0,...,0]
  25. [a0​,a1​,...,an1​,0,...,0],为了获得
  26. [
  27. i
  28. =
  29. 0
  30. N
  31. 1
  32. a
  33. i
  34. ,
  35. .
  36. .
  37. .
  38. ,
  39. i
  40. =
  41. 0
  42. N
  43. 1
  44. a
  45. i
  46. ,
  47. 0
  48. ,
  49. .
  50. .
  51. .
  52. ,
  53. 0
  54. ]
  55. [\sum^{N-1}_{i=0}a_i,...,\sum^{N-1}_{i=0}a_i,0,...,0]
  56. [∑i=0N1ai​,...,∑i=0N1ai​,0,...,0],可以将算法7的第5行替换为
  57. s
  58. ~
  59. s
  60. ~
  61. +
  62. a
  63. ~
  64. \tilde{s}\leftarrow\tilde{s}+\tilde{a}
  65. s~←s~+a~。

5.2 QuickMax

给定

  1. [
  2. a
  3. 0
  4. ,
  5. a
  6. 1
  7. ,
  8. .
  9. .
  10. .
  11. ,
  12. a
  13. n
  14. 1
  15. ,
  16. 0
  17. ,
  18. .
  19. .
  20. .
  21. ,
  22. 0
  23. ]
  24. [a_0,a_1,...,a_{n-1},0,...,0]
  25. [a0​,a1​,...,an1​,0,...,0],为了获得
  26. a
  27. m
  28. a
  29. x
  30. ,
  31. .
  32. .
  33. .
  34. ,
  35. a
  36. m
  37. a
  38. x
  39. ,
  40. 0
  41. ,
  42. .
  43. .
  44. .
  45. ,
  46. 0
  47. a_{max},...,a_{max},0,...,0
  48. amax​,...,amax​,0,...,0,其中
  49. a
  50. m
  51. a
  52. x
  53. =
  54. m
  55. a
  56. x
  57. (
  58. a
  59. 0
  60. ,
  61. a
  62. 1
  63. ,
  64. .
  65. .
  66. .
  67. ,
  68. a
  69. n
  70. 1
  71. )
  72. a_{max}=max(a_0,a_1,...,a_{n-1})
  73. amax​=max(a0​,a1​,...,an1​),很明显
  74. m
  75. a
  76. x
  77. (
  78. a
  79. ,
  80. b
  81. )
  82. max(a,b)
  83. max(a,b)可以表示为:
  84. m
  85. a
  86. x
  87. (
  88. a
  89. ,
  90. b
  91. )
  92. =
  93. a
  94. +
  95. b
  96. +
  97. (
  98. a
  99. b
  100. )
  101. S
  102. g
  103. n
  104. (
  105. a
  106. b
  107. )
  108. 2
  109. max(a,b)={a+b+(a-bSgn(a-b)\over 2}
  110. max(a,b)=2a+b+(ab)⋅Sgn(ab)​

因此可以将算法7的第5行替换为:

  1. s
  2. ~
  3. 0.5
  4. (
  5. a
  6. ~
  7. s
  8. ~
  9. (
  10. a
  11. ~
  12. s
  13. ~
  14. )
  15. S
  16. g
  17. n
  18. (
  19. a
  20. ~
  21. s
  22. ~
  23. )
  24. )
  25. \tilde{s}\leftarrow 0.5\otimes(\tilde{a}\oplus\tilde{s}\oplus(\tilde{a}\ominus\tilde{s})\otimes Sgn(\tilde{a}\ominus\tilde{s}))
  26. s~←0.5⊗(a~⊕s~⊕(a~⊖s~)⊗Sgn(a~⊖s~))

6. Conclusion

本文提出了NEXUS系统,可以说是第一个不需要客户端和服务器进行交互的安全transformer推理协议。本文提出了适用于RNS-CKKS的一系列新协议,以使得服务器可以高效且精确的在加密数据上计算transformer的每一层。


本文转载自: https://blog.csdn.net/Wwj2981/article/details/142500904
版权归原作者 lucky_wjie 所有, 如有侵权,请联系我们删除。

“基于CKKS的非交互式安全Transformer推理实现”的评论:

还没有评论