0


YOLO v5 代码精读(1) detect模块以及非极大值抑制

目录


YOLO简介

YOLO 是目前最先进的目标检测模型之一,现在博客上常有的是如何使用YOLO模型训练自己的数据集,而鲜有对YOLO代码的精读。我认为只有对算法和代码实现有全面的了解,才能将YOLO使用的更加得心应手。

这里的代码精读为YOLO v5,github版本为6.0。版本不同代码也会有所不同,请结合源码阅读本文。本文使用注释完成对每行代码的解读,文段来概括总结每个代码段。

yolo v5代码 6.0版本 github代码地址

argpares模块

在了解yolo v5代码之前,首先要了解python的一个标准模块:argparse。argparse是python自带的解析命令行参数的模块,可以用来定义和读取命令行中的参数。yolo v5中很多的参数都是通过argpares模块组织的,所以了解这个模块非常重要。

因为yolo v5是一个大型项目,最后可能会被部署至终端,所以yolo v5的代码中提供了通过命令行运行代码的方式。

1c2c192228554262b523baaf1c10e5e1.png

上图的命令中的"--source"参数,对应detect模块中下面的代码

b674a1dbe8ba4601854ee2c95ea0aa68.png

'--source'表示命令后面跟的参数名,type=str表示变量类型为字符串,default表示默认的参数,help参数表示执行help命令时,该参数名显示的帮助信息。

为了更方便理解,我创建了一个test.py 模块

  1. import argparse
  2. # 参数解析器
  3. arg = argparse.ArgumentParser()
  4. # 添加参数
  5. arg.add_argument('--aaa', default='hello world')
  6. arg.add_argument('--bbb', default=123)
  7. # 获取解析的参数
  8. opt = arg.parse_args()
  9. print(opt.aaa, opt.bbb)

命令参数通过arg.pares_args()函数获取,再通过调用属性的方式获取参数

b078a061f91a4243b887de7f52ee7b64.png

运行结果如上所示。若命令后没有参数名,则参数为默认值;若命令后跟了参数名和参数值,那么这个参数名的值,将会替换为输入的参数值。

yolo v5 的代码正是通过这种方式,将一些重要的参数(如超参数、数据集的路径等)组织起来。在开发阶段,可能不会以这种命令的方式去运行,一般是在部署的时候,才会去用命令去运行。所以开发时若想修改某个参数的值,可以修改这个命令参数名的default关键字参数。

detect模块

接下来就是对 yolo v5 代码的逐句解读

detect模块是对图像、视频、目录、流等进行推断。

导入部分

先看导入部分

  1. """
  2. Run inference on images, videos, directories, streams, etc.
  3. 对图像、视频、目录、流等进行推断。
  4. Usage:
  5. 使用
  6. $ python path/to/detect.py --source path/to/img.jpg --weights yolov5s.pt --img 640
  7. 使用 python 命令运行 detect.py 模块 --source后面跟图片的路径 --weights 后面跟权重文件的路径 --img表示图片的尺寸
  8. """
  9. import argparse
  10. import os
  11. import sys
  12. from pathlib import Path
  13. import cv2
  14. import numpy as np
  15. import torch
  16. import torch.backends.cudnn as cudnn
  17. """
  18. 确保root目录正确,避免导包时出现错误
  19. 因为以下导入的是自定义的包,若根目录错误就会导致导入失败,这里不再过多解释
  20. """
  21. FILE = Path(__file__).resolve() # __file__表示当前模块的路径
  22. ROOT = FILE.parents[0] # YOLOv5 root directory
  23. if str(ROOT) not in sys.path:
  24. sys.path.append(str(ROOT)) # add ROOT to PATH
  25. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  26. from models.experimental import attempt_load
  27. from utils.datasets import LoadImages, LoadStreams
  28. from utils.general import apply_classifier, check_img_size, check_imshow, check_requirements, check_suffix, colorstr, \
  29. increment_path, non_max_suppression, print_args, save_one_box, scale_coords, set_logging, \
  30. strip_optimizer, xyxy2xywh
  31. from utils.plots import Annotator, colors
  32. from utils.torch_utils import load_classifier, select_device, time_sync

