in easycv/apis/train.py [0:0]
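# NOTE: this excerpt omits the module-level imports. The body relies on torch,
# distutils.version.LooseVersion, mmcv.parallel (MMDataParallel,
# MMDistributedDataParallel), mmcv.runner (DistSamplerSeedHook), and EasyCV's
# own builders, hooks and EVRunner; the exact import paths follow the easycv
# package layout.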
def train_model(model,
data_loaders,
cfg,
distributed=False,
timestamp=None,
meta=None,
use_fp16=False,
validate=True,
gpu_collect=True):
""" Training API.
Args:
model (:obj:`nn.Module`): user defined model
        data_loaders: a list of dataloaders for the training data
cfg: config object
distributed: distributed training or not
        timestamp: time str formatted as '%Y%m%d_%H%M%S'
        meta: a dict of meta info, such as env_info, seed, iter, epoch
use_fp16: use fp16 training or not
validate: do evaluation while training
        gpu_collect: gather tensors on GPU if True, else on CPU
"""
logger = get_root_logger(cfg.log_level)
    logger.info(f'GPU INFO : {torch.cuda.get_device_name(0)}')
# model.cuda() must be before build_optimizer in torchacc mode
model = model.cuda()
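    # YOLOX ships a dedicated optimizer builder; every other model type goes
    # through the generic build_optimizer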
if cfg.model.type == 'YOLOX':
optimizer = build_yolo_optimizer(model, cfg.optimizer)
else:
optimizer = build_optimizer(model, cfg.optimizer)
    # When using amp from apex, amp must be initialized with the model before it
    # is wrapped by DDP or DP, so initialize it here. Torch 1.6+ has native AMP
    # and does not need this.
if use_fp16 and LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    # SyncBatchNorm: sync BN statistics across GPUs when enabled in the config
open_sync_bn = cfg.get('sync_bn', False)
if open_sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
logger.info('Using SyncBatchNorm()')
    # in torchacc mode, DDP's functionality is split between OptimizerHook and
    # TorchaccLoaderWrapper, so skip the DDP/DP wrapping here
if not is_torchacc_enabled():
if distributed:
find_unused_parameters = cfg.get('find_unused_parameters', False)
model = MMDistributedDataParallel(
model,
find_unused_parameters=find_unused_parameters,
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
else:
model = MMDataParallel(model, device_ids=range(cfg.gpus))
# build runner
runner = EVRunner(
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta,
fp16_enable=use_fp16)
runner.data_loader = data_loaders
    # an ugly workaround to make the .log and .log.json filenames the same
runner.timestamp = timestamp
    if use_fp16:
        assert torch.cuda.is_available(), 'cuda is needed for fp16'
        optimizer_config = AMPFP16OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    # convert tensors to numpy before dumping logs
if len(cfg.log_config.get('hooks', [])) > 0:
cfg.log_config.hooks.insert(0, dict(type='PreLoggerHook'))
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config)
if distributed:
logger.info('register DistSamplerSeedHook')
runner.register_hook(DistSamplerSeedHook())
# register eval hooks
if validate:
if 'eval_pipelines' not in cfg:
            runner.logger.warning(
                '`eval_pipelines` not found in cfg, skip validation!')
validate = False
else:
if isinstance(cfg.eval_pipelines, dict):
cfg.eval_pipelines = [cfg.eval_pipelines]
assert len(cfg.eval_pipelines) > 0
            runner.logger.info('enable validation hook')
    # best-metric names default to:
    #   evaluator type + eval dataset name + metric names
    best_metric_name = []
    best_metric_type = []
if validate:
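        # a single eval interval is shared by every eval pipeline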
interval = cfg.eval_config.pop('interval', 1)
for idx, eval_pipe in enumerate(cfg.eval_pipelines):
data = eval_pipe.get('data', None) or cfg.data.val
dist_eval = eval_pipe.get('dist_eval', False)
evaluator_cfg = eval_pipe.evaluators[0]
            # look up the dataset name and the evaluator's registered defaults
eval_dataset_name = evaluator_cfg.get('dataset_name', None)
default_metrics = METRICS.get(evaluator_cfg['type'])['metric_name']
default_metric_type = METRICS.get(
evaluator_cfg['type'])['metric_cmp_op']
if 'metric_names' not in evaluator_cfg:
evaluator_cfg['metric_names'] = default_metrics
eval_metric_names = evaluator_cfg['metric_names']
            # compose the full best-metric names for this pipeline
this_metric_names = generate_best_metric_name(
evaluator_cfg['type'], eval_dataset_name, eval_metric_names)
best_metric_name = best_metric_name + this_metric_names
# get the metric_type
this_metric_type = evaluator_cfg.pop('metric_type',
default_metric_type)
this_metric_type = this_metric_type + ['max'] * (
len(this_metric_names) - len(this_metric_type))
best_metric_type = best_metric_type + this_metric_type
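            # e.g. a ClsEvaluator on dataset 'val' with metric 'neck_top1'
            # would yield a name like 'ClsEvaluator_val_neck_top1' (the exact
            # format is decided by generate_best_metric_name); missing
            # comparison types are padded with 'max'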
imgs_per_gpu = data.pop('imgs_per_gpu', cfg.data.imgs_per_gpu)
workers_per_gpu = data.pop('workers_per_gpu',
cfg.data.workers_per_gpu)
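            # DALI datasets construct their own dataloader; everything else
            # goes through the generic build_dataloader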
if not is_dali_dataset_type(data['type']):
val_dataset = build_dataset(data)
val_dataloader = build_dataloader(
val_dataset,
imgs_per_gpu=imgs_per_gpu,
workers_per_gpu=workers_per_gpu,
dist=(distributed and dist_eval),
shuffle=False,
seed=cfg.seed)
else:
default_args = dict(
batch_size=imgs_per_gpu,
workers_per_gpu=workers_per_gpu,
distributed=distributed)
val_dataset = build_dataset(data, default_args)
val_dataloader = val_dataset.get_dataloader()
evaluators = build_evaluator(eval_pipe.evaluators)
eval_cfg = cfg.eval_config
eval_cfg['evaluators'] = evaluators
            eval_hook = DistEvalHook if (distributed
                                         and dist_eval) else EvalHook
            if eval_hook is EvalHook:
                eval_cfg.pop('gpu_collect', None)  # only used by DistEvalHook
            logger.info(f'register {eval_hook.__name__} {eval_cfg}')
# only flush log buffer at the last eval hook
flush_buffer = (idx == len(cfg.eval_pipelines) - 1)
runner.register_hook(
eval_hook(
val_dataloader,
interval=interval,
mode=eval_pipe.mode,
flush_buffer=flush_buffer,
**eval_cfg))
# user-defined hooks
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expects list type, but got {type(custom_hooks)}'
for hook_cfg in cfg.custom_hooks:
assert isinstance(hook_cfg, dict), \
'Each item in custom_hooks expects dict type, but got ' \
f'{type(hook_cfg)}'
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
common_params = {}
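            # DeepClusterHook additionally needs the training dataloaders
            # (it periodically re-clusters features during training)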
if hook_cfg.type == 'DeepClusterHook':
common_params = dict(
dist_mode=distributed, data_loaders=data_loaders)
else:
common_params = dict(dist_mode=distributed)
hook = build_hook(hook_cfg, default_args=common_params)
runner.register_hook(hook, priority=priority)
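    # exponential moving average of model weights, if configured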
if cfg.get('ema', None):
runner.logger.info('register ema hook')
runner.register_hook(EMAHook(decay=cfg.ema.decay))
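    # track and save the best checkpoint according to the collected
    # metric names and comparison types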
if len(best_metric_name) > 0:
runner.register_hook(
BestCkptSaverHook(
by_epoch=True,
save_optimizer=True,
best_metric_name=best_metric_name,
best_metric_type=best_metric_type))
# export hook
if getattr(cfg, 'checkpoint_sync_export', False):
runner.register_hook(ExportHook(cfg))
# oss sync hook
if cfg.oss_work_dir is not None:
if cfg.checkpoint_config.get('by_epoch', True):
runner.register_hook(
OSSSyncHook(
cfg.work_dir,
cfg.oss_work_dir,
interval=cfg.checkpoint_config.interval,
**cfg.get('oss_sync_config', {})))
        else:
            runner.register_hook(
                OSSSyncHook(
                    cfg.work_dir,
                    cfg.oss_work_dir,
                    interval=1,
                    iter_interval=cfg.checkpoint_config.interval,
                    **cfg.get('oss_sync_config', {})))
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.logger.info(f'load checkpoint from {cfg.load_from}')
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
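
# --- Illustrative usage (a minimal sketch, not part of the original file) ---
# Assumes cfg was loaded with mmcv.Config.fromfile and that the model,
# dataloader and env_info were built by the surrounding EasyCV tooling;
# all names below are placeholders:
#
#   import time
#   timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
#   train_model(
#       model,
#       data_loaders=[train_dataloader],
#       cfg=cfg,
#       distributed=distributed,
#       timestamp=timestamp,
#       meta=dict(env_info=env_info, seed=cfg.seed),
#       use_fp16=cfg.get('use_fp16', False),
#       validate=True)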