Secure Transformer Inference Made Non-interactive
本文介绍了如何使用CKKS来计算transformer推理的每个部分。同时给出了一系列优化算法。主要涉及到的计算算法有以下几种: 密文的压缩与分解技术、SIMD槽折叠技术、Sgn()、QuickSum、QuickMax、密文-明文矩阵相乘、密文-密文矩阵相乘法、Softmax算法、归一化、GELU函数、Argmax函数等等。其中密文的压缩与分解技术,和SIMD槽折叠技术是本文的核心创新算法。
Abstract
随着ChatGPT的普及,安全transformer推理已经成为一个突出了研究主题。已有的解决方法通常是交互式的,涉及到客户端和服务端之间大量的通信负载和交互轮次。
本文提出NEXUS,这是第一个用于安全transformer推理的非交互式协议,其中客户端仅需要提交一个加密输入,然后等待来自服务器的加密结果即可。NEXUS的核心是两个创新的技术:SIMD密文压缩和分解技术,以及SIMD槽折叠技术。此外,同24年的另外一个解决方案相比,本方法达到了2.8倍的加速,且减少了368.6倍的带宽消耗。
1 Introduction
Transformers,例如GPT和BERT,已经彻底改变了AI领域。Transformer擅长于广泛领域的应用,比如语言翻译,内容生成以及问题回答。然而这些应用总是涉及到敏感数据,从而导致越来越多地关于用户隐私的担忧。例,OpenAI开发的ChatGPT作为一种在线推理服务,以及为开发人员提供的远程API,其中使用者通过提交prompts或者消息可以很容易地访问这些服务。尽管这些方法是方便的,但是由于使用者提交的数据可能包含敏感信息,故而造成了严重的隐私风险。
Secure inference是一种两方密码协议,该协议使模型推理以如下方式处理运行,即服务器S不会了解到关于客户C提交的输入的任何信息,且C不会了解到关于S的模型的任何信息,仅仅能得到最终的推理结果。
该协议大多被设计于安全CNNs推
[
2
,
27
,
30
,
36
]
[2,27,30,36]
[2,27,30,36],最近的许多工作也支持基于Transformer的模型
[
10
,
24
,
26
,
35
,
38
,
40
]
[10,24,26,35,38,40]
[10,24,26,35,38,40],值得注意的是,这些安全Transformer模型大多都是交互式的,因此会导致巨大的通信开销和交互轮次,这里我们必须强调非交互式安全Transformer推理的重要性。
本文贡献:
本文中,我们提出了NEXUS,第一个secure transformer inference的非交互协议。通过NEXUS,C使用RNS-CKKS加密输入,S对FHE加密数据执行transformer。CKKS的SIMD技术被应用于批处理
N
=
2
15
N=2^{15}
N=215个数据,多项式近似可以用于处理非线性函数,比如GELU,softmax,层归一化和argmax。
NEXUS不需要对模型进行任何重训练与微调,且为了提高NEXUS的效率,我们提出了两种新颖的且基础的技术。
- SIMD密文压缩与分解:该技术可以将2N个SIMD密文压缩为一个密文,然后可以使用4N个密文—明文乘法和替换将其解压回来。该技术可以大大减少客户端和服务器之间传输的密文数量,而不会为后续计算带来任何额外的开销。
- SIMD槽折叠:在所有SIMD槽中计算关联函数f(),例如sum和max。结果值会自动的填充SIMD密文的槽,允许将其应用于原始密文的每个槽。
本文贡献总结如下:
- secure transformer inference的第一个非交互协议
- 用于密文打包的SIMD密文压缩与分解技术
- SIMD槽折叠技术,以高效操作SIMD密文的槽
- 综合的实现与评估
2 Preliminaries
符号系统描述如下:
NotationDescriptionNotationDescriptionCclientSserver
E
(
∗
)
E(*)
E(∗)encryption
π
(
∗
)
\pi(*)
π(∗)encoding
E
n
c
(
∗
)
Enc(*)
Enc(∗)encoding+encryption
a
~
\tilde{a}
a~FHE ciphertext
R
o
t
L
(
∗
)
/
R
o
t
R
(
∗
)
RotL(*)/RotR(*)
RotL(∗)/RotR(∗)左旋转和右旋转
S
u
b
s
(
∗
)
Subs(*)
Subs(∗)替换操作
S
g
n
(
∗
)
Sgn(*)
Sgn(∗)sign操作
L
L
L乘法深度
N
′
N'
N′CKKS的环维数
N
N
N
N
=
N
′
/
2
N=N'/2
N=N′/2
A
A
A输入矩阵
W
W
W权重矩阵
2.1 安全推理和威胁模型
安全推理是一个两方密码学协议,其可以在C和S之间进行模型推理,与此同时还可以保护两个参与方输入隐私。它的正式定义如下:
Definition 1:
针对两方参与者,其中
S
S
S持有模型
M
M
M,且
C
C
C持有输入
A
A
A的协议
Π
\Pi
Π是安全推理协议,当且仅当以下条件满足时:
(1) 正确性: 该协议的最终输出是正确的推理结果
M
(
A
)
M(A)
M(A)。
(2) 安全性:
V
i
e
w
C
Π
≈
c
S
i
m
C
(
A
,
o
u
t
)
View^{\Pi}_C\approx_c Sim_C(A,out)
ViewCΠ≈cSimC(A,out),其中
V
i
e
w
C
Π
View^{\Pi}_C
ViewCΠ表示协议
Π
\Pi
Π执行期间
C
C
C的视角,
o
u
t
out
out表示推理的结果。
V
i
e
w
S
Π
≈
S
S
i
m
S
(
M
)
View^{\Pi}_S\approx_S Sim_S(M)
ViewSΠ≈SSimS(M),其中
V
i
e
w
S
Π
View^{\Pi}_S
ViewSΠ表示协议
Π
\Pi
Π执行期间
S
S
S的视角。
S
i
m
∗
Sim_*
Sim∗可以理解为理想状态下希望实体
∗
*
∗可以得到的信息。
假设
C
C
C或
S
S
S为半诚实对手,其在遵守协议规范的同时也尽可能的在执行过程中手机额外的信息。且假设对手在计算上是有限的。
2.2 Transformer
这里简单介绍一下Transformerd。
图1是transformer的结构与工作流程。它将一个表示为矩阵的嵌入传递给注意层和前馈神经网络,最后根据最终对数最大值输出一个选择向量,且,LayerNorm层被应用于每个块之后。
transformer的结构和工作流程
Attention:
使用三个矩阵(
W
Q
∈
R
n
×
k
,
W
K
∈
R
n
×
k
,
W
V
∈
R
n
×
k
W_Q\in\mathbb{R}^{n\times k},W_K\in\mathbb{R}^{n\times k},W_V\in\mathbb{R}^{n\times k}
WQ∈Rn×k,WK∈Rn×k,WV∈Rn×k)乘嵌入矩阵
A
∈
R
m
×
n
A\in \mathbb{R}^{m\times n}
A∈Rm×n,生成一个query矩阵
Q
=
A
⋅
W
Q
Q = A·W_Q
Q=A⋅WQ,一个key矩阵
K
=
A
⋅
W
K
K=A·W_K
K=A⋅WK和一个value矩阵
V
=
A
⋅
W
V
V=A·W_V
V=A⋅WV。即对于Attention层的单元,transformer会学习到三个权重矩阵。
attention可以被表示为:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
m
a
x
(
Q
K
T
k
)
⋅
V
Attention(Q,K,V) = Softmax({QK^T\over{\sqrt k}})·V
Attention(Q,K,V)=Softmax(kQKT)⋅V
Layer normalization
该层的输入为
a
∈
R
n
a\in \mathbb{R}^n
a∈Rn,均值和标准差分别为
μ
\mu
μ和
σ
\sigma
σ,则该层的输出
y
∈
R
n
y\in\mathbb{R}^n
y∈Rn可以表示为:
y
i
=
γ
⋅
x
i
−
μ
σ
+
β
y_i=\gamma·{x_i-\mu\over\sigma}+\beta
yi=γ⋅σxi−μ+β
其中,
γ
,
β
∈
R
\gamma,\beta\in\mathbb{R}
γ,β∈R是两个超参数。
Feed-forward
全连接前馈网络层包含两个线性变换以及一个GELU激活函数:
F
e
e
d
F
o
r
w
a
r
d
(
X
)
=
G
E
L
U
(
X
W
1
+
b
1
)
⋅
W
2
+
b
2
FeedForward(X)=GELU(XW_1+b_1)·W_2+b_2
FeedForward(X)=GELU(XW1+b1)⋅W2+b2
其中GELU函数计算如下:
G
E
L
U
(
x
)
=
1
2
x
⋅
(
1
+
e
r
f
(
x
2
)
)
GELU(x)={1\over 2}x·(1+erf({x\over \sqrt 2}))
GELU(x)=21x⋅(1+erf(2x))
式中,高斯误差函数为
e
r
f
(
x
)
=
2
π
∫
0
x
e
−
t
2
d
t
erf(x)={2\over\sqrt{\pi}}\int_0^xe^{-t^2}dt
erf(x)=π2∫0xe−t2dt。由于其良好的曲率和非单调性,它被用作激活函数。
Argmax
根据最终对数最大值输出一个选择向量
可以看到,只要我们能够使用FHE实现各个层的计算,就可以实现一个安全Transformer。
2.3 Fully Homomorphic Encryption
FHE可以对加密数据执行任意操作,故FHE是使得我们构建非交互式安全transformer推理得主要工具。RNS-CKKS属于级全同态加密,其可以支持L级深度的乘法。RNS-CKKS的明文和密文均是多项式环
R
Q
=
Z
Q
[
X
]
/
(
X
N
′
+
1
)
R_Q=\Z_Q[X]/(X^{N'}+1)
RQ=ZQ[X]/(XN′+1)上的元素。其中
Q
=
Π
i
=
0
L
q
i
Q=\Pi^L_{i=0}q_i
Q=Πi=0Lqi,且
q
i
q_i
qi之间互素。若密文的级别变得太低,则可以运行自举操作来刷新密文到高的级别,以允许更多的计算。
简单地说,自举即利用自同构
R
q
0
≅
R
q
0
×
R
q
1
×
.
.
.
×
R
q
L
R_{q_0}\cong R_{q_0}\times R_{q_1}\times ... \times R_{q_L}
Rq0≅Rq0×Rq1×...×RqL,来将密文模从
q
0
q_0
q0提升到
q
L
q_L
qL,以及对密文同态评估解密电路。若自举本身消耗K个级别,则刷新后的密文支持
L
−
K
L-K
L−K个级深度的计算。
RNS-CKKS支持SIMD操作,其可以加密向量
a
∈
R
N
a\in \R^N
a∈RN到一个密文中,且批处理这些加密元素,而不引入其他操作。为了以SIMD格式加密,首先使用编码算法
π
(
∗
)
\pi(*)
π(∗)将向量
a
a
a编码为一个
R
Q
R_Q
RQ上的多项式,然后使用加密算法
E
(
∗
)
E(*)
E(∗)加密该多项式。
在整篇文章中,我们使用
E
(
∗
)
E(*)
E(∗)表示加密多项式,使用
E
n
c
(
∗
)
Enc(*)
Enc(∗)表示以SIMD格式加密向量,即
E
n
c
(
a
)
=
E
(
π
(
a
)
)
Enc(a)=E(\pi(a))
Enc(a)=E(π(a)),其中
a
a
a是一个向量。
一个特殊的FHE操作:
c
t
′
←
S
u
b
s
(
c
t
,
k
)
ct'\leftarrow Subs(ct,k)
ct′←Subs(ct,k):替换操作,该操作以密文
c
t
=
E
(
p
(
x
)
)
ct=E(p(x))
ct=E(p(x))以及一个奇整数
k
k
k作为输入,然后得到新的密文
c
t
′
=
E
(
p
(
x
k
)
)
ct'=E(p(x^k))
ct′=E(p(xk))。
这里的
S
u
b
s
(
c
t
,
k
)
Subs(ct,k)
Subs(ct,k)应该是一种密钥交换操作,可以描述如下:
已知密文:
c
t
=
(
−
a
(
x
)
s
(
x
)
+
e
(
x
)
+
p
(
x
)
,
a
(
x
)
)
ct=(-a(x)s(x)+e(x)+p(x),a(x))
ct=(−a(x)s(x)+e(x)+p(x),a(x))
将该密文进行自同构操作:
κ
k
(
c
t
)
=
(
−
a
(
x
k
)
s
(
x
k
)
+
e
(
x
k
)
+
p
(
x
k
)
,
a
(
x
k
)
)
\kappa_k(ct)=(-a(x^k)s(x^k)+e(x^k)+p(x^k),a(x^k))
κk(ct)=(−a(xk)s(xk)+e(xk)+p(xk),a(xk))。
然后得到用户提供的交换密钥:
k
e
y
=
(
−
a
(
x
)
s
(
x
)
+
e
(
x
)
+
P
⋅
s
(
x
k
)
,
a
(
x
)
)
key = (-a(x)s(x)+e(x)+P·s(x^k),a(x))
key=(−a(x)s(x)+e(x)+P⋅s(xk),a(x))
然后执行密钥交换操作:
c
t
′
=
(
κ
k
(
c
t
)
[
0
]
,
0
)
+
(
⌊
P
−
1
⋅
κ
k
(
c
t
)
[
1
]
⋅
k
e
y
⌉
)
ct'=(\kappa_k(ct)[0],0)+(\lfloor P^{-1}·\kappa_k(ct)[1]·key\rceil)
ct′=(κk(ct)[0],0)+(⌊P−1⋅κk(ct)[1]⋅key⌉)
此时新的密文即
c
t
′
=
(
−
a
(
x
)
s
(
x
)
+
e
(
x
)
+
p
(
x
k
)
,
a
(
x
)
)
ct'=(-a(x)s(x)+e(x)+p(x^k),a(x))
ct′=(−a(x)s(x)+e(x)+p(xk),a(x))
注意,这里的
a
(
x
)
,
e
(
x
)
a(x),e(x)
a(x),e(x)是变化的,也就是不同的密文中,这是不同的。
2.4 Homomorphic sign function
由于FHE仅支持线性函数,所以为了实现在FHE下对加密数据的比较,本文需要利用sign函数的多项式近似,即:
s
i
g
n
(
x
)
=
f
d
f
(
g
d
g
(
x
)
)
=
{
−
1
(
−
1
≤
x
≤
−
2
−
α
)
0
(
x
=
0
)
1
(
2
−
α
≤
x
≤
1
)
sign(x)=f^{d_f}(g^{d_g}(x))=\begin{cases} -1 &(-1\leq x \leq -2^{-\alpha}) \\ 0 &(x = 0) \\ 1 &(2^{-\alpha}\leq x \leq 1) \\ \end{cases}
sign(x)=fdf(gdg(x))=⎩⎨⎧−101(−1≤x≤−2−α)(x=0)(2−α≤x≤1)
其中,
f
(
)
,
g
(
)
f(),g()
f(),g()为两个多项式,
d
f
,
d
g
d_f,d_g
df,dg为这两个多项式重复的次数。注意,该多项式近似要求输入x取值范围为[-1,1]。因此,对任何输入
a
∈
[
a
m
i
n
,
a
m
a
x
]
a\in [a_{min},a_{max}]
a∈[amin,amax]都需要进行归一化处理:
x
:
=
a
/
m
a
x
{
∣
a
m
a
x
∣
,
∣
a
m
i
n
∣
}
x := a/max\{|a_{max}|,|a_{min}|\}
x:=a/max{∣amax∣,∣amin∣}
这里,我们使用Sgn()表示在SIMD密文上同时运行归一化与sign近似函数:
b
~
←
S
g
n
(
a
~
)
:
b
i
=
f
d
f
(
g
d
g
(
a
i
m
a
x
{
∣
a
m
a
x
∣
,
∣
a
m
i
n
∣
}
)
)
∀
i
∈
[
N
]
\widetilde b\leftarrow Sgn(\widetilde a): b_i =f^{d_f}(g^{d_g}({a_i\over{max\{|a_{max}|,|a_{min}|\}}})) \ \ \forall i \in [N]
b←Sgn(a):bi=fdf(gdg(max{∣amax∣,∣amin∣}ai)) ∀i∈[N]
在本文的实现中,使用的是9次的
f
(
∗
)
f(*)
f(∗)和
g
(
∗
)
g(*)
g(∗),且设计
α
=
16
,
d
f
=
2
,
d
g
=
2
\alpha=16,d_f=2,d_g=2
α=16,df=2,dg=2,然后使用BSGS算法来评估多项式。
3 Basic design
本节介绍NEXUS的基础设计,即在不优化的情况下实现上述transformer的每一层计算,在之后的章节中会对本节的算法进行优化。
3.1 Attention
3.1.1 Matrix multiplication(ciphertext-plaintext)
在Attention层的第一个MatMul步骤,我们需要计算三个密文—明文矩阵乘法:
Q
:
=
A
⋅
W
Q
;
K
:
=
A
⋅
W
K
;
V
:
=
A
⋅
W
V
;
Q:=A·W_Q;\\ K:=A·W_K;\\ V:=A·W_V;
Q:=A⋅WQ;K:=A⋅WK;V:=A⋅WV;
其中A是我们的输入,
W
Q
,
W
K
,
W
V
W_Q,W_K,W_V
WQ,WK,WV是三个给定矩阵,下面以
A
⋅
W
Q
A·W_Q
A⋅WQ为例来描述这个密文—明文矩阵乘法,该过程同样适用于
W
K
W_K
WK和
W
V
W_V
WV。
给定矩阵
A
∈
R
m
×
n
A\in \mathbb{R}^{m\times n}
A∈Rm×n和矩阵
W
Q
∈
R
n
×
k
W_Q\in \mathbb{R}^{n\times k}
WQ∈Rn×k,计算矩阵
Q
:
=
A
⋅
W
Q
Q:=A·W_Q
Q:=A⋅WQ。
设
a
i
,
j
∈
R
a_{i,j}\in \mathbb{R}
ai,j∈R表示矩阵A的第i行第j列的元素,
w
j
∈
R
k
w_j\in \mathbb{R}^k
wj∈Rk表示矩阵
W
Q
W_Q
WQ的第j行的元素向量,
q
i
∈
R
k
q_i\in \mathbb{R}^k
qi∈Rk是矩阵
Q
Q
Q的第i行的元素向量,即:
q
i
=
∑
j
∈
[
n
]
a
i
,
j
⋅
w
j
q_i=\sum_{j\in [n]}a_{i,j}·w_j
qi=j∈[n]∑ai,j⋅wj
因此,上述过程可以描述为,C将A中的每个元素
a
i
,
j
a_{i,j}
ai,j均单独加密为密文发送给S,然后S同态评估MatrixMul,一个演示的示例如下:
图2 SIMD-based matrix multiplication
在上述描述中,C需要发送
m
×
n
m\times n
m×n个密文给S,从某一方面来说,这种开销是比较大的,因此本文在第4节提出一种算法可以将如此类型的
m
×
n
m\times n
m×n个密文压缩为
m
×
n
N
′
m\times n\over{N'}
N′m×n个密文,即一个密文中存放
N
′
N'
N′个元素,随后S可以将压缩后的密文恢复为压缩前的密文形式。
3.1.2 Matrix multiplication(ciphertext-ciphertext)
经过上述步骤后,可以获得加密的
(
Q
,
K
,
V
)
(Q,K,V)
(Q,K,V),在Attention的第二个MatMul块,S需要计算
Q
⋅
K
T
Q·K^T
Q⋅KT。很明显,现在Q的每一行和
K
T
K^T
KT的每一列已经以SIMD的形式加密为
E
n
c
(
q
)
,
E
n
c
(
k
T
)
Enc(q) , Enc(k^T)
Enc(q),Enc(kT)。如果S可以计算
E
n
c
(
q
)
Enc(q)
Enc(q)和
E
n
c
(
k
T
)
Enc(k^T)
Enc(kT)的内积,则可以获得
Q
⋅
K
T
Q·K^T
Q⋅KT的加密结果。
由于SIMD,S可以很容易的计算得到Enc(u),其中
u
=
[
u
0
,
.
.
.
,
u
k
−
1
]
u=[u_0,...,u_{k-1}]
u=[u0,...,uk−1]是q和
k
T
k^T
kT的元素级的乘法,现在为了计算内积,S仅仅需要在SIMD下计算
s
:
=
∑
i
=
0
k
−
1
u
i
s:=\sum_{i=0}^{k-1}u_i
s:=∑i=0k−1ui。
为了计算这个和,我们可以通过k-1次的旋转及加和来计算,从而获得密文Enc([s,s,…,s]),但是本文提出了’QuickSum’算法,该算法仅仅需要logk次旋转就可达到这个目标。'QuickSum’算法在第5节介绍。
进一步,S将计算得到每一行的的m个密文组合到单一密文中,计算方法如下:
∑
i
=
0
m
−
1
(
E
n
c
(
s
i
,
s
i
,
.
.
.
,
s
i
)
⋅
b
i
)
\sum_{i=0}^{m-1}(Enc(s_i,s_i,...,s_i)·b_i)
i=0∑m−1(Enc(si,si,...,si)⋅bi)
其中
b
i
b_i
bi仅在第i个槽的位置是1,其余槽均为0。
易知,输出矩阵为
A
∈
R
m
×
m
A\in \mathbb{R}^{m\times m}
A∈Rm×m,其中A的每行向量以SIMD形式加密,将该结果作为Softmax的输入。
3.1.3 Softmax
Softmax函数需要被应用于A的每一行,该函数评估如下:
y
i
=
e
x
p
(
a
i
−
a
m
a
x
)
∑
j
=
0
m
−
1
e
x
p
(
a
j
−
a
m
a
x
)
(1)
y_i={exp(a_i-a_{max})\over{\sum_{j=0}^{m-1}exp(a_j-a_{max})}}\tag{1}
yi=∑j=0m−1exp(aj−amax)exp(ai−amax)(1)
其中
a
m
a
x
=
m
a
x
(
a
0
,
.
.
.
,
a
m
−
1
)
a_{max}=max(a_0,...,a_{m-1})
amax=max(a0,...,am−1),从而确保指数函数的每个输入
(
a
j
−
a
m
a
x
)
(a_j-a_{max})
(aj−amax)是非正数,保证稳定性。
本文提出了’QuickMax’算法,该算法以
E
n
c
(
[
a
0
,
.
.
.
,
a
m
−
1
]
)
Enc([a_0,...,a_{m-1}])
Enc([a0,...,am−1])为输入,并输出
E
n
c
(
[
a
m
a
x
,
.
.
.
,
a
m
a
x
]
)
Enc([a_{max},...,a_{max}])
Enc([amax,...,amax]),且,该算法仅需要logm-1次Sgn操作与logm次旋转操作。该算法描述在第5节。
给定
E
n
c
(
[
a
0
,
.
.
.
,
a
m
−
1
]
)
Enc([a_0,...,a_{m-1}])
Enc([a0,...,am−1])和
E
n
c
(
[
a
m
a
x
,
.
.
.
,
a
m
a
x
]
)
Enc([a_{max},...,a_{max}])
Enc([amax,...,amax])。
S进行如下步骤计算:
E
n
c
(
[
a
0
′
,
.
.
.
,
a
m
−
1
′
]
)
=
E
n
c
(
[
a
0
,
.
.
.
,
a
m
−
1
]
)
−
E
n
c
(
[
a
m
a
x
,
.
.
.
,
a
m
a
x
]
)
Enc([a'_0,...,a'_{m-1}])=Enc([a_0,...,a_{m-1}])-Enc([a_{max},...,a_{max}])
Enc([a0′,...,am−1′])=Enc([a0,...,am−1])−Enc([amax,...,amax])
然后根据如下公式计算指数函数,这里使用泰勒展开:
e
x
p
(
x
)
≈
(
1
+
x
2
r
)
2
r
,
x
≤
0
exp(x)\approx(1+{x\over{2^r}})^{2^r},x\leq 0
exp(x)≈(1+2rx)2r,x≤0
其中
r
=
6
r=6
r=6,此时平均误差被限制在
1
0
−
5
10^{-5}
10−5,即S以SIMD格式计算指数函数:
E
n
c
(
e
0
,
.
.
.
,
e
m
−
1
)
=
e
x
p
(
E
n
c
(
[
a
0
′
,
.
.
.
,
a
m
−
1
′
]
)
)
Enc(e_0,...,e_{m-1})=exp(Enc([a'_0,...,a'_{m-1}]))
Enc(e0,...,em−1)=exp(Enc([a0′,...,am−1′]))
很明显,这里
e
j
=
e
x
p
(
a
j
′
)
e_j=exp(a'_j)
ej=exp(aj′)。
接下来,S应用
Q
u
i
c
k
S
u
m
(
∗
)
QuickSum(*)
QuickSum(∗)算法来获得
E
n
c
(
[
∑
j
=
0
m
−
1
e
j
,
.
.
.
,
∑
j
=
0
m
−
1
e
j
]
)
Enc([\sum^{m-1}_{j=0}e_j,...,\sum^{m-1}_{j=0}e_j])
Enc([∑j=0m−1ej,...,∑j=0m−1ej])。
进一步的,S使用文献[21,24]中的Goldschmidt除法算法来计算:
E
n
c
(
y
0
,
.
.
.
,
y
m
−
1
)
=
E
n
c
(
e
0
,
.
.
.
,
e
m
−
1
)
E
n
c
(
[
∑
j
=
0
m
−
1
e
j
,
.
.
.
,
∑
j
=
0
m
−
1
e
j
]
)
Enc(y_0,...,y_{m-1})={Enc(e_0,...,e_{m-1})\over Enc([\sum^{m-1}_{j=0}e_j,...,\sum^{m-1}_{j=0}e_j])}
Enc(y0,...,ym−1)=Enc([∑j=0m−1ej,...,∑j=0m−1ej])Enc(e0,...,em−1)
Softmax算法的详细描述如算法1所示:
3.1.4 Matrix multiplication(ciphertext-ciphertext)
这里是Attention的最后一个MatMul块,该块的计算原理同3.1.2节完全一致。
3.2 Layer normalization
本文的归一化表示如下(但是不太清楚这个归一化使用的是什么计算公式):
3.3 Feed forward
前馈网络层涉及到两个矩阵乘法以及一个GELU。矩阵乘法如上文所述来计算。GELU可以使用下述分段多项式来近似,当输入
x
∈
[
−
60
,
60
]
x\in [-60,60]
x∈[−60,60],则可以确保误差在
1
0
−
3
10^{-3}
10−3内。
G
E
L
U
(
x
)
=
∈
{
0
(
x
≤
−
4
)
P
(
x
)
=
∑
i
=
0
i
=
3
c
i
x
i
(
−
4
<
x
≤
−
1.95
)
Q
(
x
)
=
∑
i
=
0
i
=
6
d
i
x
i
(
−
1.95
<
x
≤
3
)
x
(
x
>
3
)
GELU(x)=\in \begin{cases} 0 &(x\leq -4) \\ P(x)=\sum_{i=0}^{i=3}c_ix^i &(-4<x\leq -1.95) \\ Q(x)=\sum_{i=0}^{i=6}d_ix^i &(-1.95<x\leq 3) \\ x &(x>3) \end{cases}
GELU(x)=∈⎩⎨⎧0P(x)=∑i=0i=3cixiQ(x)=∑i=0i=6dixix(x≤−4)(−4<x≤−1.95)(−1.95<x≤3)(x>3)
首先,使用Sgn操作获得四个加密bit:
b
0
,
b
1
,
b
2
,
b
3
b_0,b_1,b_2,b_3
b0,b1,b2,b3,当且仅当输入x属于第i段时,
b
i
=
1
b_i=1
bi=1,否则
b
i
=
0
b_i=0
bi=0,如此,GELU(x)函数可以表示为:
G
E
L
U
(
x
)
:
=
b
0
⋅
0
+
b
1
⋅
P
(
x
)
+
b
2
⋅
Q
(
x
)
+
b
3
⋅
x
GELU(x):=b_0·0+b_1·P(x)+b_2·Q(x)+b_3·x
GELU(x):=b0⋅0+b1⋅P(x)+b2⋅Q(x)+b3⋅x。
完整的Secure GELU算法可以表示如下:
3.4 Argmax
transformer最终的输出应该是一个选择向量
E
n
c
(
[
b
0
,
.
.
.
,
b
m
−
1
]
)
Enc([b_0,...,b_{m-1}])
Enc([b0,...,bm−1]),其中
b
i
=
1
i
f
a
i
=
m
a
x
(
a
0
,
.
.
.
,
a
m
−
1
)
b_i=1 \ if \ a_i=max(a_0,...,a_{m-1})
bi=1 if ai=max(a0,...,am−1),其他情况下
b
i
=
0
b_i=0
bi=0。
因此,本文的Secure Argmax算法如下:
3.5 Placement of bootstrapping
由于bootstrapping操作是昂贵的,因此合理的放置bootstrapping的位置是至关重要的。
图4 Placement of bootstrapping for a BERT-base transformer
4. SIMD密文的压缩和分解
假设C想要发送N’个密文给S,且每个密文以SIMD方式加密N个相同的值,Enc([
a
0
,
.
.
.
,
a
0
a_0,...,a_0
a0,...,a0]),…,Enc([
a
N
′
−
1
,
.
.
.
,
a
N
′
−
1
a_{N'-1},...,a_{N'-1}
aN′−1,...,aN′−1])。
SIMD密文的压缩算法
C将向量[
a
0
,
a
1
,
.
.
.
,
a
N
′
−
1
a_0,a_1,...,a_{N'-1}
a0,a1,...,aN′−1]的各个元素打包到一个多项式的系数中,即:
p
(
x
)
=
a
0
+
a
1
x
+
a
2
x
2
+
.
.
.
+
a
N
′
−
1
x
N
′
−
1
p(x)=a_0+a_1x+a_2x^2+...+a_{N'-1}x^{N'-1}
p(x)=a0+a1x+a2x2+...+aN′−1xN′−1
然后将该多项式加密
p
~
0
=
E
(
p
(
x
)
)
\widetilde p_0=E(p(x))
p0=E(p(x))发送给S。
然后S可以对密文
p
~
0
\widetilde p_0
p0分解从而得到压缩前
N
′
N'
N′个SIMD密文。
S分解密文
p
~
0
\widetilde p_0
p0过程如下:
SIMD密文的分解算法:
(1)执行
S
u
b
s
(
p
~
0
,
N
′
+
1
)
Subs(\widetilde p_0, N'+1)
Subs(p0,N′+1)返回:
E
(
a
0
+
a
1
x
N
′
+
1
+
a
2
x
(
N
′
+
1
)
2
+
.
.
.
+
a
N
′
−
1
x
(
N
′
+
1
)
N
′
−
1
)
=
E
(
a
0
+
a
1
(
−
x
)
+
a
2
(
−
x
)
2
)
+
.
.
.
+
a
N
′
−
1
(
−
x
)
N
′
−
1
)
E(a_0+a_1x^{N'+1}+a_2x^{(N'+1)^2}+...+a_{N'-1}x^{(N'+1)^{N'-1}}) \\ =E(a_0+a_1(-x)+a_2(-x)^2)+...+a_{N'-1}(-x)^{N'-1})
E(a0+a1xN′+1+a2x(N′+1)2+...+aN′−1x(N′+1)N′−1)=E(a0+a1(−x)+a2(−x)2)+...+aN′−1(−x)N′−1)
注意,
x
N
′
+
1
≡
0
(
m
o
d
x
N
′
+
1
)
x^{N'}+1 \equiv 0 \ \ (mod \ \ x^{N'} + 1)
xN′+1≡0 (mod xN′+1),因此
x
N
′
+
1
=
x
N
′
∗
x
=
−
x
(
m
o
d
x
N
′
+
1
)
x^{N'+1} = x^{N'} * x = -x \ (mod \ \ x^{N'}+1)
xN′+1=xN′∗x=−x (mod xN′+1),这里的
N
’
N’
N’也就是分圆环的次数。
(2)执行
p
~
0
+
S
u
b
s
(
p
~
0
,
N
′
+
1
)
\widetilde p_0+Subs(\widetilde p_0,N'+1)
p0+Subs(p0,N′+1)操作,移除p(x)的所有奇数项。
a
0
+
a
1
x
+
a
2
x
2
+
.
.
.
+
a
N
′
−
1
x
N
′
+
1
+
a
0
+
a
1
(
−
x
)
+
a
2
(
−
x
)
2
)
+
.
.
.
+
a
N
′
−
1
(
−
x
)
N
′
−
1
=
a
0
+
0
x
+
a
2
x
2
+
.
.
.
+
a
N
′
−
2
x
N
′
−
2
+
0
x
N
′
−
1
a_0+a_1x+a_2x^2+...+a_{N'-1}x^{N'+1} \\+ a_0+a_1(-x)+a_2(-x)^2)+...+a_{N'-1}(-x)^{N'-1}\\= a_0+0x+a_2x^2+...+a_{N'-2}x^{N'-2}+0x^{N'-1}
a0+a1x+a2x2+...+aN′−1xN′+1+a0+a1(−x)+a2(−x)2)+...+aN′−1(−x)N′−1=a0+0x+a2x2+...+aN′−2xN′−2+0xN′−1
(3)通过
l
o
g
N
′
log N'
logN′次
S
u
b
s
(
)
Subs()
Subs()操作,S可以提取得到密文:
E
(
a
0
+
0
x
1
+
0
x
2
+
.
.
.
+
0
x
N
′
−
1
)
E(a_0+0x^1+0x^2+...+0x^{N'-1})
E(a0+0x1+0x2+...+0xN′−1),实际上,这就是密文Enc([
a
0
,
a
0
,
.
.
.
,
a
0
a_0,a_0,...,a_0
a0,a0,...,a0])。完整的操作流程如下:
类似地,为了提取E(
a
1
+
0
x
1
+
.
.
.
+
0
x
N
′
−
1
a_1+0x^1+...+0x^{N'-1}
a1+0x1+...+0xN′−1),S应该左旋明文多项式p(x)一个单位,通过乘以
x
−
1
x^{-1}
x−1,然后再次执行上述的提取过程。通过执行
N
‘
N‘
N‘次该提取过程,S可以获得向量[
a
0
,
a
1
,
.
.
.
,
a
N
′
−
1
a_0,a_1,...,a_{N'-1}
a0,a1,...,aN′−1]中每个元素的单独SIMD格式加密。
然而上述过程需要执行
(
N
′
⋅
l
o
g
N
′
)
(N'·logN')
(N′⋅logN′)次
S
u
b
s
(
)
Subs()
Subs()操作。对比之下,本文提出一种算法,可以实现相同的目标,但是仅需要
2
N
′
2N'
2N′次
S
u
b
s
(
)
Subs()
Subs()操作。该算法可以简单地描述如下:
算法5是Secure Decompression的详细描述:
下面提供上述分解操作的理论证明:
Theorem 1:
仅有常数项的多项式的加密E(
a
s
+
0
x
1
+
.
.
.
+
0
x
N
′
−
1
a_s+0x^1+...+0x^{N'-1}
as+0x1+...+0xN′−1)是向量[
a
s
,
a
s
,
.
.
.
,
a
s
a_s,a_s,...,a_s
as,as,...,as]的加密Enc([
a
s
,
a
s
,
.
.
.
,
a
s
a_s,a_s,...,a_s
as,as,...,as])。
4.1 Application to matrix multiplication
压缩分解技术可以自然地应用于MatrixMul,此外,基于下面的观察结果,本文进一步优化了矩阵乘法,观察到在transformer推理过程中,对于不同输入的矩阵
A
∈
R
m
×
n
A\in \R^{m\times n}
A∈Rm×n需要乘以相同的矩阵
W
∈
R
n
×
k
W\in \R^{n\times k}
W∈Rn×k。
设
A
=
[
a
0
,
.
.
.
,
a
n
−
1
]
A = [a_0,...,a_{n-1}]
A=[a0,...,an−1],其中
a
i
∈
R
m
a_i\in \R^m
ai∈Rm表示矩阵
A
A
A的第
i
i
i行。假设S和C需要生成t个响应词,即有t个输入矩阵:
A
0
=
[
a
0
,
0
,
a
0
,
1
,
.
.
.
,
a
0
,
n
−
1
]
A
1
=
[
a
1
,
0
,
a
1
,
1
,
.
.
.
,
a
1
,
n
−
1
]
.
.
.
A
0
=
[
a
t
−
1
,
0
,
a
t
−
1
,
1
,
.
.
.
,
a
t
−
1
,
n
−
1
]
A_0=[a_{0,0},a_{0,1},...,a_{0,n-1}] \\ A_1=[a_{1,0},a_{1,1},...,a_{1,n-1}] \\ ... \\ A_0=[a_{t-1,0},a_{t-1,1},...,a_{t-1,n-1}]
A0=[a0,0,a0,1,...,a0,n−1]A1=[a1,0,a1,1,...,a1,n−1]...A0=[at−1,0,at−1,1,...,at−1,n−1]
令
a
i
′
=
[
a
0
,
i
a
1
,
i
.
.
.
a
t
−
1
,
i
]
a'_i=\left[\begin{matrix} a_{0,i} \\ a_{1,i} \\ ... \\ a_{t-1,i} \end{matrix} \right]
ai′=a0,ia1,i...at−1,i和
q
j
′
:
=
∑
i
=
0
n
−
1
a
i
′
w
i
,
j
∀
j
∈
[
k
]
q'_j:=\sum^{n-1}_{i=0}a'_iw_{i,j}\ \ \ \forall j\in [k]
qj′:=∑i=0n−1ai′wi,j ∀j∈[k],则有
Q
′
=
q
0
′
∣
∣
q
1
′
∣
∣
.
.
.
∣
∣
q
k
−
1
′
=
[
A
0
W
A
1
W
.
.
.
A
t
−
1
W
]
Q'=q'_0||q'_1||...||q'_{k-1}=\left[\begin{matrix} A_0W \\ A_1W \\ ... \\ A_{t-1}W \end{matrix} \right]
Q′=q0′∣∣q1′∣∣...∣∣qk−1′=A0WA1W...At−1W
预计算阶段:
这里,我们引入一个预计算阶段,其中S使用上述提到的密文压缩技术,将压缩后的密文
(
E
n
c
S
(
[
w
i
,
j
,
w
i
,
j
,
.
.
.
,
w
i
.
j
⏟
t
×
m
]
)
∀
i
∈
[
n
]
,
j
∈
[
k
]
)
(Enc_S([\underbrace{w_{i,j},w_{i,j},...,w_{i.j}}_{t\times m}])\ \ \forall i \in [n],j\in [k])
(EncS([t×mwi,j,wi,j,...,wi.j]) ∀i∈[n],j∈[k])发送给C。注意,该传输仅只发生一次,除非模型发生改变。接下来,C对压缩的密文执行分解技术,以获得
E
n
c
S
(
[
w
i
,
j
,
w
i
,
j
,
.
.
.
,
w
i
.
j
⏟
t
×
m
]
)
∀
i
∈
[
n
]
,
j
∈
[
k
]
Enc_S([\underbrace{w_{i,j},w_{i,j},...,w_{i.j}}_{t\times m}])\ \ \forall i \in [n],j\in [k]
EncS([t×mwi,j,wi,j,...,wi.j]) ∀i∈[n],j∈[k]。在预计算阶段C并没有关于输入的信息,采样
U
∈
R
(
t
m
)
×
n
U\in \R^{(tm)\times n}
U∈R(tm)×n,然后计算:
E
n
c
S
(
v
j
)
←
∑
i
=
0
n
−
1
(
u
i
×
E
n
c
S
(
[
w
i
,
j
,
.
.
.
,
w
i
,
j
]
)
)
∀
j
∈
[
k
]
Enc_S(v_j)\leftarrow\sum^{n-1}_{i=0}(u_i\times Enc_S([w_{i,j},...,w_{i,j}]))\ \ \ \forall j\in[k]
EncS(vj)←i=0∑n−1(ui×EncS([wi,j,...,wi,j])) ∀j∈[k]
其中
u
i
u_i
ui是矩阵
U
U
U的第i列。接下来,C使用自己的密钥来加密
E
n
c
S
(
v
j
)
Enc_S(v_j)
EncS(vj)以获得
E
n
c
C
(
E
n
c
S
(
v
j
)
)
Enc_C(Enc_S(v_j))
EncC(EncS(vj)),并将其发送给S。注意
E
n
c
S
(
E
n
c
C
(
v
j
)
)
=
E
n
c
C
(
E
n
c
S
(
v
j
)
)
Enc_S(Enc_C(v_j))=Enc_C(Enc_S(v_j))
EncS(EncC(vj))=EncC(EncS(vj)),故S可以对其进行解密,从而获得
E
n
c
C
(
v
j
)
Enc_C(v_j)
EncC(vj)。注意,这里的
v
j
v_j
vj是矩阵
U
⋅
W
U·W
U⋅W的第
j
j
j列。
切换不同用户加密密钥的过程计算如下:
给定
c
t
S
=
(
−
a
s
S
+
m
+
e
)
ct_S=(-as_S+m+e)
ctS=(−asS+m+e),
使用
s
C
s_C
sC加密有
c
t
C
,
S
=
(
−
a
s
S
−
a
s
C
+
m
+
e
+
e
′
,
a
)
ct_{C,S}=(-as_S-as_C+m+e+e',a)
ctC,S=(−asS−asC+m+e+e′,a),
使用
s
S
s_S
sS解密有:
c
t
C
=
(
−
a
s
S
−
a
s
C
+
m
+
e
+
e
′
,
a
)
+
(
a
s
S
,
0
)
=
(
−
a
s
C
+
m
+
e
+
e
′
)
ct_C=(-as_S-as_C+m+e+e',a)+(as_S,0)=(-as_C+m+e+e')
ctC=(−asS−asC+m+e+e′,a)+(asS,0)=(−asC+m+e+e′).
在线处理阶段:
此时,C知道输入的信息
A
′
=
a
0
′
∣
∣
a
1
′
∣
∣
.
.
.
∣
∣
a
n
−
1
′
A'=a'_0||a'_1||...||a'_{n-1}
A′=a0′∣∣a1′∣∣...∣∣an−1′,然后C将明文
(
A
′
−
U
)
(A'-U)
(A′−U)发送给S,注意,由于S不知道U的值,故S也不清楚
A
′
A'
A′的值,然后S可以计算:
(
A
′
−
U
)
⋅
W
+
(
E
n
c
C
(
v
0
)
∣
∣
E
n
c
C
(
v
1
)
∣
∣
.
.
.
∣
∣
E
n
c
C
(
v
k
−
1
)
)
=
(
A
′
W
−
V
)
+
(
E
n
c
C
(
v
0
)
∣
∣
E
n
c
C
(
v
1
)
∣
∣
.
.
.
∣
∣
E
n
c
C
(
v
k
−
1
)
)
=
(
E
n
c
C
(
q
0
′
)
∣
∣
E
n
c
C
(
q
1
′
)
∣
∣
.
.
.
∣
∣
E
n
c
C
(
q
k
−
1
′
)
)
(A'-U)·W + (Enc_C(v_0)||Enc_C(v_1)||...||Enc_C(v_{k-1}))\\ =(A'W-V)+(Enc_C(v_0)||Enc_C(v_1)||...||Enc_C(v_{k-1})) \\ =(Enc_C(q'_0)||Enc_C(q'_1)||...||Enc_C(q'_{k-1}))
(A′−U)⋅W+(EncC(v0)∣∣EncC(v1)∣∣...∣∣EncC(vk−1))=(A′W−V)+(EncC(v0)∣∣EncC(v1)∣∣...∣∣EncC(vk−1))=(EncC(q0′)∣∣EncC(q1′)∣∣...∣∣EncC(qk−1′))
其中
q
j
′
q'_j
qj′是矩阵
Q
′
Q'
Q′的第
j
j
j列。算法6描述了优化后的矩阵乘法细节:
可以注意到,只需要预计算的过程交互一次,此后C可以直接向S发送明文信息
A
′
−
U
A'-U
A′−U,而不会泄露
A
′
A'
A′的信息。
5 SIMD槽折叠算法
回想矩阵Q,K,V的行向量是使用SIMD方式加密的。而上述介绍的一系列操作,如内积,Softmax,LayerNorm和Argmax等, 均涉及到利用所有槽元素计算函数
f
(
∗
)
f(*)
f(∗),并将得到的结果放置到所有槽上。例如给定
E
n
c
(
[
a
0
,
.
.
.
,
a
N
−
1
]
)
Enc([a_0,...,a_{N-1}])
Enc([a0,...,aN−1]),然后想要获得
E
n
c
(
[
s
,
.
.
.
,
s
]
)
Enc([s,...,s])
Enc([s,...,s]),其中
s
=
∑
i
=
0
N
−
1
a
i
s=\sum^{N-1}_{i=0}a_i
s=∑i=0N−1ai,此时
f
(
∗
)
f(*)
f(∗)即求和函数。
本节提供了一种通用的解决方案,只要函数
f
(
∗
)
f(*)
f(∗)满足:
f
(
f
(
a
0
,
a
1
)
,
a
2
)
=
f
(
a
0
,
f
(
a
1
,
a
2
)
)
f(f(a_0,a_1),a_2)=f(a_0,f(a_1,a_2))
f(f(a0,a1),a2)=f(a0,f(a1,a2))
算法7描述了槽折叠算法的细节:
这里是一个简单的例子,可以看到算法7的实现流程:
5.1 QuickSum
给定
[
a
0
,
a
1
,
.
.
.
,
a
n
−
1
,
0
,
.
.
.
,
0
]
[a_0,a_1,...,a_{n-1},0,...,0]
[a0,a1,...,an−1,0,...,0],为了获得
[
∑
i
=
0
N
−
1
a
i
,
.
.
.
,
∑
i
=
0
N
−
1
a
i
,
0
,
.
.
.
,
0
]
[\sum^{N-1}_{i=0}a_i,...,\sum^{N-1}_{i=0}a_i,0,...,0]
[∑i=0N−1ai,...,∑i=0N−1ai,0,...,0],可以将算法7的第5行替换为
s
~
←
s
~
+
a
~
\tilde{s}\leftarrow\tilde{s}+\tilde{a}
s~←s~+a~。
5.2 QuickMax
给定
[
a
0
,
a
1
,
.
.
.
,
a
n
−
1
,
0
,
.
.
.
,
0
]
[a_0,a_1,...,a_{n-1},0,...,0]
[a0,a1,...,an−1,0,...,0],为了获得
a
m
a
x
,
.
.
.
,
a
m
a
x
,
0
,
.
.
.
,
0
a_{max},...,a_{max},0,...,0
amax,...,amax,0,...,0,其中
a
m
a
x
=
m
a
x
(
a
0
,
a
1
,
.
.
.
,
a
n
−
1
)
a_{max}=max(a_0,a_1,...,a_{n-1})
amax=max(a0,a1,...,an−1),很明显
m
a
x
(
a
,
b
)
max(a,b)
max(a,b)可以表示为:
m
a
x
(
a
,
b
)
=
a
+
b
+
(
a
−
b
)
⋅
S
g
n
(
a
−
b
)
2
max(a,b)={a+b+(a-b)·Sgn(a-b)\over 2}
max(a,b)=2a+b+(a−b)⋅Sgn(a−b)
因此可以将算法7的第5行替换为:
s
~
←
0.5
⊗
(
a
~
⊕
s
~
⊕
(
a
~
⊖
s
~
)
⊗
S
g
n
(
a
~
⊖
s
~
)
)
\tilde{s}\leftarrow 0.5\otimes(\tilde{a}\oplus\tilde{s}\oplus(\tilde{a}\ominus\tilde{s})\otimes Sgn(\tilde{a}\ominus\tilde{s}))
s~←0.5⊗(a~⊕s~⊕(a~⊖s~)⊗Sgn(a~⊖s~))
6. Conclusion
本文提出了NEXUS系统,可以说是第一个不需要客户端和服务器进行交互的安全transformer推理协议。本文提出了适用于RNS-CKKS的一系列新协议,以使得服务器可以高效且精确的在加密数据上计算transformer的每一层。
版权归原作者 lucky_wjie 所有, 如有侵权,请联系我们删除。