中间的代码段是为了确保root目录正确,避免导包时出现错误。若根目录正确,也可省略中间的代码段。导入部分会在使用时再进行讲解。

主函数

接下来暂时跳过函数定义,先看主函数做了哪些操作

  1. if __name__ == "__main__":
  2. # 接收命令行参数
  3. opt = parse_opt()
  4. # 将命令行参数传入main函数
  5. main(opt)

在主函数中,先调用了parse_opt()函数,用于接收命令行参数

  1. """
  2. weight:表示模型的权重参数的路径
  3. source:表示数据源,可以是图片文件、目录、URL 0为网络摄像头
  4. imgsz:表示输入图片的大小 默认640*640
  5. conf-thres:置信度阈值,默认0.25 用于非极大值抑制
  6. iou-thres:iou阈值,默认0.45 用于非极大值抑制
  7. max-det:图片最多可以有多少个预测框
  8. device:程序被装载的位置 CPU或GPU
  9. view-img:是否展示图片 默认False
  10. save-text:是否将预测框保存为txt 默认为False
  11. save-conf: 是否将置信度保存到txt中 默认False
  12. save-crop: 是否保存裁剪预测框图片, 默认为False
  13. nosave: 不保存图片、视频 默认False 即保存结果
  14. classes: 设置只保留某一部分类别, 形如0或者0 2 3
  15. agnostic-nms: 是否多个类别一起计算nms 默认为False
  16. augment: 推断时是否进行数据增强 默认为False
  17. visualize: 是否可视化网络层输出特征 默认为False
  18. update: 如果为True,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False
  19. project: 保存结果的路径
  20. name: 保存结果的目录名
  21. exist_ok: 是否重新结果目录 默认为False
  22. line-thickness: 画框的线条粗细
  23. hide-labels: 可视化时隐藏预测类别
  24. hide-conf: 可视化时隐藏置信度
  25. half: 是否使用F16精度推理, 半进度提高检测速度
  26. dnn: 用OpenCV DNN预测
  27. """

以上为parse_opt()函数中,定义的所有命令行参数及注释。

函数最后返回一个参数对象,所有的命令行参数都在这个对象中,再将这个对象传入mian()函数

main()

  1. def main(opt):
  2. # general模块中的函数,用于检查依赖库是否完整
  3. check_requirements(exclude=('tensorboard', 'thop'))
  4. # 运行
  5. run(**vars(opt))

main()函数中只有两行代码,首先调用check_requirements()函数,这是从general模块中导入的函数,用于检查依赖库是否完整。exclude代表排除哪些库,此时函数不会检查这两个库是否存在,因为detect是预测阶段,thsorboard和thop是用于展示训练数据的,预测阶段不需要这两个库。

接下来调用run()函数,vars()函数返回对象的__dict__属性,可以理解为将opt转换为字典,再通过**进行解包,将字典内的键和值作为参数填入run()函数。通过解包的方式,实现了将命令行参数传参至run()函数。

run()

run()函数就是detect模块中进行预测的函数,所有预测工作都在这个函数中完成。

  1. @torch.no_grad() # 该装饰器表示以下函数内不会进行梯度计算和反向传播
  2. def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)

首先注意到run()函数有一个装饰器@torch.no_grad(),装饰器是一种拓展原来函数功能的一种函数。pytorch中的数据格式被称为tensor,用于存储高维数据。tensor中有一个属性为requires_grad,其值为True时,在反向传播的过程中就会计算其梯度,而@torch.no_grad()的作用就是将requires_grad的值置为False,此时便不会计算函数内所有tensor的梯度,有利于节省内存。

run()函数的参数与命令行参数一一对应,这里不再赘述。

