使用冻结层进行迁移学习
在yolov5的训练过程中,作者介绍了如何使用冻结层实现迁移学习的策略。具体可以参考官方话题:Transfer Learning with Frozen Layers · Issue #1314 · ultralytics/yolov5 · GitHub
在很多情况下,迁移学习是一种十分有用的方法,可以在新的数据集上快速重新训练模型,无需重新训练整个模型。对部分权重进行冻结,其余权重进行更新并计算损失,比正常训练需要更少的计算资源,更少的训练时间(更快的达到收敛速度),yolov5将冻结层的梯度设置为0控制参数更新实现冻结训练。下面我们看实施的细节:
1.1层结构
通过如下指令打印出模型的层结构:
for k, v in model.named_parameters():
print(k)
#output
model.0.conv.weight
model.0.bn.weight
model.0.bn.bias
model.1.conv.weight
model.1.bn.weight
model.1.bn.bias
model.2.cv1.conv.weight
model.2.cv1.bn.weight
model.2.cv1.bn.bias
model.2.cv2.conv.weight
model.2.cv2.bn.weight
model.2.cv2.bn.bias
model.2.cv3.conv.weight
model.2.cv3.bn.weight
model.2.cv3.bn.bias
model.2.m.0.cv1.conv.weight
model.2.m.0.cv1.bn.weight
model.2.m.0.cv1.bn.bias
model.2.m.0.cv2.conv.weight
model.2.m.0.cv2.bn.weight
model.2.m.0.cv2.bn.bias
model.3.conv.weight
model.3.bn.weight
model.3.bn.bias
model.4.cv1.conv.weight
model.4.cv1.bn.weight
model.4.cv1.bn.bias
model.4.cv2.conv.weight
model.4.cv2.bn.weight
model.4.cv2.bn.bias
model.4.cv3.conv.weight
model.4.cv3.bn.weight
model.4.cv3.bn.bias
model.4.m.0.cv1.conv.weight
model.4.m.0.cv1.bn.weight
model.4.m.0.cv1.bn.bias
model.4.m.0.cv2.conv.weight
model.4.m.0.cv2.bn.weight
model.4.m.0.cv2.bn.bias
model.4.m.1.cv1.conv.weight
model.4.m.1.cv1.bn.weight
model.4.m.1.cv1.bn.bias
model.4.m.1.cv2.conv.weight
model.4.m.1.cv2.bn.weight
model.4.m.1.cv2.bn.bias
model.5.conv.weight
model.5.bn.weight
model.5.bn.bias
model.6.cv1.conv.weight
model.6.cv1.bn.weight
model.6.cv1.bn.bias
model.6.cv2.conv.weight
model.6.cv2.bn.weight
model.6.cv2.bn.bias
model.6.cv3.conv.weight
model.6.cv3.bn.weight
model.6.cv3.bn.bias
model.6.m.0.cv1.conv.weight
model.6.m.0.cv1.bn.weight
model.6.m.0.cv1.bn.bias
model.6.m.0.cv2.conv.weight
model.6.m.0.cv2.bn.weight
model.6.m.0.cv2.bn.bias
model.6.m.1.cv1.conv.weight
model.6.m.1.cv1.bn.weight
model.6.m.1.cv1.bn.bias
model.6.m.1.cv2.conv.weight
model.6.m.1.cv2.bn.weight
model.6.m.1.cv2.bn.bias
model.6.m.2.cv1.conv.weight
model.6.m.2.cv1.bn.weight
model.6.m.2.cv1.bn.bias
model.6.m.2.cv2.conv.weight
model.6.m.2.cv2.bn.weight
model.6.m.2.cv2.bn.bias
model.7.conv.weight
model.7.bn.weight
model.7.bn.bias
model.8.cv1.conv.weight
model.8.cv1.bn.weight
model.8.cv1.bn.bias
model.8.cv2.conv.weight
model.8.cv2.bn.weight
model.8.cv2.bn.bias
model.8.cv3.conv.weight
model.8.cv3.bn.weight
model.8.cv3.bn.bias
model.8.m.0.cv1.conv.weight
model.8.m.0.cv1.bn.weight
model.8.m.0.cv1.bn.bias
model.8.m.0.cv2.conv.weight
model.8.m.0.cv2.bn.weight
model.8.m.0.cv2.bn.bias
model.9.cv1.conv.weight
model.9.cv1.bn.weight
model.9.cv1.bn.bias
model.9.cv2.conv.weight
model.9.cv2.bn.weight
model.9.cv2.bn.bias
model.10.conv.weight
model.10.bn.weight
model.10.bn.bias
model.13.cv1.conv.weight
model.13.cv1.bn.weight
model.13.cv1.bn.bias
model.13.cv2.conv.weight
model.13.cv2.bn.weight
model.13.cv2.bn.bias
model.13.cv3.conv.weight
model.13.cv3.bn.weight
model.13.cv3.bn.bias
model.13.m.0.cv1.conv.weight
model.13.m.0.cv1.bn.weight
model.13.m.0.cv1.bn.bias
model.13.m.0.cv2.conv.weight
model.13.m.0.cv2.bn.weight
model.13.m.0.cv2.bn.bias
model.14.conv.weight
model.14.bn.weight
model.14.bn.bias
model.17.cv1.conv.weight
model.17.cv1.bn.weight
model.17.cv1.bn.bias
model.17.cv2.conv.weight
model.17.cv2.bn.weight
model.17.cv2.bn.bias
model.17.cv3.conv.weight
model.17.cv3.bn.weight
model.17.cv3.bn.bias
model.17.m.0.cv1.conv.weight
model.17.m.0.cv1.bn.weight
model.17.m.0.cv1.bn.bias
model.17.m.0.cv2.conv.weight
model.17.m.0.cv2.bn.weight
model.17.m.0.cv2.bn.bias
model.18.conv.weight
model.18.bn.weight
model.18.bn.bias
model.20.cv1.conv.weight
model.20.cv1.bn.weight
model.20.cv1.bn.bias
model.20.cv2.conv.weight
model.20.cv2.bn.weight
model.20.cv2.bn.bias
model.20.cv3.conv.weight
model.20.cv3.bn.weight
model.20.cv3.bn.bias
model.20.m.0.cv1.conv.weight
model.20.m.0.cv1.bn.weight
model.20.m.0.cv1.bn.bias
model.20.m.0.cv2.conv.weight
model.20.m.0.cv2.bn.weight
model.20.m.0.cv2.bn.bias
model.21.conv.weight
model.21.bn.weight
model.21.bn.bias
model.23.cv1.conv.weight
model.23.cv1.bn.weight
model.23.cv1.bn.bias
model.23.cv2.conv.weight
model.23.cv2.bn.weight
model.23.cv2.bn.bias
model.23.cv3.conv.weight
model.23.cv3.bn.weight
model.23.cv3.bn.bias
model.23.m.0.cv1.conv.weight
model.23.m.0.cv1.bn.weight
model.23.m.0.cv1.bn.bias
model.23.m.0.cv2.conv.weight
model.23.m.0.cv2.bn.weight
model.23.m.0.cv2.bn.bias
model.24.m.0.weight
model.24.m.0.bias
model.24.m.1.weight
model.24.m.1.bias
model.24.m.2.weight
model.24.m.2.bias
1.2在训练的过程中,通过将梯度设置为0实现匹配层的冻结。
# Freeze
freeze = [f'model.{x}.' for x in range(freeze)] # layers to freeze
for k, v in model.named_parameters():
v.requires_grad = True # train all layers
if any(x in k for x in freeze):
print(f'freezing {k}')
v.requires_grad = False
1.3冻结骨干网络,根据yaml配置文件可以看出,0-9层为Backbone层,所以我们只需设置freeze为10即可在训练的时候冻结骨干网络进行训练,同理设置freeze为24即可冻结所有的层。
yaml配置文件
python train.py --freeze 10 #冻结骨干网络
python train.py --freeze 24 #冻结所有的层
同时在话题的下面有人提出了有趣的训练过程:
作者也是给出了回复:
修改训练好的模型(按需要修改)
yolov5保存的权重文件不仅仅包含是模型和参数,还包含其他的一些东西
从save model可以看出,训练结果还保存的其他的参数:
#epoch-- 当前模型对应的epoch数。
#best_fitness-- Fitness 是我们寻找最大值的变量,在 YOLOv5 中,我们将默认适应度函数定义为度量的加权组合:mAP@0.5 贡献了 10% 的权重,mAP@0.5:0.95 贡献了剩余的 90%,没有 Precision P 和 Recall R。您可以根据需要调整这些设置或使用默认的适合度定义。
#model-- 保存的模型。
#ema-- 指数移动平均。在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。
#updata-- 保存的模型
#optimizer-- 优化信息
#wandb_id-- 可视化工具
——————————
我们可以打印出模型的信息:
model = torch.load("yolov5s.pt")
print(model)
可以清晰的看出模型包含哪些信息。
我们可以按需要修改参数:
import argparse
import torch
import numpy as np
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='', help='weights of input'))
opt = parser.parse_args()
model = torch.load(opt.weights, map_location=torch.device('cpu'))
# 模型实例化
net = model['model']
# 只保留有用信息
ckpt = {'epoch': -1,
'best_fitness': model['best_fitness'],
'model': net,
'ema':None,
'updates':None,
'optimizer': None,
'wandb_id':None,
'date':model['date']}
# 保存模型
torch.save(ckpt, 'my_weight.pt')
print('=========DONE=========')
3.使用迁移学习的精度比较以及GPU的利用率可以参考官网:
Freezing Layers in YOLOv5 | yolov5_tutorial_freeze – Weights & Biases (wandb.ai)
版权归原作者 o氧气o 所有, 如有侵权,请联系我们删除。