0


用于Transformer的6种注意力的数学原理和代码实现

Transformer 的出色表现让注意力机制出现在深度学习的各处。本文整理了深度学习中最常用的6种注意力机制的数学原理和代码实现。

1、Full Attention

2017的《Attention is All You Need》中的编码器-解码器结构实现中提出。它结构并不复杂,所以不难理解。

上图 1.左侧显示了 Scaled Dot-Product Attention 的机制。当我们有多个注意力时,我们称之为多头注意力(右),这也是最常见的注意力的形式公式如下:

公式1

这里Q(Query)、K(Key)和V(values)被认为是它的输入,dₖ(输入维度)被用来降低复杂度和计算成本。这个公式可以说是深度学习中注意力机制发展的开端。下面我们看一下它的代码:

class FullAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)
        if self.output_attention:
            return (V.contiguous(), A)
        else:
            return (V.contiguous(), None)

2、ProbSparse Attention

借助“Transformer Dissection: A Unified Understanding of Transformer's Attention via the lens of Kernel”中的信息我们可以将公式修改为下面的公式2。第i个query的attention就被定义为一个概率形式的核平滑方法(kernel smoother):

公式2

从公式 2,我们可以定义第 i 个查询的稀疏度测量如下:

最后,注意力块的最终公式是下面的公式4。

代码如下:

class ProbAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
    def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
        # Q [B, H, L, D]
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape
        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)
        # find the Top_k query with sparisty measurement
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]
        # use the reduced Q to calculate Q_K
        Q_reduce = Q[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :] # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
        return Q_K, M_top
    def _get_initial_context(self, V, L_Q):
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            # V_sum = V.sum(dim=-2)
            V_sum = V.mean(dim=-2)
            contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
        else: # use mask
            assert(L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only
            contex = V.cumsum(dim=-2)
        return contex
    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        B, H, L_V, D = V.shape
        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)

        context_in[torch.arange(B)[:, None, None],
                   torch.arange(H)[None, :, None],
                   index, :] = torch.matmul(attn, V).type_as(context_in)
        if self.output_attention:
            attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device)
            attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
            return (context_in, attns)
        else:
            return (context_in, None)
    def forward(self, queries, keys, values, attn_mask):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2,1)
        keys = keys.transpose(2,1)
        values = values.transpose(2,1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
        u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 

        U_part = U_part if U_part<L_K else L_K
        u = u if u<L_Q else L_Q
        
        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) 

        # add scale factor
        scale = self.scale or 1./sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale
        # get the context
        context = self._get_initial_context(values, L_Q)
        # update the context with selected top_k queries
        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
        
        return context.transpose(2,1).contiguous(), attn

这也是 Informer这个用于长序列时间序列预测的新型Transformer中使用的注意力。

3、LogSparse Attention

我们之前讨论的注意力有两个缺点:1. 与位置无关 2. 内存的瓶颈。为了应对这两个问题,研究人员使用了卷积算子和 LogSparse Transformers。

Transformer 中相邻层之间不同注意力机制的图示

卷积自注意力显示在(右)中,它使用步长为 1,内核大小为 k 的卷积层将输入(具有适当的填充)转换为Q/K。这种位置感知可以根据(左)中的形状正确匹配最相关的特征

他们不是使用步长为 1,卷积核大小 1,而是使用步长为 1 ,核大小为 k 的随意卷积(以确保模型无法访问未来的点)将输入转换为Q和K

代码实现