接下来对run()函数逐段分析:

资源处理

  1. """解析资源路径"""
  2. # 将资源路径路径转换为字符串
  3. source = str(source)
  4. # bool类型 是否保存结果 保存(非不保存即为保存) 且 资源路径不以.txt结尾
  5. save_img = not nosave and not source.endswith('.txt')
  6. # bool类型 是否为网络摄像头 数据源为数字 或 以.txt结尾 或 小写字母以rtsp://,rtmp://,http://,https://开头
  7. webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
  8. ('rtsp://', 'rtmp://', 'http://', 'https://'))
  9. # 检查runs/detect目录下的exp目录到exp几了,并增加下一个exp目录,调用general模块中的函数,exist_ok表示只有在路径存在时创建目录
  10. save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # Path类 / 字符串表示在路径后增加一层路径
  11. # 若保存为txt,返回save/labels 若不保存为txt,则返回save_dir 再创建文件夹 parents:若父目录不存在,创建父目录。exist_ok:只有在目录不存在时创建目录
  12. (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir

首先是对资源路径进行一些基础判断。判断是否保存结果,以及数据源是否为网络摄像头。接下来就是创建保存的路径。

  1. # 初始化日志信息
  2. set_logging()
  3. # 在控制台上输出YOLO的基本信息 包括当前时间 torch版本 CPU或GPU
  4. # device表示程序被装载在那块cpu或gpu上
  5. device = select_device(device) # select_device()函数是torch_utils中的函数,将程序装载至对应的位置
  6. # 是否使用半精读计算 需要更少的内存,但需要在支持的GPU上才能运行
  7. half &= device.type != 'cpu' # half precision only supported on CUDA

接下来就是初始化日志信息,以及选择将程序装载在哪块cpu或gpu上。

  1. """加载模型,解析文件后缀"""
  2. # 若weights参数是一个列表,则返回列表的第一项 否则返回整个weights 这里w为权重文件的路径
  3. w = str(weights[0] if isinstance(weights, list) else weights)
  4. # 是否分类,当前后缀名,支持的后缀名
  5. classify, suffix, suffixes = False, Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '']
  6. # 检查后缀名是否支持,否则抛出异常
  7. check_suffix(w, suffixes) # check weights have acceptable suffix
  8. # 将后缀名保存为具体的变量,若这个变量为True,则文件为对应的后缀名
  9. pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes) # backend booleans
  10. # 这里的stride和names为临时值 stride为yolo模型中定义的值,为计算的步幅 names为类别标签
  11. stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults

然后就是解析文件的后缀,先判别文件后缀是否合规,再将文件后缀保存为对象,方面后面的判断。

