def train()

in experiments/train_imagenet.py


def train(cfg, is_leader):

    # Seed NumPy and PyTorch RNGs for reproducible runs
    np.random.seed(cfg.rng_seed)
    torch.manual_seed(cfg.rng_seed)

    log.info(cfg.pretty())

    # Build the model from the config and place it on the current GPU
    cur_device = torch.cuda.current_device()
    model = instantiate(cfg.model).cuda(device=cur_device)
    # Wrap the model in DistributedDataParallel for multi-GPU training
    if cfg.num_gpus > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            module=model,
            device_ids=[cur_device],
            output_device=cur_device,
        )
    optimizer = instantiate(cfg.optim, model.parameters())

    # Only load the training set if there is training to do and no pre-trained weights were given
    if cfg.optim.max_epoch > 0 and cfg.train.weights is None:
        print("Loading training set...")
        train_dataset = instantiate(cfg.train)
    else:
        print("Skipping loading the training dataset, 0 epochs of training to perform "
              "or pre-trained weights provided.")
        train_dataset = None
    print("Loading test set...")
    test_dataset = instantiate(cfg.test)
    lr_policy = instantiate(cfg.optim.lr_policy)

    print("Training...")
    train_net(model=model,
            optimizer=optimizer,
            train_dataset=train_dataset,
            batch_size=cfg.train.batch_size,
            max_epoch=cfg.optim.max_epoch,
            loader_params=cfg.data_loader,
            lr_policy=lr_policy,
            save_period=cfg.train.checkpoint_period,
            weights=cfg.train.weights,
            num_gpus=cfg.num_gpus,
            is_leader=is_leader
            )

    print("Testing...")
    err = test_net(model=model,
            test_dataset=test_dataset,
            batch_size=cfg.test.batch_size,
            loader_params=cfg.data_loader,
            num_gpus=cfg.num_gpus)

    # Evaluate robustness on corrupted data; the clean error and the baseline
    # errors in mCE_baseline_file are used to compute mCE
    test_corrupt_net(
        model=model,
        corrupt_cfg=cfg.corrupt,
        batch_size=cfg.corrupt.batch_size,
        loader_params=cfg.data_loader,
        aug_string=cfg.corrupt.aug_string,
        clean_err=err,
        mCE_denom=cfg.corrupt.mCE_baseline_file,
        num_gpus=cfg.num_gpus,
        log_name='train_imagenet.log',
    )
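

For context, below is a minimal, hypothetical launcher sketch showing how train() might be invoked. It assumes a Hydra-style config (the instantiate and cfg.pretty() calls above follow Hydra/OmegaConf conventions), a config named "train_imagenet" under a "conf" directory, and a torchrun-style multi-process launch with LOCAL_RANK set; the actual entry point in experiments/train_imagenet.py, its config layout, and how it derives is_leader may differ.

# Hypothetical launcher -- a sketch only, not the repository's actual entry point.
# Assumes the config exposes the fields referenced by train() above
# (rng_seed, num_gpus, model, optim, train, test, corrupt, data_loader).
import os

import hydra
import torch


@hydra.main(config_path="conf", config_name="train_imagenet")
def main(cfg):
    if cfg.num_gpus > 1:
        # Expect one process per GPU (e.g. launched via torchrun); bind each
        # process to its local GPU and let only rank 0 act as the leader.
        torch.distributed.init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
        is_leader = torch.distributed.get_rank() == 0
    else:
        is_leader = True
    train(cfg, is_leader)


if __name__ == "__main__":
    main()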