0


MMdetection之train.py源码详解

一、tools/train.py

可选参数:

  1. # =========== optional arguments ===========
  2. # --work-dir 存储日志和模型的目录
  3. # --resume-from 加载 checkpoint 的目录
  4. # --no-validate 是否在训练的时候进行验证
  5. # 互斥组:
  6. # --gpus 使用的 GPU 数量
  7. # --gpu_ids 使用指定 GPU 的 id
  8. # --seed 随机数种子
  9. # --deterministic 是否设置 cudnn 为确定性行为
  10. # --options 其他参数
  11. # --launcher 分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi']
  12. # none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。
  13. # --local_rank 本地进程编号,此参数 torch.distributed.launch 会自动传入。

对于 tools/train.py 其主要的流程如下

  1. 对于 train.py 来说,首先从命令行和配置文件读取配置,然后分别用 build_detector 和 build_dataset 构建模型和数据集,最后将模型和数据集传入 train_detector 进行训练。

(一)从命令行和配置文件获取参数配置

  1. cfg = Config.fromfile(args.config)

(二)构建模型

  1. # 构建模型: 需要传入 cfg.model,cfg.train_cfg,cfg.test_cfg
  2. model = build_detector(
  3. cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

(三)构建数据集

  1. # 构建数据集: 需要传入 cfg.data.train,表明是训练集
  2. datasets = [build_dataset(cfg.data.train)]

(四)训练模型

  1. # 训练检测器:需要传入模型、数据集、配置参数等
  2. train_detector(
  3. model,
  4. datasets,
  5. cfg,
  6. distributed=distributed,
  7. validate=(not args.no_validate),
  8. timestamp=timestamp,
  9. meta=meta)

二、源码详解

  1. import argparse
  2. import copy
  3. import os
  4. import os.path as osp
  5. import time
  6. import mmcv
  7. import torch
  8. # Config 用于读取配置文件, DictAction 将命令行字典类型参数转化为 key-value 形式
  9. from mmcv import Config, DictAction
  10. from mmcv.runner import init_dist
  11. from mmdet import __version__
  12. from mmdet.apis import set_random_seed, train_detector
  13. from mmdet.datasets import build_dataset
  14. from mmdet.models import build_detector
  15. from mmdet.utils import collect_env, get_root_logger
  16. # python tools/train.py ${CONFIG_FILE} [optional arguments]
  17. # =========== optional arguments ===========
  18. # --work-dir 存储日志和模型的目录
  19. # --resume-from 加载 checkpoint 的目录
  20. # --no-validate 是否在训练的时候进行验证
  21. # 互斥组:
  22. # --gpus 使用的 GPU 数量
  23. # --gpu_ids 使用指定 GPU 的 id
  24. # --seed 随机数种子
  25. # --deterministic 是否设置 cudnn 为确定性行为
  26. # --options 其他参数
  27. # --launcher 分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi']
  28. # none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。
  29. # --local_rank 本地进程编号,此参数 torch.distributed.launch 会自动传入。
  30. def parse_args():
  31. parser = argparse.ArgumentParser(description='Train a detector')
  32. parser.add_argument('config', help='train config file path')
  33. parser.add_argument('--work-dir', help='the dir to save logs and models')
  34. parser.add_argument(
  35. '--resume-from', help='the checkpoint file to resume from')
  36. # action: store (默认, 表示保存参数)
  37. # action: store_true, store_false (如果指定参数, 则为 True, False)
  38. parser.add_argument(
  39. '--no-validate',
  40. action='store_true',
  41. help='whether not to evaluate the checkpoint during training')
  42. # --------- 创建一个互斥组. argparse 将会确保互斥组中的参数只能出现一个 ---------
  43. group_gpus = parser.add_mutually_exclusive_group()
  44. group_gpus.add_argument(
  45. '--gpus',
  46. type=int,
  47. help='number of gpus to use '
  48. '(only applicable to non-distributed training)')
  49. # 可以使用 python train.py --gpu-ids 0 1 2 3 指定使用的 GPU id
  50. # 参数结果:[0, 1, 2, 3]
  51. # nargs = '*':参数个数可以设置0个或n个
  52. # nargs = '+':参数个数可以设置1个或n个
  53. # nargs = '?':参数个数可以设置0个或1个
  54. group_gpus.add_argument(
  55. '--gpu-ids',
  56. type=int,
  57. nargs='+',
  58. help='ids of gpus to use '
  59. '(only applicable to non-distributed training)')
  60. # ------------------------------------------------------------------------
  61. parser.add_argument('--seed', type=int, default=None, help='random seed')
  62. parser.add_argument(
  63. '--deterministic',
  64. action='store_true',
  65. help='whether to set deterministic options for CUDNN backend.')
  66. # 其他参数: 可以使用 --options a=1,2,3 指定其他参数
  67. # 参数结果: {'a': [1, 2, 3]}
  68. parser.add_argument(
  69. '--options', nargs='+', action=DictAction, help='arguments in dict')
  70. # 如果使用 dist_utils.sh 进行分布式训练, launcher 默认为 pytorch
  71. parser.add_argument(
  72. '--launcher',
  73. choices=['none', 'pytorch', 'slurm', 'mpi'],
  74. default='none',
  75. help='job launcher')
  76. # 本地进程编号,此参数 torch.distributed.launch 会自动传入。
  77. parser.add_argument('--local_rank', type=int, default=0)
  78. args = parser.parse_args()
  79. # 如果环境中没有 LOCAL_RANK,就设置它为当前的 local_rank
  80. if 'LOCAL_RANK' not in os.environ:
  81. os.environ['LOCAL_RANK'] = str(args.local_rank)
  82. return args
  83. def main():
  84. args = parse_args()
  85. # 从文件读取配置
  86. cfg = Config.fromfile(args.config)
  87. # 从命令行读取额外的配置
  88. if args.options is not None:
  89. cfg.merge_from_dict(args.options)
  90. # 设置 cudnn_benchmark = True 可以加速输入大小固定的模型. 如:SSD300
  91. if cfg.get('cudnn_benchmark', False):
  92. torch.backends.cudnn.benchmark = True
  93. # work_dir 的优先程度为: 命令行 > 配置文件
  94. if args.work_dir is not None:
  95. cfg.work_dir = args.work_dir
  96. # 当 work_dir 为 None 的时候, 使用 ./work_dir/配置文件名 作为默认工作目录
  97. elif cfg.get('work_dir', None) is None:
  98. # os.path.basename(path) 返回文件名
  99. # os.path.splitext(path) 分割路径, 返回路径名和文件扩展名的元组
  100. cfg.work_dir = osp.join('./work_dirs',
  101. osp.splitext(osp.basename(args.config))[0])
  102. # 是否继续上次的训练
  103. if args.resume_from is not None:
  104. cfg.resume_from = args.resume_from
  105. # gpu id
  106. if args.gpu_ids is not None:
  107. cfg.gpu_ids = args.gpu_ids
  108. else:
  109. cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
  110. # 如果 launcher 为 none,不启用分布式训练。不使用 dist_train.sh 默认参数为 none.
  111. if args.launcher == 'none':
  112. distributed = False
  113. # launcher 不为 none,启用分布式训练。使用 dist_train.sh,会传 ‘pytorch’
  114. else:
  115. distributed = True
  116. # 初始化 dist 里面会调用 init_process_group
  117. init_dist(args.launcher, **cfg.dist_params)
  118. # 创建 work_dir
  119. mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
  120. # 保存 config
  121. cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
  122. # init the logger before other steps
  123. # eg: 20200726_105413
  124. timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
  125. log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
  126. # 获取 root logger。
  127. logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
  128. # init the meta dict to record some important information such as
  129. # environment info and seed, which will be logged
  130. meta = dict()
  131. # log env info
  132. env_info_dict = collect_env()
  133. env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
  134. dash_line = '-' * 60 + '\n'
  135. logger.info('Environment info:\n' + dash_line + env_info + '\n' +
  136. dash_line)
  137. meta['env_info'] = env_info
  138. # log some basic info
  139. logger.info(f'Distributed training: {distributed}')
  140. logger.info(f'Config:\n{cfg.pretty_text}')
  141. # 设置随机化种子
  142. if args.seed is not None:
  143. logger.info(f'Set random seed to {args.seed}, '
  144. f'deterministic: {args.deterministic}')
  145. set_random_seed(args.seed, deterministic=args.deterministic)
  146. cfg.seed = args.seed
  147. meta['seed'] = args.seed
  148. # 构建模型: 需要传入 cfg.model, cfg.train_cfg, cfg.test_cfg
  149. model = build_detector(
  150. cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
  151. # 构建数据集: 需要传入 cfg.data.train
  152. datasets = [build_dataset(cfg.data.train)]
  153. # workflow 代表流程:
  154. # [('train', 2), ('val', 1)] 就代表,训练两个 epoch 验证一个 epoch
  155. if len(cfg.workflow) == 2:
  156. val_dataset = copy.deepcopy(cfg.data.val)
  157. val_dataset.pipeline = cfg.data.train.pipeline
  158. datasets.append(build_dataset(val_dataset))
  159. if cfg.checkpoint_config is not None:
  160. # save mmdet version, config file content and class names in
  161. # checkpoints as meta data
  162. cfg.checkpoint_config.meta = dict(
  163. mmdet_version=__version__,
  164. config=cfg.pretty_text,
  165. CLASSES=datasets[0].CLASSES)
  166. # add an attribute for visualization convenience
  167. model.CLASSES = datasets[0].CLASSES
  168. # 训练检测器, 传入:模型, 数据集, config 等
  169. train_detector(
  170. model,
  171. datasets,
  172. cfg,
  173. distributed=distributed,
  174. validate=(not args.no_validate),
  175. timestamp=timestamp,
  176. meta=meta)
  177. if __name__ == '__main__':
  178. main()

三、核心函数详解

在 train.py 中主要调用:构建模型(build_detector),构建数据集(build_dataset),训练模型(train_detector)的函数。

(一)build_detector(mmdet/models/builder.py)

build_detector 函数将配置文件中的:model、train_cfg 和 test_cfg 传入参数。

下面以 faster_rcnn_r50_fpn_1x_coco.py 配置文件来举例:

model

  1. model = dict(
  2. type='FasterRCNN',
  3. pretrained='torchvision://resnet50',
  4. backbone=dict(
  5. type='ResNet',
  6. depth=50,
  7. num_stages=4,
  8. out_indices=(0, 1, 2, 3),
  9. frozen_stages=1,
  10. norm_cfg=dict(type='BN', requires_grad=True),
  11. norm_eval=True,
  12. style='pytorch'),
  13. neck=dict(
  14. type='FPN',
  15. in_channels=[256, 512, 1024, 2048],
  16. out_channels=256,
  17. num_outs=5),
  18. rpn_head=dict(
  19. type='RPNHead',
  20. in_channels=256,
  21. feat_channels=256,
  22. anchor_generator=dict(
  23. type='AnchorGenerator',
  24. scales=[8],
  25. ratios=[0.5, 1.0, 2.0],
  26. strides=[4, 8, 16, 32, 64]),
  27. bbox_coder=dict(
  28. type='DeltaXYWHBBoxCoder',
  29. target_means=[.0, .0, .0, .0],
  30. target_stds=[1.0, 1.0, 1.0, 1.0]),
  31. loss_cls=dict(
  32. type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
  33. loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
  34. roi_head=dict(
  35. type='StandardRoIHead',
  36. bbox_roi_extractor=dict(
  37. type='SingleRoIExtractor',
  38. roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
  39. out_channels=256,
  40. featmap_strides=[4, 8, 16, 32]),
  41. bbox_head=dict(
  42. type='Shared2FCBBoxHead',
  43. in_channels=256,
  44. fc_out_channels=1024,
  45. roi_feat_size=7,
  46. num_classes=80,
  47. bbox_coder=dict(
  48. type='DeltaXYWHBBoxCoder',
  49. target_means=[0., 0., 0., 0.],
  50. target_stds=[0.1, 0.1, 0.2, 0.2]),
  51. reg_class_agnostic=False,
  52. loss_cls=dict(
  53. type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
  54. loss_bbox=dict(type='L1Loss', loss_weight=1.0))))

train_cfg

  1. train_cfg = dict(
  2. rpn=dict(
  3. assigner=dict(
  4. type='MaxIoUAssigner',
  5. pos_iou_thr=0.7,
  6. neg_iou_thr=0.3,
  7. min_pos_iou=0.3,
  8. match_low_quality=True,
  9. ignore_iof_thr=-1),
  10. sampler=dict(
  11. type='RandomSampler',
  12. num=256,
  13. pos_fraction=0.5,
  14. neg_pos_ub=-1,
  15. add_gt_as_proposals=False),
  16. allowed_border=-1,
  17. pos_weight=-1,
  18. debug=False),
  19. rpn_proposal=dict(
  20. nms_across_levels=False,
  21. nms_pre=2000,
  22. nms_post=1000,
  23. max_num=1000,
  24. nms_thr=0.7,
  25. min_bbox_size=0),
  26. rcnn=dict(
  27. assigner=dict(
  28. type='MaxIoUAssigner',
  29. pos_iou_thr=0.5,
  30. neg_iou_thr=0.5,
  31. min_pos_iou=0.5,
  32. match_low_quality=False,
  33. ignore_iof_thr=-1),
  34. sampler=dict(
  35. type='RandomSampler',
  36. num=512,
  37. pos_fraction=0.25,
  38. neg_pos_ub=-1,
  39. add_gt_as_proposals=True),
  40. pos_weight=-1,
  41. debug=False))

test_cfg

  1. test_cfg = dict(
  2. rpn=dict(
  3. nms_across_levels=False,
  4. nms_pre=1000,
  5. nms_post=1000,
  6. max_num=1000,
  7. nms_thr=0.7,
  8. min_bbox_size=0),
  9. rcnn=dict(
  10. score_thr=0.05,
  11. nms=dict(type='nms', iou_threshold=0.5),
  12. max_per_img=100)
  13. # soft-nms is also supported for rcnn testing
  14. # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
  15. )

运行时会将上面的三个值作为参数传入 build_detector 函数,build_detector 函数会调用 build 函数,build 函数调用 build_from_cfg 函数构建检测器对象。其中 train_cfg 和 test_cfg 作为默认参数用于构建 detector 对象。

  1. def build(cfg, registry, default_args=None):
  2. if isinstance(cfg, list):
  3. modules = [
  4. build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
  5. ]
  6. return nn.Sequential(*modules)
  7. else:
  8. # 调用 build_from_cfg 用来根据 config 字典构建 registry 里面的对象
  9. return build_from_cfg(cfg, registry, default_args)
  10. def build_detector(cfg, train_cfg=None, test_cfg=None):
  11. # 调用 build 函数,传入 cfg, registry 对象,
  12. # 把 train_cfg 和 test_cfg 作为默认字典传入
  13. return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

build_from_cfg 在 mmcv/utils/registry.py 中。其中参数 cfg 字典中的 type 键所对应的值表示需要创建的对象的类型。build_from_cfg 会自动在 Registry 注册的类中找到需要创建的类,并传入默认参数实例化。

  1. def build_from_cfg(cfg, registry, default_args=None):
  2. """Build a module from config dict.
  3. Args:
  4. cfg (dict): Config dict. It should at least contain the key "type".
  5. registry (:obj:`Registry`): The registry to search the type from.
  6. default_args (dict, optional): Default initialization arguments.
  7. Returns:
  8. object: The constructed object.
  9. """
  10. if not isinstance(cfg, dict):
  11. raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
  12. if 'type' not in cfg:
  13. raise KeyError(
  14. f'the cfg dict must contain the key "type", but got {cfg}')
  15. if not isinstance(registry, Registry):
  16. raise TypeError('registry must be an mmcv.Registry object, '
  17. f'but got {type(registry)}')
  18. if not (isinstance(default_args, dict) or default_args is None):
  19. raise TypeError('default_args must be a dict or None, '
  20. f'but got {type(default_args)}')
  21. args = cfg.copy()
  22. # 获取 type 对应的值
  23. obj_type = args.pop('type')
  24. if is_str(obj_type):
  25. # 获取需要创建的对象
  26. obj_cls = registry.get(obj_type)
  27. if obj_cls is None:
  28. raise KeyError(
  29. f'{obj_type} is not in the {registry.name} registry')
  30. elif inspect.isclass(obj_type):
  31. obj_cls = obj_type
  32. else:
  33. raise TypeError(
  34. f'type must be a str or valid type, but got {type(obj_type)}')
  35. # 如果 default_args 不是 None,传入默认值再实例化。
  36. if default_args is not None:
  37. for name, value in default_args.items():
  38. args.setdefault(name, value)
  39. return obj_cls(**args)

那么什么是 registry?registry 就是注册类,将一个字符串和类关联起来。如果索引字符串就会获得类。Registry 是注册所需要的类,可以用它来注册类。我们可以使用如下的方式来注册类。

下面是 Registry 类的代码,它的内部维护了一个已经注册的类的字典 ——_module_dict。每当注册一个类就在字典里添加一个字符串(默认为类名)与类的映射。register_module 方法,利用装饰器将类名和类添加到 _module_dict 中。对于注册的模块可以通过 build_from_cfg 来构建。

  1. import inspect
  2. import warnings
  3. from functools import partial
  4. from .misc import is_str
  5. class Registry:
  6. """A registry to map strings to classes.
  7. Args:
  8. name (str): Registry name.
  9. """
  10. def __init__(self, name):
  11. self._name = name
  12. # 已经注册的类的字典
  13. self._module_dict = dict()
  14. def __len__(self):
  15. return len(self._module_dict)
  16. def __contains__(self, key):
  17. return self.get(key) is not None
  18. def __repr__(self):
  19. format_str = self.__class__.__name__ + \
  20. f'(name={self._name}, ' \
  21. f'items={self._module_dict})'
  22. return format_str
  23. @property
  24. def name(self):
  25. return self._name
  26. @property
  27. def module_dict(self):
  28. return self._module_dict
  29. def get(self, key):
  30. """Get the registry record.
  31. Args:
  32. key (str): The class name in string format.
  33. Returns:
  34. class: The corresponding class.
  35. """
  36. return self._module_dict.get(key, None)
  37. def _register_module(self, module_class, module_name=None, force=False):
  38. if not inspect.isclass(module_class):
  39. raise TypeError('module must be a class, '
  40. f'but got {type(module_class)}')
  41. if module_name is None:
  42. module_name = module_class.__name__
  43. if not force and module_name in self._module_dict:
  44. raise KeyError(f'{module_name} is already registered '
  45. f'in {self.name}')
  46. self._module_dict[module_name] = module_class
  47. def deprecated_register_module(self, cls=None, force=False):
  48. warnings.warn(
  49. 'The old API of register_module(module, force=False) '
  50. 'is deprecated and will be removed, please use the new API '
  51. 'register_module(name=None, force=False, module=None) instead.')
  52. if cls is None:
  53. return partial(self.deprecated_register_module, force=force)
  54. self._register_module(cls, force=force)
  55. return cls
  56. def register_module(self, name=None, force=False, module=None):
  57. """Register a module.
  58. A record will be added to `self._module_dict`, whose key is the class
  59. name or the specified name, and value is the class itself.
  60. It can be used as a decorator or a normal function.
  61. Example:
  62. >>> backbones = Registry('backbone')
  63. >>> @backbones.register_module()
  64. >>> class ResNet:
  65. >>> pass
  66. >>> backbones = Registry('backbone')
  67. >>> @backbones.register_module(name='mnet')
  68. >>> class MobileNet:
  69. >>> pass
  70. >>> backbones = Registry('backbone')
  71. >>> class ResNet:
  72. >>> pass
  73. >>> backbones.register_module(ResNet)
  74. Args:
  75. name (str | None): The module name to be registered. If not
  76. specified, the class name will be used.
  77. force (bool, optional): Whether to override an existing class with
  78. the same name. Default: False.
  79. module (type): Module class to be registered.
  80. """
  81. if not isinstance(force, bool):
  82. raise TypeError(f'force must be a boolean, but got {type(force)}')
  83. # NOTE: This is a walkaround to be compatible with the old api,
  84. # while it may introduce unexpected bugs.
  85. if isinstance(name, type):
  86. return self.deprecated_register_module(name, force=force)
  87. # use it as a normal method: x.register_module(module=SomeClass)
  88. if module is not None:
  89. self._register_module(
  90. module_class=module, module_name=name, force=force)
  91. return module
  92. # raise the error ahead of time
  93. if not (name is None or isinstance(name, str)):
  94. raise TypeError(f'name must be a str, but got {type(name)}')
  95. # use it as a decorator: @x.register_module()
  96. def _register(cls):
  97. self._register_module(
  98. module_class=cls, module_name=name, force=force)
  99. return cls
  100. return _register

(二) build_dataset(mmdet/datasets/builder.py)

build_dataset 也类似,通过调用 build_from_cfg 创建。

  1. def build_dataset(cfg, default_args=None):
  2. from .dataset_wrappers import (ConcatDataset, RepeatDataset,
  3. ClassBalancedDataset)
  4. if isinstance(cfg, (list, tuple)):
  5. dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
  6. elif cfg['type'] == 'RepeatDataset':
  7. dataset = RepeatDataset(
  8. build_dataset(cfg['dataset'], default_args), cfg['times'])
  9. elif cfg['type'] == 'ClassBalancedDataset':
  10. dataset = ClassBalancedDataset(
  11. build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
  12. elif isinstance(cfg.get('ann_file'), (list, tuple)):
  13. dataset = _concat_dataset(cfg, default_args)
  14. else:
  15. dataset = build_from_cfg(cfg, DATASETS, default_args)
  16. return dataset

(三) train_detector(mmdet/apis/train.py)

train_detector 的主要流程为:

(1.)构建 data loaders:

  1. data_loaders = [
  2. build_dataloader(
  3. ds,
  4. cfg.data.samples_per_gpu,
  5. cfg.data.workers_per_gpu,
  6. # cfg.gpus will be ignored if distributed
  7. len(cfg.gpu_ids),
  8. dist=distributed,
  9. seed=cfg.seed) for ds in dataset
  10. ]

(2.) 构建分布式处理对象:

  1. model = MMDistributedDataParallel(
  2. model.cuda(),
  3. device_ids=[torch.cuda.current_device()],
  4. broadcast_buffers=False,
  5. find_unused_parameters=find_unused_parameters)

(3.) 构建优化器:

  1. optimizer = build_optimizer(model, cfg.optimizer)

(4.) 创建 EpochBasedRunner 并进行训练:

  1. runner = EpochBasedRunner(
  2. model,
  3. optimizer=optimizer,
  4. work_dir=cfg.work_dir,
  5. logger=logger,
  6. meta=meta)

源码如下

  1. def train_detector(model,
  2. dataset,
  3. cfg,
  4. distributed=False,
  5. validate=False,
  6. timestamp=None,
  7. meta=None):
  8. # 获取 logger
  9. logger = get_root_logger(cfg.log_level)
  10. # ==================== 构建 data loaders ====================
  11. dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
  12. # 获得 samples_per_gpu
  13. if 'imgs_per_gpu' in cfg.data:
  14. logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
  15. 'Please use "samples_per_gpu" instead')
  16. if 'samples_per_gpu' in cfg.data:
  17. logger.warning(
  18. f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
  19. f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
  20. f'={cfg.data.imgs_per_gpu} is used in this experiments')
  21. else:
  22. logger.warning(
  23. 'Automatically set "samples_per_gpu"="imgs_per_gpu"='
  24. f'{cfg.data.imgs_per_gpu} in this experiments')
  25. cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
  26. data_loaders = [
  27. build_dataloader(
  28. ds,
  29. cfg.data.samples_per_gpu,
  30. cfg.data.workers_per_gpu,
  31. # cfg.gpus will be ignored if distributed
  32. len(cfg.gpu_ids),
  33. dist=distributed,
  34. seed=cfg.seed) for ds in dataset
  35. ]
  36. # ==================== 构建分布式处理对象 =====================
  37. # 如果是多卡会进入此 if
  38. if distributed:
  39. find_unused_parameters = cfg.get('find_unused_parameters', False)
  40. # Sets the `find_unused_parameters` parameter in
  41. # torch.nn.parallel.DistributedDataParallel
  42. model = MMDistributedDataParallel(
  43. model.cuda(),
  44. device_ids=[torch.cuda.current_device()],
  45. broadcast_buffers=False,
  46. find_unused_parameters=find_unused_parameters)
  47. # 单卡进入
  48. else:
  49. model = MMDataParallel(
  50. model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
  51. # ====================== 构建优化器 ==========================
  52. optimizer = build_optimizer(model, cfg.optimizer)
  53. # ============= 创建 EpochBasedRunner 并进行训练 ==============
  54. runner = EpochBasedRunner(
  55. model,
  56. optimizer=optimizer,
  57. work_dir=cfg.work_dir,
  58. logger=logger,
  59. meta=meta)
  60. # an ugly workaround to make .log and .log.json filenames the same
  61. runner.timestamp = timestamp
  62. # fp16 setting
  63. fp16_cfg = cfg.get('fp16', None)
  64. if fp16_cfg is not None:
  65. optimizer_config = Fp16OptimizerHook(
  66. **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
  67. elif distributed and 'type' not in cfg.optimizer_config:
  68. optimizer_config = OptimizerHook(**cfg.optimizer_config)
  69. else:
  70. optimizer_config = cfg.optimizer_config
  71. # register hooks
  72. runner.register_training_hooks(cfg.lr_config, optimizer_config,
  73. cfg.checkpoint_config, cfg.log_config,
  74. cfg.get('momentum_config', None))
  75. if distributed:
  76. runner.register_hook(DistSamplerSeedHook())
  77. # register eval hooks
  78. if validate:
  79. val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
  80. val_dataloader = build_dataloader(
  81. val_dataset,
  82. samples_per_gpu=1,
  83. workers_per_gpu=cfg.data.workers_per_gpu,
  84. dist=distributed,
  85. shuffle=False)
  86. eval_cfg = cfg.get('evaluation', {})
  87. eval_hook = DistEvalHook if distributed else EvalHook
  88. runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
  89. if cfg.resume_from:
  90. runner.resume(cfg.resume_from)
  91. elif cfg.load_from:
  92. runner.load_checkpoint(cfg.load_from)
  93. runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

(四)set_random_seed:

  1. 此函数会对 python、numpy、torch 都设置随机数种子。
  2. 保持随机数种子相同时,卷积的结果在CPU上相同,在GPU上仍然不相同。这是因为,cudnn卷积行为的不确定性。使用 torch.backends.cudnn.deterministic = True 可以解决。
  3. cuDNN 中部分算法是非确定性的,可以使用 torch.backends.cudnn.enabled = False 来整体禁用 cuDNN。如果设置为 torch.backends.cudnn.enabled = True,表示允许使用 cuDNN;在此基础上设置 torch.backends.cudnn.benchmark = True,才会自动寻找最适合当前配置的高效算法,来达到优化运行效率的目的。

一般来讲,应该遵循以下准则:

  1. 如果网络的输入数据维度或类型上变化不大,设置 torch.backends.cudnn.benchmark = True 可以增加运行效率
  2. 如果网络的输入数据在每次 iteration 都变化的话,会导致 cuDNN 每次都会去寻找一遍最优配置,这样反而会降低运行效率。设置 torch.backends.cudnn.benchmark = False 避免重复搜索。
  1. def set_random_seed(seed, deterministic=False):
  2. random.seed(seed)
  3. np.random.seed(seed)
  4. torch.manual_seed(seed)
  5. # manual_seed_all 是为所有 GPU 都设置随机数种子。
  6. torch.cuda.manual_seed_all(seed)
  7. if deterministic:
  8. torch.backends.cudnn.deterministic = True
  9. torch.backends.cudnn.benchmark = False

(五)get_root_logger:

get_root_logger 调用 get_logger 函数获取 logger 对象。

  1. import logging
  2. from mmcv.utils import get_logger
  3. def get_root_logger(log_file=None, log_level=logging.INFO):
  4. logger = get_logger(name='mmdet', log_file=log_file, log_level=log_level)
  5. return logger

这里实现的 get_logger 函数非常灵活,如果传入相同的 log 的 name,会返回配置相同的 log。传入以点分割的日志名称的子模块,也会返回相同的 log。如:a 和 a.b 会返回相同的 log。如果传入 log_file 会保存 log 的输出到 log_file 指定的路径,如果不传入 log_file,不保存日志的输出。只在控制台输出。

  1. import logging
  2. import torch.distributed as dist
  3. # 记录是否创建过 name 对应的 log,如果创建过设置为 True
  4. logger_initialized = {}
  5. def get_logger(name, log_file=None, log_level=logging.INFO):
  6. # 获取 log 对象。
  7. logger = logging.getLogger(name)
  8. # 如果已经创建过,直接返回
  9. if name in logger_initialized:
  10. return logger
  11. # 如果是创建过的以 ‘.’ 分割的子模块,也直接返回
  12. for logger_name in logger_initialized:
  13. if name.startswith(logger_name):
  14. return logger
  15. stream_handler = logging.StreamHandler()
  16. handlers = [stream_handler]
  17. # 获取当前的 rank(总进程编号)
  18. if dist.is_available() and dist.is_initialized():
  19. rank = dist.get_rank()
  20. else:
  21. rank = 0
  22. # 只有 rank 0(master 节点的 local_rank 为 0 的进程)的主机才保存日志
  23. if rank == 0 and log_file is not None:
  24. file_handler = logging.FileHandler(log_file, 'w')
  25. handlers.append(file_handler)
  26. formatter = logging.Formatter(
  27. '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  28. for handler in handlers:
  29. handler.setFormatter(formatter)
  30. handler.setLevel(log_level)
  31. logger.addHandler(handler)
  32. # 对于非 rank 为 0 的进程,只有 error 以上的信息才会显示
  33. if rank == 0:
  34. logger.setLevel(log_level)
  35. else:
  36. logger.setLevel(logging.ERROR)
  37. # 将 log name 对应的值设为 True,表示创建过。
  38. logger_initialized[name] = True
  39. return logger

本文转载自: https://blog.csdn.net/weixin_45912366/article/details/124504983
版权归原作者 小刺猬69 所有, 如有侵权,请联系我们删除。

“MMdetection之train.py源码详解”的评论:

还没有评论