0


YoloV5 模型自定义评估-误报、漏报、错报

YoloV5模型训练成功后,可以通过自带的val.py文件进行评估分析,其提供mAp、Iou以及混淆矩阵等,很好,但是……领导不认可……/(ㄒoㄒ)/~~。领导要的是最直观的东西,比如这个模型识别目标的准确率,还有误报率等……。那么,领导的要求就是我们开发的方向:

为了得到准确率以及误报、漏报、错报的情况,需要使用模型检测已经标注过的样本,将检测结果与标注进行比对。

YoloV5的val.py文件已经实现了大部分功能,可以直接拿来改造:

首先,原文件中对预测结果的NMS处理函数调用如下:

  1. out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)

val.py本来的一个作用就是与train.py配合,在模型迭代训练过程中提供评估,但是我们这里不需要处理太多可能,只要置信度最高的就好了,修改如下:

  1. out = non_max_suppression(out, conf_thres, iou_thres, multi_label=False, agnostic=single_cls)

然后将检测结果与标注结果做对比,原代码中是在process_batch这个函数中进行处理,计算所有的检测框和标注框的IoU,根据检测类比和标注类别再进行处理。这里做一下改造,代码如下:

  1. def process_batch(batch_i, detections, labels, iouv):
  2. """
  3. Return correct predictions matrix. Both sets of boxes are in (x1, y1, x2, y2) format.
  4. Arguments:
  5. detections (Array[N, 6]), x1, y1, x2, y2, conf, class
  6. labels (Array[M, 5]), class, x1, y1, x2, y2
  7. Returns:
  8. correct (Array[N, 10]), for 10 IoU levels
  9. """
  10. # 正确:0 误报: 1 漏报: 2 错报: 3
  11. labelNum = labels.shape[0]
  12. detectNum = detections.shape[0]
  13. if labelNum == 0 and detectNum == 0:
  14. return 0
  15. elif labelNum == 0: #存在误报
  16. results = torch.zeros(detectNum, 6 , dtype=torch.float, device=iouv.device)
  17. results[:,0] = batch_i
  18. results[:,1] = 1
  19. results[:,2] = detections[:, 5] #检测类别
  20. results[:,3] = -1.0 #标注类别
  21. #results[:,4] = 0.0 #检测框和标注框的IoU
  22. results[:,5] = detections[:, 4] #检测置信度
  23. return results
  24. elif detectNum == 0: #存在漏报
  25. results = torch.zeros(labelNum, 6, dtype=torch.float, device=iouv.device)
  26. results[:,0] = batch_i
  27. results[:,1] = 2
  28. results[:,2] = -1.0 #检测类别
  29. results[:,3] = labels[:, 0] #标注类别
  30. #results[:,4] = 0.0 #检测框和标注框的IoU
  31. #results[:,5] = 0.0 #检测置信度
  32. return results
  33. else:
  34. tempresults = torch.zeros(0, 6, dtype=torch.float, device=iouv.device)
  35. indexes = torch.zeros(1, detectNum, dtype=torch.float, device=iouv.device)
  36. validNum = 0
  37. i=0
  38. for label in labels:
  39. j=0
  40. bMatchLabel = False
  41. for detect in detections:
  42. iou = singlebox_iou(label[1:], detect[:4])
  43. if iou > 0.5:
  44. bMatchLabel = True
  45. indexes[:,j] = 1.0
  46. validNum += 1
  47. result = torch.zeros(1, 6 , dtype=torch.float, device=iouv.device)
  48. result[:,0] = batch_i
  49. result[:,1] = 0 if detect[5] == label[0] else 3
  50. result[:,2] = detect[5] #检测类别
  51. result[:,3] = label[0] #标注类别
  52. result[:,4] = iou #检测框和标注框的IoU
  53. result[:,5] = detect[4] #检测置信度
  54. #tempresults.append(result)
  55. tempresults = torch.cat((tempresults, result), 0)
  56. j += 1
  57. i += 1
  58. if not bMatchLabel: #如果有标注狂没有匹配到检测框的为漏报
  59. validNum += 1
  60. result = torch.zeros(1, 6 , dtype=torch.float, device=iouv.device)
  61. result[:,0] = batch_i
  62. result[:,1] = 2
  63. result[:,2] = -1.0 #检测类别
  64. result[:,3] = label[0] #标注类别
  65. #tempresults.append(result)
  66. tempresults = torch.cat((tempresults, result), 0)
  67. #如果有检测框没有匹配到标注框的为误报
  68. noMatchDect = indexes[0][(indexes[0] == 0).nonzero()]
  69. if noMatchDect.shape[0] > 0:
  70. for index in noMatchDect:
  71. result = torch.zeros(1, 6 , dtype=torch.float, device=iouv.device)
  72. i = int(index[0])
  73. result[:,0] = batch_i
  74. result[:,1] = 1
  75. result[:,2] = detections[i][5] #检测类别
  76. result[:,3] = -1.0 #标注类别
  77. result[:,5] = detections[i][4] #检测置信度
  78. #tempresults.append(result)
  79. tempresults = torch.cat((tempresults, result), 0)
  80. return tempresults

