def main()

in online_attacks/experiments/mnist_online.py [0:0]


def _make_attack_stream(
    dataset, attacker, classifier, device, criterion=None, batch_size=1000
):
    """Return a batched stream over ``dataset`` that attacks and classifies inputs.

    Each batch goes through ``ToDevice(device)`` -> ``AttackerTransform(attacker)``
    -> ``ClassifierTransform(classifier)``. When ``criterion`` is given, a
    ``LossTransform(criterion)`` step is appended as the final transform.
    """
    transforms = [
        datastream.ToDevice(device),
        datastream.AttackerTransform(attacker),
        datastream.ClassifierTransform(classifier),
    ]
    if criterion is not None:
        transforms.append(datastream.LossTransform(criterion))
    return datastream.BatchDataStream(
        dataset, batch_size=batch_size, transform=datastream.Compose(transforms)
    )


def main(args, params: MnistParams = MnistParams()):
    """Run the MNIST online-attack transfer experiment.

    The MNIST test split (``train=False``) is randomly permuted. Classifier
    ``index=1`` is the *source* model the attacker is built on; classifier
    ``index=0`` is the *target* model the attacks are transferred to.

    Steps:
      1. Compute online and offline selection indices on the source-model loss
         stream, and offline indices on the target-model loss stream.
      2. Print three competitive ratios, each divided by ``params.online_params.K``.
      3. Print the target-model attack success rate on K random examples and on
         each of the three selected index sets.
      4. Save clean and attacked images of the source-online picks to
         ``<args.output_dir>/img/clean.png`` and ``attack.png``.

    Args:
        args: namespace read for ``model_type``, ``device`` and ``output_dir``.
        params: experiment configuration. NOTE(review): this function
            overwrites ``params.online_params.N`` with the dataset size. Because
            the default ``MnistParams()`` is a single instance created when the
            function is defined, calling ``main`` without ``params`` mutates
            that shared default.

    Returns:
        The competitive ratio (not divided by K) between the source offline
        and target offline selections.
    """
    dataset = load_mnist_dataset(train=False)
    dataset = datastream.PermuteDataset(
        dataset, permutation=np.random.permutation(len(dataset))
    )

    target_classifier = load_mnist_classifier(
        args.model_type,
        index=0,
        model_dir=params.model_dir,
        device=args.device,
        eval=True,
    )
    source_classifier = load_mnist_classifier(
        args.model_type,
        index=1,
        model_dir=params.model_dir,
        device=args.device,
        eval=True,
    )

    criterion = CrossEntropyLoss(reduction="none")

    # The attacker is crafted against the source model only; the target model
    # just evaluates the resulting perturbations (transfer setting).
    attacker = create_attacker(
        source_classifier, params.attacker_type, params.attacker_params
    )

    # Loss streams drive the online/offline index selection.
    target_stream = _make_attack_stream(
        dataset, attacker, target_classifier, args.device, criterion
    )
    source_stream = _make_attack_stream(
        dataset, attacker, source_classifier, args.device, criterion
    )

    params.online_params.N = len(dataset)
    offline_algorithm, online_algorithm = create_online_algorithm(params.online_params)
    print("Computing indices...")
    source_online_indices, source_offline_indices = compute_indices(
        source_stream, [online_algorithm, offline_algorithm], pbar_flag=True
    )
    target_offline_indices = compute_indices(
        target_stream, [offline_algorithm], pbar_flag=True
    )[0]

    print("Computing Competitive Ratio...")
    k = params.online_params.K
    comparisons = [
        ("source online vs source offline", source_online_indices, source_offline_indices),
        ("source online vs target offline", source_online_indices, target_offline_indices),
        ("source offline vs target offline", source_offline_indices, target_offline_indices),
    ]
    comp_ratio = None
    for label, lhs_indices, rhs_indices in comparisons:
        comp_ratio = compute_competitive_ratio(lhs_indices, rhs_indices)
        print("Comp ratio %s: %.2f" % (label, comp_ratio / k))
    # comp_ratio now holds the last comparison (source offline vs target
    # offline), which is what this function returns.

    # Success-rate evaluation uses the target model's outputs, so this stream
    # is rebuilt without the loss step.
    target_stream = _make_attack_stream(
        dataset, attacker, target_classifier, args.device
    )

    random_indices = np.random.permutation(len(dataset))[:k]
    # NOTE(review): elements returned by compute_indices look like tuples whose
    # second entry is the dataset index; confirm against compute_indices.
    source_online_indices = [x[1] for x in source_online_indices]
    source_offline_indices = [x[1] for x in source_offline_indices]
    target_offline_indices = [x[1] for x in target_offline_indices]

    selections = [
        ("Random", random_indices),
        ("Source Online", source_online_indices),
        ("Source Offline", source_offline_indices),
        ("Target Offline", target_offline_indices),
    ]
    for label, indices in selections:
        fool_rate = compute_attack_success_rate(target_stream.subset(indices))
        print("Attack success rate (%s): %.4f" % (label, fool_rate))

    # Dump the source-online picks as a single image grid, clean and attacked.
    nrow = int(math.sqrt(len(source_online_indices)))
    stream = datastream.BatchDataStream(dataset)
    stream = stream.subset(source_online_indices, batch_size=len(source_online_indices))
    x, _labels = next(stream)
    output_dir = os.path.join(args.output_dir, "img")
    os.makedirs(output_dir, exist_ok=True)
    torchvision.utils.save_image(x, os.path.join(output_dir, "clean.png"), nrow=nrow)

    x = attacker.perturb(x.to(args.device))
    torchvision.utils.save_image(x, os.path.join(output_dir, "attack.png"), nrow=nrow)

    return comp_ratio