def random_model()

in online_attacks/scripts/random_eval.py [0:0]


def random_model(model_dir, dataset, pattern="train_*.pth"):
    """Pick a random saved checkpoint for ``dataset`` from ``model_dir``.

    Checkpoints are located with the glob
    ``<model_dir>/<dataset.value>/<model_type.value>/<pattern>``. A model
    type is drawn uniformly among the model types of ``dataset`` that have
    at least one matching checkpoint. Then one of that type's checkpoints is
    drawn uniformly.

    Args:
        model_dir: Root directory containing the saved models.
        dataset: ``DatasetType.MNIST`` or ``DatasetType.CIFAR``.
        pattern: Glob pattern for checkpoint filenames inside each
            model-type directory.

    Returns:
        A ``(model_type, model_name)`` tuple.
        - ``model_type`` is an element of ``MnistModel`` or ``CifarModel``.
        - ``model_name`` is the chosen checkpoint's file name without its
          directory or extension.

    Raises:
        ValueError: If ``dataset`` is neither ``DatasetType.MNIST`` nor
            ``DatasetType.CIFAR``.
        FileNotFoundError: If no model type of ``dataset`` has a checkpoint
            matching ``pattern`` under ``model_dir``.
    """
    if dataset == DatasetType.MNIST:
        model_types = list(MnistModel)
    elif dataset == DatasetType.CIFAR:
        model_types = list(CifarModel)
    else:
        raise ValueError("%s not in DatasetType" % dataset)

    # Gather the checkpoints of every model type up front rather than
    # re-drawing a type until a non-empty one turns up. That way a directory
    # with no matching checkpoints raises instead of looping forever, and
    # picking uniformly among the non-empty types keeps the same distribution.
    candidates = []
    for model_type in model_types:
        # glob's result order is filesystem-dependent; sort it so a seeded
        # ``random`` picks the same checkpoint on every machine.
        paths = sorted(
            glob.glob(
                os.path.join(model_dir, dataset.value, model_type.value, pattern)
            )
        )
        if paths:
            candidates.append((model_type, paths))

    if not candidates:
        raise FileNotFoundError(
            "No checkpoint matching %r found for any model type under %s"
            % (pattern, os.path.join(model_dir, dataset.value))
        )

    model_type, paths = random.choice(candidates)
    model_name = os.path.splitext(os.path.basename(random.choice(paths)))[0]

    return model_type, model_name