其中stride为特征层级的缩放尺寸,根据YOLO模型的原理,作者将原数据分成了多个大小不同的feature map,每个feature map 感受野不同,可以用于检测不同大小的物体,feature map 越小,模型的感受野越大,可以检测更大的物体,反之同理。stride即为feature map 的缩放尺寸。

  1. """根据不同的文件后缀,用不同的方式加载模型"""
  2. # 文件后缀为pt
  3. if pt:
  4. # 加载.pt格式的模型 如果文件名中含有torchscript,则通过torch.jit.load(w)加载模型,
  5. # 否则通过attempt_load(weights, map_location=device)加载模型
  6. model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
  7. # 从模型中获取计算的步幅
  8. stride = int(model.stride.max()) # model stride
  9. # 从模型中获取分类标签 如果模型中存在module属性,则返回model.module.names 否则返回model.names
  10. names = model.module.names if hasattr(model, 'module') else model.names # get class names
  11. if half:
  12. # 使用半精读计算
  13. model.half() # to FP16
  14. # 使用两阶段的分类器
  15. if classify: # second-stage classifier
  16. # 加载resnet50作为模型
  17. modelc = load_classifier(name='resnet50', n=2) # initialize
  18. # 将模型装载到对应的位置
  19. modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
  20. # 文件后缀为 onnx
  21. elif onnx:
  22. # 如果使用opencv加载深度学习模型
  23. if dnn:
  24. # check_requirements(('opencv-python>=4.5.4',))
  25. # 通过opencv加载模型
  26. net = cv2.dnn.readNetFromONNX(w)
  27. else:
  28. # 如果使用opencv加载深度学习模型,则使用onnxruntime库加载
  29. check_requirements(('onnx', 'onnxruntime'))
  30. import onnxruntime
  31. session = onnxruntime.InferenceSession(w, None)
  32. # 其余的则为tensorflow模型
  33. else: # TensorFlow models
  34. # 检查tensorflow库是否存在
  35. check_requirements(('tensorflow>=2.4.1',))
  36. # 导入tensorflow库
  37. import tensorflow as tf
  38. # 文件后缀为pb
  39. if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
  40. # 以下代码为tensorflow加载pb模型的步骤
  41. def wrap_frozen_graph(gd, inputs, outputs):
  42. x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped import
  43. return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
  44. tf.nest.map_structure(x.graph.as_graph_element, outputs))
  45. graph_def = tf.Graph().as_graph_def()
  46. graph_def.ParseFromString(open(w, 'rb').read())
  47. frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
  48. # 文件后缀为 saved_model
  49. elif saved_model:
  50. # 加载saved_model模型
  51. model = tf.keras.models.load_model(w)
  52. # 文件后缀名为 tflite
  53. elif tflite:
  54. # 加载tflite模型
  55. interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
  56. interpreter.allocate_tensors() # allocate
  57. input_details = interpreter.get_input_details() # inputs
  58. output_details = interpreter.get_output_details() # outputs
  59. int8 = input_details[0]['dtype'] == np.uint8 # is TFLite quantized uint8 model
  60. # 检查图片尺寸 判断图片尺寸是不是模型步长的倍数 若不满足重新计算图片尺寸
  61. imgsz = check_img_size(imgsz, s=stride) # check image size

以上的大段代码是根据不同的模型文件,使用不同的方法加载模型。根据代码可以看出,yolo v5 不仅仅支持pytorch的模型,还支持opencv,tensorflow等深度学习库的模型。export模块中也写出了不同模型不同的导出方法。yolo v5 要考虑到系统的兼容性,所以需要兼容这么多格式的模型。但我认为,在实际的使用过程中,这样的代码过于冗杂,只需要兼容一种模型即可。

  1. # 调用网络摄像头
  2. if webcam:
  3. # 检查图片是否可以展示成功
  4. # 这里通过opencv调用摄像头
  5. view_img = check_imshow()
  6. # 优化运行效率
  7. cudnn.benchmark = True # set True to speed up constant image size inference
  8. # 加载流 可以加载网络摄像头甚至Youtube中的视频链接
  9. dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
  10. bs = len(dataset) # batch_size
  11. else:
  12. # 如果不是网络摄像头,那么加载图片
  13. dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
  14. bs = 1 # batch_size
  15. # 每个batch_size的vid_path与vide_writer 二维数组 初始化为None
  16. vid_path, vid_writer = [None] * bs, [None] * bs

上述视频为对数据源的加载,根据webcam判断应该加载视频流或图片。其中LoadStreams与LoadImages均重写了__next__()函数,可以使用for循环进行迭代,将每张照片拿到 。

  1. # Run inference
  2. """运行推断过程 将图片带入模型得出结果"""
  3. if pt and device.type != 'cpu':
  4. # 带入数据校验模型 使用一张空白的图片进行一次前向推断
  5. model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters()))) # run once
  6. # 初始化一些中间变量
  7. dt, seen = [0.0, 0.0, 0.0], 0

接下来执行推断过程,首先要用空白的图片数据带入模型,进行一次前向推断。这个过程可以理解为一个热身的过程,通过热身可以校验模型中数据的维度等是否正确。这是一种训练技巧。

