def train_model()

in train_indep.py [0:0]


def train_model(config):
    """Train a network as described by ``config``, evaluating and checkpointing.

    Everything needed for training (network, criterion/optimizer/scheduler,
    training parameters, regime parameters and data loaders) is obtained from
    ``utils.network_and_params(config)``.

    Every ``test_freq`` epochs, counting from ``start_epoch``, the model is
    evaluated once per entry of ``regime_params["eval_param_grid"]`` and a
    checkpoint for that epoch is saved. This happens *before* that epoch's
    training step. These periodic evaluations pass ``metric_dict=None``, so
    they do not contribute to the saved metrics.

    After the last epoch, a final evaluation over the same grid fills
    ``metric_dict``. A final checkpoint (epoch index ``last_epoch + 1``) and
    ``test_metrics.npy`` are then written to ``regime_params["save_dir"]``
    through ``model_logging``.

    Args:
        config: configuration forwarded unchanged to
            ``utils.network_and_params``; its structure is not visible here.
    """
    # Get network, data, and training parameters
    (
        net,
        opt_params,
        train_params,
        regime_params,
        data,
    ) = utils.network_and_params(config)

    # Unpack training parameters (the third element is not used here)
    epochs, test_freq, _ = train_params

    # Unpack dataset; only the loaders are used by this function
    _trainset, trainloader, _testset, testloader = data

    # Unpack optimizer parameters
    criterion, optimizer, scheduler = opt_params

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = net.to(device)
    if device == "cuda":
        # Let cuDNN benchmark and select convolution algorithms.
        cudnn.benchmark = True
    net = torch.nn.DataParallel(net)

    start_epoch = 0

    # Metric containers depend on the regime; test() receives this dict in
    # the final evaluation and its return value replaces it.
    metric_dict = _init_metric_dict(regime_params["regime"])

    # Read once up front: later updates of regime_params by train() do not
    # change the evaluation grid or the output directory.
    eval_param_grid = regime_params["eval_param_grid"]
    save_dir = regime_params["save_dir"]

    # Pre-bind the "last completed epoch" index. Without this, epochs == 0
    # leaves `epoch` unbound and the final evaluation/save raise NameError.
    epoch = start_epoch - 1

    # Training loop
    for epoch in range(start_epoch, start_epoch + epochs):
        scheduler(epoch, None)
        if epoch % test_freq == 0:
            for param in eval_param_grid:
                test(
                    net,
                    param,
                    testloader,
                    criterion,
                    epoch,
                    device,
                    metric_dict=None,
                    **regime_params,
                )

            model_logging.save_model_at_epoch(net, epoch, save_dir)

        # train() hands back updated regime parameters, which are forwarded
        # to subsequent train()/test() calls.
        _, regime_params = train(
            net,
            trainloader,
            optimizer,
            criterion,
            epoch,
            device,
            **regime_params,
        )

    # Final evaluation: collect metrics for every evaluation parameter.
    for param in eval_param_grid:
        metric_dict = test(
            net,
            param,
            testloader,
            criterion,
            epoch,
            device,
            metric_dict=metric_dict,
            **regime_params,
        )

    # Save final model and metrics
    model_logging.save_model_at_epoch(net, epoch + 1, save_dir)
    model_logging.save("test_metrics.npy", metric_dict, save_dir)


# Regimes for which a "sparsity" metric list is tracked.
_SPARSITY_REGIMES = ("sparse", "lec", "us", "ns")


def _init_metric_dict(regime):
    """Return a fresh metrics dict with the empty lists tracked for ``regime``."""
    metric_dict = {"acc": [], "alpha": []}
    if regime == "quantized":
        metric_dict["num_bits"] = []
    if regime in _SPARSITY_REGIMES:
        metric_dict["sparsity"] = []
    return metric_dict