in online_attacks/experiments/mnist_online.py [0:0]
def _build_attacked_stream(
    dataset, device, attacker, classifier, criterion=None, batch_size=1000
):
    """Build a batched stream that attacks each batch and feeds it to `classifier`.

    The pipeline is: move to `device`, apply `attacker`, run `classifier`, and,
    only when `criterion` is given, append a loss transform using it.
    """
    steps = [
        datastream.ToDevice(device),
        datastream.AttackerTransform(attacker),
        datastream.ClassifierTransform(classifier),
    ]
    if criterion is not None:
        steps.append(datastream.LossTransform(criterion))
    return datastream.BatchDataStream(
        dataset, batch_size=batch_size, transform=datastream.Compose(steps)
    )


def _report_competitive_ratio(label, indices, reference_indices, k):
    """Compute a competitive ratio, print it divided by `k`, return the raw value."""
    ratio = compute_competitive_ratio(indices, reference_indices)
    print("Comp ratio %s: %.2f" % (label, ratio / k))
    return ratio


def _report_success_rate(label, stream, indices):
    """Print the attack success rate measured on `stream` restricted to `indices`."""
    fool_rate = compute_attack_success_rate(stream.subset(indices))
    print("Attack success rate (%s): %.4f" % (label, fool_rate))


def main(args, params: MnistParams = MnistParams()):
    """Run the MNIST online-attack experiment and print its metrics.

    Steps, in order:
      1. Load the MNIST split requested with ``train=False`` and wrap it in a
         random permutation.
      2. Load two classifiers of type ``args.model_type`` (index 0 is used as
         the target, index 1 as the source) and create an attacker from the
         source classifier.
      3. Compute online/offline indices on the source stream and offline
         indices on the target stream.
      4. Print competitive ratios (each divided by ``params.online_params.K``)
         and attack success rates on the target classifier.
      5. Save the selected clean and perturbed images under
         ``<args.output_dir>/img``.

    Args:
        args: Object providing ``model_type``, ``device`` and ``output_dir``.
        params: Experiment parameters. ``model_dir``, ``attacker_type``,
            ``attacker_params`` and ``online_params`` are read;
            ``online_params.N`` is overwritten with the dataset length.

    Returns:
        The last competitive ratio computed (source offline vs target
        offline), not divided by ``K``.
    """
    # NOTE(review): the default `MnistParams()` is instantiated once, at
    # definition time, and `online_params.N` is written on it below, so that
    # default object is shared (and mutated) across calls. Callers that need
    # isolation should pass their own instance.
    dataset = load_mnist_dataset(train=False)
    dataset = datastream.PermuteDataset(
        dataset, permutation=np.random.permutation(len(dataset))
    )

    target_classifier = load_mnist_classifier(
        args.model_type,
        index=0,
        model_dir=params.model_dir,
        device=args.device,
        eval=True,
    )
    source_classifier = load_mnist_classifier(
        args.model_type,
        index=1,
        model_dir=params.model_dir,
        device=args.device,
        eval=True,
    )

    # reduction="none" keeps one loss value per sample instead of a batch mean.
    criterion = CrossEntropyLoss(reduction="none")
    attacker = create_attacker(
        source_classifier, params.attacker_type, params.attacker_params
    )

    # Same attacker in both streams; only the classifier being evaluated differs.
    target_stream = _build_attacked_stream(
        dataset, args.device, attacker, target_classifier, criterion
    )
    source_stream = _build_attacked_stream(
        dataset, args.device, attacker, source_classifier, criterion
    )

    params.online_params.N = len(dataset)
    num_selected = params.online_params.K
    offline_algorithm, online_algorithm = create_online_algorithm(params.online_params)

    print("Computing indices...")
    source_online_indices, source_offline_indices = compute_indices(
        source_stream, [online_algorithm, offline_algorithm], pbar_flag=True
    )
    target_offline_indices = compute_indices(
        target_stream, [offline_algorithm], pbar_flag=True
    )[0]

    print("Computing Competitive Ratio...")
    # Labels name exactly the two index sets that are being compared.
    _report_competitive_ratio(
        "source online vs source offline",
        source_online_indices,
        source_offline_indices,
        num_selected,
    )
    _report_competitive_ratio(
        "source online vs target offline",
        source_online_indices,
        target_offline_indices,
        num_selected,
    )
    comp_ratio = _report_competitive_ratio(
        "source offline vs target offline",
        source_offline_indices,
        target_offline_indices,
        num_selected,
    )

    # Success rates are measured on the target classifier's outputs, so this
    # stream has no loss transform.
    target_stream = _build_attacked_stream(
        dataset, args.device, attacker, target_classifier
    )

    random_indices = np.random.permutation(len(dataset))[:num_selected]
    _report_success_rate("Random", target_stream, random_indices)

    # Each selected entry is indexable; element 1 is used as the dataset index.
    source_online_indices = [entry[1] for entry in source_online_indices]
    _report_success_rate("Source Online", target_stream, source_online_indices)

    source_offline_indices = [entry[1] for entry in source_offline_indices]
    _report_success_rate("Source Offline", target_stream, source_offline_indices)

    target_offline_indices = [entry[1] for entry in target_offline_indices]
    _report_success_rate("Target Offline", target_stream, target_offline_indices)

    # Fetch every source-online pick in a single batch so they can be saved as
    # one image grid.
    num_images = len(source_online_indices)
    stream = datastream.BatchDataStream(dataset)
    stream = stream.subset(source_online_indices, batch_size=num_images)
    x, target = next(stream)

    output_dir = os.path.join(args.output_dir, "img")
    os.makedirs(output_dir, exist_ok=True)

    # Roughly square grid: about sqrt(num_images) images per row.
    images_per_row = int(math.sqrt(num_images))
    torchvision.utils.save_image(
        x,
        os.path.join(output_dir, "clean.png"),
        nrow=images_per_row,
    )

    x = attacker.perturb(x.to(args.device))
    torchvision.utils.save_image(
        x,
        os.path.join(output_dir, "attack.png"),
        nrow=images_per_row,
    )

    return comp_ratio