0


大模型 - 知识蒸馏原理解析

知识蒸馏的详细过程和原理解析

知识蒸馏是一种通过将大型预训练模型(教师模型)的知识传递给较小模型(学生模型)的方法。这样可以在减少模型的复杂度和计算资源需求的同时,尽量保留模型的性能。以下是知识蒸馏的详细过程和每个步骤中用到的原理。

1. 输入数据

假设我们有一个图像分类任务,输入数据

     x 
    
   
  
    x 
   
  
x 是一张图像。这个图像同时馈送给教师模型和学生模型。

2. 教师模型

  • 教师模型是一个已经训练好的大模型,它对输入 x x x 进行预测。
  • 教师模型的输出经过一个带温度参数 T T T 的 softmax 函数,得到软标签(soft labels)。温度参数 T T T 用于平滑预测概率,使得输出概率分布更平缓。

具体来说,假设教师模型输出的 logits 为

     [ 
    
   
     2.0 
    
   
     , 
    
   
     1.0 
    
   
     , 
    
   
     0.1 
    
   
     ] 
    
   
  
    [2.0, 1.0, 0.1] 
   
  
[2.0,1.0,0.1],在温度  
 
  
   
   
     T 
    
   
     = 
    
   
     2 
    
   
  
    T=2 
   
  
T=2 下,softmax 计算如下:


  
   
    
    
      softmax 
     
    
      ( 
     
     
     
       z 
      
     
       i 
      
     
    
      ; 
     
    
      T 
     
    
      = 
     
    
      2 
     
    
      ) 
     
    
      = 
     
     
      
      
        e 
       
       
        
        
          z 
         
        
          i 
         
        
       
         / 
        
       
         2 
        
       
      
      
       
       
         ∑ 
        
       
         j 
        
       
       
       
         e 
        
        
         
         
           z 
          
         
           j 
          
         
        
          / 
         
        
          2 
         
        
       
      
     
    
   
     \text{softmax}(z_i; T=2) = \frac{e^{z_i / 2}}{\sum_{j} e^{z_j / 2}} 
    
   
 softmax(zi​;T=2)=∑j​ezj​/2ezi​/2​

计算得:

      softmax 
     
    
      ( 
     
    
      2.0 
     
    
      / 
     
    
      2 
     
    
      ) 
     
    
      = 
     
     
      
      
        e 
       
      
        1.0 
       
      
      
       
       
         e 
        
       
         1.0 
        
       
      
        + 
       
       
       
         e 
        
       
         0.5 
        
       
      
        + 
       
       
       
         e 
        
       
         0.05 
        
       
      
     
    
      = 
     
    
      0.504 
     
    
   
     \text{softmax}(2.0 / 2) = \frac{e^{1.0}}{e^{1.0} + e^{0.5} + e^{0.05}} = 0.504 
    
   
 softmax(2.0/2)=e1.0+e0.5+e0.05e1.0​=0.504

  
   
    
    
      softmax 
     
    
      ( 
     
    
      1.0 
     
    
      / 
     
    
      2 
     
    
      ) 
     
    
      = 
     
     
      
      
        e 
       
      
        0.5 
       
      
      
       
       
         e 
        
       
         1.0 
        
       
      
        + 
       
       
       
         e 
        
       
         0.5 
        
       
      
        + 
       
       
       
         e 
        
       
         0.05 
        
       
      
     
    
      = 
     
    
      0.277 
     
    
   
     \text{softmax}(1.0 / 2) = \frac{e^{0.5}}{e^{1.0} + e^{0.5} + e^{0.05}} = 0.277 
    
   
 softmax(1.0/2)=e1.0+e0.5+e0.05e0.5​=0.277

  
   
    
    
      softmax 
     
    
      ( 
     
    
      0.1 
     
    
      / 
     
    
      2 
     
    
      ) 
     
    
      = 
     
     
      
      
        e 
       
      
        0.05 
       
      
      
       
       
         e 
        
       
         1.0 
        
       
      
        + 
       
       
       
         e 
        
       
         0.5 
        
       
      
        + 
       
       
       
         e 
        
       
         0.05 
        
       
      
     
    
      = 
     
    
      0.219 
     
    
   
     \text{softmax}(0.1 / 2) = \frac{e^{0.05}}{e^{1.0} + e^{0.5} + e^{0.05}} = 0.219 
    
   
 softmax(0.1/2)=e1.0+e0.5+e0.05e0.05​=0.219

