0


sliding window attention

sliding window attention是为了解决在输出序列长度sequence length很大的时候attention计算量爆炸增长的问题。

用一句话来总结sliding window attention其实就是:每一个token只和包含其本身在内的前W个token做Attention。最简单的实现其实就是给不需要计算attention的其它token都加上一个mask就可以了,是不是非常简单?

用图片更直观一些,如下(图片来源:图解Mixtral 8 * 7b推理优化原理与源码实现):

核心代码如下:

  1. def scaled_dot_product_attention(q, k, v, window_size, mask=None):
  2. matmul_qk = torch.matmul(q, k.transpose(-1, -2))
  3. dk = torch.tensor(k.shape[-1], dtype=torch.float32)
  4. scaled_attention_logits = matmul_qk / torch.sqrt(dk)
  5. if mask is not None:
  6. scaled_attention_logits += mask * -1e9
  7. # 添加sliding window attention
  8. seq_len = q.shape[-2]
  9. window_mask = torch.full((seq_len, seq_len), float("-inf"), device=q.device)
  10. for i in range(seq_len):
  11. # 计算sliding window attention的起始位置
  12. start = max(0, i - window_size)
  13. end = min(seq_len, i + window_size + 1)
  14. window_mask[i, start:end] = 0
  15. scaled_attention_logits += window_mask # 其它位置由于是-inf会被mask掉
  16. attention_weights = F.softmax(scaled_attention_logits, dim=-1)
  17. output = torch.matmul(attention_weights, v)
  18. return (
  19. output,
  20. attention_weights,
  21. )

同时,这种方式并不是意味着当前token只能获取到前window_size个token的信息,因为当前token前面的window_size个token也都是能够获取到前面的信息的,因此只要网络达到一定的深度,这样的sliding window attention是可行的,并不会损失太多信息。

简单实践结果:

使用window_size=5训练:

不使用sliding window attention训练:

可以发现在训练时间上sliding window attention明显更快,同时性能上也相差不大。不过我这个实践比较简陋,用的corpus也都是一些中短句,有兴趣的朋友也可以自己试试!

未完待续……


本文转载自: https://blog.csdn.net/m0_62053105/article/details/140154494
版权归原作者 小麦要吃麦当劳 所有, 如有侵权,请联系我们删除。

“sliding window attention”的评论:

还没有评论