def objective()

in benchmarking/training_scripts/resnet_cifar10/resnet_cifar10.py

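The function body relies on imports and helpers that live elsewhere in resnet_cifar10.py (get_CIFAR10, Model, train, valid and the metric name constants). A minimal sketch of the imports it needs is given below; the module path assumed for the Syne Tune checkpointing helpers is an assumption and may differ between Syne Tune versions.

import os
import time

import numpy as np
import torch
from filelock import SoftFileLock, Timeout
from torch.utils.data.sampler import SubsetRandomSampler

from syne_tune import Reporter
# Assumed location of the checkpointing utilities; adjust to the installed
# Syne Tune version if necessary
from syne_tune.utils import (
    checkpoint_model_at_rung_level,
    pytorch_load_save_functions,
    resume_from_checkpointed_model,
)

# get_CIFAR10, Model, train, valid, RESOURCE_ATTR, METRIC_NAME and
# ELAPSED_TIME_ATTR are defined in the same script and are not shown here.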

def objective(config):
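    # Seed PyTorch with a fresh random value so that repeated evaluations of
    # the same configuration do not produce identical results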
    torch.manual_seed(np.random.randint(10000))
    batch_size = config['batch_size']
    lr = config['lr']
    momentum = config['momentum']
    weight_decay = config['weight_decay']
    num_gpus = config.get('num_gpus')
    if num_gpus is None:
        num_gpus = 1
    trial_id = config.get('trial_id')
    debug_log = trial_id is not None
    if debug_log:
        print("Trial {}: Starting evaluation".format(trial_id), flush=True)

    path = config['dataset_path']
    os.makedirs(path, exist_ok=True)
    # Lock protection is needed for backends which run multiple worker
    # processes on the same instance
    lock_path = os.path.join(path, 'lock')
    lock = SoftFileLock(lock_path)
    try:
        with lock.acquire(timeout=120, poll_intervall=1):
            input_size, num_classes, train_dataset, valid_dataset = get_CIFAR10(
                root=path)
    except Timeout:
        print(
            "WARNING: Could not obtain lock for dataset files. Trying anyway...",
            flush=True)
        input_size, num_classes, train_dataset, valid_dataset = get_CIFAR10(
            root=path)

    # We do not want to count the time needed to download the dataset, which
    # can be substantial the first time around
    ts_start = time.time()
    report = Reporter()

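    # The CIFAR-10 training set has 50,000 images: the first 40,000 are used
    # for training, the remaining 10,000 are held out for validation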
    indices = list(range(train_dataset.data.shape[0]))
    train_idx, valid_idx = indices[:40000], indices[40000:]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               # shuffle=True,
                                               num_workers=0,
                                               sampler=train_sampler,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=128,
                                               # shuffle=False,
                                               num_workers=0,
                                               sampler=valid_sampler,
                                               pin_memory=True)

    model = Model()
    if torch.cuda.is_available():
        model = model.cuda()
        device = torch.device("cuda")
        model = torch.nn.DataParallel(
            model, device_ids=list(range(num_gpus))).to(device)
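    # Step-wise learning rate schedule: decay the learning rate by a factor
    # of 10 at epochs 25 and 40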
    milestones = [25, 40]
    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, momentum=momentum,
        weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=0.1)

    # Checkpointing
    load_model_fn, save_model_fn = pytorch_load_save_functions(
        {'model': model, 'optimizer': optimizer, 'lr_scheduler': scheduler})
    # Resume from checkpoint (optional)
    resume_from = resume_from_checkpointed_model(config, load_model_fn)

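    # resume_from is the epoch of the checkpoint that was loaded (0 if no
    # checkpoint was found), so training resumes with the following epoch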
    for epoch in range(resume_from + 1, config['epochs'] + 1):
        train(model, train_loader, optimizer)
        loss, y = valid(model, valid_loader)
        scheduler.step()
        elapsed_time = time.time() - ts_start

        # Report the score for this epoch back to the tuner.
        report(**{RESOURCE_ATTR: epoch,
                  METRIC_NAME: y,
                  ELAPSED_TIME_ATTR: elapsed_time})

        # Write checkpoint (optional)
        checkpoint_model_at_rung_level(config, save_model_fn, epoch)

        if debug_log:
            print("Trial {}: epoch = {}, objective = {:.3f}, elapsed_time = {:.2f}".format(
                trial_id, epoch, y, elapsed_time), flush=True)