class Attention(nn.Module):
    def __init__(self, n_head, n_embd, win_len, scale, q_len, sub_len, sparse=None, attn_pdrop=0.1, resid_pdrop=0.1):
        super(Attention, self).__init__()

        if(sparse):
            print('Activate log sparse!')
            mask = self.log_mask(win_len, sub_len)
        else:
            mask = torch.tril(torch.ones(win_len, win_len)).view(1, 1, win_len, win_len)

        self.register_buffer('mask_tri', mask)
        self.n_head = n_head
        self.split_size = n_embd * self.n_head
        self.scale = scale
        self.q_len = q_len
        self.query_key = nn.Conv1d(n_embd, n_embd * n_head * 2, self.q_len)
        self.value = Conv1D(n_embd * n_head, 1, n_embd)
        self.c_proj = Conv1D(n_embd, 1, n_embd * self.n_head)
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)

    def log_mask(self, win_len, sub_len):
        mask = torch.zeros((win_len, win_len), dtype=torch.float)
        for i in range(win_len):
            mask[i] = self.row_mask(i, sub_len, win_len)
        return mask.view(1, 1, mask.size(0), mask.size(1))

    def row_mask(self, index, sub_len, win_len):
        """
        Remark:
        1 . Currently, dense matrices with sparse multiplication are not supported by Pytorch. Efficient implementation
            should deal with CUDA kernel, which we haven't implemented yet.
        2 . Our default setting here use Local attention and Restart attention.
        3 . For index-th row, if its past is smaller than the number of cells the last
            cell can attend, we can allow current cell to attend all past cells to fully
            utilize parallel computing in dense matrices with sparse multiplication."""
        log_l = math.ceil(np.log2(sub_len))
        mask = torch.zeros((win_len), dtype=torch.float)
        if((win_len // sub_len) * 2 * (log_l) > index):
            mask[:(index + 1)] = 1
        else:
            while(index >= 0):
                if((index - log_l + 1) < 0):
                    mask[:index] = 1
                    break
                mask[index - log_l + 1:(index + 1)] = 1  # Local attention
                for i in range(0, log_l):
                    new_index = index - log_l + 1 - 2**i
                    if((index - new_index) <= sub_len and new_index >= 0):
                        mask[new_index] = 1
                index -= sub_len
        return mask

    def attn(self, query: torch.Tensor, key, value: torch.Tensor, activation="Softmax"):
        activation = activation_dict[activation](dim=-1)
        pre_att = torch.matmul(query, key)
        if self.scale:
            pre_att = pre_att / math.sqrt(value.size(-1))
        mask = self.mask_tri[:, :, :pre_att.size(-2), :pre_att.size(-1)]
        pre_att = pre_att * mask + -1e9 * (1 - mask)
        pre_att = activation(pre_att)
        pre_att = self.attn_dropout(pre_att)
        attn = torch.matmul(pre_att, value)

        return attn

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x):

        value = self.value(x)
        qk_x = nn.functional.pad(x.permute(0, 2, 1), pad=(self.q_len - 1, 0))
        query_key = self.query_key(qk_x).permute(0, 2, 1)
        query, key = query_key.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        attn = self.attn(query, key, value)
        attn = self.merge_heads(attn)
        attn = self.c_proj(attn)
        attn = self.resid_dropout(attn)
        return attn

class Conv1D(nn.Module):
    def __init__(self, out_dim, rf, in_dim):
        super(Conv1D, self).__init__()
        self.rf = rf
        self.out_dim = out_dim
        if rf == 1:
            w = torch.empty(in_dim, out_dim)
            nn.init.normal_(w, std=0.02)
            self.w = Parameter(w)
            self.b = Parameter(torch.zeros(out_dim))
        else:
            raise NotImplementedError

    def forward(self, x):
        if self.rf == 1:
            size_out = x.size()[:-1] + (self.out_dim,)
            x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
            x = x.view(*size_out)
        else:
            raise NotImplementedError
        return x

来自:https://github.com/AIStream-Peelout/flow-forecast/

4、LSH Attention

Reformer的论文选择了局部敏感哈希的angular变体。它们首先约束每个输入向量的L2范数(即将向量投影到一个单位球面上),然后应用一系列的旋转,最后找到每个旋转向量所属的切片。这样一来就需要找到最近邻的值,这就需要局部敏感哈希(LSH)了,它能够快速在高维空间中找到最近邻。一个局部敏感哈希算法可以将每个向量 x 转换为 hash h(x),和这个 x 靠近的哈希更有可能有着相同的哈希值,而距离远的则不会。作者希望最近的向量最可能得到相同的哈希值,或者 hash-bucket 大小相似的更有可能相同。

局部敏感哈希算法使用球投影点的随机旋转,通过argmax在有符号的轴投影上建立bucket。在这个高度简化的2D描述中,对于三个不同的角哈希,两个点x和y不太可能共享相同的哈希桶(上图),除非它们的球面投影彼此接近(下图)。

通过固定一个大小为 [dₖ, b/2] 的随机矩阵 R 来获得 b 个哈希值。h(x) = argmax([xR;-xR]) 其中 [u;v] 表示两个向量的串联。这样就可以使用LSH,将查询位置I带入重写公式1:

下图我们可以示意性地解释 LSH 注意力:

  • 原始的注意力矩阵通常是稀疏的,但不利于计算
  • LSH Attention基于哈希桶进行键的排序进行查询
  • 在排序后的注意矩阵中,来自同一桶的对将聚集在对角线附近
  • 采用批处理方法,m个连续查询的块相互处理,一个块返回。

代码很长为了节约时间这里就不贴了:

https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reformer_pytorch.py

5、Sparse Attention(Generating Long Sequences with Sparse Transformers)

OpenAI的Sparse Attention,通过“只保留小区域内的数值、强制让大部分注意力为零”的方式,来减少Attention的计算量。通过top-k选择,将注意退化为稀疏注意。这样,保留最有助于引起注意的部分,并删除其他无关的信息。这种选择性方法在保存重要信息和消除噪声方面是有效的。注意力可以更多地集中在最有贡献的价值因素上。

代码:https://github.com/openai/sparse_attention

6、Single-Headed Attention(Single Headed Attention RNN: Stop Thinking With Your Head)

SHA-RNN模型的注意力是简化到只保留了一个头并且唯一的矩阵乘法出现在query (下图Q) 那里,A是缩放点乘注意力 (Scaled Dot-Product Attention) ,是向量之间的运算。所以这种计算量比较小,能够快速的进行训练,就像它介绍的那样:

Obtain strong results on a byte level language modeling dataset (enwik8) in under 24 hours on a single GPU (12GB Titan V)

代码:https://github.com/Smerity/sha-rnn

引用:

  1. Kitaev, N., Ł. Kaiser, and A. Levskaya, Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451, 2020.
  2. Li, S., et al., Enhancing the locality and breaking the memory bottleneck of transformer on time series forecasting. Advances in Neural Information Processing Systems, 2019. 32.
  3. Zhou, H., et al. Informer: Beyond efficient transformer for long sequence time-series forecasting. in Proceedings of AAAI. 2021.
  4. Vaswani, A., et al., Attention is all you need. Advances in neural information processing systems, 2017. 30.
  5. Rewon Child, Scott Gray, Alec Radford, Ilya Sutskever Generating Long Sequences with Sparse Transformers
  6. Stephen Merity ,Single Headed Attention RNN: Stop Thinking With Your Head

注意机制的发展到现在远远不止这些,在本篇文章中只整理了一些常见的注意力机制,希望对你有所帮助。

另外就是来自Erasmus University的Gianni Brauwers 和Flavius Frasincar在TKDE上发表的《A General Survey on Attention Mechanisms in Deep Learning》综述论文,提供了一个关于深度学习注意力机制的重要概述。各种注意力机制通过一个由注意力模型,统一符号和一个全面的分类注意力机制组成的框架来进行解释,还有注意力模型评价的各种方法。

有兴趣和有资源的话可以进行阅读,女神Alexandra Elbakyan的网站还未提供该论文。

作者:Reza Yazdanfar

“用于Transformer的6种注意力的数学原理和代码实现”的评论:

还没有评论