in online_attacks/scripts/compute_hist_all.py [0:0]
def compute_hist(logger, model_type, model_name, list_records=None):
    """Compute and store per-record loss values against a target classifier.

    Hyperparameters are read from ``logger`` and used to load the test split
    of the dataset (``train=False``), a source classifier (from the stored
    hyperparameters), and a target classifier (from ``model_type`` /
    ``model_name``).  An attacker is created on the source classifier, and
    every batch is passed through: device transfer -> attacker -> target
    classifier -> un-reduced cross-entropy loss.

    For each record that has no stored result yet, the loss values of every
    entry in ``record["indices"]`` are computed and saved through
    ``logger.save_hist`` under ``<model_type>/<model_name>``.

    Args:
        logger: Object providing ``load_hparams``, ``list_all_records``,
            ``check_hist_exist``, ``load_record`` and ``save_hist``.
        model_type: Target model type; used as a path component and converted
            to ``MnistModel`` / ``CifarModel`` when the dataset is MNIST /
            CIFAR (left unchanged for any other dataset).
        model_name: Name of the target model.
        list_records: Record names to process.  When ``None``, all records
            reported by ``logger.list_all_records()`` are processed.

    Returns:
        None.  Results are persisted via ``logger.save_hist``.
    """
    # Sub-directory under which results for this target model are stored.
    hist_dir = os.path.join(model_type, model_name)

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    params = create_params(
        OmegaConf.structured(OnlineAttackParams(**logger.load_hparams()))
    )

    # Resolve the dataset-specific model enum for the target classifier.
    target_model_type = model_type
    if params.dataset == DatasetType.MNIST:
        target_model_type = MnistModel(model_type)
    elif params.dataset == DatasetType.CIFAR:
        target_model_type = CifarModel(model_type)

    dataset = load_dataset(params.dataset, train=False)

    # Both classifiers share the dataset, model directory, device and eval mode.
    target_classifier = load_classifier(
        params.dataset,
        target_model_type,
        name=model_name,
        model_dir=params.model_dir,
        device=device,
        eval=True,
    )
    source_classifier = load_classifier(
        params.dataset,
        params.model_type,
        name=params.model_name,
        model_dir=params.model_dir,
        device=device,
        eval=True,
    )

    # The attack is crafted on the source model and evaluated on the target.
    attacker = create_attacker(
        source_classifier, params.attacker_type, params.attacker_params
    )

    pipeline_stages = [
        datastream.ToDevice(device),
        datastream.AttackerTransform(attacker),
        datastream.ClassifierTransform(target_classifier),
        datastream.LossTransform(CrossEntropyLoss(reduction="none")),
    ]
    target_transform = datastream.Compose(pipeline_stages)

    record_names = (
        logger.list_all_records() if list_records is None else list_records
    )

    for record_name in tqdm.tqdm(record_names):
        # Results already on disk are not recomputed.
        if logger.check_hist_exist(hist_dir, record_name):
            continue

        record = logger.load_record(record_name)
        permutation = record["permutation"]

        eval_results = {}
        for name, entries in record["indices"].items():
            # Only the second field of each stored entry is used for the subset.
            selected = [entry[1] for entry in entries]
            eval_results[name] = compute_loss_values(
                datastream.BatchDataStream(
                    dataset,
                    batch_size=params.batch_size,
                    transform=target_transform,
                    permutation=permutation,
                ).subset(selected)
            )

        logger.save_hist(eval_results, hist_dir, record_name)