软标签为

     [ 
    
   
     0.504 
    
   
     , 
    
   
     0.277 
    
   
     , 
    
   
     0.219 
    
   
     ] 
    
   
  
    [0.504, 0.277, 0.219] 
   
  
[0.504,0.277,0.219]。

3. 学生模型

  • 学生模型是一个较小的模型,它也对输入 x x x 进行预测。
  • 学生模型的输出经过两个 softmax 函数处理,一个带温度 T T T 得到软预测(soft predictions),另一个带温度 T = 1 T=1 T=1 得到硬预测(hard predictions)。

假设学生模型输出的 logits 为

     [ 
    
   
     1.8 
    
   
     , 
    
   
     0.9 
    
   
     , 
    
   
     0.4 
    
   
     ] 
    
   
  
    [1.8, 0.9, 0.4] 
   
  
[1.8,0.9,0.4],在温度  
 
  
   
   
     T 
    
   
     = 
    
   
     2 
    
   
  
    T=2 
   
  
T=2 下,softmax 计算如下:


  
   
    
    
      softmax 
     
    
      ( 
     
    
      1.8 
     
    
      / 
     
    
      2 
     
    
      ) 
     
    
      = 
     
     
      
      
        e 
       
      
        0.9 
       
      
      
       
       
         e 
        
       
         0.9 
        
       
      
        + 
       
       
       
         e 
        
       
         0.45 
        
       
      
        + 
       
       
       
         e 
        
       
         0.2 
        
       
      
     
    
      = 
     
    
      0.474 
     
    
   
     \text{softmax}(1.8 / 2) = \frac{e^{0.9}}{e^{0.9} + e^{0.45} + e^{0.2}} = 0.474 
    
   
 softmax(1.8/2)=e0.9+e0.45+e0.2e0.9​=0.474

  
   
    
    
      softmax 
     
    
      ( 
     
    
      0.9 
     
    
      / 
     
    
      2 
     
    
      ) 
     
    
      = 
     
     
      
      
        e 
       
      
        0.45 
       
      
      
       
       
         e 
        
       
         0.9 
        
       
      
        + 
       
       
       
         e 
        
       
         0.45 
        
       
      
        + 
       
       
       
         e 
        
       
         0.2 
        
       
      
     
    
      = 
     
    
      0.301 
     
    
   
     \text{softmax}(0.9 / 2) = \frac{e^{0.45}}{e^{0.9} + e^{0.45} + e^{0.2}} = 0.301 
    
   
 softmax(0.9/2)=e0.9+e0.45+e0.2e0.45​=0.301

  
   
    
    
      softmax 
     
    
      ( 
     
    
      0.4 
     
    
      / 
     
    
      2 
     
    
      ) 
     
    
      = 
     
     
      
      
        e 
       
      
        0.2 
       
      
      
       
       
         e 
        
       
         0.9 
        
       
      
        + 
       
       
       
         e 
        
       
         0.45 
        
       
      
        + 
       
       
       
         e 
        
       
         0.2 
        
       
      
     
    
      = 
     
    
      0.225 
     
    
   
     \text{softmax}(0.4 / 2) = \frac{e^{0.2}}{e^{0.9} + e^{0.45} + e^{0.2}} = 0.225 
    
   
 softmax(0.4/2)=e0.9+e0.45+e0.2e0.2​=0.225

软预测为

     [ 
    
   
     0.474 
    
   
     , 
    
   
     0.301 
    
   
     , 
    
   
     0.225 
    
   
     ] 
    
   
  
    [0.474, 0.301, 0.225] 
   
  
[0.474,0.301,0.225]。

