0


self-attention(自注意力) 和 cross-attetion(交叉注意力) 中的差异

self attention(自注意力)

在这里插入图片描述
来源: https://arxiv.org/pdf/1706.03762

计算实现:
1.计算出

  1. score
  2. =
  3. Q
  4. K
  5. T
  6. \text{score} =QK^T
  7. score=QKT

2.计算

  1. attention
  2. =
  3. s
  4. o
  5. f
  6. t
  7. m
  8. a
  9. x
  10. (
  11. score
  12. )
  13. \text{attention}=softmax(\text{score})
  14. attention=softmax(score)
  1. 计算

    1. weighted
    2. =
    3. a
    4. t
    5. t
    6. e
    7. n
    8. t
    9. i
    10. o
    11. n
    12. V
    13. \text{weighted}=attention*V

    weighted=attention∗V

1.计算复杂度是

  1. O
  2. (
  3. L
  4. 2
  5. )
  6. O(L^2)
  7. O(L2)

2.因为需要计算 LXL 的 注意力矩阵

  1. softmax
  2. (
  3. Q
  4. K
  5. T
  6. d
  7. )
  8. \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)
  9. softmax(dQKT​)

完整公式

  1. self_attention
  2. (
  3. Q
  4. ,
  5. K
  6. ,
  7. V
  8. )
  9. =
  10. softmax
  11. (
  12. Q
  13. K
  14. T
  15. d
  16. )
  17. \text{self\_attention}(Q, K, V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)
  18. self_attention(Q,K,V)=softmax(dQKT​)

在这里插入图片描述
来源:https://arxiv.org/pdf/2009.14794

代码实现:

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. classSelfAttention(nn.Module):def__init__(self, input_dim):super(SelfAttention, self).__init__()
  5. self.input_dim = input_dim
  6. # 定义线性变换层,用于计算查询、键和值
  7. self.query = nn.Linear(input_dim, input_dim)# [batch_size, seq_length, input_dim]
  8. self.key = nn.Linear(input_dim, input_dim)# [batch_size, seq_length, input_dim]
  9. self.value = nn.Linear(input_dim, input_dim)# [batch_size, seq_length, input_dim]
  10. self.softmax = nn.Softmax(dim=2)# 注意力权重的softmax函数,沿着最后一个维度进行defforward(self, x):# x的形状为 (batch_size, seq_length, input_dim)
  11. queries = self.query(x)# 计算查询矩阵
  12. keys = self.key(x)# 计算键矩阵
  13. values = self.value(x)# 计算值矩阵# 计算注意力得分
  14. score = torch.bmm(queries, keys.transpose(1,2))/(self.input_dim **0.5)
  15. attention = self.softmax(score)# 对得分应用softmax函数得到注意力权重
  16. weighted = torch.bmm(attention, values)# 使用注意力权重加权值矩阵return weighted # 返回加权后的值矩阵

Cross Attention(交叉注意力)

在这里插入图片描述
这张图片展示了交叉注意力模块的工作原理。

交叉注意力模块

  • 输入:- “What?”:这是表示“内容”的输入序列,包含值(Value,(V))和键(Key,(K))。- “Where?”:这是表示“位置”的输入序列,包含查询(Query,(Q))。
  • 计算过程:- 从“内容”输入序列中提取出值 (V) 和键 (K)。- 从“位置”输入序列中提取出查询 (Q)。- 计算查询 (Q) 和键 (K) 的点积,得到注意力能量(Attention energy)。- 将注意力能量除以 (\sqrt{C/h}),其中 (C) 是键的维度,(h) 是注意力头的数量,用以进行缩放。- 对缩放后的注意力能量应用 softmax 函数,得到注意力权重。- 将注意力权重应用到值 (V) 上,得到输出上下文(Output context)。

数学公式:

  1. Cross_attention
  2. (
  3. Q
  4. ,
  5. K
  6. ,
  7. V
  8. )
  9. =
  10. Softmax
  11. (
  12. Q
  13. K
  14. T
  15. C
  16. /
  17. h
  18. )
  19. V
  20. \text{Cross\_attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{C/h}}\right) \cdot V \
  21. Cross_attention(Q,K,V)=Softmax(C/hQKT​)⋅V
  • 解释:- ( Q ):查询矩阵。- ( K ):键矩阵。- ( V ):值矩阵。- (\text{Softmax}):softmax 函数,用于将注意力能量转换为概率分布。- ( \sqrt{C/h} ):缩放因子,控制注意力能量的大小。

