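Once a YoloV5 model has been trained, you can evaluate it with the bundled val.py file, which reports mAP, IoU, a confusion matrix and so on. All very nice, but... the boss doesn't buy it... /(ㄒoㄒ)/~~. What the boss wants is something intuitive, such as how accurately the model recognizes targets, and its false alarm rate. So the boss's requirements set our development direction:
To get the accuracy along with the false alarms, missed detections and wrong-class detections, we run the model on samples that have already been annotated and compare the detection results with the annotations.
YoloV5's val.py already implements most of what we need, so we can adapt it directly:
First, the original file calls the NMS function on the predictions like this: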
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
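One of val.py's original jobs is to work together with train.py and provide evaluation while the model is being trained. Here we don't need to handle that many possibilities; the highest-confidence class for each box is enough. So the call drops the labels=lb argument and sets multi_label=False: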
out = non_max_suppression(out, conf_thres, iou_thres, multi_label=False, agnostic=single_cls)
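Next, compare the detection results with the annotations. The original code does this in the process_batch function: it computes the IoU between every detection box and every label box, then processes them according to the detected class and the labelled class. Here it is reworked as follows: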
def process_batch(batch_i, detections, labels, iouv):
    """
    Compare the detections with the labels. Both sets of boxes are in (x1, y1, x2, y2) format.
    Arguments:
        batch_i: index supplied by the caller, written to column 0 of every row
        detections (Array[N, 6]), x1, y1, x2, y2, conf, class
        labels (Array[M, 5]), class, x1, y1, x2, y2
        iouv: only its device is used here
    Returns:
        results (Array[K, 6]): batch_i, status, detected class, labelled class, IoU, confidence
        (or 0 when there are neither labels nor detections)
    """
    # Status codes: correct: 0, false alarm: 1, missed detection: 2, wrong class: 3
    labelNum = labels.shape[0]
    detectNum = detections.shape[0]
    if labelNum == 0 and detectNum == 0:
        return 0  # nothing to compare; the caller has to handle this value
    elif labelNum == 0:  # no labels: every detection is a false alarm
        results = torch.zeros(detectNum, 6, dtype=torch.float, device=iouv.device)
        results[:, 0] = batch_i
        results[:, 1] = 1
        results[:, 2] = detections[:, 5]  # detected class
        results[:, 3] = -1.0  # labelled class (none)
        # results[:, 4] stays 0.0: IoU between detection box and label box
        results[:, 5] = detections[:, 4]  # detection confidence
        return results
    elif detectNum == 0:  # no detections: every label is a missed detection
        results = torch.zeros(labelNum, 6, dtype=torch.float, device=iouv.device)
        results[:, 0] = batch_i
        results[:, 1] = 2
        results[:, 2] = -1.0  # detected class (none)
        results[:, 3] = labels[:, 0]  # labelled class
        # results[:, 4] and results[:, 5] stay 0.0: IoU and detection confidence
        return results
    else:
        tempresults = torch.zeros(0, 6, dtype=torch.float, device=iouv.device)
        # indexes[0, j] is set to 1 once detection j has matched some label box
        indexes = torch.zeros(1, detectNum, dtype=torch.float, device=iouv.device)
        for label in labels:
            bMatchLabel = False
            for j, detect in enumerate(detections):
                iou = singlebox_iou(label[1:], detect[:4])
                if iou > 0.5:
                    bMatchLabel = True
                    indexes[:, j] = 1.0
                    result = torch.zeros(1, 6, dtype=torch.float, device=iouv.device)
                    result[:, 0] = batch_i
                    result[:, 1] = 0 if detect[5] == label[0] else 3
                    result[:, 2] = detect[5]  # detected class
                    result[:, 3] = label[0]  # labelled class
                    result[:, 4] = iou  # IoU between detection box and label box
                    result[:, 5] = detect[4]  # detection confidence
                    tempresults = torch.cat((tempresults, result), 0)
            if not bMatchLabel:  # a label box that matched no detection box is a missed detection
                result = torch.zeros(1, 6, dtype=torch.float, device=iouv.device)
                result[:, 0] = batch_i
                result[:, 1] = 2
                result[:, 2] = -1.0  # detected class (none)
                result[:, 3] = label[0]  # labelled class
                tempresults = torch.cat((tempresults, result), 0)
        # A detection box that matched no label box is a false alarm.
        # Take the indices of the unmatched detections (shape [K, 1]); indexing
        # indexes[0] with them would return the stored zeros, not the positions.
        noMatchDect = (indexes[0] == 0).nonzero()
        for index in noMatchDect:
            result = torch.zeros(1, 6, dtype=torch.float, device=iouv.device)
            i = int(index[0])
            result[:, 0] = batch_i
            result[:, 1] = 1
            result[:, 2] = detections[i][5]  # detected class
            result[:, 3] = -1.0  # labelled class (none)
            result[:, 5] = detections[i][4]  # detection confidence
            tempresults = torch.cat((tempresults, result), 0)
        return tempresults
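process_batch relies on a singlebox_iou helper that the post does not show. The block below is only a minimal sketch of what such a helper might look like, assuming both boxes are in (x1, y1, x2, y2) format as stated in the docstring; the author's actual implementation may differ. It converts each coordinate with float(), so it accepts plain lists as well as 1-D tensors.

```python
def singlebox_iou(box_a, box_b):
    """IoU of two boxes in (x1, y1, x2, y2) format (assumed helper, not from the post)."""
    ax1, ay1, ax2, ay2 = (float(v) for v in box_a[:4])
    bx1, by1, bx2, by2 = (float(v) for v in box_b[:4])

    # Width and height of the overlap, clamped to zero when the boxes are disjoint
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h

    area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1)
    area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
    union = area_a + area_b - inter

    # Degenerate boxes with zero area give an IoU of 0 instead of dividing by zero
    return inter / union if union > 0 else 0.0
```

Because the result is a plain Python float, the `iou > 0.5` test and the assignment into `result[:, 4]` in process_batch work unchanged.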
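The definitions of false alarm, missed detection and wrong class used here also need explaining:
False alarm: there is no matching annotation, but a target was detected
Missed detection: there is an annotation, but no target was detected for it
Wrong class: there is an annotation, but the detected target's class differs from it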
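To make these outcomes easier to read in later code, the numeric status codes (0 to 3) can be given names. This is just an illustrative sketch: the Outcome enum and the describe_row helper are hypothetical names that do not appear in the original code, and the row layout assumed is the six-column one produced by process_batch.

```python
from enum import IntEnum


class Outcome(IntEnum):
    """Names for the status codes stored in column 1 (hypothetical helper)."""
    CORRECT = 0
    FALSE_ALARM = 1
    MISSED = 2
    WRONG_CLASS = 3


def describe_row(row, names):
    """Render one row [index, status, det_cls, label_cls, iou, conf] as text."""
    outcome = Outcome(int(row[1]))
    det_cls, label_cls = int(row[2]), int(row[3])
    # A class of -1 means "no detection" or "no annotation" for this row
    detected = names[det_cls] if det_cls >= 0 else "none"
    labelled = names[label_cls] if label_cls >= 0 else "none"
    return f"image {int(row[0])}: {outcome.name} (detected={detected}, labelled={labelled})"
```

For example, `describe_row([3, 1, 0, -1, 0.0, 0.9], ["cat", "dog"])` describes a false alarm of class "cat" on index 3.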
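After all samples have been processed, tally the results: count the false alarms and missed detections for each class, as well as the wrong-class cases (this is really what the confusion matrix does, except that raw counts are used here). The code is as follows: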
def analyse(totalResults, names, save_dir):
    # Class names as strings, used for the printed summary and the bar-chart labels
    classnames = [str(names[i]) for i in range(len(names))]
    # False alarms (status 1), counted by detected class (column 2)
    misIndexes = torch.where(totalResults[:, 1] == 1.0)
    misItems = totalResults[misIndexes]
    misStats = torch.zeros(1, len(names), dtype=torch.int, device=totalResults.device)
    line = "False alarms- "
    for i in range(len(names)):
        n = len(torch.where(misItems[:, 2] == i)[0])
        misStats[:, i] = n
        line += classnames[i] + ":" + str(n) + " "
    print(line + "\n")
    #plt.rcParams['font.sans-serif']=['SimHei']
    #plt.style.use('ggplot')
    #plt.title("False alarms")
    #plt.xlabel("Object class")
    #plt.ylabel("Count")
    #plt.bar(classnames, misStats[0].tolist(), width=1.2, color='blue')
    #plt.savefig(str(save_dir) + "/evaluates/false_alarms.jpg")
    # Missed detections (status 2), counted by labelled class (column 3)
    failIndexes = torch.where(totalResults[:, 1] == 2.0)
    failItems = totalResults[failIndexes]
    failStats = torch.zeros(1, len(names), dtype=torch.int, device=totalResults.device)
    line = "Missed detections- "
    for i in range(len(names)):
        n = len(torch.where(failItems[:, 3] == i)[0])
        failStats[:, i] = n
        line += classnames[i] + ":" + str(n) + " "
    #plt.title("Missed detections")
    #plt.xlabel("Object class")
    #plt.ylabel("Count")
    #plt.bar(classnames, failStats[0].tolist(), width=1.2, color='blue')
    #plt.savefig(str(save_dir) + "/evaluates/missed_detections.jpg")
    print(line + "\n")
    # Wrong class (status 3): rows are the detected class, columns the labelled class
    errorIndexes = torch.where(totalResults[:, 1] == 3.0)
    errorItems = totalResults[errorIndexes]
    errorStats = torch.zeros(len(names), len(names), dtype=torch.int, device=totalResults.device)
    for item in errorItems:
        errorStats[int(item[2]), int(item[3])] += 1
    print("Wrong class- \n")
    print(errorStats)
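The counts above can be turned into the per-class rates mentioned at the start. The post does not give formulas, so the sketch below uses assumed definitions: accuracy is correct rows divided by all rows carrying that labelled class, and false alarm rate is false alarms divided by all rows carrying that detected class. summarize_rates is a hypothetical name. Note that a label matched by several detections produces several rows, so these are rates over rows rather than over unique boxes.

```python
from collections import Counter


def summarize_rates(rows, names):
    """Per-class rates from rows of [index, status, det_cls, label_cls, iou, conf]."""
    correct, missed, false_alarm = Counter(), Counter(), Counter()
    wrong_by_label, wrong_by_det = Counter(), Counter()
    for row in rows:
        status, det_cls, label_cls = int(row[1]), int(row[2]), int(row[3])
        if status == 0:      # correct
            correct[label_cls] += 1
        elif status == 1:    # false alarm, attributed to the detected class
            false_alarm[det_cls] += 1
        elif status == 2:    # missed detection, attributed to the labelled class
            missed[label_cls] += 1
        elif status == 3:    # wrong class, counted on both sides
            wrong_by_label[label_cls] += 1
            wrong_by_det[det_cls] += 1

    summary = {}
    for c in range(len(names)):
        n_labelled = correct[c] + missed[c] + wrong_by_label[c]
        n_detected = correct[c] + false_alarm[c] + wrong_by_det[c]
        summary[str(names[c])] = {
            # None marks a rate that is undefined because the class never occurred
            "accuracy": correct[c] / n_labelled if n_labelled else None,
            "miss_rate": missed[c] / n_labelled if n_labelled else None,
            "false_alarm_rate": false_alarm[c] / n_detected if n_detected else None,
        }
    return summary
```

With the tensor built from process_batch, it could be called as `summarize_rates(totalResults.tolist(), names)`.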
Copyright belongs to the original author, 业余码农. In case of infringement, please contact us for removal.