文章目录
多头注意力
在实践中,当给定 相同的查询、键和值的集合 时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖关系)。因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces) 可能是有益的。
为此,与其只使用单独一个注意力汇聚,我们可以用独立学习得到的
h
h
h 组不同的**线性投影(linear projections)** 来变换查询、键和值。然后,这
h
h
h 组变换后的查询、键和值将**并行**地送到注意力汇聚中。最后,将这
h
h
h 个注意力汇聚的输出**拼接**在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为**多头注意力(multihead attention)**。对于
h
h
h 个注意力汇聚输出,每一个注意力汇聚都被称作一个**头(head)**。
本质地讲,自注意力机制是:通过某种运算来直接计算得到句子在编码过程中每个位置上的注意力权重;然后再以权重和的形式来计算得到整个句子的隐含向量表示。
自注意力机制的缺陷是:模型在对当前位置的信息进行编码时,会过度的将注意力集中于自身的位置, 因此作者提出了通过多头注意力机制来解决这一问题。
下图展示了使用全连接层来实现可学习的线性变换的多头注意力。
模型
在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。给定查询
q
∈
R
d
q
\mathbf{q} \in \mathbb{R}^{d_q}
q∈Rdq、键
k
∈
R
d
k
\mathbf{k} \in \mathbb{R}^{d_k}
k∈Rdk和值
v
∈
R
d
v
\mathbf{v} \in \mathbb{R}^{d_v}
v∈Rdv,每个注意力头
h
i
\mathbf{h}_i
hi(
i
=
1
,
…
,
h
i = 1, \ldots, h
i=1,…,h)的计算方法为:
h
i
=
f
(
W
i
(
q
)
q
,
W
i
(
k
)
k
,
W
i
(
v
)
v
)
∈
R
p
v
,
\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},
hi=f(Wi(q)q,Wi(k)k,Wi(v)v)∈Rpv,
其中,可学习的参数包括
W
i
(
q
)
∈
R
p
q
×
d
q
\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}
Wi(q)∈Rpq×dq、
W
i
(
k
)
∈
R
p
k
×
d
k
\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}
Wi(k)∈Rpk×dk和
W
i
(
v
)
∈
R
p
v
×
d
v
\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}
Wi(v)∈Rpv×dv,以及代表注意力汇聚的函数
f
f
f。
f
f
f 可以是之前学习的**加性注意力**和**缩放点积注意力**。多头注意力的输出需要经过另一个线性转换,它对应着
h
h
h 个头连结后的结果,因此其可学习参数是
W
o
∈
R
p
o
×
h
p
v
\mathbf W_o\in\mathbb R^{p_o\times h p_v}
Wo∈Rpo×hpv:
W
o
[
h
1
⋮
h
h
]
∈
R
p
o
.
\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.
Wo⎣⎡h1⋮hh⎦⎤∈Rpo.
基于这种设计,每个头都可能会关注输入的不同部分,可以表示比简单加权平均值更复杂的函数。
import math
import torch
from torch import nn
from d2l import torch as d2l
实现
在实现过程中,我们选择缩放点积注意力作为每一个注意力头。为了避免计算代价和参数代价的大幅增长,我们设定
p
q
=
p
k
=
p
v
=
p
o
/
h
p_q = p_k = p_v = p_o / h
pq=pk=pv=po/h。值得注意的是,如果我们将查询、键和值的线性变换的输出数量设置为
p
q
h
=
p
k
h
=
p
v
h
=
p
o
p_q h = p_k h = p_v h = p_o
pqh=pkh=pvh=po,则可以并行计算
h
h
h 个头。在下面的实现中,
p
o
p_o
po是通过参数 num_hiddens 指定的。
classMultiHeadAttention(nn.Module):"""多头注意力"""def__init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False,**kwargs):super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)defforward(self, queries, keys, values, valid_lens):# queries, keys, values的形状:# (batch_size,查询或“键-值”对的个数,num_hiddens)# valid_len 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values的形状:# (batch_size*num_heads,查询或“键-值”个数,num_hiddens/num_head)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens isnotNone:# 在轴0,将第一项(标量或矢量) 复制 num_heads次,# 然后如此复制第二项,然后诸如此类
valid_lens = torch.repeat_interleave(valid_lens,
repeats=self.num_heads,
dim=0)# output的形状:(batch_size*num_heads, 查询个数,num_hiddens/num_head)
output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size, 查询个数,num_hiddens)
output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)
为了能够使多个头并行计算,上面的 MultiHeadAttention 类将使用下面定义的两个转置函数。具体来说,transpose_output 函数反转了 transpose_qkv 函数的操作。
deftranspose_qkv(X, num_heads):"""为了多头注意力的并行计算而变换形状"""# 输入X的形状(batch_size, 查询或”键-值“对的个数,num_hiddens)# 输出X的形状(batch_size,查询或”键-值“对的个数,# num_heads,num_hiddens/num_heads)
X = X.reshape(X.shape[0], X.shape[1], num_heads,-1)# 输出X的形状(batch_size,# num_heads,查询或”键-值“对的个数,num_hiddens/num_heads)
X = X.permute(0,2,1,3)# 输出X的形状(batch_size*num_heads,# 查询或”键-值“对的个数,num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])deftranspose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""# 输入X的形状(batch_size*num_heads,# 查询或”键-值“对的个数,num_hiddens/num_heads)# 输出X的形状(batch_size,# num_heads,查询或”键-值“对的个数,num_hiddens/num_heads)
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])# 输出X的形状(batch_size,查询或”键-值“对的个数,# num_heads,num_hiddens/num_heads)
X = X.permute(0,2,1,3)# 输出X的形状(batch_size,查询或”键-值“对的个数,num_hiddens)return X.reshape(X.shape[0], X.shape[1],-1)
下面我们使用键和值相同的小例子来测试我们编写的 MultiHeadAttention 类。多头注意力输出的形状是 (batch_size,num_queries, num_hiddens)。
num_hiddens, num_heads =100,5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads,0.5)
attention.eval()
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries =2,4
num_kvpairs, valid_lens =6, torch.tensor([3,2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])
小结
1、多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
2、基于适当的张量操作,可以实现多头注意力的并行计算。
版权归原作者 Gaolw1102 所有, 如有侵权,请联系我们删除。