代码实现

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import math
  5. classCrossAttention(nn.Module):def__init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):super().__init__()# 定义线性变换层,用于计算查询、键和值
  6. self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
  7. self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
  8. self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
  9. self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
  10. self.n_heads = n_heads # 注意力头的数量
  11. self.d_head = d_embed // n_heads # 每个注意力头的维度defforward(self, x, y):# x (潜在表示): (batch_size, seq_len_q, dim_q)# y (上下文): (batch_size, seq_len_kv, dim_kv) = (batch_size, 77, 768)
  12. input_shape = x.shape
  13. batch_size, sequence_length, d_embed = input_shape
  14. # 将每个查询的嵌入向量划分为多个头,确保 d_heads * n_heads = dim_q
  15. interim_shape =(batch_size,-1, self.n_heads, self.d_head)# 计算查询矩阵 (batch_size, seq_len_q, dim_q) -> (batch_size, seq_len_q, dim_q)
  16. q = self.q_proj(x)# 计算键矩阵 (batch_size, seq_len_kv, dim_kv) -> (batch_size, seq_len_kv, dim_q)
  17. k = self.k_proj(y)# 计算值矩阵 (batch_size, seq_len_kv, dim_kv) -> (batch_size, seq_len_kv, dim_q)
  18. v = self.v_proj(y)# 将查询矩阵重塑并转置以匹配注意力头 (batch_size, seq_len_q, dim_q) -> (batch_size, seq_len_q, h, dim_q / h) -> (batch_size, h, seq_len_q, dim_q / h)
  19. q = q.view(interim_shape).transpose(1,2)# 将键矩阵重塑并转置以匹配注意力头 (batch_size, seq_len_kv, dim_q) -> (batch_size, seq_len_kv, h, dim_q / h) -> (batch_size, h, seq_len_kv, dim_q / h)
  20. k = k.view(interim_shape).transpose(1,2)# 将值矩阵重塑并转置以匹配注意力头 (batch_size, seq_len_kv, dim_q) -> (batch_size, seq_len_kv, h, dim_q / h) -> (batch_size, h, seq_len_kv, dim_q / h)
  21. v = v.view(interim_shape).transpose(1,2)# 计算注意力得分 (batch_size, h, seq_len_q, dim_q / h) @ (batch_size, h, dim_q / h, seq_len_kv) -> (batch_size, h, seq_len_q, seq_len_kv)
  22. weight = q @ k.transpose(-1,-2)# 缩放注意力得分 (batch_size, h, seq_len_q, seq_len_kv)
  23. weight /= math.sqrt(self.d_head)# 对注意力得分应用softmax函数 (batch_size, h, seq_len_q, seq_len_kv)
  24. weight = F.softmax(weight, dim=-1)# 计算加权后的值矩阵 (batch_size, h, seq_len_q, seq_len_kv) @ (batch_size, h, seq_len_kv, dim_q / h) -> (batch_size, h, seq_len_q, dim_q / h)
  25. output = weight @ v
  26. # 将输出矩阵转置回原始形状 (batch_size, h, seq_len_q, dim_q / h) -> (batch_size, seq_len_q, h, dim_q / h)
  27. output = output.transpose(1,2).contiguous()# 将输出矩阵重塑回原始形状 (batch_size, seq_len_q, h, dim_q / h) -> (batch_size, seq_len_q, dim_q)
  28. output = output.view(input_shape)# 应用最后的线性变换 (batch_size, seq_len_q, dim_q) -> (batch_size, seq_len_q, dim_q)
  29. output = self.out_proj(output)# 返回最终的输出 (batch_size, seq_len_q, dim_q)return output

代码来源
https://github.com/hkproj/pytorch-stable-diffusion/blob/main/sd/attention.py


本文转载自: https://blog.csdn.net/weixin_55982578/article/details/140847442
版权归原作者 骆驼穿针眼 所有, 如有侵权,请联系我们删除。

“self-attention(自注意力) 和 cross-attetion(交叉注意力) 中的差异”的评论:

还没有评论