还得解释一下这里对误报、漏报、错报的定义:

误报:没有标注,但是检测出目标

漏报:有标注,但是没检测出目标

错报:有标注,和检测目标的类别不一样

所有样本处理完后进行统计,统计每个类别的误报、漏报数量,以及错报(其实混淆矩阵就是这个功能,只不过这里用的是数量),代码如下:

  1. def analyse(totalResults, names, save_dir):
  2. classnames =[]
  3. for i in range(len(names)):
  4. classnames.append(str(names[i]))
  5. #误报
  6. misIndexes = torch.where(totalResults[:,1] == 1.0)
  7. misItems = totalResults[misIndexes]
  8. misStats = torch.zeros(1, len(names), dtype=torch.int, device=totalResults.device)
  9. line = "误报- "
  10. for i in range(len(names)):
  11. n = len(torch.where(misItems[:,2] == i)[0])
  12. misStats[:,i] = n
  13. line += names[i] + ":" + str(n) +" "
  14. print(line + "\n")
  15. #plt.rcParams['font.sans-serif']=['SimHei']
  16. #plt.style.use('ggplot')
  17. #plt.title("误报")
  18. #plt.xlabel("异物类别")
  19. #plt.ylabel("数量")
  20. #plt.bar(classnames, misStats[0].tolist(), width=1.2, color='blue')
  21. #plt.savefig(str(save_dir) + "/evaluates/误报.jpg")
  22. #漏报
  23. failIndexes = torch.where(totalResults[:,1] == 2.0)
  24. failItems = totalResults[failIndexes]
  25. failStats = torch.zeros(1, len(names), dtype=torch.int, device=totalResults.device)
  26. line = "漏报- "
  27. for i in range(len(names)):
  28. n = len(torch.where(failItems[:,3] == i)[0])
  29. failStats[:,i] = n
  30. line += names[i] + ":" + str(n) +" "
  31. #plt.title("漏报")
  32. #plt.xlabel("异物类别")
  33. #plt.ylabel("数量")
  34. #plt.bar(classnames, failStats[0].tolist(), width=1.2, color='blue')
  35. #plt.savefig(str(save_dir) + "/evaluates/漏报.jpg")
  36. print(line + "\n")
  37. #错报
  38. errorIndexes = torch.where(totalResults[:,1] == 3.0)
  39. errorItems = totalResults[errorIndexes]
  40. errorStats = torch.zeros(len(names), len(names), dtype=torch.int, device=totalResults.device)
  41. for item in errorItems:
  42. errorStats[int(item[2]), int(item[3])] += 1
  43. print("错报- \n")
  44. print(errorStats)

本文转载自: https://blog.csdn.net/ntw516/article/details/125328362
版权归原作者 业余码农 所有, 如有侵权,请联系我们删除。

“YoloV5 模型自定义评估-误报、漏报、错报”的评论:

还没有评论