def train()

in MaskRCNN/pytorch/tools/train_net.py
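
train() builds the detection model, optimizer, and LR scheduler, optionally wraps them with apex AMP and (apex or native) DistributedDataParallel, restores any existing checkpoint, constructs the training data loader, and then hands everything to do_train(). It returns the trained model together with the number of iterations per epoch.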


def train(cfg, local_rank, distributed, fp16, dllogger, args):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

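    # use_amp and use_apex_ddp are module-level flags set by the apex import
    # guards at the top of train_net.py (True when apex imports successfully).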
    if use_amp:
        # Initialize mixed-precision training
        if fp16:
            use_mixed_precision = True
        else:
            use_mixed_precision = cfg.DTYPE == "float16"

        amp_opt_level = 'O1' if use_mixed_precision else 'O0'
        model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

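    # Wrap the model for multi-GPU training: apex DistributedDataParallel when
    # available, otherwise PyTorch's native DistributedDataParallel.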
    if distributed:
        if use_apex_ddp:
            model = DDP(model, delay_allreduce=True)
        else:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank,
                # this should be removed if we update BatchNorm stats
                broadcast_buffers=False,
            )

    if is_main_process():
        print(model)

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

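    # make_data_loader in this repo also returns the number of iterations per
    # epoch, which drives the per-epoch evaluation callback below.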
    data_loader, iters_per_epoch = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    args.iters_per_epoch = iters_per_epoch
    # Set the per-iteration callback that evaluates the model at each epoch
    # boundary and may trigger an early exit once the target mAP is reached.
    if cfg.PER_EPOCH_EVAL:
        per_iter_callback_fn = functools.partial(
                mlperf_test_early_exit,
                iters_per_epoch=iters_per_epoch,
                tester=functools.partial(test, cfg=cfg, dllogger=dllogger, args=args),
                model=model,
                distributed=distributed,
                min_bbox_map=cfg.MIN_BBOX_MAP,
                min_segm_map=cfg.MIN_MASK_MAP,
                args=args)
    else:
        per_iter_callback_fn = None

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        use_amp,
        cfg,
        dllogger,
        args,
        per_iter_end_callback_fn=per_iter_callback_fn,
    )

    return model, iters_per_epoch
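
For context, a driver for train() looks roughly like the sketch below. This is a minimal sketch, not the file's actual main(): the argument names, the WORLD_SIZE-based distributed check, and the dllogger backend setup are assumptions; only the call signature of train() and the maskrcnn_benchmark config object come from the code above.

import argparse
import os

import torch
import dllogger
from maskrcnn_benchmark.config import cfg  # yacs config node used throughout the repo


def main():
    parser = argparse.ArgumentParser(description="Mask R-CNN training driver (sketch)")
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--fp16", action="store_true", help="force mixed-precision training")
    parser.add_argument("opts", nargs=argparse.REMAINDER, help="config overrides as KEY VALUE pairs")
    args = parser.parse_args()

    # Distributed mode inferred from the launcher-provided WORLD_SIZE (assumption).
    num_gpus = int(os.environ.get("WORLD_SIZE", "1"))
    args.distributed = num_gpus > 1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # Load the experiment config and freeze it before training.
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # A simple stdout logger; the real script may configure additional backends.
    dllogger.init(backends=[dllogger.StdOutBackend(dllogger.Verbosity.DEFAULT)])

    model, iters_per_epoch = train(
        cfg, args.local_rank, args.distributed, args.fp16, dllogger, args
    )


if __name__ == "__main__":
    main()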