0


sliding window attention

sliding window attention是为了解决在输出序列长度sequence length很大的时候attention计算量爆炸增长的问题。

用一句话来总结sliding window attention其实就是:每一个token只和包含其本身在内的前W个token做Attention。最简单的实现其实就是给不需要计算attention的其它token都加上一个mask就可以了,是不是非常简单?

用图片更直观一些,如下(图片来源:图解Mixtral 8 * 7b推理优化原理与源码实现):

核心代码如下:

def scaled_dot_product_attention(q, k, v, window_size, mask=None):
    matmul_qk = torch.matmul(q, k.transpose(-1, -2))
    dk = torch.tensor(k.shape[-1], dtype=torch.float32)
    scaled_attention_logits = matmul_qk / torch.sqrt(dk)
    if mask is not None:
        scaled_attention_logits += mask * -1e9

    # 添加sliding window attention
    seq_len = q.shape[-2]
    window_mask = torch.full((seq_len, seq_len), float("-inf"), device=q.device)
    for i in range(seq_len):
        # 计算sliding window attention的起始位置
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        window_mask[i, start:end] = 0
    scaled_attention_logits += window_mask  # 其它位置由于是-inf会被mask掉

    attention_weights = F.softmax(scaled_attention_logits, dim=-1)
    output = torch.matmul(attention_weights, v)
    return (
        output,
        attention_weights,
    )

同时,这种方式并不是意味着当前token只能获取到前window_size个token的信息,因为当前token前面的window_size个token也都是能够获取到前面的信息的,因此只要网络达到一定的深度,这样的sliding window attention是可行的,并不会损失太多信息。

简单实践结果:

使用window_size=5训练:

不使用sliding window attention训练:

可以发现在训练时间上sliding window attention明显更快,同时性能上也相差不大。不过我这个实践比较简陋,用的corpus也都是一些中短句,有兴趣的朋友也可以自己试试!

未完待续……


本文转载自: https://blog.csdn.net/m0_62053105/article/details/140154494
版权归原作者 小麦要吃麦当劳 所有, 如有侵权,请联系我们删除。

“sliding window attention”的评论:

还没有评论