硬预测(

     T 
    
   
     = 
    
   
     1 
    
   
  
    T=1 
   
  
T=1)的 softmax 计算如下:

  
   
    
    
      softmax 
     
    
      ( 
     
    
      1.8 
     
    
      ) 
     
    
      = 
     
     
      
      
        e 
       
      
        1.8 
       
      
      
       
       
         e 
        
       
         1.8 
        
       
      
        + 
       
       
       
         e 
        
       
         0.9 
        
       
      
        + 
       
       
       
         e 
        
       
         0.4 
        
       
      
     
    
      = 
     
    
      0.659 
     
    
   
     \text{softmax}(1.8) = \frac{e^{1.8}}{e^{1.8} + e^{0.9} + e^{0.4}} = 0.659 
    
   
 softmax(1.8)=e1.8+e0.9+e0.4e1.8​=0.659

  
   
    
    
      softmax 
     
    
      ( 
     
    
      0.9 
     
    
      ) 
     
    
      = 
     
     
      
      
        e 
       
      
        0.9 
       
      
      
       
       
         e 
        
       
         1.8 
        
       
      
        + 
       
       
       
         e 
        
       
         0.9 
        
       
      
        + 
       
       
       
         e 
        
       
         0.4 
        
       
      
     
    
      = 
     
    
      0.242 
     
    
   
     \text{softmax}(0.9) = \frac{e^{0.9}}{e^{1.8} + e^{0.9} + e^{0.4}} = 0.242 
    
   
 softmax(0.9)=e1.8+e0.9+e0.4e0.9​=0.242

  
   
    
    
      softmax 
     
    
      ( 
     
    
      0.4 
     
    
      ) 
     
    
      = 
     
     
      
      
        e 
       
      
        0.4 
       
      
      
       
       
         e 
        
       
         1.8 
        
       
      
        + 
       
       
       
         e 
        
       
         0.9 
        
       
      
        + 
       
       
       
         e 
        
       
         0.4 
        
       
      
     
    
      = 
     
    
      0.099 
     
    
   
     \text{softmax}(0.4) = \frac{e^{0.4}}{e^{1.8} + e^{0.9} + e^{0.4}} = 0.099 
    
   
 softmax(0.4)=e1.8+e0.9+e0.4e0.4​=0.099

硬预测为

     [ 
    
   
     0.659 
    
   
     , 
    
   
     0.242 
    
   
     , 
    
   
     0.099 
    
   
     ] 
    
   
  
    [0.659, 0.242, 0.099] 
   
  
[0.659,0.242,0.099]。

4. 蒸馏损失(Distillation Loss)

  • 蒸馏损失是教师模型的软标签和学生模型的软预测之间的差异,通常使用 KL 散度(Kullback-Leibler Divergence)作为损失函数。

         D 
        
        
        
          K 
         
        
          L 
         
        
       
      
        ( 
       
      
        P 
       
      
        ∥ 
       
      
        Q 
       
      
        ) 
       
      
        = 
       
       
       
         ∑ 
        
        
        
          x 
         
        
          ∈ 
         
        
          X 
         
        
       
      
        P 
       
      
        ( 
       
      
        x 
       
      
        ) 
       
      
        log 
       
      
        ⁡ 
       
       
       
         ( 
        
        
         
         
           P 
          
         
           ( 
          
         
           x 
          
         
           ) 
          
         
         
         
           Q 
          
         
           ( 
          
         
           x 
          
         
           ) 
          
         
        
       
         ) 
        
       
      
     
       D_{KL}(P \parallel Q) = \sum_{x \in X} P(x) \log \left( \frac{P(x)}{Q(x)} \right) 
      
     
    

    DKL​(P∥Q)=x∈X∑​P(x)log(Q(x)P(x)​)

假设软标签

     P 
    
   
  
    P 
   
  
P 为  
 
  
   
   
     [ 
    
   
     0.504 
    
   
     , 
    
   
     0.277 
    
   
     , 
    
   
     0.219 
    
   
     ] 
    
   
  
    [0.504, 0.277, 0.219] 
   
  
[0.504,0.277,0.219],软预测  
 
  
   
   
     Q 
    
   
  
    Q 
   
  
Q 为  
 
  
   
   
     [ 
    
   
     0.474 
    
   
     , 
    
   
     0.301 
    
   
     , 
    
   
     0.225 
    
   
     ] 
    
   
  
    [0.474, 0.301, 0.225] 
   
  
[0.474,0.301,0.225]:

  
   
    
     
     
       D 
      
      
      
        K 
       
      
        L 
       
      
     
    
      ( 
     
    
      P 
     
    
      ∥ 
     
    
      Q 
     
    
      ) 
     
    
      = 
     
    
      0.504 
     
    
      log 
     
    
      ⁡ 
     
     
     
       ( 
      
      
      
        0.504 
       
      
        0.474 
       
      
     
       ) 
      
     
    
      + 
     
    
      0.277 
     
    
      log 
     
    
      ⁡ 
     
     
     
       ( 
      
      
      
        0.277 
       
      
        0.301 
       
      
     
       ) 
      
     
    
      + 
     
    
      0.219 
     
    
      log 
     
    
      ⁡ 
     
     
     
       ( 
      
      
      
        0.219 
       
      
        0.225 
       
      
     
       ) 
      
     
    
   
     D_{KL}(P \parallel Q) = 0.504 \log \left( \frac{0.504}{0.474} \right) + 0.277 \log \left( \frac{0.277}{0.301} \right) + 0.219 \log \left( \frac{0.219}{0.225} \right) 
    
   
 DKL​(P∥Q)=0.504log(0.4740.504​)+0.277log(0.3010.277​)+0.219log(0.2250.219​)

