

YoloV5 Custom Model Evaluation: False Positives, Missed Detections, Misclassifications

Once a YoloV5 model has been trained, it can be evaluated with the bundled val.py script, which reports mAP, IoU, a confusion matrix and so on. All very nice, but... the boss doesn't buy it... /(ㄒoㄒ)/~~. What the boss wants are the most intuitive numbers: how accurately the model identifies targets, the false-positive rate, and the like. So the boss's requirements become our development direction:

To obtain the accuracy along with the false-positive, missed-detection, and misclassification counts, we run the model on samples that have already been annotated and compare the detections against the annotations.

YoloV5's val.py already implements most of what we need, so we can adapt it directly:

First, here is how the original file calls NMS post-processing on the predictions:

out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)

val.py was originally designed to work alongside train.py and provide evaluation during iterative training, but here we do not need to keep multiple candidate labels per box; the highest-confidence prediction is enough, so the call becomes:

out = non_max_suppression(out, conf_thres, iou_thres, multi_label=False, agnostic=single_cls)
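For reference, each element of the `out` list returned by non_max_suppression corresponds to one image of the batch and is an (N, 6) tensor whose columns are x1, y1, x2, y2, confidence and class id, which is exactly the detections layout that process_batch below expects. A quick way to inspect it:

# Each entry of `out` holds one image's surviving detections after NMS:
# columns 0-3 are the box corners (x1, y1, x2, y2), column 4 the confidence, column 5 the class id.
for si, pred in enumerate(out):
    print(si, pred.shape)   # e.g. torch.Size([3, 6]) for an image with 3 detections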

Next, the detections are compared against the annotations. In the original code this happens in the process_batch function, which computes the IoU between every detection box and every label box and then resolves them using the detected class and the annotated class. We rework it as follows:

def process_batch(batch_i, detections, labels, iouv):
    """
    Compare one image's detections with its labels. Both sets of boxes are in (x1, y1, x2, y2) format.
    Arguments:
        detections (Array[N, 6]), x1, y1, x2, y2, conf, class
        labels (Array[M, 5]), class, x1, y1, x2, y2
    Returns:
        results (Array[K, 6]): batch index, outcome code, detected class, labeled class, IoU, confidence
    """

    # outcome codes: correct = 0, false positive = 1, missed detection = 2, misclassification = 3
    labelNum = labels.shape[0]
    detectNum = detections.shape[0]
    if labelNum == 0 and detectNum == 0:
        return torch.zeros(0, 6, dtype=torch.float, device=iouv.device)   # nothing to report
    elif labelNum == 0:  # only false positives
        results = torch.zeros(detectNum, 6, dtype=torch.float, device=iouv.device)
        results[:,0] = batch_i
        results[:,1] = 1
        results[:,2] = detections[:, 5]   # detected class
        results[:,3] = -1.0               # labeled class (none)
        #results[:,4] = 0.0               # IoU between detection box and label box
        results[:,5] = detections[:, 4]   # detection confidence
        return results
    elif detectNum == 0:  # only missed detections
        results = torch.zeros(labelNum, 6, dtype=torch.float, device=iouv.device)
        results[:,0] = batch_i
        results[:,1] = 2
        results[:,2] = -1.0               # detected class (none)
        results[:,3] = labels[:, 0]       # labeled class
        #results[:,4] = 0.0               # IoU between detection box and label box
        #results[:,5] = 0.0               # detection confidence
        return results
    else:
        tempresults = torch.zeros(0, 6, dtype=torch.float, device=iouv.device)
        indexes = torch.zeros(1, detectNum, dtype=torch.float, device=iouv.device)  # 1 marks a detection that matched some label
        validNum = 0
        for label in labels:
            j = 0
            bMatchLabel = False
            for detect in detections:
                iou = singlebox_iou(label[1:], detect[:4])
                if iou > 0.5:
                    bMatchLabel = True
                    indexes[:,j] = 1.0
                    validNum += 1
                    result = torch.zeros(1, 6, dtype=torch.float, device=iouv.device)
                    result[:,0] = batch_i
                    result[:,1] = 0 if detect[5] == label[0] else 3   # same class: correct, different class: misclassification
                    result[:,2] = detect[5]     # detected class
                    result[:,3] = label[0]      # labeled class
                    result[:,4] = iou           # IoU between detection box and label box
                    result[:,5] = detect[4]     # detection confidence
                    tempresults = torch.cat((tempresults, result), 0)
                j += 1
            if not bMatchLabel:  # a label box with no matching detection is a missed detection
                validNum += 1
                result = torch.zeros(1, 6, dtype=torch.float, device=iouv.device)
                result[:,0] = batch_i
                result[:,1] = 2
                result[:,2] = -1.0          # detected class (none)
                result[:,3] = label[0]      # labeled class
                tempresults = torch.cat((tempresults, result), 0)

        # any detection that matched no label box is a false positive
        noMatchDect = (indexes[0] == 0).nonzero()   # indices of unmatched detections
        if noMatchDect.shape[0] > 0:
            for index in noMatchDect:
                i = int(index[0])
                result = torch.zeros(1, 6, dtype=torch.float, device=iouv.device)
                result[:,0] = batch_i
                result[:,1] = 1
                result[:,2] = detections[i][5]   # detected class
                result[:,3] = -1.0               # labeled class (none)
                result[:,5] = detections[i][4]   # detection confidence
                tempresults = torch.cat((tempresults, result), 0)

        return tempresults
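
The code above calls singlebox_iou, a small helper that is not shown in the original post. A minimal sketch of what it might look like, assuming both boxes are 1-D tensors in (x1, y1, x2, y2) format, is:

def singlebox_iou(box1, box2):
    # Hypothetical helper (not part of YoloV5): IoU of two individual boxes in (x1, y1, x2, y2) format.
    inter_w = (torch.min(box1[2], box2[2]) - torch.max(box1[0], box2[0])).clamp(0)
    inter_h = (torch.min(box1[3], box2[3]) - torch.max(box1[1], box2[1])).clamp(0)
    inter = inter_w * inter_h
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    return inter / (area1 + area2 - inter + 1e-7)   # small epsilon avoids division by zero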

A note on how false positives, missed detections, and misclassifications are defined here:

False positive: no annotation, but a target was detected.

Missed detection: an annotation exists, but no target was detected.

Misclassification: an annotation exists, but the detected class differs from the annotated one.
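
The post does not show how the per-image results are collected. One plausible way to wire this into val.py's per-image loop is sketched below; predn, labelsn, batch_i, iouv and device follow val.py's existing names, while totalResults is an accumulator assumed here:

# Assumed wiring sketch, not the original author's exact code.
# Create the accumulator once, before the dataloader loop:
totalResults = torch.zeros(0, 6, dtype=torch.float, device=device)

# Inside the loop over the images of a batch, after predn and labelsn
# (class plus x1, y1, x2, y2 in native image space) are built as in val.py:
results = process_batch(batch_i, predn, labelsn, iouv)
if results.shape[0] > 0:
    totalResults = torch.cat((totalResults, results), 0)

# After the whole dataset has been processed:
# analyse(totalResults, names, save_dir)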

After all samples have been processed, we compile the statistics: the number of false positives and missed detections per class, plus the misclassifications (this is essentially what the confusion matrix provides, except that here we work with raw counts). The code is as follows:

def analyse(totalResults, names, save_dir):
    classnames = []
    for i in range(len(names)):
        classnames.append(str(names[i]))

    # false positives
    misIndexes = torch.where(totalResults[:,1] == 1.0)
    misItems = totalResults[misIndexes]
    misStats = torch.zeros(1, len(names), dtype=torch.int, device=totalResults.device)
    line = "False positives - "
    for i in range(len(names)):
        n = len(torch.where(misItems[:,2] == i)[0])
        misStats[:,i] = n
        line += names[i] + ":" + str(n) + " "

    print(line + "\n")

    #plt.rcParams['font.sans-serif']=['SimHei']   # only needed for non-ASCII plot labels
    #plt.style.use('ggplot')
    #plt.title("False positives")
    #plt.xlabel("class")
    #plt.ylabel("count")
    #plt.bar(classnames, misStats[0].tolist(), width=1.2, color='blue')
    #plt.savefig(str(save_dir) + "/evaluates/false_positives.jpg")

    # missed detections
    failIndexes = torch.where(totalResults[:,1] == 2.0)
    failItems = totalResults[failIndexes]
    failStats = torch.zeros(1, len(names), dtype=torch.int, device=totalResults.device)
    line = "Missed detections - "
    for i in range(len(names)):
        n = len(torch.where(failItems[:,3] == i)[0])
        failStats[:,i] = n
        line += names[i] + ":" + str(n) + " "

    #plt.title("Missed detections")
    #plt.xlabel("class")
    #plt.ylabel("count")
    #plt.bar(classnames, failStats[0].tolist(), width=1.2, color='blue')
    #plt.savefig(str(save_dir) + "/evaluates/missed_detections.jpg")

    print(line + "\n")

    # misclassifications: rows = detected class, columns = labeled class
    errorIndexes = torch.where(totalResults[:,1] == 3.0)
    errorItems = totalResults[errorIndexes]
    errorStats = torch.zeros(len(names), len(names), dtype=torch.int, device=totalResults.device)
    for item in errorItems:
        errorStats[int(item[2]), int(item[3])] += 1

    print("Misclassifications -\n")
    print(errorStats)
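
As a quick sanity check, here is a self-contained toy run, assuming the process_batch, singlebox_iou and analyse functions above are defined. The class names and coordinates are invented for illustration: one detection matches a label of the same class, one label is never detected, and one detection matches nothing.

import torch

device = torch.device('cpu')
iouv = torch.linspace(0.5, 0.95, 10, device=device)   # only its device is used by process_batch
names = ['person', 'car']

detections = torch.tensor([[10., 10., 50., 50., 0.9, 0.],      # overlaps the first label, same class -> correct
                           [200., 200., 240., 240., 0.8, 1.]]) # overlaps nothing -> false positive
labels = torch.tensor([[0., 12., 12., 48., 48.],               # matched by the first detection
                       [1., 300., 300., 340., 340.]])          # never detected -> missed detection

results = process_batch(0, detections, labels, iouv)
analyse(results, names, save_dir='.')
# expected output:
# False positives - person:0 car:1
# Missed detections - person:0 car:1
# Misclassifications - (an all-zero 2x2 matrix)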

Reprinted from: https://blog.csdn.net/ntw516/article/details/125328362
Copyright belongs to the original author, 业余码农. If there is any infringement, please contact us and we will remove it.
