def network_and_params()

in utils.py [0:0]


def network_and_params(config=None):
    """
    Returns the network and training parameters for the specified model type.
    """

    model_config = config["parameters"].get("model_config")
    model_class = model_config["model_class"]
    model_class_dict = {
        "cpreresnet20": networks.cpreresnet,
        "resnet18": networks.ResNet18,
        "vgg19": networks.vgg19,
    }
    if model_config is not None:
        # Get the class.
        if model_class in model_class_dict:
            model_class = model_class_dict[model_class]
        else:
            raise NotImplementedError(
                f"Invalid model_class={model_config['model_class']}"
            )
        if "model_kwargs" in model_config:
            extra_model_kwargs = model_config["model_kwargs"]
        else:
            extra_model_kwargs = {}
    else:
        extra_model_kwargs = {}

    # General params
    epochs = config["parameters"].get("epochs", 200)

    test_freq = config["parameters"].get("test_freq", 20)
    batch_size = config["parameters"].get("batch_size", 128)
    learning_rate = config["parameters"].get("learning_rate", 0.01)
    momentum = config["parameters"].get("momentum", 0.9)
    weight_decay = config["parameters"].get("weight_decay", 0.0005)
    warmup_budget = config["parameters"].get("warmup_budget", 80) / 100.0
    dataset = config["parameters"].get("dataset", "cifar10")
    alpha_grid = config["parameters"].get("alpha_grid", None)

    regime = config["parameters"]["regime"]

    if dataset == "cifar10":
        data = get_cifar10_data(batch_size)
    elif dataset == "imagenet":
        imagenet_dir = config["parameters"]["dataset_dir"]
        data = get_imagenet_data(imagenet_dir, batch_size)
    else:
        raise ValueError(f"Dataset {dataset} not supported")

    train_size = len(data[0])
    warmup_iterations = np.ceil(
        warmup_budget * epochs * train_size / batch_size
    )

    # Get model layers
    conv_type = config["parameters"]["conv_type"]
    bn_type = config["parameters"]["bn_type"]
    block_conv_type = config["parameters"]["block_conv_type"]
    block_bn_type = config["parameters"]["block_bn_type"]

    # Get regime-specific parameters
    regime_params = config["parameters"].get("regime_params", {})
    regime_params["regime"] = regime
    builder_parms = {}
    block_builder_params = {}

    # Append dataset to extra_model_kwargs args
    extra_model_kwargs["dataset"] = dataset

    if regime == "sparse":
        if "Sparse" not in block_conv_type:
            raise ValueError(
                "Regime set to sparse but non-sparse convolution layer received..."
            )
        regime_params["topk"] = config["parameters"].get("topk", 0.0)
        regime_params["current_iteration"] = 0
        regime_params["warmup_iterations"] = warmup_iterations
        regime_params["alpha_sampling"] = config["parameters"].get(
            "alpha_sampling", [0.025, 1, 0]
        )

        method = config["parameters"].get("method", "topk")
        block_builder_params["method"] = method
        if "Sparse" in conv_type:
            builder_parms["method"] = method

    elif regime == "lec":
        regime_params["topk"] = config["parameters"].get("topk", 0.0)
        regime_params["current_iteration"] = 0
        regime_params["warmup_iterations"] = warmup_iterations
        regime_params["alpha_sampling"] = config["parameters"].get(
            "alpha_sampling", [0, 1, 0]
        )
        regime_params["model_kwargs"] = {
            "dataset": dataset,
            **extra_model_kwargs,
        }
        regime_params["bn_update_factor"] = config["parameters"].get(
            "bn_update_factor", 0
        )

        regime_params["bn_type"] = config["parameters"]["bn_type"]

    elif regime == "ns":
        width_factors_list = config["parameters"]["builder_kwargs"][
            "width_factors_list"
        ]
        regime_params["width_factors_list"] = width_factors_list

        builder_parms["pass_first_last"] = True

        block_builder_params["pass_first_last"] = True

        if config["parameters"]["block_conv_type"] != "AdaptiveConv2d":
            block_builder_params["width_factors_list"] = width_factors_list
            builder_parms["width_factors_list"] = width_factors_list

        if "BN" in config["parameters"]["bn_type"]:
            norm_kwargs = config["parameters"].get("norm_kwargs", {})

            builder_parms["norm_kwargs"] = {
                "width_factors_list": width_factors_list,
                **norm_kwargs,
            }
            block_builder_params["norm_kwargs"] = {
                "width_factors_list": width_factors_list,
                **norm_kwargs,
            }

        regime_params["bn_type"] = config["parameters"]["bn_type"]

    elif regime == "us":
        builder_parms["pass_first_last"] = True
        block_builder_params["pass_first_last"] = True

        if "BN" in config["parameters"]["bn_type"]:
            norm_kwargs = config["parameters"]["norm_kwargs"]
            assert "width_factors_list" in norm_kwargs

            block_builder_params["norm_kwargs"] = norm_kwargs
            builder_parms["norm_kwargs"] = norm_kwargs

    elif regime == "quantized":
        if "ConvBn2d" not in block_conv_type:
            raise ValueError(
                "Regime set to quanitzed but non-quantized convolution layer received..."
            )
        block_builder_params["num_bits"] = config["parameters"].get(
            "num_bits", 8
        )
        block_builder_params["iteration_delay"] = warmup_iterations

        if conv_type == "ConvBn2d":
            builder_parms["num_bits"] = config["parameters"].get("num_bits", 8)
            builder_parms["iteration_delay"] = warmup_iterations

        regime_params["min_bits"] = config["parameters"].get("min_bits", 2)
        regime_params["max_bits"] = config["parameters"].get("max_bits", 8)
        regime_params["num_bits"] = config["parameters"].get("num_bits", 8)
        regime_params["discrete"] = config["parameters"].get(
            "discrete_alpha_map", False
        )

    regime_params["is_standard"] = config["parameters"].get(
        "is_standard", False
    )
    regime_params["random_param"] = config["parameters"].get(
        "random_param", False
    )
    regime_params["num_points"] = config["parameters"].get("num_points", 0)

    # Evaluation parameters for independent models
    regime_params["eval_param_grid"] = config["parameters"].get(
        "eval_param_grid", None
    )

    # If norm_kwargs haven't been set, and they are present, add them to
    # builder_params and block_builder_params.
    if "norm_kwargs" not in builder_parms:
        builder_parms["norm_kwargs"] = config["parameters"].get(
            "norm_kwargs", {}
        )
    if "norm_kwargs" not in block_builder_params:
        block_builder_params["norm_kwargs"] = config["parameters"].get(
            "norm_kwargs", {}
        )

    # Construct network
    builder = Builder(conv_type=conv_type, bn_type=bn_type, **builder_parms)
    block_builder = Builder(
        block_conv_type, block_bn_type, **block_builder_params
    )

    net = model_class(
        builder=builder, block_builder=block_builder, **extra_model_kwargs
    )

    # Input size
    regime_params["input_size"] = get_input_size(dataset)

    # Save directory
    regime_params["save_dir"] = config["parameters"]["save_dir"]

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        net.parameters(),
        lr=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
    )

    scheduler = schedulers.cosine_lr(
        optimizer, learning_rate, warmup_length=5, epochs=epochs
    )

    train_params = [epochs, test_freq, alpha_grid]
    opt_params = [criterion, optimizer, scheduler]

    print(f"Got network:\n{net}")

    return net, opt_params, train_params, regime_params, data