for循环

  1. # 从图片或视频加载每一张图片
  2. # 每张图片的推断过程均在for循环内完成
  3. # path为图片的路径 img为resize处理后的图片 im0s表示未处理的原图 vid_cap为视频流实例
  4. for path, img, im0s, vid_cap in dataset:
  5. """处理图片"""
  6. # 获取cpu上执行的时间
  7. t1 = time_sync()
  8. # 如果模型为onnx格式
  9. if onnx:
  10. # 将图片数组中的元素改为float32
  11. img = img.astype('float32')
  12. # 若模型不为onnx
  13. else:
  14. # 把图片数组装载在对应的cpu或gpu上
  15. img = torch.from_numpy(img).to(device)
  16. # 如果使用半精读计算 就将数据转为半精读 否则还是float
  17. img = img.half() if half else img.float() # uint8 to fp16/32
  18. # /255.0将数据映射至0-1之间 归一化处理
  19. img = img / 255.0 # 0 - 255 to 0.0 - 1.0
  20. # 若图片为三维
  21. if len(img.shape) == 3:
  22. # 为图片扩展一个维度 batch_size的维度
  23. img = img[None] # expand for batch dim
  24. # 获取结束时间
  25. t2 = time_sync()
  26. # 将时间累积
  27. dt[0] += t2 - t1

接下来就是通过for循环,将每张照片从流或文件夹中获取出来,每执行一次for循环就是完成一次对图片的推断,对于这张图片的推断均体现在for循环内。这里先截取了for循环的一部分,首先是对图片的处理,将图片数组进行归一化,并修改维度。

  1. # Inference
  2. """推断过程 不同模型通过不同方式得出预测结果"""
  3. # 若模型为pt格式
  4. if pt:
  5. # visualize为可视化,默认为False,若进行可视化就新建目录,并保存结果,否则返回false
  6. visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
  7. # 获取预测结果,并保存第一维度为pred
  8. # pred为预测的结果 shape为(1,18900,85)
  9. pred = model(img, augment=augment, visualize=visualize)[0]
  10. # 若模型为onnx格式
  11. elif onnx:
  12. # 若使用opencv的深度学习
  13. if dnn:
  14. # 输入模型
  15. net.setInput(img)
  16. # 获取结果
  17. pred = torch.tensor(net.forward())
  18. else:
  19. # 获取预测结果
  20. pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
  21. # 使用tensorflow模型
  22. else: # tensorflow model (tflite, pb, saved_model)
  23. imn = img.permute(0, 2, 3, 1).cpu().numpy() # image in numpy
  24. if pb:
  25. # 获取pb模型的预测结果
  26. pred = frozen_func(x=tf.constant(imn)).numpy()
  27. elif saved_model:
  28. # 获取save_model模型的预测结果
  29. pred = model(imn, training=False).numpy()
  30. elif tflite:
  31. # 获取tflite模型的预测结果
  32. if int8:
  33. scale, zero_point = input_details[0]['quantization']
  34. imn = (imn / scale + zero_point).astype(np.uint8) # de-scale
  35. interpreter.set_tensor(input_details[0]['index'], imn)
  36. interpreter.invoke()
  37. pred = interpreter.get_tensor(output_details[0]['index'])
  38. if int8:
  39. scale, zero_point = output_details[0]['quantization']
  40. pred = (pred.astype(np.float32) - zero_point) * scale # re-scale
  41. # pred[..., 0]这样的语法表示抽取数组的第几列,作为一个tensor
  42. pred[..., 0] *= imgsz[1] # x
  43. pred[..., 1] *= imgsz[0] # y
  44. pred[..., 2] *= imgsz[1] # w
  45. pred[..., 3] *= imgsz[0] # h
  46. pred = torch.tensor(pred)

上述代码将图片输入不同的模型,并得到预测结果。将data/images/bus.jpg图片输入模型,得到的是一个shape为(1,18900,85)的tensor;将data/images/zidane.jpg输入模型,齐达内的图片得到的是一个shape为(1,15120,85)的tensor。

