一 修改思想
目前yoloV7已经发布有一段时间了,其中yoloV7有一个关键点检测的分支pose,是一个姿态关键点的检测算法,其中有给出的数据,大家可以下载运行起来。
由于实际项目需求,发现17个关键点是不能满足大家的需求的,这里我就稍作修改了一下,把关键点的数量修改为任意数量,并且添加目标检测多分类情况。
二 数据修改
这里我是以车牌的关键点进行举例修改,所有关键点的数量为4个。
修改的过程中需要做左右翻转,所以我的关键点翻转后,1和2交换,3和4交换,5和6交换,依次类推。
1 yaml文件修改
- 设置关键点数量,修改关键点数量为4
- 设置类别数量和类别标签
数据制作
- 训练标签制作
# -nfs-阿拉伯车牌字符-沙特阿拉伯卡口车牌-2-沙特阿拉伯卡口车牌-2-image1837.txt# data.txt 含义分别是: cls x y w h point1x point1y point2x point2y point3x point3y point4x point4y ...# 类别 目标中心点x 目标中心点y 目标宽w 目标高h 目标点1x坐标 目标点1y坐标 目标点2x坐标 目标点2y坐标 目标点3x坐标 目标点3y坐标 目标点4x坐标 目标点4y坐标 依次类推 00.57392996108949410.17241379310344830.37159533073929960.290640394088669930.389105058365758760.083743842364532010.75875486381322950.0295566502463054170.76070038910505830.26600985221674880.392996108949416330.3201970443349753720.57392996108949410.17241379310344830.37159533073929960.290640394088669930.389105058365758760.083743842364532010.75875486381322950.0295566502463054170.76070038910505830.26600985221674880.392996108949416330.3201970443349753700.57392996108949410.17241379310344830.37159533073929960.290640394088669930.389105058365758760.083743842364532010.75875486381322950.0295566502463054170.76070038910505830.26600985221674880.392996108949416330.32019704433497537
- 训练文件train.txt 和val.txt 制作
# train.txt./train/images/-nfs-阿拉伯车牌字符-沙特阿拉伯卡口车牌-2-沙特阿拉伯卡口车牌-2-image1837.jpeg
./train/images/-nfs-车牌字符-埃及车牌-埃及车牌截图-2021-04-3011-11-52屏幕截图.png
./train/images/-nfs-车牌字符-埃及车牌-埃及车牌截图-2021-04-3013-57-27屏幕截图.png
./train/images/-nfs-车牌字符-埃及车牌-埃及车牌截图-2021-04-3010-19-54屏幕截图.png
./train/images/-nfs-阿拉伯车牌字符-外国车牌现场_20210519_1-外国车牌现场_20210519_1-e0d92b0990a1249388bc77bdfa8e43ed.jpg
./train/images/-nfs-车牌字符-埃及车牌-埃及车牌截图-2021-04-3013-51-28屏幕截图.png
./train/images/-nfs-车牌字符-约旦车牌-videoplayback-videoplayback_13_1460.jpg
./train/images/-nfs-车牌字符-埃及车牌-埃及车牌截图-2021-04-3013-56-51屏幕截图.png
./train/images/-nfs-车牌字符-埃及车牌-埃及车牌截图-2021-04-3010-27-50屏幕截图.png
数据读取修改
核心思想就是要把关键点的数量传入数据读取中,根据关键点数量进行数据读取操作。
- datasets.py/LoadImagesAndLabels() 初始化修改
- cache_labels() 方法 :数据读取修改
- LoadImagesAndLabels()中__getitem__() 方法 :数据左右翻转修改
- datasets.py/random_perspective() 方法
用到random_perspective()方法的地方记得都去添加一个关机键点数量参数。
三 网络结构修改
1 model文件修改
- yolo.py 我们用到的是IKeypoint()方法,所以这里只修改这个方法,其他方法是一样的修改。 结构中主要是把分类数量和关机键点数量加入进去就ok了。
loss文件修改
- loss初始化类别数量和关机键数量
- loss 计算中加入类别和关机键点计算
- build_targets() 方法修改 添加关键点数量
四 训练代码修改
1 train.py
- 添加关机键点数量
- 读取数据加入关机键点数量
- 初始化loss 添加关键点数量和类别数量
- 画图添加关键点数量
2 test.py
- 读取数据添加关机键点数量
- 画图添加关机键点数量
3 general.py 中non_max_suppression()方法修改
4 plots.py 文件夹修改
主要是添加关机键点的数量
五 代码分享
1 训练测试
关键点效果已经成功加上去了,并且也添加了多分类。
2 代码链接
代码地址(yolov7-pose_Npoint_Ncla):https://github.com/qinggangwu/yolov7-pose_Npoint_Ncla
版权归原作者 五小白 所有, 如有侵权,请联系我们删除。