sliding window attention是为了解决在输出序列长度sequence length很大的时候attention计算量爆炸增长的问题。
用一句话来总结sliding window attention其实就是:每一个token只和包含其本身在内的前W个token做Attention。最简单的实现其实就是给不需要计算attention的其它token都加上一个mask就可以了,是不是非常简单?
用图片更直观一些,如下(图片来源:图解Mixtral 8 * 7b推理优化原理与源码实现):
核心代码如下:
def scaled_dot_product_attention(q, k, v, window_size, mask=None):
matmul_qk = torch.matmul(q, k.transpose(-1, -2))
dk = torch.tensor(k.shape[-1], dtype=torch.float32)
scaled_attention_logits = matmul_qk / torch.sqrt(dk)
if mask is not None:
scaled_attention_logits += mask * -1e9
# 添加sliding window attention
seq_len = q.shape[-2]
window_mask = torch.full((seq_len, seq_len), float("-inf"), device=q.device)
for i in range(seq_len):
# 计算sliding window attention的起始位置
start = max(0, i - window_size)
end = min(seq_len, i + window_size + 1)
window_mask[i, start:end] = 0
scaled_attention_logits += window_mask # 其它位置由于是-inf会被mask掉
attention_weights = F.softmax(scaled_attention_logits, dim=-1)
output = torch.matmul(attention_weights, v)
return (
output,
attention_weights,
)
同时,这种方式并不是意味着当前token只能获取到前window_size个token的信息,因为当前token前面的window_size个token也都是能够获取到前面的信息的,因此只要网络达到一定的深度,这样的sliding window attention是可行的,并不会损失太多信息。
简单实践结果:
使用window_size=5训练:
不使用sliding window attention训练:
可以发现在训练时间上sliding window attention明显更快,同时性能上也相差不大。不过我这个实践比较简陋,用的corpus也都是一些中短句,有兴趣的朋友也可以自己试试!
未完待续……
版权归原作者 小麦要吃麦当劳 所有, 如有侵权,请联系我们删除。