1表示batch_size,表示这个batch_size中只有一张图片的预测结果,因为输入模型的batch_size就是1,所以输出的结果也为1。

18900或15120表示模型预测出了1890或15120个预测框。

85表示每个预测框中含有4个位置信息(包括预测框的x,y,w,h)、一个置信度信息和coco数据集80个类别的条件概率信息。

  1. # NMS 非极大值抑制 pred为预测结果, conf_thres为置信度阈值 默认为0.25 iou_thres为iou阈值 默认为0.45
  2. # classes为是否只保留特定的类别 默认为None agnostic_nms True表示多个类一起计算nms,False表示按照不同的类分别进行计算nms
  3. # max_det为保留的最大检测框数 默认为1000 也就是一张图片最多检测1000个物体
  4. # 经过极大值抑制后 shape变为(1,5,6) 1表示batch_size 5表示共有5个预测框 6表示x,y,x,y两个坐标,1个类别概率,1个类别索引
  5. pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
  6. dt[2] += time_sync() - t3

有了预测结果,接下来要从预测出的18900个预测框中筛选出最合适的框,这个过程被称为非极大值抑制。

  1. def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
  2. labels=(), max_det=300):
  3. """Runs Non-Maximum Suppression (NMS) on inference results
  4. Returns:
  5. list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  6. """
  7. # 获取类别数
  8. nc = prediction.shape[2] - 5 # number of classes
  9. # prediction[..., 4] 表示数据的第四列,这里指置信度
  10. # 判断每一位的置信度是否大于置信度阈值 返回一个shape为(1,18900)的bool类型的tensor,代表这一位是否大于置信度阈值
  11. xc = prediction[..., 4] > conf_thres # candidates
  12. # Checks 检测阈值是否合法
  13. assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
  14. assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
  15. # Settings
  16. min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
  17. max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
  18. time_limit = 10.0 # seconds to quit after
  19. redundant = True # require redundant detections
  20. multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  21. merge = False # use merge-NMS
  22. t = time.time()
  23. # 定义输出数据
  24. output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
  25. # xi表示第几个batch x表示这个batch内所有预测框
  26. for xi, x in enumerate(prediction): # image index, image inference
  27. # Apply constraints
  28. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  29. # xi为0 表示第0个batch
  30. # xc[xi] 表示获取这个batch内所有的置信度是否大于置信度阈值
  31. # 这种表示方式为x[xc[xi]] 将x中为True的值取出,并赋值给x
  32. # x现在表示,18900个预测框置信度中大于置信度阈值的预测框 shape为(52, 85)
  33. # 此时根据置信度阈值的过滤,预测框只剩下18900个
  34. x = x[xc[xi]] # confidence
  35. # Cat apriori labels if autolabelling
  36. # 暂时不理解
  37. if labels and len(labels[xi]):
  38. l = labels[xi]
  39. v = torch.zeros((len(l), nc + 5), device=x.device)
  40. v[:, :4] = l[:, 1:5] # box
  41. v[:, 4] = 1.0 # conf
  42. v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
  43. x = torch.cat((x, v), 0)
  44. # If none remain process next image
  45. # 若预测框数为0,则处理下一个batch
  46. if not x.shape[0]:
  47. continue
  48. # Compute conf
  49. # x[:, 5:] = x[:, 5:] * x[:, 4:5]
  50. # 80个类别的概率为条件类别概率,是假设这个框内有物体的情况下,该物体是某一类的概率
  51. # 置信度可以理解为预测框内存在物体的概率
  52. # 将所有类别的条件类别概率与置信度相乘才是某个类别的真正概率
  53. # 经过计算,物体的条件类别概率被计算为了真正的概率
  54. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  55. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  56. # 将x,y,w,h (一个坐标和一个宽一个高) 的表示方法改为 x,y,x,y(两个坐标,矩形的左上角和右下角)表示
  57. box = xywh2xyxy(x[:, :4])
  58. # Detections matrix nx6 (xyxy, conf, cls)
  59. if multi_label:
  60. i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
  61. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  62. # 每个类只标一个标签
  63. else: # best class only
  64. # conf为最大的类别概率,j为最大类别概率的索引值
  65. conf, j = x[:, 5:].max(1, keepdim=True)
  66. # 将结果拼到一起
  67. # box shape为(52, 4) 表示52个预测框的xyxy坐标表示
  68. # conf shape为(52, 1) 表示52个预测框的最大类别概率
  69. # j shape为(52,1) 表示最大类别概率的索引值,用于表示哪个类别
  70. # 此时x的shape为(52, 6)
  71. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  72. # Filter by class
  73. if classes is not None:
  74. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  75. # Apply finite constraint
  76. # if not torch.isfinite(x).all():
  77. # x = x[torch.isfinite(x).all(1)]
  78. # Check shape
  79. n = x.shape[0] # number of boxes
  80. # 如果没有box 执行下一张图片的推断
  81. if not n: # no boxes
  82. continue
  83. # 如果预测框的个数大于了最大值
  84. # 这里的预测框个数指的不是最终预测的个数
  85. elif n > max_nms: # excess boxes
  86. # 根据概率值排序
  87. x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
  88. # Batched NMS
  89. # agnostic参数 True表示多个类一起计算nms,False表示按照不同的类分别进行计算nms
  90. # 这里的c为偏移量 若不同的类分别进行计算nms,就把类别的索引乘一个很大的值
  91. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  92. # boxes为原来的box加上c偏置量 这样做是为了确保不同类别的预测框不会重叠
  93. # scores为类别的概率值
  94. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  95. # 执行非极大值抑制
  96. # 这里的非极大值抑制是torchvision中实现的非极大值抑制
  97. # nms的原理为计算不同预测框的iou(交并比),若大于阈值,则判定两个预测框预测了同一物体
  98. # 通过这种方式筛出预测框 返回最后确定的预测框索引
  99. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  100. # 如果最后的预测框数量大于最大预测数
  101. if i.shape[0] > max_det: # limit detections
  102. # 只取到最大预测数
  103. i = i[:max_det]
  104. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  105. # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  106. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  107. weights = iou * scores[None] # box weights
  108. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  109. if redundant:
  110. i = i[iou.sum(1) > 1] # require redundancy
  111. # output为先前定义的空白结果
  112. # 让output[xi]目的是使输出结果的batch与输入相对应
  113. # x[i]表述从x中取出i中的索引
  114. output[xi] = x[i]
  115. # 非极大值抑制过程时间超时
  116. if (time.time() - t) > time_limit:
  117. print(f'WARNING: NMS time limit {time_limit}s exceeded')
  118. break # time limit exceeded
  119. # 返回结果 output的shape为(1,5,6)
  120. return output

