def main()

in main_train.py


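# module-level imports assumed by this extract (defined elsewhere in
# main_train.py, not shown here): torch, math, logger, GradScaler from
# torch.cuda.amp, and the repo helpers create_train_val_loader, get_model,
# build_loss_fn, build_optimizer, build_scheduler, EMA, load_checkpoint,
# load_model_state, Trainer, is_master, plus the DEFAULT_* constants.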
def main(opts, **kwargs):
    num_gpus = getattr(opts, "dev.num_gpus", 0) # defaults are for CPU
    dev_id = getattr(opts, "dev.device_id", torch.device('cpu'))
    device = getattr(opts, "dev.device", torch.device('cpu'))
    is_distributed = getattr(opts, "ddp.use_distributed", False)
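    # opts is a flat namespace keyed by dotted option names (e.g. "dev.num_gpus"),
    # hence the getattr(opts, "<group>.<key>", default) pattern throughout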

    is_master_node = is_master(opts)

    # set up the data loaders
    train_loader, val_loader, train_sampler = create_train_val_loader(opts)
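    # the sampler is returned separately so the training loop can re-seed
    # shuffling every epoch (e.g., DistributedSampler.set_epoch under DDP)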

    # resolve max. iterations and max. epochs from the scheduler config;
    # iteration-based schedulers (e.g., polynomial decay) need both values set
    is_iteration_based = getattr(opts, "scheduler.is_iteration_based", False)
    if is_iteration_based:
        max_iter = getattr(opts, "scheduler.max_iterations", DEFAULT_ITERATIONS)
        if max_iter is None or max_iter <= 0:
            logger.log('Setting max. iterations to {}'.format(DEFAULT_ITERATIONS))
            setattr(opts, "scheduler.max_iterations", DEFAULT_ITERATIONS)
            max_iter = DEFAULT_ITERATIONS
        setattr(opts, "scheduler.max_epochs", DEFAULT_MAX_EPOCHS)
        if is_master_node:
            logger.log('Max. iteration for training: {}'.format(max_iter))
    else:
        max_epochs = getattr(opts, "scheduler.max_epochs", DEFAULT_EPOCHS)
        if max_epochs is None or max_epochs <= 0:
            logger.log('Setting max. epochs to {}'.format(DEFAULT_EPOCHS))
            setattr(opts, "scheduler.max_epochs", DEFAULT_EPOCHS)
            max_epochs = DEFAULT_EPOCHS
        setattr(opts, "scheduler.max_iterations", DEFAULT_MAX_ITERATIONS)
        if is_master_node:
            logger.log('Max. epochs for training: {}'.format(max_epochs))

    # set up the model
    model = get_model(opts)

    if num_gpus == 0:
        logger.error('Need at least 1 GPU for training. Got {} GPUs'.format(num_gpus))
    elif num_gpus == 1:
        model = model.to(device=device)
    elif is_distributed:
        model = model.to(device=device)
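        # wrap with DDP: gradients are all-reduced across processes and each
        # process drives the single GPU identified by dev_id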
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[dev_id], output_device=dev_id
        )
        if is_master_node:
            logger.log('Using DistributedDataParallel for training')
    else:
        model = torch.nn.DataParallel(model)
        model = model.to(device=device)
        if is_master_node:
            logger.log('Using DataParallel for training')

    # set up the loss function (criteria)
    criteria = build_loss_fn(opts)
    criteria = criteria.to(device=device)
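    # loss modules may hold tensors of their own (e.g., class weights), so the
    # criteria are moved to the same device as the model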

    # create the optimizer
    optimizer = build_optimizer(model, opts=opts)

    # create the gradient scaler for mixed-precision training
    gradient_scalar = GradScaler(
        enabled=getattr(opts, "common.mixed_precision", False)
    )
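    # when mixed precision is off, enabled=False turns the scaler's
    # scale()/step()/update() calls into pass-throughs, so the training
    # loop can invoke it unconditionally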

    # LR scheduler
    scheduler = build_scheduler(opts=opts)

    model_ema = None
    use_ema = getattr(opts, "ema.enable", False)

    if use_ema:
        ema_momentum = getattr(opts, "ema.momentum", 0.0001)
        model_ema = EMA(
            model=model,
            ema_momentum=ema_momentum,
            device=device
        )
        if is_master_node:
            logger.log('Using EMA')

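    # best checkpoint metric so far: start at 0 when higher is better (e.g.,
    # top-1 accuracy) and at +inf when lower is better (e.g., validation loss)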
    best_metric = 0.0 if getattr(opts, "stats.checkpoint_metric_max", False) else math.inf

    start_epoch = 0
    start_iteration = 0
    resume_loc = getattr(opts, "common.resume", None)
    finetune_loc = getattr(opts, "common.finetune", None)
    auto_resume = getattr(opts, "common.auto_resume", False)
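    # resuming restores the full training state (weights, optimizer, scaler,
    # epoch/iteration counters, best metric); finetuning loads weights only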
    if resume_loc is not None or auto_resume:
        model, optimizer, gradient_scalar, start_epoch, start_iteration, best_metric, model_ema = load_checkpoint(
            opts=opts,
            model=model,
            optimizer=optimizer,
            model_ema=model_ema,
            gradient_scalar=gradient_scalar
        )
    elif finetune_loc is not None:
        model, model_ema = load_model_state(opts=opts, model=model, model_ema=model_ema)
        if is_master_node:
            logger.log('Finetuning model from checkpoint {}'.format(finetune_loc))

    training_engine = Trainer(
        opts=opts,
        model=model,
        validation_loader=val_loader,
        training_loader=train_loader,
        optimizer=optimizer,
        criterion=criteria,
        scheduler=scheduler,
        start_epoch=start_epoch,
        start_iteration=start_iteration,
        best_metric=best_metric,
        model_ema=model_ema,
        gradient_scalar=gradient_scalar,
    )

    training_engine.run(train_sampler=train_sampler)
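
A minimal sketch of how an entry point like this is typically launched, one
process per GPU. The helper names below (distributed_worker,
get_training_arguments) and the exact option keys are illustrative
assumptions, not necessarily this repo's API:


import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def distributed_worker(rank, main_fn, opts):
    # hypothetical per-process wrapper: pin one GPU per rank, join the
    # process group, then run the shared main() shown above
    torch.cuda.set_device(rank)
    setattr(opts, "dev.device_id", rank)
    setattr(opts, "dev.device", torch.device("cuda", rank))
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        rank=rank,
        world_size=getattr(opts, "dev.num_gpus", 1),
    )
    main_fn(opts)


if __name__ == "__main__":
    opts = get_training_arguments()  # assumed: the repo's argument parser
    num_gpus = getattr(opts, "dev.num_gpus", 0)
    if getattr(opts, "ddp.use_distributed", False) and num_gpus > 1:
        # mp.spawn passes the process rank as the first positional argument
        mp.spawn(distributed_worker, args=(main, opts), nprocs=num_gpus)
    else:
        main(opts)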