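This post adds a few small fixes on top of the original author's blog post. Follow the changes below and your code should run.
Replacing the YOLOv5 backbone with MobileViT
Paid tutorial: Swapping the YOLOv5 backbone for MobileViT-S / MobileViT-XS / MobileViT-XXS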
Background reading
An introduction to the MobileViT model
MobileViT, MobileViTv2, MobileViTv3 study notes (for personal use)
MobileViTv1, MobileViTv2, MobileViTv3 network architectures explained
Preparation:
I am using yolov5s from YOLOv5 v6.0.
mobilevit
The modifications
1. Put mobilevit.py in yolov5/models
2. Modify models/yolo.py
Import all of the modules, or only MV2Block and MobileViTBlock
Add MV2Block and MobileViTBlock
3. Modify the yaml file
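To give a sense of what adding the two blocks to yolo.py implies, here is a minimal sketch of the width and depth scaling that YOLOv5's parse_model applies to the module types it knows about. It assumes MV2Block and MobileViTBlock are added to that list, so that the first value in their args is treated as an output channel count. scale_row is a hypothetical helper written for illustration only; it is not code from yolo.py.

```python
import math


def make_divisible(x, divisor):
    # Round x up to the nearest multiple of divisor (mirrors the YOLOv5 helper).
    return math.ceil(x / divisor) * divisor


def scale_row(number, out_channels, depth_multiple=0.33, width_multiple=0.50):
    # Hypothetical helper: apply the yolov5s gains to one yaml row.
    # The depth gain only affects rows that repeat more than once.
    n = max(round(number * depth_multiple), 1) if number > 1 else number
    # The width gain scales the output channels, rounded up to a multiple of 8.
    c2 = make_divisible(out_channels * width_multiple, 8)
    return n, c2


# (module, number, first arg) triples in the style of the yaml rows.
rows = [("MV2Block", 2, 48), ("MobileViTBlock", 1, 64), ("MobileViTBlock", 1, 96)]
for name, number, out_channels in rows:
    print(name, scale_row(number, out_channels))
```

Under this assumption only the first argument is scaled. The remaining arguments (dim, depth, kernel_size, patch_size, mlp_dim) are passed to the block unchanged.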
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32
# YOLOv5 backbone
backbone:
  # [from, number, module, args] 640 x 640
  # [[-1, 1, Conv, [32, 6, 2, 2]],  # 0-P1/2 320 x 320
  [[-1, 1, Focus, [32, 3]],
   [-1, 1, MV2Block, [32, 1, 2]],  # 1 (stride 1, still 320 x 320)
   [-1, 1, MV2Block, [48, 2, 2]],  # 160 x 160
   [-1, 2, MV2Block, [48, 1, 2]],
   [-1, 1, MV2Block, [64, 2, 2]],  # 80 x 80
   [-1, 1, MobileViTBlock, [64, 96, 2, 3, 2, 192]],  # 5 (args: out_dim, dim, depth, kernel_size, patch_size, mlp_dim)
   [-1, 1, MV2Block, [80, 2, 2]],  # 40 x 40
   [-1, 1, MobileViTBlock, [80, 120, 4, 3, 2, 480]],  # 7
   [-1, 1, MV2Block, [96, 2, 2]],  # 20 x 20
   [-1, 1, MobileViTBlock, [96, 144, 3, 3, 2, 576]],  # 9
  ]
# YOLOv5 head
head:
  [[-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [256, False]],  # 13
   [-1, 1, Conv, [128, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [128, False]],  # 17 (P3/8-small)
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [256, False]],  # 20 (P4/16-medium)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [512, False]],  # 23 (P5/32-large)
   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
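When editing a yaml like this, the layer indices in the head are easy to get wrong, so a quick sanity check can help. The sketch below traces the feature map size through the layers listed above and confirms that every Concat joins maps of the same size. It is a sketch under assumptions: the up/down/keep operations are read off the resolution comments in the yaml (Focus, stride-2 MV2Block and stride-2 Conv halve the map, nn.Upsample doubles it). LAYERS and trace_sizes are hypothetical names, not part of YOLOv5.

```python
# One entry per yaml row (Detect excluded): (from, module, op). The list index
# is the layer index used by the head's `from` fields.
LAYERS = [
    (-1, "Focus", "down"),            # 0
    (-1, "MV2Block", "keep"),         # 1
    (-1, "MV2Block", "down"),         # 2
    (-1, "MV2Block", "keep"),         # 3
    (-1, "MV2Block", "down"),         # 4
    (-1, "MobileViTBlock", "keep"),   # 5
    (-1, "MV2Block", "down"),         # 6
    (-1, "MobileViTBlock", "keep"),   # 7
    (-1, "MV2Block", "down"),         # 8
    (-1, "MobileViTBlock", "keep"),   # 9
    (-1, "Conv", "keep"),             # 10
    (-1, "Upsample", "up"),           # 11
    ([-1, 7], "Concat", "keep"),      # 12
    (-1, "C3", "keep"),               # 13
    (-1, "Conv", "keep"),             # 14
    (-1, "Upsample", "up"),           # 15
    ([-1, 5], "Concat", "keep"),      # 16
    (-1, "C3", "keep"),               # 17
    (-1, "Conv", "down"),             # 18
    ([-1, 14], "Concat", "keep"),     # 19
    (-1, "C3", "keep"),               # 20
    (-1, "Conv", "down"),             # 21
    ([-1, 10], "Concat", "keep"),     # 22
    (-1, "C3", "keep"),               # 23
]

OPS = {"keep": lambda s: s, "down": lambda s: s // 2, "up": lambda s: s * 2}


def trace_sizes(img_size=640):
    """Return the square feature map size produced by each layer."""
    sizes = []
    for i, (src, name, op) in enumerate(LAYERS):
        sources = src if isinstance(src, list) else [src]
        inputs = []
        for s in sources:
            idx = i + s if s < 0 else s  # negative indices are relative
            inputs.append(img_size if idx < 0 else sizes[idx])
        if len(set(inputs)) != 1:
            raise ValueError(f"layer {i} ({name}) joins mismatched sizes {inputs}")
        sizes.append(OPS[op](inputs[0]))
    return sizes


sizes = trace_sizes(640)
print([sizes[i] for i in (17, 20, 23)])  # sizes fed to Detect
```

For a 640 x 640 input the three Detect inputs come out at 80, 40 and 20, which matches the P3/8, P4/16 and P5/32 comments in the head.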
4. Modify mobilevit.py
Now it runs happily!!!
END
Thanks for reading. If this was useful, please leave a like!
ADD
einops.EinopsError: Error while processing rearrange-reduction pattern "b d (h ph) (w pw) -> b (ph pw) (h w) d".
Input tensor shape: torch.Size([1, 120, 42, 42]). Additional info: {'ph': 4, 'pw': 4}
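This happens because the feature map size does not match what the rearrange pattern expects: the height and width (42) must be divisible by the patch size (ph = pw = 4), and 42 is not.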
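The failure can be predicted with simple arithmetic before running the model. The sketch below assumes the MobileViT blocks sit at strides 8, 16 and 32, as the resolution comments in the yaml suggest. The input size 672 is only an illustration, chosen because 672 / 16 = 42 matches the tensor in the traceback; it is not a value taken from the original post. check_patch_divisibility is a hypothetical helper.

```python
def check_patch_divisibility(height, width, patch_size, strides=(8, 16, 32)):
    """List the (stride, h, w) stages whose feature map cannot be split into patches."""
    bad = []
    for stride in strides:
        h, w = height // stride, width // stride
        # rearrange with "(h ph) (w pw)" needs both sides divisible by the patch size.
        if h % patch_size or w % patch_size:
            bad.append((stride, h, w))
    return bad


print(check_patch_divisibility(640, 640, 4))  # [] : every stage is divisible
print(check_patch_divisibility(672, 672, 4))  # [(16, 42, 42), (32, 21, 21)]
```

An empty list means every stage can be cut into whole patches. Any entry in the list is a stage where einops would raise the error shown above.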
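Remember to turn off rect! It is set in two places: one is the rect option in the arguments, and the other is shown in the figure below. If you run test or val, make the same change there.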
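With rect off, the remaining question is which square image size to request. A minimal sketch, assuming every batch is then padded to a square of the requested size and that the deepest MobileViT block sits at stride 32: any size that is a multiple of 32 x patch_size keeps every stage divisible by the patch size. safe_img_size is a hypothetical helper, not a YOLOv5 function.

```python
import math


def safe_img_size(requested, max_stride=32, patch_size=2):
    # Hypothetical helper: round the requested size up to a multiple of
    # max_stride * patch_size, so the stride-8/16/32 maps all split into patches.
    multiple = max_stride * patch_size
    return math.ceil(requested / multiple) * multiple


print(safe_img_size(640))                # 640
print(safe_img_size(416))                # 448
print(safe_img_size(672, patch_size=4))  # 768
```

With patch_size 2, as in the yaml above, the default 640 already satisfies this, so turning off rect is enough.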
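Special thanks to 养乐多阿.
Copyright belongs to the original author, 蜗牛学ai. If there is any infringement, please contact us and we will remove it.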