非极大值抑制是yolo在处理预测结果时的重要环节,所以这里把非极大值抑制的代码解读也写出来,非极大值抑制主要分为两个步骤。

第一步为过滤低置信度预测框。置信度反应了该预测框中存在对象的概率。经过模型的推断,产生了18900个预测框,过滤掉低置信度的预测框后,仅剩52个预测框,而剩下的52个预测框,大多数都是多个预测框预测了同一个物体,这时便要进行第二步。

第二步为通过IOU阈值过滤。首先把不同类别的预测框加上不同的偏置量,保证不同类别的预测框不会有重叠。然后计算每个类别预测框的IOU,若两个预测框的IOU大于给定阈值,那么就判定这两个预测框预测了同一个的对象,并只保留一个类别概率大的框。

  1. # Second-stage classifier (optional)
  2. if classify:
  3. pred = apply_classifier(pred, modelc, img, im0s)
  4. # Process predictions
  5. # i表示 batch det表示五个预测框
  6. for i, det in enumerate(pred): # per image
  7. # seen 为计数
  8. seen += 1
  9. # 如果数据源是网络摄像头
  10. if webcam: # batch_size >= 1
  11. p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count
  12. # 数据源不是网络摄像头
  13. else:
  14. p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)
  15. # 设置保存路径
  16. p = Path(p) # to Path
  17. save_path = str(save_dir / p.name) # img.jpg
  18. txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt
  19. # 设置打印图片的信息
  20. s += '%gx%g ' % img.shape[2:] # print string
  21. gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
  22. # 保存图片
  23. imc = im0.copy() if save_crop else im0 # for save_crop
  24. # 绘图类实例
  25. annotator = Annotator(im0, line_width=line_thickness, example=str(names))
  26. # 如果有预测框
  27. if len(det):
  28. # Rescale boxes from img_size to im0 size
  29. # 映射图片的尺寸
  30. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
  31. # Print results
  32. # 输出结果
  33. for c in det[:, -1].unique():
  34. n = (det[:, -1] == c).sum() # detections per class
  35. s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
  36. # Write results
  37. # 保存结果
  38. for *xyxy, conf, cls in reversed(det):
  39. # 保存txt文件
  40. if save_txt: # Write to file
  41. # 将坐标转换为旧的格式
  42. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  43. line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
  44. # 将预测结果写入文件 路径默认为“runs\detect\exp*\labels”
  45. with open(txt_path + '.txt', 'a') as f:
  46. f.write(('%g ' * len(line)).rstrip() % line + '\n')
  47. if save_img or save_crop or view_img: # Add bbox to image
  48. c = int(cls) # integer class
  49. # 获取类别标签
  50. label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
  51. # 绘制含有标签的边框
  52. annotator.box_label(xyxy, label, color=colors(c, True))
  53. # 将预测框内的图片单独保存
  54. if save_crop:
  55. save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
  56. # Print time (inference-only)
  57. print(f'{s}Done. ({t3 - t2:.3f}s)')
  58. # Stream results
  59. # im0为绘制好的图片
  60. im0 = annotator.result()
  61. # 如果显示该图片
  62. if view_img:
  63. cv2.imshow(str(p), im0)
  64. cv2.waitKey(1) # 1 millisecond
  65. # Save results (image with detections)
  66. # 保存绘制完的图片
  67. if save_img:
  68. # 若为图片
  69. if dataset.mode == 'image':
  70. # 向路径中保存图片
  71. cv2.imwrite(save_path, im0)
  72. # 是视频或者流
  73. else: # 'video' or 'stream'
  74. if vid_path[i] != save_path: # new video
  75. vid_path[i] = save_path
  76. if isinstance(vid_writer[i], cv2.VideoWriter):
  77. vid_writer[i].release() # release previous video writer
  78. if vid_cap: # video
  79. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  80. w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  81. h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  82. else: # stream
  83. fps, w, h = 30, im0.shape[1], im0.shape[0]
  84. save_path += '.mp4'
  85. # 最后保存为视频
  86. vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
  87. vid_writer[i].write(im0)

接下来的内容就比较简单了,首先将预测款绘制在图片上,然后将图片保存。此时for循环结束

输出结果

  1. # Print results
  2. t = tuple(x / seen * 1E3 for x in dt) # speeds per image
  3. print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
  4. if save_txt or save_img:
  5. s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
  6. print(f"Results saved to {colorstr('bold', save_dir)}{s}")
  7. if update:
  8. strip_optimizer(weights) # update model (to fix SourceChangeWarning)

最后将结果输出在控制台,detect模块到此结束。

其他的几个模块后续更新。


本文转载自: https://blog.csdn.net/qq_63708623/article/details/128448549
版权归原作者 G.E.N. 所有, 如有侵权,请联系我们删除。

“YOLO v5 代码精读(1) detect模块以及非极大值抑制”的评论:

还没有评论