in online_attacks/scripts/random_eval.py [0:0]
def random_model(model_dir, dataset, pattern="train_*.pth"):
    """Pick a random model type and a random saved checkpoint for ``dataset``.

    Files are looked up under
    ``<model_dir>/<dataset.value>/<model_type.value>/<pattern>``. The model
    type is drawn uniformly among the types for ``dataset`` that have at least
    one matching file. A file is then drawn uniformly among that type's matches.

    Args:
        model_dir: Root directory holding the saved models.
        dataset: Dataset selector; must be ``DatasetType.MNIST`` or
            ``DatasetType.CIFAR``.
        pattern: Glob pattern selecting checkpoint files inside each
            model-type directory.

    Returns:
        A ``(model_type, model_name)`` tuple. ``model_type`` is a member of
        ``MnistModel`` or ``CifarModel``. ``model_name`` is the chosen file's
        base name with its extension removed.

    Raises:
        ValueError: If ``dataset`` is neither ``DatasetType.MNIST`` nor
            ``DatasetType.CIFAR``.
        FileNotFoundError: If no model type has a file matching ``pattern``.
            Previously this case looped forever.
    """
    # The dataset -> candidate-architectures mapping does not change between
    # attempts, so resolve (and validate) it once up front.
    if dataset == DatasetType.MNIST:
        model_types = list(MnistModel)
    elif dataset == DatasetType.CIFAR:
        model_types = list(CifarModel)
    else:
        raise ValueError("%s not in DatasetType" % dataset)

    # Visit the types in a uniformly random order and stop at the first one
    # that has checkpoints. This gives the same distribution as rejection
    # sampling (uniform over the types that have checkpoints) but terminates
    # when none do.
    for model_type in random.sample(model_types, len(model_types)):
        # glob's ordering is filesystem-dependent; sort so the choice below is
        # reproducible under a fixed random seed.
        list_models = sorted(
            glob.glob(
                os.path.join(model_dir, dataset.value, model_type.value, pattern)
            )
        )
        if list_models:
            break
    else:
        raise FileNotFoundError(
            "No model file matching %r found under %s"
            % (pattern, os.path.join(model_dir, dataset.value))
        )

    # Strip the directory and extension to get the bare checkpoint name.
    model_names = [
        os.path.splitext(os.path.basename(path))[0] for path in list_models
    ]
    model_name = random.choice(model_names)
    return model_type, model_name