def compute_hist()

in online_attacks/scripts/compute_hist_all.py
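
Rebuilds a run's attacker from its saved hyperparameters, attacks the test set with it, and saves the per-example cross-entropy losses of a target classifier for each recorded index set as histogram data.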


import os

import torch
import tqdm
from omegaconf import OmegaConf
from torch.nn import CrossEntropyLoss

# Project-local helpers (OnlineAttackParams, create_params, DatasetType,
# MnistModel, CifarModel, load_dataset, load_classifier, create_attacker,
# datastream, compute_loss_values) come from the online_attacks package;
# their exact module paths depend on the repository layout.


def compute_hist(logger, model_type, model_name, list_records=None):
    """Compute per-example loss values for `model_name` under the run's attacker.

    Args:
        logger: experiment logger holding the run's hyperparameters and records.
        model_type: architecture family of the target classifier.
        model_name: name of the target classifier checkpoint.
        list_records: records to process; defaults to all records in the run.
    """
    # The output directory is keyed by the *string* arguments, before
    # model_type is converted to an enum below.
    dir_name = os.path.join(model_type, model_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Reload and normalize the hyperparameters of the recorded run.
    params = logger.load_hparams()
    params = OmegaConf.structured(OnlineAttackParams(**params))
    params = create_params(params)

    # Map the architecture name to the dataset-specific model enum.
    if params.dataset == DatasetType.MNIST:
        model_type = MnistModel(model_type)
    elif params.dataset == DatasetType.CIFAR:
        model_type = CifarModel(model_type)

    dataset = load_dataset(params.dataset, train=False)
    # Target classifier: the model whose per-example losses are collected.
    target_classifier = load_classifier(
        params.dataset,
        model_type,
        name=model_name,
        model_dir=params.model_dir,
        device=device,
        eval=True,
    )
    # Source classifier: the model recorded in the run's hyperparameters;
    # the attacker crafts its perturbations against this model (a transfer
    # setup whenever it differs from the target).
    source_classifier = load_classifier(
        params.dataset,
        params.model_type,
        name=params.model_name,
        model_dir=params.model_dir,
        device=device,
        eval=True,
    )
    attacker = create_attacker(
        source_classifier, params.attacker_type, params.attacker_params
    )

    # Per-batch pipeline: move the batch to the device, perturb it with the
    # attacker, run the target classifier, and keep the per-example
    # cross-entropy losses (reduction="none").
    target_transform = datastream.Compose(
        [
            datastream.ToDevice(device),
            datastream.AttackerTransform(attacker),
            datastream.ClassifierTransform(target_classifier),
            datastream.LossTransform(CrossEntropyLoss(reduction="none")),
        ]
    )

    if list_records is None:
        list_records = logger.list_all_records()
    for record_name in tqdm.tqdm(list_records):
        # Skip records whose histograms already exist.
        if logger.check_hist_exist(dir_name, record_name):
            continue
        record = logger.load_record(record_name)
        permutation = record["permutation"]
        eval_results = {}
        for name, indices in record["indices"].items():
            # Each stored entry is a pair; keep the dataset index
            # (second element).
            indices = [x[1] for x in indices]
            # Stream the permuted test set through the attack/loss pipeline,
            # restricted to this record's index set.
            target_stream = datastream.BatchDataStream(
                dataset,
                batch_size=params.batch_size,
                transform=target_transform,
                permutation=permutation,
            )
            stream = target_stream.subset(indices)
            loss_values = compute_loss_values(stream)
            eval_results[name] = loss_values
        logger.save_hist(eval_results, dir_name, record_name)
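
A minimal sketch of what the target_transform pipeline computes for a single batch, written in plain PyTorch. The one-step FGSM attacker and the tiny linear classifier below are illustrative stand-ins, not the project's create_attacker or load_classifier implementations.

import torch
from torch.nn import CrossEntropyLoss

def fgsm(model, x, y, eps=0.1):
    # One-step FGSM, standing in for the attacker built by create_attacker.
    x = x.clone().detach().requires_grad_(True)
    CrossEntropyLoss()(model(x), y).backward()
    return (x + eps * x.grad.sign()).detach()

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10)).to(device)
x = torch.rand(8, 1, 28, 28, device=device)                   # ToDevice
y = torch.randint(0, 10, (8,), device=device)
x_adv = fgsm(model, x, y)                                     # AttackerTransform
logits = model(x_adv)                                         # ClassifierTransform
loss_values = CrossEntropyLoss(reduction="none")(logits, y)   # LossTransform
print(loss_values.shape)  # torch.Size([8]): one loss per example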
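
compute_loss_values is defined elsewhere in the script. Assuming the stream yields per-example loss tensors batch by batch (an assumption about the datastream API, not a documented contract), a plausible minimal implementation is:

import torch

def compute_loss_values(stream):
    # Assumed behavior: concatenate the per-batch loss tensors from the
    # stream into a single CPU array for histogramming.
    return torch.cat([batch.detach().cpu() for batch in stream]).numpy()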