计算得:

       D 
      
      
      
        K 
       
      
        L 
       
      
     
    
      ( 
     
    
      P 
     
    
      ∥ 
     
    
      Q 
     
    
      ) 
     
    
      = 
     
    
      0.504 
     
    
      ⋅ 
     
    
      0.0623 
     
    
      + 
     
    
      0.277 
     
    
      ⋅ 
     
    
      − 
     
    
      0.0848 
     
    
      + 
     
    
      0.219 
     
    
      ⋅ 
     
    
      − 
     
    
      0.0267 
     
    
   
     D_{KL}(P \parallel Q) = 0.504 \cdot 0.0623 + 0.277 \cdot -0.0848 + 0.219 \cdot -0.0267 
    
   
 DKL​(P∥Q)=0.504⋅0.0623+0.277⋅−0.0848+0.219⋅−0.0267

  
   
    
    
      = 
     
    
      0.0314 
     
    
      − 
     
    
      0.0235 
     
    
      − 
     
    
      0.0058 
     
    
   
     = 0.0314 - 0.0235 - 0.0058 
    
   
 =0.0314−0.0235−0.0058

  
   
    
    
      = 
     
    
      0.0021 
     
    
   
     = 0.0021 
    
   
 =0.0021

5. 学生损失(Student Loss)

  • 学生损失是学生模型的硬预测和真实标签(硬标签)之间的差异,通常使用交叉熵损失函数。

假设真实标签

     y 
    
   
  
    y 
   
  
y 为类别 1,则 one-hot 编码为  
 
  
   
   
     [ 
    
   
     1 
    
   
     , 
    
   
     0 
    
   
     , 
    
   
     0 
    
   
     ] 
    
   
  
    [1, 0, 0] 
   
  
[1,0,0],硬预测为  
 
  
   
   
     [ 
    
   
     0.659 
    
   
     , 
    
   
     0.242 
    
   
     , 
    
   
     0.099 
    
   
     ] 
    
   
  
    [0.659, 0.242, 0.099] 
   
  
[0.659,0.242,0.099],交叉熵损失为:


  
   
    
    
      H 
     
    
      ( 
     
    
      y 
     
    
      , 
     
     
     
       y 
      
     
       ^ 
      
     
    
      ) 
     
    
      = 
     
    
      − 
     
     
     
       ∑ 
      
     
       i 
      
     
     
     
       y 
      
     
       i 
      
     
    
      log 
     
    
      ⁡ 
     
    
      ( 
     
     
      
      
        y 
       
      
        ^ 
       
      
     
       i 
      
     
    
      ) 
     
    
   
     H(y, \hat{y}) = - \sum_{i} y_i \log(\hat{y}_i) 
    
   
 H(y,y^​)=−i∑​yi​log(y^​i​)

  
   
    
    
      H 
     
    
      ( 
     
    
      y 
     
    
      , 
     
     
     
       y 
      
     
       ^ 
      
     
    
      ) 
     
    
      = 
     
    
      − 
     
    
      ( 
     
    
      1 
     
    
      ⋅ 
     
    
      log 
     
    
      ⁡ 
     
    
      ( 
     
    
      0.659 
     
    
      ) 
     
    
      + 
     
    
      0 
     
    
      ⋅ 
     
    
      log 
     
    
      ⁡ 
     
    
      ( 
     
    
      0.242 
     
    
      ) 
     
    
      + 
     
    
      0 
     
    
      ⋅ 
     
    
      log 
     
    
      ⁡ 
     
    
      ( 
     
    
      0.099 
     
    
      ) 
     
    
      ) 
     
    
   
     H(y, \hat{y}) = - (1 \cdot \log(0.659) + 0 \cdot \log(0.242) + 0 \cdot \log(0.099)) 
    
   
 H(y,y^​)=−(1⋅log(0.659)+0⋅log(0.242)+0⋅log(0.099))

  
   
    
    
      = 
     
    
      − 
     
    
      log 
     
    
      ⁡ 
     
    
      ( 
     
    
      0.659 
     
    
      ) 
     
    
      = 
     
    
      0.416 
     
    
   
     = - \log(0.659) = 0.416 
    
   
 =−log(0.659)=0.416

6. 总损失(Total Loss)

  • 总损失是蒸馏损失和学生损失的加权和: Total Loss = α × Student Loss + β × Distillation Loss \text{Total Loss} = \alpha \times \text{Student Loss} + \beta \times \text{Distillation Loss} Total Loss=α×Student Loss+β×Distillation Loss

