def train_model()

in easycv/apis/train.py [0:0]


def train_model(model,
                data_loaders,
                cfg,
                distributed=False,
                timestamp=None,
                meta=None,
                use_fp16=False,
                validate=True,
                gpu_collect=True):
    """ Training API.

    Moves the model to GPU, builds the optimizer, wraps the model for
    (distributed) data parallel training unless torchacc is enabled, builds
    an ``EVRunner``, registers the training / evaluation / custom / EMA /
    best-checkpoint / export / OSS-sync hooks, optionally resumes or loads a
    checkpoint, and finally runs ``cfg.workflow`` for ``cfg.total_epochs``.

    Args:
        model (:obj:`nn.Module`): user defined model
        data_loaders: a list of dataloader for training data
        cfg: config object. Read here: ``log_level``, ``model.type``,
            ``optimizer``, ``optimizer_config``, ``lr_config``,
            ``checkpoint_config``, ``log_config``, ``work_dir``, ``gpus``,
            ``workflow``, ``total_epochs``, ``oss_work_dir``,
            ``resume_from``, ``load_from``; when validating also
            ``eval_pipelines``, ``eval_config``, ``data`` and ``seed``.
            Optional keys: ``sync_bn``, ``find_unused_parameters``,
            ``custom_hooks``, ``ema``, ``checkpoint_sync_export``,
            ``oss_sync_config``. Note that ``cfg`` is mutated in place
            (e.g. ``eval_config``, ``eval_pipelines``, ``log_config.hooks``).
        distributed: distributed training or not
        timestamp: time str formatted as '%Y%m%d_%H%M%S'
        meta: a dict containing meta data info, such as env_info, seed, iter, epoch
        use_fp16: use fp16 training or not
        validate: do evaluation while training
        gpu_collect: use gpu collect or cpu collect for tensor gathering.
            NOTE: this argument is currently not read by this function; a
            ``gpu_collect`` entry in ``cfg.eval_config`` is what reaches
            ``DistEvalHook``.

    Returns:
        None
    """
    logger = get_root_logger(cfg.log_level)
    print('GPU INFO : ', torch.cuda.get_device_name(0))

    # model.cuda() must be before build_optimizer in torchacc mode
    model = model.cuda()

    if cfg.model.type == 'YOLOX':
        optimizer = build_yolo_optimizer(model, cfg.optimizer)
    else:
        optimizer = build_optimizer(model, cfg.optimizer)

    # When using amp from apex, amp must be initialized with the bare model
    # (not wrapped by DDP or DP), so it is initialized here. With torch 1.6
    # or later this is not needed.
    if use_fp16 and LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # SyncBatchNorm
    open_sync_bn = cfg.get('sync_bn', False)

    if open_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        logger.info('Using SyncBatchNorm()')

    # the functions of torchacc DDP split into OptimizerHook and TorchaccLoaderWrapper
    if not is_torchacc_enabled():
        if distributed:
            find_unused_parameters = cfg.get('find_unused_parameters', False)
            model = MMDistributedDataParallel(
                model,
                find_unused_parameters=find_unused_parameters,
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False)
        else:
            model = MMDataParallel(model, device_ids=range(cfg.gpus))

    # build runner
    runner = EVRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta,
        fp16_enable=use_fp16)
    runner.data_loader = data_loaders

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # Pick the optimizer hook matching the precision mode.
    if use_fp16:
        assert torch.cuda.is_available(), 'cuda is needed for fp16'
        optimizer_config = AMPFP16OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)

    # process tensor type, convert to numpy for dump logs
    if len(cfg.log_config.get('hooks', [])) > 0:
        cfg.log_config.hooks.insert(0, dict(type='PreLoggerHook'))
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if distributed:
        logger.info('register DistSamplerSeedHook')
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        if 'eval_pipelines' not in cfg:
            runner.logger.warning(
                'Not find `eval_pipelines` in cfg, skip validation!')
            validate = False
        else:
            # A single pipeline given as a dict is normalized to a list.
            if isinstance(cfg.eval_pipelines, dict):
                cfg.eval_pipelines = [cfg.eval_pipelines]
            assert len(cfg.eval_pipelines) > 0
            runner.logger.info('open validate hook')

    # default is eval_pipe.evaluators[0]['type'] + eval_dataset_name + [metric_names]
    best_metric_name = []
    best_metric_type = []
    if validate:
        interval = cfg.eval_config.pop('interval', 1)
        for idx, eval_pipe in enumerate(cfg.eval_pipelines):
            # Fall back to the global validation data when the pipeline
            # does not define its own.
            data = eval_pipe.get('data', None) or cfg.data.val
            dist_eval = eval_pipe.get('dist_eval', False)

            # Only the first evaluator of a pipeline drives best-ckpt metrics.
            evaluator_cfg = eval_pipe.evaluators[0]
            # get the metric_name
            eval_dataset_name = evaluator_cfg.get('dataset_name', None)
            default_metrics = METRICS.get(evaluator_cfg['type'])['metric_name']
            default_metric_type = METRICS.get(
                evaluator_cfg['type'])['metric_cmp_op']
            if 'metric_names' not in evaluator_cfg:
                evaluator_cfg['metric_names'] = default_metrics
            eval_metric_names = evaluator_cfg['metric_names']

            # get the metric_name
            this_metric_names = generate_best_metric_name(
                evaluator_cfg['type'], eval_dataset_name, eval_metric_names)
            best_metric_name = best_metric_name + this_metric_names

            # get the metric_type; pad with 'max' so that every metric name
            # has a comparison type
            this_metric_type = evaluator_cfg.pop('metric_type',
                                                 default_metric_type)
            this_metric_type = this_metric_type + ['max'] * (
                len(this_metric_names) - len(this_metric_type))
            best_metric_type = best_metric_type + this_metric_type

            imgs_per_gpu = data.pop('imgs_per_gpu', cfg.data.imgs_per_gpu)
            workers_per_gpu = data.pop('workers_per_gpu',
                                       cfg.data.workers_per_gpu)
            if not is_dali_dataset_type(data['type']):
                val_dataset = build_dataset(data)
                val_dataloader = build_dataloader(
                    val_dataset,
                    imgs_per_gpu=imgs_per_gpu,
                    workers_per_gpu=workers_per_gpu,
                    dist=(distributed and dist_eval),
                    shuffle=False,
                    seed=cfg.seed)
            else:
                # DALI datasets provide their own dataloader.
                default_args = dict(
                    batch_size=imgs_per_gpu,
                    workers_per_gpu=workers_per_gpu,
                    distributed=distributed)
                val_dataset = build_dataset(data, default_args)
                val_dataloader = val_dataset.get_dataloader()

            evaluators = build_evaluator(eval_pipe.evaluators)
            # NOTE(review): eval_cfg aliases cfg.eval_config, so the
            # 'evaluators' assignment and the 'gpu_collect' pop below persist
            # across pipelines - confirm this is intended.
            eval_cfg = cfg.eval_config
            eval_cfg['evaluators'] = evaluators
            eval_hook = DistEvalHook if (distributed
                                         and dist_eval) else EvalHook
            if eval_hook == EvalHook:
                eval_cfg.pop('gpu_collect', None)  # only use in DistEvalHook
            logger.info(f'register EvaluationHook {eval_cfg}')
            # only flush log buffer at the last eval hook
            flush_buffer = (idx == len(cfg.eval_pipelines) - 1)
            runner.register_hook(
                eval_hook(
                    val_dataloader,
                    interval=interval,
                    mode=eval_pipe.mode,
                    flush_buffer=flush_buffer,
                    **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            # Copy so that popping 'priority' does not alter cfg.
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')

            # DeepClusterHook additionally receives the training dataloaders.
            if hook_cfg.type == 'DeepClusterHook':
                common_params = dict(
                    dist_mode=distributed, data_loaders=data_loaders)
            else:
                common_params = dict(dist_mode=distributed)

            hook = build_hook(hook_cfg, default_args=common_params)
            runner.register_hook(hook, priority=priority)

    if cfg.get('ema', None):
        runner.logger.info('register ema hook')
        runner.register_hook(EMAHook(decay=cfg.ema.decay))

    # Best-checkpoint saving is only possible when validation produced
    # metric names.
    if len(best_metric_name) > 0:
        runner.register_hook(
            BestCkptSaverHook(
                by_epoch=True,
                save_optimizer=True,
                best_metric_name=best_metric_name,
                best_metric_type=best_metric_type))

    # export hook
    if getattr(cfg, 'checkpoint_sync_export', False):
        runner.register_hook(ExportHook(cfg))

    # oss sync hook
    if cfg.oss_work_dir is not None:
        # Extra options belong to the OSSSyncHook constructor in both
        # branches, not to runner.register_hook.
        oss_sync_config = cfg.get('oss_sync_config', {})
        if cfg.checkpoint_config.get('by_epoch', True):
            runner.register_hook(
                OSSSyncHook(
                    cfg.work_dir,
                    cfg.oss_work_dir,
                    interval=cfg.checkpoint_config.interval,
                    **oss_sync_config))
        else:
            # Iteration-based checkpoints: sync on the checkpoint iteration
            # interval instead of the epoch interval.
            runner.register_hook(
                OSSSyncHook(
                    cfg.work_dir,
                    cfg.oss_work_dir,
                    interval=1,
                    iter_interval=cfg.checkpoint_config.interval,
                    **oss_sync_config))

    # Resuming takes precedence over loading plain weights.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.logger.info(f'load checkpoint from {cfg.load_from}')
        runner.load_checkpoint(cfg.load_from)

    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)