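A Walkthrough of the mmdetection 2.x Framework Source Code
MMDetection is a PyTorch-based deep learning object detection framework developed jointly by SenseTime and the Chinese University of Hong Kong (CUHK). It is a unified detection platform that bundles popular object detection algorithms with commonly used modules. I recently studied how mmdetection is used and how its source code is organized, and learned how to extend models and add specific functions. This article covers how the mmdetection codebase is organized and my understanding of three of its components: hook, runner, and register.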
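Other OpenMMLab open-source codebases, such as MMClassification (classification), MMSegmentation (segmentation), and MMAction2 (video action recognition), organize their code in essentially the same way, so this article is also useful for those projects.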
(1) Overall structure of mmdetection 2.x
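mmdetection depends on mmcv.
1.1 Overall code structure of mmdetection 2.x
- configs: configuration for the network component structure; mmdetection 2 configs use an inheritance structure.
- tools: the final wrappers for training and testing, plus some everyday utility scripts.
- mmdet:
  - apis: base code for inference, training, and testing; it holds three scripts: train.py, test.py, and inference.py.
  - core: anchor generation, bbox and mask encoding/decoding, transforms, label assignment, sampling, model evaluation, acceleration, optimizers, post-processing, and so on.
  - datasets: dataset classes such as COCO and VOC, a unified format for data pipelines, data augmentation, and data sampling.
  - models: model components (backbone, dense_heads, roi_heads, necks, detectors, losses).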
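To make the "inheritance structure" of the configs concrete, here is a minimal sketch of the idea: a child config only writes the fields that differ, and they are merged recursively over a base config. The function `merge_config` and both config dicts are hypothetical illustrations, not mmcv code; the real config loader in mmcv handles more cases than this.

```python
import copy


def merge_config(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base` (child values win)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge field by field.
            merged[key] = merge_config(merged[key], value)
        else:
            # Otherwise the child value replaces the base value.
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical base config, standing in for a file that a child inherits from.
base_cfg = dict(
    model=dict(type='FasterRCNN', backbone=dict(type='ResNet', depth=50)),
    optimizer=dict(type='SGD', lr=0.02),
)
# Hypothetical child config: only the fields that change are written out.
child_cfg = dict(
    model=dict(backbone=dict(depth=101)),
    optimizer=dict(lr=0.01),
)

cfg = merge_config(base_cfg, child_cfg)
print(cfg['model']['backbone'])  # {'type': 'ResNet', 'depth': 101}
```

The child keeps `type='ResNet'` from the base while overriding only `depth`, which is why inherited configs can stay short.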
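1.2 Training logic
The training entry point is tools/train.py. From its code, the overall flow breaks down into the following steps:
- mmcv.utils.config.Config.fromfile parses the configuration from the config file.
- builder.py in mmdet builds the model, the dataset classes, and the corresponding train_pipeline. Based on build_from_cfg from mmcv.utils, builder.py defines the functions build_backbone, build_detector, build_head, build_loss, build_neck, build_roi_extractor, and build_shared_head.
- In mmdet.models, the build_detector defined above builds the model from the config file. The build functions call build_from_cfg, which looks up the matching object in the registry by the type key.
- This lookup relies on the model component registry in mmcv/utils/registry.py. The registry's register_module member function is a decorator that performs the registration.
- mmdet/models/builder.py defines the model registries BACKBONES, DETECTORS, HEADS, LOSSES, NECKS, ROI_EXTRACTORS, and SHARED_HEADS.
- builder.py in mmdet.datasets defines DATASETS, PIPELINES, build_dataloader, and build_dataset; the dataset and pipeline are then built from the corresponding config file.
- train_detector in mmdet.apis is called to train the model.
What needs a closer look here is how register, runner, and hook work, and what the Python decorator used in register does.
- Runner controls the flow of the whole training process.
- HOOK defines operations that run before and after every epoch or every iteration step (iter).
- Register is what makes mmdetection's modular design possible (a string -> class dictionary in the service of modularity).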
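Since the registry leans on Python decorators, a small sketch of the underlying language feature may help. It shows a decorator factory, that is, a function that takes arguments and returns the actual decorator, and what the `@` syntax expands to. The names `TABLE`, `register`, `Foo`, and `Bar` are hypothetical and exist only for this illustration.

```python
# A plain dict standing in for a registry table (hypothetical name).
TABLE = {}


def register(name=None):
    """Decorator factory: calling `register(...)` returns the real decorator."""
    def _decorator(cls):
        TABLE[name or cls.__name__] = cls
        return cls  # hand the class back unchanged
    return _decorator


@register()  # equivalent to: Foo = register()(Foo)
class Foo:
    pass


class Bar:
    pass


# The same decoration written out by hand, with an explicit name.
Bar = register(name='bar')(Bar)

print(sorted(TABLE))  # ['Foo', 'bar']
```

The decorator has a side effect (filling the table) and returns the class untouched, so decorated classes behave exactly as before.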
1.3 HOOK
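A hook, also called a hook function, is a technique for dynamically attaching extra code to an existing process at runtime, in order to replace existing logic or insert extra functionality.
The hook mechanism in mmdetection is defined in mmcv.runner.hooks, and the hook base class is defined in mmcv/runner/hooks/hook.py. The base class defines the operations to run at the start and end of every training epoch or iter; every method takes the runner as its argument. What the runner is will be covered later, in Section 1.5.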
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, is_method_overridden

HOOKS = Registry('hook')


class Hook:
    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_run')

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)

    def is_last_epoch(self, runner):
        return runner.epoch + 1 == runner._max_epochs

    def is_last_iter(self, runner):
        return runner.iter + 1 == runner._max_iters

    def get_triggered_stages(self):
        trigger_stages = set()
        for stage in Hook.stages:
            if is_method_overridden(stage, Hook, self):
                trigger_stages.add(stage)

        # some methods will be triggered in multi stages
        # use this dict to map method to stages.
        method_stages_map = {
            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
            'before_iter': ['before_train_iter', 'before_val_iter'],
            'after_iter': ['after_train_iter', 'after_val_iter'],
        }

        for method, map_stages in method_stages_map.items():
            if is_method_overridden(method, Hook, self):
                trigger_stages.update(map_stages)

        return [stage for stage in Hook.stages if stage in trigger_stages]
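As an illustration of how a subclass plugs into this base class, the sketch below overrides only `after_iter` and relies on the train/val stage methods forwarding to it. The `Hook` here is a cut-down stand-in written for this example (iteration stages only), and `FakeRunner` and `RecordEveryN` are hypothetical names, not part of mmcv.

```python
class Hook:
    """Cut-down stand-in for the base class (iteration stages only)."""

    def after_iter(self, runner):
        pass

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False


class FakeRunner:
    """Hypothetical runner exposing only the global `iter` counter."""

    def __init__(self):
        self.iter = 0


class RecordEveryN(Hook):
    """Hypothetical hook: remember the iteration index every `interval` iters."""

    def __init__(self, interval):
        self.interval = interval
        self.fired_at = []

    def after_iter(self, runner):
        if self.every_n_iters(runner, self.interval):
            self.fired_at.append(runner.iter)


runner = FakeRunner()
hook = RecordEveryN(interval=3)
for _ in range(7):
    hook.after_train_iter(runner)  # a runner would call this stage by name
    runner.iter += 1

print(hook.fired_at)  # [2, 5]
```

Because `after_train_iter` and `after_val_iter` both forward to `after_iter`, overriding the generic method is enough when the behavior should be the same in both phases.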
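The runner holds a list of hooks; each element of the list is an instantiated HOOK object.
There are two ways to register a hook:
- register_hook takes an instantiated HOOK object and inserts it into a list.
- register_hook_from_cfg takes a config item, instantiates a HOOK object from it, and inserts that object into the list.
The second way builds on a basic pattern defined across the MMLab open-source ecosystem, mmcv.build_from_cfg (found in mmcv/utils/registry.py). MMDetection and the other MMLab algorithm frameworks all follow this MMCV approach of instantiating objects from config items.
The code for register_hook and register_hook_from_cfg is in mmcv/runner/base_runner.py.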
# Insert a hook object into hook_list.
def register_hook(self, hook, priority='NORMAL'):
    """Register a hook into the hook list.

    The hook will be inserted into a priority queue, with the specified
    priority (See :class:`Priority` for details of priorities).
    For hooks with the same priority, they will be triggered in the same
    order as they are registered.

    Args:
        hook (:obj:`Hook`): The hook to be registered.
        priority (int or str or :obj:`Priority`): Hook priority.
            Lower value means higher priority.
    """
    assert isinstance(hook, Hook)
    if hasattr(hook, 'priority'):
        raise ValueError('"priority" is a reserved attribute for hooks')
    # Resolve the priority value for this hook.
    priority = get_priority(priority)
    hook.priority = priority
    # Insert the hook into the sorted list at the position
    # given by its priority.
    inserted = False
    for i in range(len(self._hooks) - 1, -1, -1):
        if priority >= self._hooks[i].priority:
            self._hooks.insert(i + 1, hook)
            inserted = True
            break
    if not inserted:
        self._hooks.insert(0, hook)

# Build a hook object from a cfg item and insert it into hook_list.
def register_hook_from_cfg(self, hook_cfg):
    """Register a hook from its cfg.

    Args:
        hook_cfg (dict): Hook config. It should have at least keys 'type'
            and 'priority' indicating its type and priority.

    Note:
        The specific hook class to register should not use 'type' and
        'priority' arguments during initialization.
    """
    hook_cfg = hook_cfg.copy()
    priority = hook_cfg.pop('priority', 'NORMAL')
    hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
    self.register_hook(hook, priority=priority)

# Call the hook functions.
def call_hook(self, fn_name):
    """Call all hooks.

    Args:
        fn_name (str): The function name in each hook to be called, such as
            "before_train_epoch".
    """
    for hook in self._hooks:
        getattr(hook, fn_name)(self)
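As the code above shows, calling a HOOK means iterating over the hook list and invoking the method with the given name on each HOOK. Hooks with higher priority sit nearer the front of the list, so they are called earlier. If you want a hook point such as before_train_epoch to do two things, A and B, the priorities of the two HOOKs decide the order in which A and B run.
mmcv/runner/priority.py defines nine priority levels in total.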
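A small sketch of the ordering rule, reusing the same backward scan as register_hook but on a plain list. `FakeHook` and the module-level `register_hook` function are hypothetical stand-ins for this example; the numeric priorities follow the convention that a lower value means a higher priority.

```python
class FakeHook:
    """Hypothetical hook that only carries a name."""

    def __init__(self, name):
        self.name = name


def register_hook(hooks, hook, priority):
    """Insert `hook` into the sorted list `hooks` using a backward scan."""
    hook.priority = priority
    for i in range(len(hooks) - 1, -1, -1):
        if priority >= hooks[i].priority:
            hooks.insert(i + 1, hook)
            return
    hooks.insert(0, hook)


hooks = []
register_hook(hooks, FakeHook('B'), 50)    # normal priority
register_hook(hooks, FakeHook('log'), 90)  # very low priority
register_hook(hooks, FakeHook('A'), 30)    # high priority
register_hook(hooks, FakeHook('B2'), 50)   # same priority as B, registered later

order = [h.name for h in hooks]
print(order)  # ['A', 'B', 'B2', 'log']
```

A runs before B because of its smaller priority value, and B2 lands after B because ties keep registration order.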
from enum import Enum


class Priority(Enum):
    """Hook priority levels.

    +--------------+------------+
    | Level        | Value      |
    +==============+============+
    | HIGHEST      | 0          |
    +--------------+------------+
    | VERY_HIGH    | 10         |
    +--------------+------------+
    | HIGH         | 30         |
    +--------------+------------+
    | ABOVE_NORMAL | 40         |
    +--------------+------------+
    | NORMAL       | 50         |
    +--------------+------------+
    | BELOW_NORMAL | 60         |
    +--------------+------------+
    | LOW          | 70         |
    +--------------+------------+
    | VERY_LOW     | 90         |
    +--------------+------------+
    | LOWEST       | 100        |
    +--------------+------------+
    """

    HIGHEST = 0
    VERY_HIGH = 10
    HIGH = 30
    ABOVE_NORMAL = 40
    NORMAL = 50
    BELOW_NORMAL = 60
    LOW = 70
    VERY_LOW = 90
    LOWEST = 100


def get_priority(priority):
    """Get priority value.

    Args:
        priority (int or str or :obj:`Priority`): Priority.

    Returns:
        int: The priority value.
    """
    if isinstance(priority, int):
        if priority < 0 or priority > 100:
            raise ValueError('priority must be between 0 and 100')
        return priority
    elif isinstance(priority, Priority):
        return priority.value
    elif isinstance(priority, str):
        return Priority[priority.upper()].value
    else:
        raise TypeError('priority must be an integer or Priority enum value')
1.4 Register
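The registry (register) essentially implements a mapping from strings to classes, so a model or module can be swapped simply by editing the config file. In MMDetection, all functionality is modularized through registries. The classic example is the builder.py that constructs the MMDetection model (code in mmdet/models/builder.py), which uses registries to modularize the model.
The relevant source of the Registry class in mmcv/utils/registry.py is as follows.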
def _register_module(self, module_class, module_name=None, force=False):
    if not inspect.isclass(module_class):
        raise TypeError('module must be a class, '
                        f'but got {type(module_class)}')

    if module_name is None:
        module_name = module_class.__name__
    if isinstance(module_name, str):
        module_name = [module_name]
    for name in module_name:
        if not force and name in self._module_dict:
            raise KeyError(f'{name} is already registered '
                           f'in {self.name}')
        # Core of the registry: map the string to the class
        # (module_class is a class).
        self._module_dict[name] = module_class

def deprecated_register_module(self, cls=None, force=False):
    warnings.warn(
        'The old API of register_module(module, force=False) '
        'is deprecated and will be removed, please use the new API '
        'register_module(name=None, force=False, module=None) instead.',
        DeprecationWarning)
    if cls is None:
        return partial(self.deprecated_register_module, force=force)
    self._register_module(cls, force=force)
    return cls
# This is the registration method; it registers a class.
def register_module(self, name=None, force=False, module=None):
    """Register a module.

    A record will be added to `self._module_dict`, whose key is the class
    name or the specified name, and value is the class itself.
    It can be used as a decorator or a normal function.

    Example:
        # First way: decorator with the default name
        # (the form mmdetection's own modules use)
        >>> backbones = Registry('backbone')
        >>> @backbones.register_module()
        >>> class ResNet:
        >>>     pass

        # Second way: decorator with an explicit name
        >>> backbones = Registry('backbone')
        >>> @backbones.register_module(name='mnet')
        >>> class MobileNet:
        >>>     pass

        # Third way: plain function call
        >>> backbones = Registry('backbone')
        >>> class ResNet:
        >>>     pass
        >>> backbones.register_module(ResNet)

    Args:
        name (str | None): The module name to be registered. If not
            specified, the class name will be used.
        force (bool, optional): Whether to override an existing class with
            the same name. Default: False.
        module (type): Module class to be registered.
    """
    if not isinstance(force, bool):
        raise TypeError(f'force must be a boolean, but got {type(force)}')
    # NOTE: This is a walkaround to be compatible with the old api,
    # while it may introduce unexpected bugs.
    if isinstance(name, type):
        return self.deprecated_register_module(name, force=force)

    # raise the error ahead of time
    if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
        raise TypeError(
            'name must be either of None, an instance of str or a sequence'
            f' of str, but got {type(name)}')

    # use it as a normal method: x.register_module(module=SomeClass)
    if module is not None:
        self._register_module(
            module_class=module, module_name=name, force=force)
        return module

    # use it as a decorator: @x.register_module()
    def _register(cls):
        self._register_module(
            module_class=cls, module_name=name, force=force)
        return cls

    return _register
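To see the three registration styles side by side without installing mmcv, here is a minimal sketch of a registry. `MiniRegistry` is a hypothetical, heavily simplified stand-in written for this example; it keeps only the decorator/function duality and the duplicate-name check, and omits scopes, parents, and the deprecated call form.

```python
class MiniRegistry:
    """Hypothetical, simplified registry: a named string -> class table."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def get(self, key):
        return self._module_dict.get(key)

    def register_module(self, name=None, force=False, module=None):
        def _register(cls):
            key = name or cls.__name__
            if not force and key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls

        if module is not None:  # used as a normal function
            return _register(module)
        return _register        # used as a decorator


backbones = MiniRegistry('backbone')


@backbones.register_module()             # key defaults to the class name
class ResNet:
    pass


@backbones.register_module(name='mnet')  # explicit key
class MobileNet:
    pass


class VGG:
    pass


backbones.register_module(module=VGG)    # plain function call

try:
    backbones.register_module(module=VGG)  # same key again, force=False
    duplicate_rejected = False
except KeyError:
    duplicate_rejected = True

print(backbones.get('mnet') is MobileNet, duplicate_rejected)  # True True
```

Note that `MobileNet` is reachable only under `'mnet'`: an explicit name replaces the class name as the key rather than adding a second one.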
Below are the registries instantiated in builder.py, the file used to build the model.
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry

MODELS = Registry('models', parent=MMCV_MODELS)

BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS
DETECTORS = MODELS


def build_backbone(cfg):
    """Build backbone."""
    return BACKBONES.build(cfg)


def build_neck(cfg):
    """Build neck."""
    return NECKS.build(cfg)


def build_roi_extractor(cfg):
    """Build roi extractor."""
    return ROI_EXTRACTORS.build(cfg)


def build_shared_head(cfg):
    """Build shared head."""
    return SHARED_HEADS.build(cfg)


def build_head(cfg):
    """Build head."""
    return HEADS.build(cfg)


def build_loss(cfg):
    """Build loss."""
    return LOSSES.build(cfg)


def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build detector."""
    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '
    return DETECTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
The assembly of the model is implemented in mmcv/cnn/builder.py.
def build_model_from_cfg(cfg, registry, default_args=None):
    """Build a PyTorch model from config dict(s). Different from
    ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.

    Args:
        cfg (dict, list[dict]): The config of modules, it is either a config
            dict or a list of config dicts. If cfg is a list,
            the built modules will be wrapped with ``nn.Sequential``.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


MODELS = Registry('model', build_func=build_model_from_cfg)
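How the backbone, neck, head, loss, and so on are wired together still has to be written by hand in code. For example, in mmdet/models/detectors/two_stage.py (the base class that faster_rcnn inherits from), the forward flow looks like this.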
def extract_feat(self, img):
"""Directly extract features from the backbone+neck."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None,
**kwargs):
"""
Args:
img (Tensor): of shape (N, C, H, W) encoding input images.
Typically these should be mean centered and std scaled.
img_metas (list[dict]): list of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmdet/datasets/pipelines/formatting.py:Collect`.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
gt_bboxes_ignore (None | list[Tensor]): specify which bounding
boxes can be ignored when computing the loss.
gt_masks (None | Tensor) : true segmentation masks for each box
used if the architecture supports a segmentation task.
proposals : override rpn proposals with custom proposals. Use when
`with_rpn` is False.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
x = self.extract_feat(img)
losses = dict()
# RPN forward and loss
if self.with_rpn:
proposal_cfg = self.train_cfg.get('rpn_proposal',
self.test_cfg.rpn)
rpn_losses, proposal_list = self.rpn_head.forward_train(
x,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_ignore=gt_bboxes_ignore,
proposal_cfg=proposal_cfg,
**kwargs)
losses.update(rpn_losses)
else:
proposal_list = proposals
roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
gt_bboxes, gt_labels,
gt_bboxes_ignore, gt_masks,
**kwargs)
losses.update(roi_losses)
return losses
Summary: the registry only provides a way to generate instance objects from a config file, by building a string -> class dictionary.
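The division of labor summarized above can be sketched in a few lines: the registry turns nested config dicts into objects, while the detector class decides by hand how those objects are called. Everything here (`REGISTRY`, `register`, `build`, and the `Toy*` classes) is a hypothetical toy written for this example, with numbers standing in for tensors.

```python
REGISTRY = {}  # hypothetical string -> class table


def register(cls):
    REGISTRY[cls.__name__] = cls
    return cls


def build(cfg):
    """Instantiate the class named by cfg['type'] with the remaining keys."""
    args = dict(cfg)
    return REGISTRY[args.pop('type')](**args)


@register
class ToyBackbone:
    def __init__(self, depth):
        self.depth = depth

    def __call__(self, x):
        return x * self.depth


@register
class ToyNeck:
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        return x + self.scale


@register
class ToyDetector:
    def __init__(self, backbone, neck=None):
        # Sub-modules arrive as config dicts and are built here.
        self.backbone = build(backbone)
        self.neck = build(neck) if neck is not None else None

    @property
    def with_neck(self):
        return self.neck is not None

    def extract_feat(self, img):
        # The call order is ordinary hand-written code, not configuration.
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x


cfg = dict(
    type='ToyDetector',
    backbone=dict(type='ToyBackbone', depth=3),
    neck=dict(type='ToyNeck', scale=1),
)
detector = build(cfg)
feat = detector.extract_feat(2)
print(feat)  # 7
```

Swapping `ToyNeck` for another registered class only needs a config change, whereas changing the order backbone -> neck would require editing `extract_feat`.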
1.5 Runner
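The runner ties together the five parts of a deep learning algorithm: data loading, model construction, training, evaluation, and inference.
The Runner source is packaged in the MMCV library and comes in two main kinds: the epoch runner (mmcv/runner/epoch_based_runner.py) and the iter runner (mmcv/runner/iter_based_runner.py). The epoch runner is the one commonly used.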
@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
"""Epoch-based Runner.
This runner train models epoch by epoch.
"""
def run_iter(self, data_batch, train_mode, **kwargs):
if self.batch_processor is not None:
outputs = self.batch_processor(
self.model, data_batch, train_mode=train_mode, **kwargs)
elif train_mode:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()"'
'and "model.val_step()" must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
# Run the hook functions during training
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
self.run_iter(data_batch, train_mode=False)
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
if max_epochs is not None:
warnings.warn(
'setting max_epochs in run is deprecated, '
'please set max_epochs in runner_config', DeprecationWarning)
self._max_epochs = max_epochs
assert self._max_epochs is not None, (
'max_epochs must be specified during instantiation')
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('Hooks will be executed in the following order:\n%s',
self.get_hook_info())
self.logger.info('workflow: %s, max: %d epochs', workflow,
self._max_epochs)
self.call_hook('before_run')
while self.epoch < self._max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner = getattr(self, mode)
else:
raise TypeError(
'mode in workflow must be a str, but got {}'.format(
type(mode)))
for _ in range(epochs):
if mode == 'train' and self.epoch >= self._max_epochs:
break
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
def save_checkpoint(self,
out_dir,
filename_tmpl='epoch_{}.pth',
save_optimizer=True,
meta=None,
create_symlink=True):
"""Save the checkpoint.
Args:
out_dir (str): The directory that checkpoints are saved.
filename_tmpl (str, optional): The checkpoint filename template,
which contains a placeholder for the epoch number.
Defaults to 'epoch_{}.pth'.
save_optimizer (bool, optional): Whether to save the optimizer to
the checkpoint. Defaults to True.
meta (dict, optional): The meta information to be saved in the
checkpoint. Defaults to None.
create_symlink (bool, optional): Whether to create a symlink
"latest.pth" to point to the latest checkpoint.
Defaults to True.
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError(
f'meta should be a dict or None, but got {type(meta)}')
if self.meta is not None:
meta.update(self.meta)
# Note: meta.update(self.meta) should be done before
# meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
# there will be problems with resumed checkpoints.
# More details in https://github.com/open-mmlab/mmcv/pull/1108
meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
dst_file = osp.join(out_dir, 'latest.pth')
if platform.system() != 'Windows':
mmcv.symlink(filename, dst_file)
else:
shutil.copy(filepath, dst_file)
@RUNNERS.register_module()
class Runner(EpochBasedRunner):
"""Deprecated name of EpochBasedRunner."""
def __init__(self, *args, **kwargs):
warnings.warn(
'Runner was deprecated, please use EpochBasedRunner instead',
DeprecationWarning)
super().__init__(*args, **kwargs)
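Stripped of logging, checkpointing, and the model itself, the interplay between the runner and its hooks comes down to the order of call_hook invocations. The sketch below records that order for one epoch of two batches. `MiniEpochRunner` and `RecorderHook` are hypothetical stand-ins written for this example and only mirror the train-side control flow of the class above.

```python
class RecorderHook:
    """Hypothetical hook that records every stage it is asked to run."""

    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
              'after_train_iter', 'after_train_epoch', 'after_run')

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Only reached for attributes that are not defined normally.
        if name in RecorderHook.stages:
            return lambda runner: self.calls.append(name)
        raise AttributeError(name)


class MiniEpochRunner:
    """Hypothetical runner keeping only the control flow of train()/run()."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.epoch = 0
        self.iter = 0
        self.hooks = []

    def call_hook(self, fn_name):
        for hook in self.hooks:
            getattr(hook, fn_name)(self)

    def train(self, data_loader):
        self.call_hook('before_train_epoch')
        for _batch in data_loader:
            self.call_hook('before_train_iter')
            # A real runner would run the model's training step here.
            self.call_hook('after_train_iter')
            self.iter += 1
        self.call_hook('after_train_epoch')
        self.epoch += 1

    def run(self, data_loader):
        self.call_hook('before_run')
        while self.epoch < self.max_epochs:
            self.train(data_loader)
        self.call_hook('after_run')


runner = MiniEpochRunner(max_epochs=1)
recorder = RecorderHook()
runner.hooks.append(recorder)
runner.run(data_loader=[0, 1])

for stage in recorder.calls:
    print(stage)
```

Everything that is not the forward pass, such as logging or saving checkpoints, can therefore live in hooks attached to one of these stages.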
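(2) Workflow analysis of mmdetection 2.x
2.1 Training entry function
Start directly from the training entry point, tools/train.py, and read the code with the HOOK, Register, and Runner mechanisms introduced above in mind.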
def main():
args = parse_args()
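# mmcv.Config.fromfile parses the config from the config file and applies the needed updates,
# including environment collection, pretrained model files, distributed setup, logging, and so on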
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# import modules from string list.
if cfg.get('custom_imports', None):
from mmcv.utils import import_modules_from_strings
import_modules_from_strings(**cfg['custom_imports'])
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids
else:
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# re-set gpu_ids with distributed training mode
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
meta['exp_name'] = osp.basename(args.config)
###################################################################
## Build the detector (mmdet/models/builder.py)
# The model is built from the `model` dict in the config file; an instance of the detector class is returned
model = build_detector(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
model.init_weights()
# Build the dataset (mmdet/datasets/builder.py)
# based on the `data` parameters in the given config file
datasets = [build_dataset(cfg.data.train)]
## The `workflow` parameter in configs/_base_/default_runtime.py decides whether a validation set is added.
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
val_dataset.pipeline = cfg.data.train.pipeline
datasets.append(build_dataset(val_dataset))
if cfg.checkpoint_config is not None:
# save mmdet version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmdet_version=__version__ + get_git_hash()[:7],
CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
## Train the model; the function lives in mmdet/apis/train.py
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
2.2 Model construction: mmdet/models/builder.py
# Model construction as it appears in tools/train.py.
# model = build_detector(cfg.model,
# train_cfg=cfg.get('train_cfg'),
# test_cfg=cfg.get('test_cfg'))
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
if train_cfg is not None or test_cfg is not None:
warnings.warn(
'train_cfg and test_cfg is deprecated, '
'please specify them in model', UserWarning)
assert cfg.get('train_cfg') is None or train_cfg is None, \
'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field '
return DETECTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
# Code from: mmcv/utils/registry.py
# def build_from_cfg(cfg, registry, default_args=None):
# """Build a module from config dict.
#
# Args:
# cfg (dict): Config dict. It should at least contain the key "type".
#             (the config file must contain at least the "type" key)
# registry (:obj:`Registry`): The registry to search the type from.
# default_args (dict, optional): Default initialization arguments.
#
# Returns:
# object: The constructed object.
# """
#     # cfg must be a dict
# if not isinstance(cfg, dict):
# raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
# if 'type' not in cfg:
# if default_args is None or 'type' not in default_args:
# raise KeyError(
# '`cfg` or `default_args` must contain the key "type", '
# f'but got {cfg}\n{default_args}')
# if not isinstance(registry, Registry):
# raise TypeError('registry must be an mmcv.Registry object, '
# f'but got {type(registry)}')
# if not (isinstance(default_args, dict) or default_args is None):
# raise TypeError('default_args must be a dict or None, '
# f'but got {type(default_args)}')
#
# args = cfg.copy()
#
# if default_args is not None:
# for name, value in default_args.items():
#             # if the key is missing from the dict, add it with the default value.
# args.setdefault(name, value)
#
# obj_type = args.pop('type')
# if isinstance(obj_type, str):
#         # fetch the class registered under `type`; the registry's job is to map strings to classes
# obj_cls = registry.get(obj_type)
# if obj_cls is None:
# raise KeyError(
# f'{obj_type} is not in the {registry.name} registry')
# elif inspect.isclass(obj_type):
# obj_cls = obj_type
# else:
# raise TypeError(
# f'type must be a str or valid type, but got {type(obj_type)}')
# try:
# return obj_cls(**args)
# except Exception as e:
# # Normal TypeError does not print class name.
# raise type(e)(f'{obj_cls.__name__}: {e}')
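One detail of the listing above that is easy to miss is the use of setdefault: values written in the config win, and default_args only fills in keys the config leaves out. The sketch below reproduces just that behavior. The dict-based `REGISTRY`, the `ToyDetector` class, and this stripped-down `build_from_cfg` (no type checks or error wrapping) are hypothetical simplifications, not the mmcv function itself.

```python
class ToyDetector:
    """Hypothetical class standing in for a registered detector."""

    def __init__(self, depth, train_cfg=None, test_cfg=None):
        self.depth = depth
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg


# A plain dict standing in for the registry's string -> class table.
REGISTRY = {'ToyDetector': ToyDetector}


def build_from_cfg(cfg, registry, default_args=None):
    """Simplified sketch: copy the cfg, fill in defaults, instantiate."""
    args = cfg.copy()  # the caller's cfg is left untouched
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)  # only fills missing keys
    obj_cls = registry[args.pop('type')]
    return obj_cls(**args)


cfg = dict(type='ToyDetector', depth=50, train_cfg='from_cfg')
det = build_from_cfg(
    cfg,
    REGISTRY,
    default_args=dict(train_cfg='from_default', test_cfg='from_default'))

print(det.train_cfg, det.test_cfg)  # from_cfg from_default
```

Because the function works on a copy, the same config dict can be reused to build several objects; `'type'` is still present in `cfg` afterwards.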
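2.3 Dataset construction: datasets/builder.py
The dataset is built with the build_dataset function in datasets/builder.py. This construction includes building the pipeline: all datasets inherit from CustomDataset, which uses the Compose class to process the pipeline. The pipeline itself is then built in mmdet/datasets/pipelines/compose.py.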
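The idea behind Compose can be sketched as a chain of callables that each take and return a results dict. This is a toy stand-in: the class below accepts plain functions only, whereas the pipeline steps in mmdet are described by config dicts. The step functions `load`, `flip`, and `drop_empty` are hypothetical, and a list of numbers plays the role of the image.

```python
class Compose:
    """Toy stand-in: apply a list of callables to a results dict in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
            if data is None:  # a step may decide to drop the sample
                return None
        return data


def load(results):
    """Hypothetical loading step: attach fake image data."""
    results = dict(results)
    results['img'] = [1, 2, 3]
    return results


def flip(results):
    """Hypothetical augmentation step: reverse the fake image."""
    results = dict(results)
    results['img'] = results['img'][::-1]
    results['flip'] = True
    return results


def drop_empty(results):
    """Hypothetical filter step: discard samples without image data."""
    return results if results['img'] else None


pipeline = Compose([load, flip, drop_empty])
out = pipeline({'filename': 'a.jpg'})
dropped = Compose([drop_empty, flip])({'img': []})

print(out)      # {'filename': 'a.jpg', 'img': [3, 2, 1], 'flip': True}
print(dropped)  # None
```

Passing one dict through every step is what lets steps be reordered or swapped from the config without changing their signatures.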
2.4 The training process in mmdet/apis/train.py
def train_detector(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
logger = get_root_logger(log_level=cfg.log_level)
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
if 'imgs_per_gpu' in cfg.data:
logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
'Please use "samples_per_gpu" instead')
if 'samples_per_gpu' in cfg.data:
logger.warning(
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
f'={cfg.data.imgs_per_gpu} is used in this experiments')
else:
logger.warning(
'Automatically set "samples_per_gpu"="imgs_per_gpu"='
f'{cfg.data.imgs_per_gpu} in this experiments')
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
'type']
# Build the dataloaders
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
# `num_gpus` will be ignored if distributed
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed,
runner_type=runner_type) for ds in dataset
]
# put model on gpus
if distributed:
find_unused_parameters = cfg.get('find_unused_parameters', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDataParallel(
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
# build runner
optimizer = build_optimizer(model, cfg.optimizer)
if 'runner' not in cfg:
cfg.runner = {
'type': 'EpochBasedRunner',
'max_epochs': cfg.total_epochs
}
warnings.warn(
'config is now expected to have a `runner` section, '
'please set `runner` in your config.', UserWarning)
else:
if 'total_epochs' in cfg:
assert cfg.total_epochs == cfg.runner.max_epochs
# build runner
runner = build_runner(
cfg.runner,
default_args=dict(
model=model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta))
# an ugly workaround to make .log and .log.json filenames the same
runner.timestamp = timestamp
# fp16 setting
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
optimizer_config = Fp16OptimizerHook(
**cfg.optimizer_config, **fp16_cfg, distributed=distributed)
elif distributed and 'type' not in cfg.optimizer_config:
optimizer_config = OptimizerHook(**cfg.optimizer_config)
else:
optimizer_config = cfg.optimizer_config
# register hooks
# Registering hooks amounts to inserting each operation into a priority queue according to its priority
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
# Finally, start the run; the code is in mmcv/runner
# This calls the run method of the runner object
runner.run(data_loaders, cfg.workflow)
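The cfg.workflow value passed to runner.run decides how training and validation epochs interleave. The sketch below replays only the scheduling part of the run loop from the EpochBasedRunner listing in Section 1.5 and returns the resulting phase sequence. The function `schedule` is a hypothetical helper written for this example; it assumes, as in that listing, that only a train epoch advances the epoch counter.

```python
def schedule(workflow, max_epochs):
    """Return the sequence of phases the run loop would execute."""
    phases = []
    epoch = 0
    while epoch < max_epochs:
        for mode, epochs in workflow:
            for _ in range(epochs):
                if mode == 'train' and epoch >= max_epochs:
                    break
                phases.append(mode)
                if mode == 'train':
                    epoch += 1  # a val epoch leaves the counter unchanged
    return phases


train_only = schedule([('train', 1)], max_epochs=2)
train_val = schedule([('train', 2), ('val', 1)], max_epochs=3)

print(train_only)  # ['train', 'train']
print(train_val)   # ['train', 'train', 'val', 'train', 'val']
```

With three epochs and a `('train', 2)` entry, the last cycle trains only once before the epoch limit is hit, yet the trailing validation phase still runs.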
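2.5 Processing inside the Runner
Finally, look at what the runner does. The main code is in the mmcv/runner folder. The analysis covers:
- how the runner class is built
- how hooks are registered on the runner class
- how the runner's run member function executes
2.5.1 Building the runner class
The code is in mmcv/runner/base_runner.py. EpochBasedRunner inherits from BaseRunner, and most of the runner's initialization happens in the BaseRunner class.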
class BaseRunner(metaclass=ABCMeta):
"""The base class of Runner, a training helper for PyTorch.
All subclasses should implement the following APIs:
- ``run()``
- ``train()``
- ``val()``
- ``save_checkpoint()``
Args:
model (:obj:`torch.nn.Module`): The model to be run.
batch_processor (callable): A callable method that process a data
batch. The interface of this method should be
`batch_processor(model, data, train_mode) -> dict`
optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
optimizer (in most cases) or a dict of optimizers (in models that
requires more than one optimizer, e.g., GAN).
work_dir (str, optional): The working directory to save checkpoints
and logs. Defaults to None.
logger (:obj:`logging.Logger`): Logger used during training.
Defaults to None. (The default value is just for backward
compatibility)
meta (dict | None): A dict records some import information such as
environment info and seed, which will be logged in logger hook.
Defaults to None.
max_epochs (int, optional): Total training epochs.
max_iters (int, optional): Total training iterations.
"""
def __init__(self,
model,
batch_processor=None,
optimizer=None,
work_dir=None,
logger=None,
meta=None,
max_iters=None,
max_epochs=None):
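2.5.2 Hook registration in the runner
The code for both registration methods is in the BaseRunner class in mmcv/runner/base_runner.py.
# The runner's hooks form a list: higher-priority hooks sit toward the front, lower-priority ones toward the back.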
@property
def hooks(self):
"""list[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
###################### The two hook registration methods below were already shown in the earlier introduction to the hook mechanism.
def register_hook(self, hook, priority='NORMAL'):
"""Register a hook into the hook list.
The hook will be inserted into a priority queue, with the specified
priority (See :class:`Priority` for details of priorities).
For hooks with the same priority, they will be triggered in the same
order as they are registered.
Args:
hook (:obj:`Hook`): The hook to be registered.
priority (int or str or :obj:`Priority`): Hook priority.
Lower value means higher priority.
"""
assert isinstance(hook, Hook)
if hasattr(hook, 'priority'):
raise ValueError('"priority" is a reserved attribute for hooks')
priority = get_priority(priority)
hook.priority = priority
# insert the hook to a sorted list
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority:
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)
def register_hook_from_cfg(self, hook_cfg):
"""Register a hook from its cfg.
Args:
hook_cfg (dict): Hook config. It should have at least keys 'type'
and 'priority' indicating its type and priority.
Note:
The specific hook class to register should not use 'type' and
'priority' arguments during initialization.
"""
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
self.register_hook(hook, priority=priority)
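The insertion logic of register_hook is easier to follow on a toy example. The sketch below re-implements the same backward scan in plain Python, without mmcv. ToyHook, ToyRunner, the PRIORITIES table and the plain dict used as a registry are hypothetical stand-ins, and the numeric priority values are only illustrative (smaller means higher priority, as in the docstring above).

```python
# Illustrative priority table: smaller value = higher priority.
PRIORITIES = {'HIGH': 30, 'NORMAL': 50, 'VERY_LOW': 90}


class ToyHook:
    """Hypothetical hook that only carries a name."""

    def __init__(self, name):
        self.name = name


class ToyRunner:
    """Hypothetical runner that keeps only the hook list."""

    def __init__(self):
        self._hooks = []

    def register_hook(self, hook, priority='NORMAL'):
        hook.priority = PRIORITIES[priority]
        # Scan from the back; insert right after the last hook whose
        # priority value is <= the new one, so equal priorities keep
        # their registration order.
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if hook.priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg, registry):
        # A plain dict plays the role of the HOOKS registry here.
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook_cls = registry[hook_cfg.pop('type')]
        self.register_hook(hook_cls(**hook_cfg), priority=priority)


runner = ToyRunner()
runner.register_hook(ToyHook('logger'), priority='VERY_LOW')
runner.register_hook(ToyHook('checkpoint'))  # defaults to NORMAL
runner.register_hook(ToyHook('optimizer'), priority='HIGH')
runner.register_hook_from_cfg(
    dict(type='ToyHook', name='custom', priority='NORMAL'),
    registry={'ToyHook': ToyHook})

hook_order = [h.name for h in runner._hooks]
print(hook_order)
```

The hooks end up ordered by priority rather than by registration order, and the two NORMAL hooks keep the order in which they were registered.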
2.5.3 How the runner's run function works
The code lives in mmcv/runner/epoch_based_runner.py; the EpochBasedRunner class overrides the run function.
@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner train models epoch by epoch.
    """

    def run_iter(self, data_batch, train_mode, **kwargs):
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            # Run a single training iteration.
            # Taking Faster R-CNN as an example, train_step is resolved along
            # FasterRCNN ==> TwoStageDetector ==> BaseDetector; the code is in
            # mmdet/models/detectors/base.py.
            # This call therefore performs one forward pass only. It does not
            # include backpropagation or the optimizer update; those are done
            # in the optimizer hook.
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()"'
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        # self._max_iters is the total number of iterations to run.
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        # Loop over the batches of one epoch.
        for i, data_batch in enumerate(self.data_loader):
            # _inner_iter is the index of the current step within this epoch.
            self._inner_iter = i
            self.call_hook('before_train_iter')
            # Run one iteration step.
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            # _iter counts the iterations accumulated over all epochs.
            self._iter += 1
        self.call_hook('after_train_epoch')
        self._epoch += 1
    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')
        self.call_hook('after_val_epoch')
    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs
        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')
        # Run the whole training process until max_epochs is reached.
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    # Look up the method to run: train() or val().
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    # Call train() or val() for one epoch.
                    epoch_runner(data_loaders[i], **kwargs)
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
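The control flow of EpochBasedRunner can be condensed into a small sketch that records which hook stage is triggered and when. TinyEpochRunner is a hypothetical class written for this article, not part of mmcv: call_hook only appends the stage name to a list, the data loaders are plain Python lists, and the model, logging and sleeps are omitted.

```python
class TinyEpochRunner:
    """Hypothetical, stripped-down imitation of the epoch-based control flow."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.epoch = 0
        self.iter = 0
        self.trace = []  # names of the hook stages, in call order

    def call_hook(self, stage):
        # The real runner calls getattr(hook, stage)(self) on every hook.
        self.trace.append(stage)

    def train(self, data_loader):
        self.call_hook('before_train_epoch')
        for _batch in data_loader:
            self.call_hook('before_train_iter')
            # The forward pass (run_iter) would happen here.
            self.call_hook('after_train_iter')
            self.iter += 1
        self.call_hook('after_train_epoch')
        self.epoch += 1  # only training epochs advance the epoch counter

    def val(self, data_loader):
        self.call_hook('before_val_epoch')
        for _batch in data_loader:
            self.call_hook('before_val_iter')
            self.call_hook('after_val_iter')
        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow):
        assert len(data_loaders) == len(workflow)
        self.call_hook('before_run')
        while self.epoch < self.max_epochs:
            for i, (mode, epochs) in enumerate(workflow):
                epoch_runner = getattr(self, mode)  # self.train or self.val
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self.max_epochs:
                        break
                    epoch_runner(data_loaders[i])
        self.call_hook('after_run')


runner = TinyEpochRunner(max_epochs=3)
# Two training batches and one validation batch per epoch.
runner.run([[0, 1], [0]], workflow=[('train', 2), ('val', 1)])
finished_epochs = [s for s in runner.trace if s.startswith('after_')
                   and s.endswith('_epoch')]
print(finished_epochs)
```

With max_epochs=3 and the workflow [('train', 2), ('val', 1)], the sketch trains two epochs, validates once, trains the third epoch, skips the fourth because of the max_epochs check, and validates once more before after_run is called.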
@RUNNERS.register_module()
class Runner(EpochBasedRunner):
    """Deprecated name of EpochBasedRunner."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            'Runner was deprecated, please use EpochBasedRunner instead',
            DeprecationWarning)
        super().__init__(*args, **kwargs)
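Since Runner is only a deprecated alias, a config file should name EpochBasedRunner directly. The fragment below is a sketch of the two config entries that typically drive the run function in an mmdetection 2.x config; the concrete numbers are illustrative and should be treated as assumptions rather than recommended settings.

```python
# Selects the runner class through the RUNNERS registry and sets max_epochs.
runner = dict(type='EpochBasedRunner', max_epochs=12)

# One training epoch per cycle; validation during training is usually
# handled by an evaluation hook rather than by a ('val', 1) phase.
workflow = [('train', 1)]

# To alternate two training epochs with one validation epoch instead,
# the workflow could look like this (one data loader per entry is required):
alternating_workflow = [('train', 2), ('val', 1)]
```

Each workflow entry is a (mode, epochs) tuple, and mode must match the name of a runner method, which is why run looks it up with getattr.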
This article was first published on the 「小哲AI」 WeChat official account.