假设

     α 
    
   
     = 
    
   
     1 
    
   
  
    \alpha = 1 
   
  
α=1, 
 
  
   
   
     β 
    
   
     = 
    
   
     0.5 
    
   
  
    \beta = 0.5 
   
  
β=0.5,则总损失为:

  
   
    
    
      Total Loss 
     
    
      = 
     
    
      1 
     
    
      × 
     
    
      0.416 
     
    
      + 
     
    
      0.5 
     
    
      × 
     
    
      0.0021 
     
    
      = 
     
    
      0.417 
     
    
   
     \text{Total Loss} = 1 \times 0.416 + 0.5 \times 0.0021 = 0.417 
    
   
 Total Loss=1×0.416+0.5×0.0021=0.417

代码示例

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# 定义教师模型和学生模型classTeacherModel(nn.Module):def__init__(self):super(TeacherModel, self).__init__()
        self.fc = nn.Linear(784,10)defforward(self, x):return self.fc(x)classStudentModel(nn.Module):def__init__(self):super(StudentModel, self).__init__()
        self.fc = nn.Linear(784,10)defforward(self, x):return self.fc(x)# 定义蒸馏损失函数defdistillation_loss(soft_labels, soft_predictions, T):
    soft_labels = F.softmax(soft_labels / T, dim=1)
    soft_predictions = F.log_softmax(soft_predictions / T, dim=1)
    loss = F.kl_div(soft_predictions, soft_labels, reduction='batchmean')*(T **2)return loss

# 定义学生损失函数defstudent_loss(hard_labels, hard_predictions):return F.cross_entropy(hard_predictions, hard_labels)# 超参数
alpha =1.0
beta =0.5
temperature =2.0
learning_rate =0.001
num_epochs =10# 数据加载器(使用MNIST数据集作为示例)from torchvision import datasets, transforms
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)# 初始化模型、优化器
teacher_model = TeacherModel()
student_model = StudentModel()
optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)# 假设教师模型已经预训练好,这里直接加载预训练权重# teacher_model.load_state_dict(torch.load('teacher_model.pth'))# 训练过程
teacher_model.eval()# 教师模型设为评估模式,不进行训练
student_model.train()# 学生模型设为训练模式for epoch inrange(num_epochs):
    total_loss =0for data, target in train_loader:
        data = data.view(data.size(0),-1)# 展开图像数据# 教师模型预测with torch.no_grad():
            teacher_output = teacher_model(data)# 学生模型预测
        student_output = student_model(data)
        soft_predictions = student_output / temperature
        hard_predictions = student_output
        
        # 计算蒸馏损失和学生损失
        dist_loss = distillation_loss(teacher_output, student_output, temperature)
        stud_loss = student_loss(target, hard_predictions)# 计算总损失
        loss = alpha * stud_loss + beta * dist_loss
        
        # 优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader)}')# 保存学生模型
torch.save(student_model.state_dict(),'student_model.pth')

代码解释

  1. 模型定义:定义了一个简单的全连接层的教师模型和学生模型。
  2. 蒸馏损失和学生损失函数:- distillation_loss 计算KL散度作为蒸馏损失。- student_loss 计算交叉熵损失作为学生损失。
  3. 超参数:- alphabeta 分别是学生损失和蒸馏损失的权重。- temperature 是温度参数,用于平滑教师模型的输出。
  4. 数据加载:使用MNIST数据集作为示例。
  5. 模型初始化:初始化教师模型和学生模型,并定义优化器。
  6. 训练过程:- 教师模型设为评估模式,学生模型设为训练模式。- 在每个训练周期中,对每个批次数据进行预测,计算损失,并进行优化。
  7. 保存模型:在训练结束后保存学生模型的权重。

该代码示例展示了如何通过PyTorch实现模型蒸馏的训练过程。如果有其他需求或需要进一步解释的地方,请告诉我。

总结

知识蒸馏通过教师模型提供的软标签引导学生模型,使得学生模型不仅关注硬标签的分类准确性,还能从软标签中学习更丰富的类别间关系,从而在模型压缩的同时尽量保留性能。这种方法特别适用于在资源受限的环境中部署高效的深度学习模型。


本文转载自: https://blog.csdn.net/weixin_47552266/article/details/140230614
版权归原作者 想胖的壮壮 所有, 如有侵权,请联系我们删除。

“大模型 - 知识蒸馏原理解析”的评论:

还没有评论