YOLOv5 + Pose Estimation with HRNet and SimDR: Detecting Human Keypoints in Video

I. Introduction

  1. A project of mine required keypoint detection for persons in video. After testing several algorithms, I chose the more accurate HRNet family instead of the more production-ready Openpose. However, the official code only evaluates on the standard datasets, so inference code for real images has to be written around the model yourself.
  2. This article implements video keypoint detection on top of the SimDR project. SimDR is an improvement over HRNet, and its repository contains both HRNet and the improved algorithm, which makes it convenient to use. Everything described here runs end to end on a CPU.

II. Environment Setup

  1. For the Python environment, simply install the packages listed in the requirements files of the SimDR and yolov5 projects (requirements.txt in each repository). In short, install whatever is reported as missing; a typical setup is sketched below.
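
For reference, once the two repositories from the next section have been cloned, the dependencies can typically be installed like this (file names are assumed to match the upstream repositories; adjust if yours differ):

  1. pip install -r SimDR/requirements.txt
  2. pip install -r SimDR/yolov5/requirements.txt
  3. pip install ffmpeg-python # imported as "ffmpeg" by lib/utils/visualization.py later on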

III. Project Preparation

1. Clone the repositories

  1. git clone https://github.com/leeyegy/SimDR.git # clone the pose-estimation project
  2. cd SimDR
  3. git clone -b v5.0 https://github.com/ultralytics/yolov5.git # add yolov5 inside the pose-estimation project (resulting layout sketched below)
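
After both clones, the folder layout referenced in the rest of this article looks roughly like this (only the folders used later are shown; the two weights folders are created in the following steps):

    SimDR/
        experiments/        # yaml configuration files
        lib/                # SimDR/HRNet models and utilities
        weight/hrnet/       # to be created: pose-estimation checkpoints
        yolov5/             # the detector cloned above
            models/
            utils/
            weights/        # to be created: yolov5x.pt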

2. Object detection

① Add the weight file

  1. Place yolov5x.pt (see the netdisk link in the original post's comment section) under the 'SimDR/yolov5/weights/' folder.

② Get the bounding boxes

  1. Create a new file YOLOv5.py under the yolov5 folder and copy the following code into it. Note: according to reader feedback, the way the yolov5 packages are imported differs between machines. In the code below I write from yolov5.xxx import xxx, but on some setups the leading yolov5. prefix has to be dropped; try what works for you. In general, a "No module named xxx" error means one of the yolov5 imports needs adjusting.
  1. import argparse
  2. import time
  3. from pathlib import Path
  4. import numpy as np
  5. import cv2
  6. import torch
  7. import torch.backends.cudnn as cudnn
  8. from numpy import random
  9. import sys
  10. import os
  11. from yolov5.models.experimental import attempt_load
  12. from yolov5.utils.datasets import LoadStreams, LoadImages
  13. from yolov5.utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
  14. scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
  15. from yolov5.utils.plots import plot_one_box
  16. from yolov5.utils.torch_utils import select_device, load_classifier, time_synchronized
  17. from yolov5.utils.datasets import letterbox
  18. class Yolov5():
  19. def __init__(self, weights=None, opt=None, device=None):
  20. """
  21. @param weights:
  22. @param save_txt:
  23. @param opt:
  24. @param device:
  25. """
  26. self.weights = weights
  27. self.device = device
  28. # save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
  29. # save_dir.mkdir(parents=True, exist_ok=True) # make dir
  30. self.img_size = 640
  31. self.model = attempt_load(weights, map_location=self.device)
  32. self.stride = int(self.model.stride.max())
  33. self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
  34. self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in self.names]
  35. self.opt = opt
  36. def detect(self,img0):
  37. """
  38. @param img0: input image, shape = [h, w, 3]
  39. @return:
  40. """
  41. person_boxes = np.ones((6))
  42. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  43. # Convert
  44. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  45. img = np.ascontiguousarray(img)
  46. img = torch.from_numpy(img).to(self.device)
  47. img = img.float() # uint8 to fp16/32
  48. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  49. if img.ndimension() == 3:
  50. img = img.unsqueeze(0)
  51. pred = self.model(img, augment=self.opt.augment)[0]
  52. # Apply NMS
  53. pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, classes=self.opt.classes, agnostic=self.opt.agnostic_nms)
  54. for i, det in enumerate(pred):
  55. if len(det):
  56. # Rescale boxes from img_size to im0 size
  57. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0.shape).round()
  58. boxes = reversed(det)
  59. boxes = boxes.cpu().numpy() # changed 2022.04.06: boxes computed on the GPU cannot be converted to numpy directly
  60. #for i , box in enumerate(np.array(boxes)):
  61. for i , box in enumerate(boxes):
  62. if int(box[-1]) == 0 and box[-2]>=0.7:
  63. person_boxes=np.vstack((person_boxes , box))
  64. # label = f'{self.names[int(box[-1])]} {box[-2]:.2f}'
  65. # print(label)
  66. # plot_one_box(box, img0, label=label, color=self.colors[int(box[-1])], line_thickness=3)
  67. # cv2.imwrite('result1.jpg',img0)
  68. # print(s)
  69. # print(person_boxes,np.ndim(person_boxes))
  70. if np.ndim(person_boxes)>=2 :
  71. person_boxes_result = person_boxes[1:]
  72. boxes_result = person_boxes[1:,:4]
  73. else:
  74. person_boxes_result = []
  75. boxes_result = []
  76. return person_boxes_result,boxes_result
  77. def yolov5test(opt,path = ''):
  78. detector = Yolov5(weights='weights/yolov5x.pt',opt=opt,device=torch.device('cpu'))
  79. img0 = cv2.imread(path)
  80. personboxes ,boxes= detector.detect(img0)
  81. for i,(x1,y1,x2,y2) in enumerate(boxes):
  82. print(x1,y1,x2,y2)
  83. print(personboxes,'\n',boxes)
  84. if __name__ == '__main__':
  85. parser = argparse.ArgumentParser()
  86. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  87. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  88. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  89. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  90. parser.add_argument('--augment', action='store_true', help='augmented inference')
  91. parser.add_argument('--update', action='store_true', help='update all model')
  92. parser.add_argument('--project', default='runs/detect', help='save results to project/name')
  93. parser.add_argument('--name', default='exp', help='save results to project/name')
  94. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  95. opt = parser.parse_args()
  96. print(opt)
  97. # check_requirements(exclude=('pycocotools', 'thop'))
  98. with torch.no_grad():
  99. yolov5test(opt,'data/images/zidane.jpg')
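
For later reference, detect() returns two arrays, as can be read off the code above: person_boxes_result with shape (N, 6), one row [x1, y1, x2, y2, confidence, class_id] for every detection of class 0 (person) with confidence >= 0.7, and boxes_result, its (N, 4) [x1, y1, x2, y2] slice, which is what the pose code consumes later. A minimal calling sketch, assuming the working directory is SimDR, the import caveats of the next subsection (and, if your checkpoint needs it, the SPPF addition of subsection ④) are already handled, and opt is an argparse.Namespace with the same fields as in the __main__ block above:

    import cv2
    import torch
    from yolov5.YOLOv5 import Yolov5

    # opt must provide conf_thres, iou_thres, classes, agnostic_nms and augment
    detector = Yolov5(weights='./yolov5/weights/yolov5x.pt', opt=opt, device=torch.device('cpu'))
    img = cv2.imread('./yolov5/data/images/zidane.jpg')   # sample image shipped with yolov5
    person_boxes, boxes = detector.detect(img)            # (N, 6) and (N, 4) arrays as described above
    for x1, y1, x2, y2 in boxes:
        print(int(x1), int(y1), int(x2), int(y2))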

③ Path issues

  1. The code in this article was run in PyCharm. Adding the yolov5 project introduces folders whose names clash with existing ones (models, utils), so the interpreter can resolve imports to the wrong package and report that some modules cannot be found. Run the YOLOv5.py script once and fix the imports according to the error messages. For example, ./SimDR/yolov5/models/experimental.py fails with exactly this kind of import error (the original post shows a screenshot of it).

Changing the imports as sketched below fixes it; the other affected files are changed the same way.
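
As a sketch of the kind of change involved (the original lines are quoted from yolov5 v5.0 as far as I recall; adapt them to whatever your own error message points at):

    # ./SimDR/yolov5/models/experimental.py -- original imports in yolov5 v5.0
    from models.common import Conv, DWConv
    from utils.google_utils import attempt_download

    # changed so that they resolve with SimDR as the project root
    from yolov5.models.common import Conv, DWConv
    from yolov5.utils.google_utils import attempt_download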

④ Add the SPPF module

The yolov5 v5.0 code base does not define an SPPF module, so if loading the provided weights complains about a missing SPPF class, append the following code to the end of ./SimDR/yolov5/models/common.py.

  1. import warnings
  2. class SPPF(nn.Module):
  3. # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
  4. def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
  5. super().__init__()
  6. c_ = c1 // 2 # hidden channels
  7. self.cv1 = Conv(c1, c_, 1, 1)
  8. self.cv2 = Conv(c_ * 4, c2, 1, 1)
  9. self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  10. def forward(self, x):
  11. x = self.cv1(x)
  12. with warnings.catch_warnings():
  13. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  14. y1 = self.m(x)
  15. y2 = self.m(y1)
  16. return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
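
A quick way to confirm that the checkpoint now loads with the SPPF definition in place (a minimal sketch, run from the SimDR directory with the paths used above):

    import torch
    from yolov5.models.experimental import attempt_load

    # loads ./yolov5/weights/yolov5x.pt onto the CPU; raises if SPPF (or anything else) is missing
    model = attempt_load('./yolov5/weights/yolov5x.pt', map_location=torch.device('cpu'))
    print('loaded, stride =', int(model.stride.max()), ', classes =', len(model.names))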

3. Pose estimation

① Add the weights

  1. Create a weight/hrnet folder under the SimDR folder and place pose_hrnet_w48_384x288.pth and the other checkpoint files there (see the netdisk link in the original post's comment section).

② Modify the yaml files

  1. The SimDR/experiments/ folder holds the configuration files for the coco and mpii datasets; this article uses coco as the example.

  1. Next, in ./SimDR/experiments/coco/hrnet/heatmap/w48_384x288_adam_lr1e-3.yaml, set the MODEL_FILE path in the TEST section to the checkpoint added above (see the sketch below). The configuration files of the SimDR variants are changed in the same way.
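
After the edit, the TEST block of that yaml should contain something like the following (the path is an assumption matching the weight/hrnet folder created above; all other keys stay as in the upstream config):

    TEST:
      ...
      MODEL_FILE: 'weight/hrnet/pose_hrnet_w48_384x288.pth'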

③ Get the keypoints

  1. Create a new file Point_detect.py under the 'SimDR/' folder and copy the following code into it.
  2. Note: the path in the sys.path.insert(0, ...) call (line 12 of the code) must be changed to the yolov5 folder inside your own SimDR project; the script will not run without this line.

[Update 2022.04.16: following a suggestion in the comments, each keypoint now also carries a confidence value. I take it as the maximum of the softmax over the model output (the keypoint coordinate is the index of that maximum), so treat it as a reference value only; note that the code below scales this maximum by 10 before storing it, so the drawing functions' default confidence_threshold of 0.1 corresponds to a raw softmax peak of 0.01. Filtering on this confidence solves the problem of all points being drawn even for half-body shots.]

  1. import cv2
  2. import numpy as np
  3. import torch
  4. from torchvision.transforms import transforms
  5. import torch.nn.functional as F
  6. from lib.config import cfg
  7. from yolov5.YOLOv5 import Yolov5
  8. from lib.utils.transforms import flip_back_simdr,transform_preds,get_affine_transform
  9. from lib import models
  10. import argparse
  11. import sys
  12. sys.path.insert(0, 'D:\\Study\\Pose Estimation\\SimDR\\yolov5')
  13. class Points():
  14. def __init__(self,
  15. model_name='sa-simdr',
  16. resolution=(384,288),
  17. opt=None,
  18. yolo_weights_path="./yolov5/weights/yolov5x.pt",
  19. ):
  20. """
  21. Initializes a new SimpleHRNet object.
  22. HRNet (and YOLOv3) are initialized on the torch.device("device") and
  23. its (their) pre-trained weights will be loaded from disk.
  24. Args:
  25. c (int): number of channels (when using HRNet model) or resnet size (when using PoseResNet model).
  26. nof_joints (int): number of joints.
  27. checkpoint_path (str): path to an official hrnet checkpoint or a checkpoint obtained with `train_coco.py`.
  28. model_name (str): model name (HRNet or PoseResNet).
  29. Valid names for HRNet are: `HRNet`, `hrnet`
  30. Valid names for PoseResNet are: `PoseResNet`, `poseresnet`, `ResNet`, `resnet`
  31. Default: "HRNet"
  32. resolution (tuple): hrnet input resolution - format: (height, width).
  33. Default: (384, 288)
  34. interpolation (int): opencv interpolation algorithm.
  35. Default: cv2.INTER_CUBIC
  36. multiperson (bool): if True, multiperson detection will be enabled.
  37. This requires the use of a people detector (like YOLOv3).
  38. Default: True
  39. return_heatmaps (bool): if True, heatmaps will be returned along with poses by self.predict.
  40. Default: False
  41. return_bounding_boxes (bool): if True, bounding boxes will be returned along with poses by self.predict.
  42. Default: False
  43. max_batch_size (int): maximum batch size used in hrnet inference.
  44. Useless without multiperson=True.
  45. Default: 16
  46. yolo_model_def (str): path to yolo model definition file.
  47. Default: "./model/detectors/yolo/config/yolov3.cfg"
  48. yolo_class_path (str): path to yolo class definition file.
  49. Default: "./model/detectors/yolo/data/coco.names"
  50. yolo_weights_path (str): path to yolo pretrained weights file.
  51. Default: "./model/detectors/yolo/weights/yolov3.weights.cfg"
  52. device (:class:`torch.device`): the hrnet (and yolo) inference will be run on this device.
  53. Default: torch.device("cpu")
  54. """
  55. self.model_name = model_name
  56. self.resolution = resolution # in the form (height, width) as in the original implementation
  57. self.aspect_ratio = resolution[1]/resolution[0]
  58. self.yolo_weights_path = yolo_weights_path
  59. self.flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
  60. [9, 10], [11, 12], [13, 14], [15, 16]]
  61. self.device = torch.device(opt.device)
  62. cfg.defrost()
  63. if model_name in ('sa-simdr','sasimdr','sa_simdr'):
  64. if resolution ==(384,288):
  65. cfg.merge_from_file('./experiments/coco/hrnet/sa_simdr/w48_384x288_adam_lr1e-3_split1_5_sigma4.yaml')
  66. elif resolution == (256,192):
  67. cfg.merge_from_file('./experiments/coco/hrnet/sa_simdr/w48_256x192_adam_lr1e-3_split2_sigma4.yaml')
  68. else:
  69. raise ValueError('Wrong cfg file')
  70. elif model_name in ('simdr',):
  71. if resolution == (256, 192):
  72. cfg.merge_from_file('./experiments/coco/hrnet/simdr/nmt_w48_256x192_adam_lr1e-3.yaml')
  73. else:
  74. raise ValueError('Wrong cfg file')
  75. elif model_name in ('hrnet','HRnet','Hrnet'):
  76. if resolution == (384,288):
  77. cfg.merge_from_file('./experiments/coco/hrnet/heatmap/w48_384x288_adam_lr1e-3.yaml')
  78. elif resolution == (256,192):
  79. cfg.merge_from_file('./experiments/coco/hrnet/heatmap/w48_256x192_adam_lr1e-3.yaml')
  80. else:
  81. raise ValueError('Wrong cfg file')
  82. else:
  83. raise ValueError('Wrong model name.')
  84. cfg.freeze()
  85. self.model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
  86. cfg, is_train=False)
  87. print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
  88. checkpoint_path = cfg.TEST.MODEL_FILE
  89. checkpoint = torch.load(checkpoint_path, map_location=self.device)
  90. if 'model' in checkpoint:
  91. self.model.load_state_dict(checkpoint['model'])
  92. else:
  93. self.model.load_state_dict(checkpoint)
  94. if 'cuda' in str(self.device):
  95. print("device: 'cuda' - ", end="")
  96. if 'cuda' == str(self.device):
  97. # if device is set to 'cuda', all available GPUs will be used
  98. print("%d GPU(s) will be used" % torch.cuda.device_count())
  99. device_ids = None
  100. else:
  101. # if device is set to 'cuda:IDS', only that/those device(s) will be used
  102. print("GPU(s) '%s' will be used" % str(self.device))
  103. device_ids = [int(x) for x in str(self.device)[5:].split(',')]
  104. elif 'cpu' == str(self.device):
  105. print("device: 'cpu'")
  106. else:
  107. raise ValueError('Wrong device name.')
  108. self.model = self.model.to(self.device)
  109. self.model.eval()
  110. self.detector = Yolov5(
  111. weights=yolo_weights_path,
  112. opt=opt ,
  113. device=self.device)
  114. self.transform = transforms.Compose([
  115. transforms.ToPILImage(),
  116. transforms.Resize((self.resolution[0], self.resolution[1])), # (height, width)
  117. transforms.ToTensor(),
  118. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  119. ])
  120. def _box2cs(self, box):
  121. x, y, w, h = box[:4]
  122. return self._xywh2cs(x, y, w, h)
  123. def _xywh2cs(self, x, y, w, h):
  124. center = np.zeros((2), dtype=np.float32)
  125. center[0] = x + w * 0.5
  126. center[1] = y + h * 0.5
  127. if w > self.aspect_ratio * h:
  128. h = w * 1.0 / self.aspect_ratio
  129. elif w < self.aspect_ratio * h:
  130. w = h * self.aspect_ratio
  131. scale = np.array(
  132. [w * 1.0 / 200, h * 1.0 / 200],
  133. dtype=np.float32)
  134. if center[0] != -1:
  135. scale = scale * 1.25
  136. return center, scale
  137. def predict(self, image):
  138. """
  139. Predicts the human pose on a single image or a stack of n images.
  140. Args:
  141. image (:class:`np.ndarray`):
  142. the image(s) on which the human pose will be estimated.
  143. image is expected to be in the opencv format.
  144. image can be:
  145. - a single image with shape=(height, width, BGR color channel)
  146. - a stack of n images with shape=(n, height, width, BGR color channel)
  147. Returns:
  148. :class:`np.ndarray` or list:
  149. a numpy array containing human joints for each (detected) person.
  150. Format:
  151. if image is a single image:
  152. shape=(# of people, # of joints (nof_joints), 3); dtype=(np.float32).
  153. if image is a stack of n images:
  154. list of n np.ndarrays with
  155. shape=(# of people, # of joints (nof_joints), 3); dtype=(np.float32).
  156. Each joint has 3 values: (x position, y position, joint confidence).
  157. If self.return_heatmaps, the class returns a list with (heatmaps, human joints)
  158. If self.return_bounding_boxes, the class returns a list with (bounding boxes, human joints)
  159. If self.return_heatmaps and self.return_bounding_boxes, the class returns a list with
  160. (heatmaps, bounding boxes, human joints)
  161. """
  162. if len(image.shape) == 3:
  163. return self._predict_single(image)
  164. else:
  165. raise ValueError('Wrong image format.')
  166. def sa_simdr_pts(self,img,detection,images,boxes):
  167. c, s = [], []
  168. if detection is not None:
  169. for i, (x1, y1, x2, y2) in enumerate(detection):
  170. x1 = int(round(x1.item()))
  171. x2 = int(round(x2.item()))
  172. y1 = int(round(y1.item()))
  173. y2 = int(round(y2.item()))
  174. boxes[i] = [x1, y1, x2, y2]
  175. w, h = x2 - x1, y2 - y1
  176. xx1 = np.max((0, x1))
  177. yy1 = np.max((0, y1))
  178. xx2 = np.min((img.shape[1] - 1, x1 + np.max((0, w - 1))))
  179. yy2 = np.min((img.shape[0] - 1, y1 + np.max((0, h - 1))))
  180. box = [xx1, yy1, xx2 - xx1, yy2 - yy1]
  181. center, scale = self._box2cs(box)
  182. c.append(center)
  183. s.append(scale)
  184. trans = get_affine_transform(center, scale, 0, np.array(cfg.MODEL.IMAGE_SIZE))
  185. input = cv2.warpAffine(
  186. img,
  187. trans,
  188. (int(self.resolution[1]), int(self.resolution[0])),
  189. flags=cv2.INTER_LINEAR)
  190. images[i] = self.transform(input)
  191. if images.shape[0] > 0:
  192. images = images.to(self.device)
  193. with torch.no_grad():
  194. output_x, output_y = self.model(images)
  195. if cfg.TEST.FLIP_TEST:
  196. input_flipped = images.flip(3)
  197. output_x_flipped_, output_y_flipped_ = self.model(input_flipped)
  198. output_x_flipped = flip_back_simdr(output_x_flipped_.cpu().numpy(),
  199. self.flip_pairs, type='x')
  200. output_y_flipped = flip_back_simdr(output_y_flipped_.cpu().numpy(),
  201. self.flip_pairs, type='y')
  202. output_x_flipped = torch.from_numpy(output_x_flipped.copy()).to(self.device)
  203. output_y_flipped = torch.from_numpy(output_y_flipped.copy()).to(self.device)
  204. # feature is not aligned, shift flipped heatmap for higher accuracy
  205. if cfg.TEST.SHIFT_HEATMAP:
  206. output_x_flipped[:, :, 0:-1] = \
  207. output_x_flipped.clone()[:, :, 1:]
  208. output_x = F.softmax((output_x + output_x_flipped) * 0.5, dim=2)
  209. output_y = F.softmax((output_y + output_y_flipped) * 0.5, dim=2)
  210. else:
  211. output_x = F.softmax(output_x, dim=2)
  212. output_y = F.softmax(output_y, dim=2)
  213. max_val_x, preds_x = output_x.max(2, keepdim=True)
  214. max_val_y, preds_y = output_y.max(2, keepdim=True)
  215. mask = max_val_x > max_val_y
  216. max_val_x[mask] = max_val_y[mask]
  217. maxvals = max_val_x * 10.0
  218. output = torch.ones([images.size(0), preds_x.size(1), 3])
  219. output[:, :, 0] = torch.squeeze(torch.true_divide(preds_x, cfg.MODEL.SIMDR_SPLIT_RATIO))
  220. output[:, :, 1] = torch.squeeze(torch.true_divide(preds_y, cfg.MODEL.SIMDR_SPLIT_RATIO))
  221. # output[:, :, 2] = maxvals.squeeze(2)
  222. output = output.cpu().numpy()
  223. preds = output.copy()
  224. for i in range(output.shape[0]):
  225. preds[i] = transform_preds(
  226. output[i], c[i], s[i], [cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]]
  227. )
  228. preds[:, :, 2] = maxvals.squeeze(2)
  229. else:
  230. preds = np.empty((0, 0, 3), dtype=np.float32)
  231. return preds
  232. def simdr_pts(self,img,detection,images,boxes):
  233. c, s = [], []
  234. if detection is not None:
  235. for i, (x1, y1, x2, y2) in enumerate(detection):
  236. x1 = int(round(x1.item()))
  237. x2 = int(round(x2.item()))
  238. y1 = int(round(y1.item()))
  239. y2 = int(round(y2.item()))
  240. boxes[i] = [x1, y1, x2, y2]
  241. w, h = x2 - x1, y2 - y1
  242. xx1 = np.max((0, x1))
  243. yy1 = np.max((0, y1))
  244. xx2 = np.min((img.shape[1] - 1, x1 + np.max((0, w - 1))))
  245. yy2 = np.min((img.shape[0] - 1, y1 + np.max((0, h - 1))))
  246. box = [xx1, yy1, xx2 - xx1, yy2 - yy1]
  247. center, scale = self._box2cs(box)
  248. c.append(center)
  249. s.append(scale)
  250. trans = get_affine_transform(center, scale, 0, np.array(cfg.MODEL.IMAGE_SIZE))
  251. input = cv2.warpAffine(
  252. img,
  253. trans,
  254. (int(self.resolution[1]), int(self.resolution[0])),
  255. flags=cv2.INTER_LINEAR)
  256. images[i] = self.transform(input)
  257. if images.shape[0] > 0:
  258. images = images.to(self.device)
  259. with torch.no_grad():
  260. output_x, output_y = self.model(images)
  261. if cfg.TEST.FLIP_TEST:
  262. input_flipped = images.flip(3)
  263. output_x_flipped_, output_y_flipped_ = self.model(input_flipped)
  264. output_x_flipped = flip_back_simdr(output_x_flipped_.cpu().numpy(),
  265. self.flip_pairs, type='x')
  266. output_y_flipped = flip_back_simdr(output_y_flipped_.cpu().numpy(),
  267. self.flip_pairs, type='y')
  268. output_x_flipped = torch.from_numpy(output_x_flipped.copy()).to(self.device)
  269. output_y_flipped = torch.from_numpy(output_y_flipped.copy()).to(self.device)
  270. # feature is not aligned, shift flipped heatmap for higher accuracy
  271. if cfg.TEST.SHIFT_HEATMAP:
  272. output_x_flipped[:, :, 0:-1] = \
  273. output_x_flipped.clone()[:, :, 1:]
  274. output_x = (F.softmax(output_x, dim=2) + F.softmax(output_x_flipped, dim=2)) * 0.5
  275. output_y = (F.softmax(output_y, dim=2) + F.softmax(output_y_flipped, dim=2)) * 0.5
  276. else:
  277. output_x = F.softmax(output_x, dim=2)
  278. output_y = F.softmax(output_y, dim=2)
  279. max_val_x, preds_x = output_x.max(2, keepdim=True)
  280. max_val_y, preds_y = output_y.max(2, keepdim=True)
  281. mask = max_val_x > max_val_y
  282. max_val_x[mask] = max_val_y[mask]
  283. maxvals = max_val_x * 10.0
  284. output = torch.ones([images.size(0), preds_x.size(1), 3])
  285. output[:, :, 0] = torch.squeeze(torch.true_divide(preds_x, cfg.MODEL.SIMDR_SPLIT_RATIO))
  286. output[:, :, 1] = torch.squeeze(torch.true_divide(preds_y, cfg.MODEL.SIMDR_SPLIT_RATIO))
  287. output = output.cpu().numpy()
  288. preds = output.copy()
  289. for i in range(output.shape[0]):
  290. preds[i] = transform_preds(
  291. output[i], c[i], s[i], [cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]]
  292. )
  293. preds[:, :, 2] = maxvals.squeeze(2)
  294. else:
  295. preds = np.empty((0, 0, 3), dtype=np.float32)
  296. return preds
  297. def hrnet_pts(self,img,detection,images,boxes):
  298. if detection is not None:
  299. for i, (x1, y1, x2, y2) in enumerate(detection):
  300. x1 = int(round(x1.item()))
  301. x2 = int(round(x2.item()))
  302. y1 = int(round(y1.item()))
  303. y2 = int(round(y2.item()))
  304. # Adapt detections to match HRNet input aspect ratio (as suggested by xtyDoge in issue #14)
  305. correction_factor = self.resolution[0] / self.resolution[1] * (x2 - x1) / (y2 - y1)
  306. if correction_factor > 1:
  307. # increase y side
  308. center = y1 + (y2 - y1) // 2
  309. length = int(round((y2 - y1) * correction_factor))
  310. y1 = max(0, center - length // 2)
  311. y2 = min(img.shape[0], center + length // 2)
  312. elif correction_factor < 1:
  313. # increase x side
  314. center = x1 + (x2 - x1) // 2
  315. length = int(round((x2 - x1) * 1 / correction_factor))
  316. x1 = max(0, center - length // 2)
  317. x2 = min(img.shape[1], center + length // 2)
  318. boxes[i] = [x1, y1, x2, y2]
  319. images[i] = self.transform(img[y1:y2, x1:x2, ::-1])
  320. if images.shape[0] > 0:
  321. images = images.to(self.device)
  322. with torch.no_grad():
  323. out = self.model(images)
  324. out = out.detach().cpu().numpy()
  325. pts = np.empty((out.shape[0], out.shape[1], 3), dtype=np.float32)
  326. # For each human, for each joint: x, y, confidence
  327. for i, human in enumerate(out):
  328. for j, joint in enumerate(human):
  329. pt = np.unravel_index(np.argmax(joint), (self.resolution[0] // 4, self.resolution[1] // 4))
  330. # 0: pt_x / (height // 4) * (bb_y2 - bb_y1) + bb_y1
  331. # 1: pt_y / (width // 4) * (bb_x2 - bb_x1) + bb_x1
  332. # 2: confidences
  333. pts[i, j, 0] = pt[1] * 1. / (self.resolution[1] // 4) * (boxes[i][2] - boxes[i][0]) + boxes[i][0]
  334. pts[i, j, 1] = pt[0] * 1. / (self.resolution[0] // 4) * (boxes[i][3] - boxes[i][1]) + boxes[i][1]
  335. pts[i, j, 2] = joint[pt]
  336. else:
  337. pts = np.empty((0, 0, 3), dtype=np.float32)
  338. return pts
  339. def _predict_single(self, image):
  340. _,detections = self.detector.detect(image)
  341. nof_people = len(detections) if detections is not None else 0
  342. boxes = np.empty((nof_people, 4), dtype=np.int32)
  343. images = torch.empty((nof_people, 3, self.resolution[0], self.resolution[1])) # (height, width)
  344. if self.model_name in ('sa-simdr','sasimdr','sa_simdr'):
  345. pts=self.sa_simdr_pts(image,detections,images,boxes)
  346. elif self.model_name in ('hrnet','HRnet','Hrnet'):
  347. pts = self.hrnet_pts(image, detections, images, boxes)
  348. elif self.model_name in ('simdr',):
  349. pts = self.simdr_pts(image, detections, images, boxes)
  350. return pts
  351. # c,s=[],[]
  352. # if detections is not None:
  353. # for i, (x1, y1, x2, y2) in enumerate(detections):
  354. # x1 = int(round(x1.item()))
  355. # x2 = int(round(x2.item()))
  356. # y1 = int(round(y1.item()))
  357. # y2 = int(round(y2.item()))
  358. # boxes[i] = [x1,y1,x2,y2]
  359. # w ,h= x2-x1,y2-y1
  360. # xx1 = np.max((0, x1))
  361. # yy1 = np.max((0, y1))
  362. # xx2 = np.min((image.shape[1] - 1, x1 + np.max((0, w - 1))))
  363. # yy2 = np.min((image.shape[0] - 1, y1 + np.max((0, h - 1))))
  364. # box = [xx1, yy1, xx2-xx1, yy2-yy1]
  365. # center,scale = self._box2cs(box)
  366. # c.append(center)
  367. # s.append(scale)
  368. #
  369. # trans = get_affine_transform(center, scale, 0, np.array(cfg.MODEL.IMAGE_SIZE))
  370. # input = cv2.warpAffine(
  371. # image,
  372. # trans,
  373. # (int(self.resolution[1]), int(self.resolution[0])),
  374. # flags=cv2.INTER_LINEAR)
  375. # images[i] = self.transform(input)
  376. # if images.shape[0] > 0:
  377. # images = images.to(self.device)
  378. # with torch.no_grad():
  379. # output_x,output_y = self.model(images)
  380. # if cfg.TEST.FLIP_TEST:
  381. # input_flipped = images.flip(3)
  382. # output_x_flipped_, output_y_flipped_ = self.model(input_flipped)
  383. # output_x_flipped = flip_back_simdr(output_x_flipped_.cpu().numpy(),
  384. # self.flip_pairs, type='x')
  385. # output_y_flipped = flip_back_simdr(output_y_flipped_.cpu().numpy(),
  386. # self.flip_pairs, type='y')
  387. # output_x_flipped = torch.from_numpy(output_x_flipped.copy()).to(self.device)
  388. # output_y_flipped = torch.from_numpy(output_y_flipped.copy()).to(self.device)
  389. #
  390. # # feature is not aligned, shift flipped heatmap for higher accuracy
  391. # if cfg.TEST.SHIFT_HEATMAP:
  392. # output_x_flipped[:, :, 0:-1] = \
  393. # output_x_flipped.clone()[:, :, 1:]
  394. # output_x = F.softmax((output_x + output_x_flipped) * 0.5, dim=2)
  395. # output_y = F.softmax((output_y + output_y_flipped) * 0.5, dim=2)
  396. # else:
  397. # output_x = F.softmax(output_x, dim=2)
  398. # output_y = F.softmax(output_y, dim=2)
  399. # max_val_x, preds_x = output_x.max(2, keepdim=True)
  400. # max_val_y, preds_y = output_y.max(2, keepdim=True)
  401. #
  402. # mask = max_val_x > max_val_y
  403. # max_val_x[mask] = max_val_y[mask]
  404. # maxvals = max_val_x.cpu().numpy()
  405. #
  406. # output = torch.ones([images.size(0), preds_x.size(1), 2])
  407. # output[:, :, 0] = torch.squeeze(torch.true_divide(preds_x, cfg.MODEL.SIMDR_SPLIT_RATIO))
  408. # output[:, :, 1] = torch.squeeze(torch.true_divide(preds_y, cfg.MODEL.SIMDR_SPLIT_RATIO))
  409. #
  410. # output = output.cpu().numpy()
  411. # preds = output.copy()
  412. # for i in range(output.shape[0]):
  413. # preds[i] = transform_preds(
  414. # output[i], c[i], s[i], [cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]]
  415. # )
  416. # else:
  417. # preds = np.empty((0, 0, 2), dtype=np.float32)
  418. # return preds
  419. # parser = argparse.ArgumentParser()
  420. # parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  421. # parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  422. # parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  423. # parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  424. # parser.add_argument('--augment', action='store_true', help='augmented inference')
  425. # parser.add_argument('--update', action='store_true', help='update all model')
  426. # parser.add_argument('--project', default='runs/detect', help='save results to project/name')
  427. # parser.add_argument('--name', default='exp', help='save results to project/name')
  428. # parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  429. # opt = parser.parse_args()
  430. # model = Points(model_name='hrnet',opt=opt)
  431. # img0 = cv2.imread('./data/test1.jpg')
  432. # pts = model.predict(img0)
  433. # print(pts.shape)
  434. # for point in pts[0]:
  435. # image = cv2.circle(img0, (int(point[0]), int(point[1])), 3, [255,0,255], -1 , lineType= cv2.LINE_AA)
  436. # cv2.imwrite('./data/test11_result.jpg',image)

④ Draw the keypoints and skeleton

  1. With the steps above we already have the keypoint coordinates; next they need to be drawn on the image to visualize the detection results.
  2. First create a new file visualization.py under the './SimDR/lib/utils/' folder and copy the following code into it. The skeleton-drawing code combines the simple-HRNet and Openpose projects.

[Update 2022.04.16: my earlier drawing code had been heavily customized, so it has now been restored to draw all points and bones. I still find the result rather ugly compared with Openpose's rendering; if anyone comes up with a nicer-looking skeleton, please share it so we can all improve!]

  1. import cv2
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. import torch
  5. import torchvision
  6. import ffmpeg
  7. import random
  8. import math
  9. import copy
  10. def plot_one_box(x, img, color=None, label=None, line_thickness=3):
  11. # Plots one bounding box on image img
  12. tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
  13. color = color or [random.randint(0, 255) for _ in range(3)]
  14. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  15. cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  16. if label:
  17. tf = max(tl - 1, 1) # font thickness
  18. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  19. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  20. cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
  21. cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  22. return img
  23. def joints_dict():
  24. joints = {
  25. "coco": {
  26. "keypoints": {
  27. 0: "nose",
  28. 1: "left_eye",
  29. 2: "right_eye",
  30. 3: "left_ear",
  31. 4: "right_ear",
  32. 5: "left_shoulder",
  33. 6: "right_shoulder",
  34. 7: "left_elbow",
  35. 8: "right_elbow",
  36. 9: "left_wrist",
  37. 10: "right_wrist",
  38. 11: "left_hip",
  39. 12: "right_hip",
  40. 13: "left_knee",
  41. 14: "right_knee",
  42. 15: "left_ankle",
  43. 16: "right_ankle"
  44. },
  45. "skeleton": [
  46. # # [16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8],
  47. # # [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]
  48. # [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7],
  49. # [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]
  50. [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7],
  51. [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], # [3, 5], [4, 6]
  52. [0, 5], [0, 6]
  53. # [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7],
  54. # [6, 8], [7, 9], [8, 10], [0, 3], [0, 4], [1, 3], [2, 4], # [3, 5], [4, 6]
  55. # [0, 5], [0, 6]
  56. ]
  57. },
  58. "mpii": {
  59. "keypoints": {
  60. 0: "right_ankle",
  61. 1: "right_knee",
  62. 2: "right_hip",
  63. 3: "left_hip",
  64. 4: "left_knee",
  65. 5: "left_ankle",
  66. 6: "pelvis",
  67. 7: "thorax",
  68. 8: "upper_neck",
  69. 9: "head top",
  70. 10: "right_wrist",
  71. 11: "right_elbow",
  72. 12: "right_shoulder",
  73. 13: "left_shoulder",
  74. 14: "left_elbow",
  75. 15: "left_wrist"
  76. },
  77. "skeleton": [
  78. # [5, 4], [4, 3], [0, 1], [1, 2], [3, 2], [13, 3], [12, 2], [13, 12], [13, 14],
  79. # [12, 11], [14, 15], [11, 10], # [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]
  80. [5, 4], [4, 3], [0, 1], [1, 2], [3, 2], [3, 6], [2, 6], [6, 7], [7, 8], [8, 9],
  81. [13, 7], [12, 7], [13, 14], [12, 11], [14, 15], [11, 10],
  82. ]
  83. },
  84. }
  85. return joints
  86. def draw_points(image, points, color_palette='tab20', palette_samples=16, confidence_threshold=0.1,color=None):
  87. """
  88. Draws `points` on `image`.
  89. Args:
  90. image: image in opencv format
  91. points: list of points to be drawn.
  92. Shape: (nof_points, 3)
  93. Format: each point should contain (x, y, confidence)
  94. color_palette: name of a matplotlib color palette
  95. Default: 'tab20'
  96. palette_samples: number of different colors sampled from the `color_palette`
  97. Default: 16
  98. confidence_threshold: only points with a confidence higher than this threshold will be drawn. Range: [0, 1]
  99. Default: 0.1
  100. Returns:
  101. A new image with overlaid points
  102. """
  103. circle_size = max(2, int(np.sqrt(np.max(np.max(points, axis=0) - np.min(points, axis=0)) // 16)))
  104. for i, pt in enumerate(points):
  105. if pt[2] >= confidence_threshold:
  106. image = cv2.circle(image, (int(pt[0]), int(pt[1])), circle_size, color[i] ,-1, lineType= cv2.LINE_AA)
  107. return image
  108. def draw_skeleton(image, points, skeleton, color_palette='Set2', palette_samples=8, person_index=0,
  109. confidence_threshold=0.1,sk_color=None):
  110. """
  111. Draws a `skeleton` on `image`.
  112. Args:
  113. image: image in opencv format
  114. points: list of points to be drawn.
  115. Shape: (nof_points, 3)
  116. Format: each point should contain (x, y, confidence)
  117. skeleton: list of joints to be drawn
  118. Shape: (nof_joints, 2)
  119. Format: each joint should contain (point_a, point_b) where `point_a` and `point_b` are an index in `points`
  120. color_palette: name of a matplotlib color palette
  121. Default: 'Set2'
  122. palette_samples: number of different colors sampled from the `color_palette`
  123. Default: 8
  124. person_index: index of the person in `image`
  125. Default: 0
  126. confidence_threshold: only points with a confidence higher than this threshold will be drawn. Range: [0, 1]
  127. Default: 0.1
  128. Returns:
  129. A new image with overlaid joints
  130. """
  131. canvas = copy.deepcopy(image)
  132. cur_canvas = canvas.copy()
  133. for i, joint in enumerate(skeleton):
  134. pt1, pt2 = points[joint]
  135. if pt1[2] >= confidence_threshold and pt2[2]>= confidence_threshold :
  136. length = ((pt1[0] - pt2[0]) ** 2 + (pt1[1] - pt2[1]) ** 2) ** 0.5
  137. angle = math.degrees(math.atan2(pt1[1] - pt2[1],pt1[0] - pt2[0]))
  138. polygon = cv2.ellipse2Poly((int(np.mean((pt1[0],pt2[0]))), int(np.mean((pt1[1],pt2[1])))), (int(length / 2), 2), int(angle), 0, 360, 1)
  139. cv2.fillConvexPoly(cur_canvas, polygon, sk_color[i],lineType=cv2.LINE_AA)
  140. # cv2.fillConvexPoly(cur_canvas, polygon, sk_color,lineType=cv2.LINE_AA)
  141. canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
  142. return canvas
  143. def draw_points_and_skeleton(image, points, skeleton, points_color_palette='tab20', points_palette_samples=16,
  144. skeleton_color_palette='Set2', skeleton_palette_samples=8, person_index=0,
  145. confidence_threshold=0.1,color=None,sk_color=None):
  146. """
  147. Draws `points` and `skeleton` on `image`.
  148. Args:
  149. image: image in opencv format
  150. points: list of points to be drawn.
  151. Shape: (nof_points, 3)
  152. Format: each point should contain (x, y, confidence)
  153. skeleton: list of joints to be drawn
  154. Shape: (nof_joints, 2)
  155. Format: each joint should contain (point_a, point_b) where `point_a` and `point_b` are an index in `points`
  156. points_color_palette: name of a matplotlib color palette
  157. Default: 'tab20'
  158. points_palette_samples: number of different colors sampled from the `color_palette`
  159. Default: 16
  160. skeleton_color_palette: name of a matplotlib color palette
  161. Default: 'Set2'
  162. skeleton_palette_samples: number of different colors sampled from the `color_palette`
  163. Default: 8
  164. person_index: index of the person in `image`
  165. Default: 0
  166. confidence_threshold: only points with a confidence higher than this threshold will be drawn. Range: [0, 1]
  167. Default: 0.1
  168. Returns:
  169. A new image with overlaid joints
  170. """
  171. colors1 = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
  172. [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
  173. [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 85]]
  174. image = draw_skeleton(image, points, skeleton, color_palette=skeleton_color_palette,
  175. palette_samples=skeleton_palette_samples, person_index=person_index,
  176. confidence_threshold=confidence_threshold,sk_color=colors1)
  177. image = draw_points(image, points, color_palette=points_color_palette, palette_samples=points_palette_samples,
  178. confidence_threshold=confidence_threshold,color=colors1)
  179. return image
  180. def save_images(images, target, joint_target, output, joint_output, joint_visibility, summary_writer=None, step=0,
  181. prefix=''):
  182. """
  183. Creates a grid of images with gt joints and a grid with predicted joints.
  184. This is a basic function for debugging purposes only.
  185. If summary_writer is not None, the grid will be written in that SummaryWriter with name "{prefix}_images" and
  186. "{prefix}_predictions".
  187. Args:
  188. images (torch.Tensor): a tensor of images with shape (batch x channels x height x width).
  189. target (torch.Tensor): a tensor of gt heatmaps with shape (batch x channels x height x width).
  190. joint_target (torch.Tensor): a tensor of gt joints with shape (batch x joints x 2).
  191. output (torch.Tensor): a tensor of predicted heatmaps with shape (batch x channels x height x width).
  192. joint_output (torch.Tensor): a tensor of predicted joints with shape (batch x joints x 2).
  193. joint_visibility (torch.Tensor): a tensor of joint visibility with shape (batch x joints).
  194. summary_writer (tb.SummaryWriter): a SummaryWriter where write the grids.
  195. Default: None
  196. step (int): summary_writer step.
  197. Default: 0
  198. prefix (str): summary_writer name prefix.
  199. Default: ""
  200. Returns:
  201. A pair of images which are built from torchvision.utils.make_grid
  202. """
  203. # Input images with gt
  204. images_ok = images.detach().clone()
  205. images_ok[:, 0].mul_(0.229).add_(0.485)
  206. images_ok[:, 1].mul_(0.224).add_(0.456)
  207. images_ok[:, 2].mul_(0.225).add_(0.406)
  208. for i in range(images.shape[0]):
  209. joints = joint_target[i] * 4.
  210. joints_vis = joint_visibility[i]
  211. for joint, joint_vis in zip(joints, joints_vis):
  212. if joint_vis[0]:
  213. a = int(joint[1].item())
  214. b = int(joint[0].item())
  215. # images_ok[i][:, a-1:a+1, b-1:b+1] = torch.tensor([1, 0, 0])
  216. images_ok[i][0, a - 1:a + 1, b - 1:b + 1] = 1
  217. images_ok[i][1:, a - 1:a + 1, b - 1:b + 1] = 0
  218. grid_gt = torchvision.utils.make_grid(images_ok, nrow=int(images_ok.shape[0] ** 0.5), padding=2, normalize=False)
  219. if summary_writer is not None:
  220. summary_writer.add_image(prefix + 'images', grid_gt, global_step=step)
  221. # Input images with prediction
  222. images_ok = images.detach().clone()
  223. images_ok[:, 0].mul_(0.229).add_(0.485)
  224. images_ok[:, 1].mul_(0.224).add_(0.456)
  225. images_ok[:, 2].mul_(0.225).add_(0.406)
  226. for i in range(images.shape[0]):
  227. joints = joint_output[i] * 4.
  228. joints_vis = joint_visibility[i]
  229. for joint, joint_vis in zip(joints, joints_vis):
  230. if joint_vis[0]:
  231. a = int(joint[1].item())
  232. b = int(joint[0].item())
  233. # images_ok[i][:, a-1:a+1, b-1:b+1] = torch.tensor([1, 0, 0])
  234. images_ok[i][0, a - 1:a + 1, b - 1:b + 1] = 1
  235. images_ok[i][1:, a - 1:a + 1, b - 1:b + 1] = 0
  236. grid_pred = torchvision.utils.make_grid(images_ok, nrow=int(images_ok.shape[0] ** 0.5), padding=2, normalize=False)
  237. if summary_writer is not None:
  238. summary_writer.add_image(prefix + 'predictions', grid_pred, global_step=step)
  239. # Heatmaps
  240. # ToDo
  241. # for h in range(0,17):
  242. # heatmap = torchvision.utils.make_grid(output[h].detach(), nrow=int(np.sqrt(output.shape[0])),
  243. # padding=2, normalize=True, range=(0, 1))
  244. # summary_writer.add_image('train_heatmap_%d' % h, heatmap, global_step=step + epoch*len_dl_train)
  245. return grid_gt, grid_pred
  246. def check_video_rotation(filename):
  247. # thanks to
  248. # https://stackoverflow.com/questions/53097092/frame-from-video-is-upside-down-after-extracting/55747773#55747773
  249. # this returns meta-data of the video file in form of a dictionary
  250. meta_dict = ffmpeg.probe(filename)
  251. # from the dictionary, meta_dict['streams'][0]['tags']['rotate'] is the key
  252. # we are looking for
  253. rotation_code = None
  254. try:
  255. if int(meta_dict['streams'][0]['tags']['rotate']) == 90:
  256. rotation_code = cv2.ROTATE_90_CLOCKWISE
  257. elif int(meta_dict['streams'][0]['tags']['rotate']) == 180:
  258. rotation_code = cv2.ROTATE_180
  259. elif int(meta_dict['streams'][0]['tags']['rotate']) == 270:
  260. rotation_code = cv2.ROTATE_90_COUNTERCLOCKWISE
  261. else:
  262. raise ValueError
  263. except KeyError:
  264. pass
  265. return rotation_code
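
To illustrate the input format these functions expect (a synthetic example, not part of the original project): points is an array of shape (17, 3) for coco, one row (x, y, confidence) per joint in image coordinates, and only entries above confidence_threshold are drawn.

    # run from the SimDR root so that lib.utils resolves
    import cv2
    import numpy as np
    from lib.utils.visualization import draw_points_and_skeleton, joints_dict

    canvas = np.zeros((480, 640, 3), dtype=np.uint8)   # blank test image
    points = np.zeros((17, 3), dtype=np.float32)       # 17 coco joints: (x, y, confidence)
    points[:, 0] = np.linspace(150, 500, 17)           # dummy x coordinates
    points[:, 1] = np.linspace(100, 420, 17)           # dummy y coordinates
    points[:, 2] = 1.0                                 # full confidence so every point and bone is drawn
    canvas = draw_points_and_skeleton(canvas, points, joints_dict()['coco']['skeleton'])
    cv2.imwrite('dummy_skeleton.jpg', canvas)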

4. Test the algorithm

① Main program

  1. Create a new file main.py under the SimDR folder, copy the following code into it, change the default value of the --source argument in the parser to your own image or video, and run it.
  1. import argparse
  2. import time
  3. import os
  4. import cv2 as cv
  5. import numpy as np
  6. from pathlib import Path
  7. from Point_detect import Points
  8. from lib.utils.visualization import draw_points_and_skeleton,joints_dict
  9. def image_detect(opt):
  10. skeleton = joints_dict()['coco']['skeleton']
  11. hrnet_model = Points(model_name='hrnet', opt=opt,resolution=(384,288)) #resolution = (384,288) or (256,192)
  12. # simdr_model = Points(model_name='simdr', opt=opt,resolution=(256,192)) #resolution = (256,192)
  13. # sa_simdr_model = Points(model_name='sa-simdr', opt=opt,resolution=(384,288)) #resolution = (384,288) or (256,192)
  14. img0 = cv.imread(opt.source)
  15. frame = img0.copy()
  16. #predict
  17. pred = hrnet_model.predict(img0)
  18. # pred = simdr_model.predict(frame)
  19. # pred = sa_simdr_model.predict(frame)
  20. #vis
  21. for i, pt in enumerate(pred):
  22. frame = draw_points_and_skeleton(frame, pt, skeleton)
  23. #save
  24. cv.imwrite('test_result.jpg', frame)
  25. def video_detect(opt):
  26. hrnet_model = Points(model_name='hrnet', opt=opt, resolution=(384, 288)) # resolution = (384,288) or (256,192)
  27. # simdr_model = Points(model_name='simdr', opt=opt,resolution=(256,192)) #resolution = (256,192)
  28. # sa_simdr_model = Points(model_name='sa-simdr', opt=opt,resolution=(384,288)) #resolution = (384,288) or (256,192)
  29. skeleton = joints_dict()['coco']['skeleton']
  30. cap = cv.VideoCapture(opt.source)
  31. if opt.save_video:
  32. fourcc = cv.VideoWriter_fourcc(*'MJPG')
  33. out = cv.VideoWriter('data/runs/{}_out.avi'.format(os.path.basename(opt.source).split('.')[0]), fourcc, 24, (int(cap.get(3)), int(cap.get(4))))
  34. while cap.isOpened():
  35. ret, frame = cap.read()
  36. if not ret:
  37. break
  38. pred = hrnet_model.predict(frame)
  39. # pred = simdr_model.predict(frame)
  40. # pred = sa_simdr_model.predict(frame)
  41. for pt in pred:
  42. frame = draw_points_and_skeleton(frame,pt,skeleton)
  43. if opt.show:
  44. cv.imshow('result', frame)
  45. if opt.save_video:
  46. out.write(frame)
  47. if cv.waitKey(1) == 27:
  48. break
  49. if opt.save_video: out.release() # the writer exists only when --save_video is set
  50. cap.release()
  51. cv.destroyAllWindows()
  52. # video_detect(0)
  53. if __name__ == '__main__':
  54. parser = argparse.ArgumentParser()
  55. parser.add_argument('--source', type=str, default='./data/images/test1.jpg', help='source') # file/folder, 0 for webcam
  56. parser.add_argument('--detect_weight', type=str, default="./yolov5/weights/yolov5x.pt", help='e.g "./yolov5/weights/yolov5x.pt"')
  57. parser.add_argument('--save_video', action='store_true', default=False,help='save results to *.avi')
  58. parser.add_argument('--show', action='store_true', default=True, help='show results in a window')
  59. parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  60. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  61. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  62. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  63. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  64. parser.add_argument('--augment', action='store_true', help='augmented inference')
  65. opt = parser.parse_args()
  66. image_detect(opt)
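
With the defaults above, main.py runs image_detect on a single image. To process a video instead, call video_detect(opt) in the __main__ block and point --source at a video file; a typical invocation might look like this (paths are examples):

  1. python main.py --source ./data/videos/test1.mp4 --save_video --device cpu

Note that the VideoWriter writes <name>_out.avi into data/runs/, so make sure that folder exists before running with --save_video.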

② Results

(The result images can be seen in the original post linked at the end of this article.)

IV. Summary

  1. This article is long and consists mostly of code. Going from running the datasets to performing real detection requires some understanding of the code base; the project is not difficult overall, the main challenge being how the classes are structured. If you need the complete project, contact me directly. As I am still a beginner, there are certainly mistakes in my coding style and in my understanding of the theory; corrections are welcome so we can improve together, and likes and comments are appreciated if this helped.

V. References

1. GitHub - leeyegy/SimDR: PyTorch implementation of "Is 2D Heatmap Representation Even Necessary for Human Pose Estimation?" (http://arxiv.org/abs/2107.03332)

2. https://github.com/ultralytics/yolov5

3. GitHub - GreenTeaHua/simple-HRNet: Multi-person Human Pose Estimation with HRNet in PyTorch


This article is reposted from: https://blog.csdn.net/qq_40691868/article/details/122855962
Copyright belongs to the original author, 围白的尾巴. In case of infringement, please contact us for removal.
