# Excerpt: `run()` from data_utils/calculate_inception_moments.py

def run(config):
    """Accumulate InceptionV3 features over a dataset and save FID moments.

    Runs the full dataset through a pretrained Inception network, reports the
    Inception Score, computes the feature mean/covariance (mu, sigma) used by
    FID, and saves them to disk. For ImageNet-LT, optionally also computes
    stratified moments for the many-/low-/few-shot class groups.

    Args:
        config (dict): experiment configuration. Keys read here:
            "which_dataset", "split", "resolution", "data_root", "out_path",
            "batch_size", "num_workers", "load_in_mem", "parallel",
            "stratified_moments".

    Raises:
        ValueError: if ``config["which_dataset"]`` is not one of
            "imagenet", "imagenet_lt", or "coco".

    Side effects:
        Writes one or more ``*_inception_moments.npz`` files to disk.
    """
    # Resolve the output filename prefix up front so an unsupported dataset
    # fails early with a clear error instead of a NameError at save time.
    # (The original computed an unused "ILSVRC" prefix here and silently
    # overwrote it with "I" just before saving; only "I"/"COCO" are used.)
    if config["which_dataset"] in ["imagenet", "imagenet_lt"]:
        dataset_name_prefix = "I"
    elif config["which_dataset"] == "coco":
        dataset_name_prefix = "COCO"
    else:
        raise ValueError(
            "Unsupported dataset: %s" % config["which_dataset"]
        )

    # COCO validation data lives in a separate "test" partition of the hdf5.
    test_part = config["which_dataset"] == "coco" and config["split"] == "val"

    # Get dataset and loader (HDF5-backed).
    kwargs = {
        "num_workers": config["num_workers"],
        "pin_memory": False,
        "drop_last": False,
        "load_in_mem": config["load_in_mem"],
    }
    dataset = utils.get_dataset_hdf5(
        config["resolution"],
        data_path=config["data_root"],
        # Long-tail handling only applies to the ImageNet-LT training split.
        longtail=config["which_dataset"] == "imagenet_lt"
        and config["split"] == "train",
        split=config["split"],
        load_in_mem=config["load_in_mem"],
        which_dataset=config["which_dataset"],
        test_part=test_part,
    )

    loader = utils.get_dataloader(
        dataset, config["batch_size"], shuffle=False, **kwargs
    )

    # Load inception net
    net = inception_utils.load_inception_net(parallel=config["parallel"])
    device = "cuda"

    # Accumulate pool features, softmax-ed logits, and labels batch by batch.
    pool, logits, labels = [], [], []
    for batch in tqdm(loader):
        x, y = batch[0], batch[1]
        x = x.to(device)
        with torch.no_grad():
            pool_val, logits_val = net(x)
            pool.append(np.asarray(pool_val.cpu()))
            logits.append(np.asarray(F.softmax(logits_val, 1).cpu()))
            labels.append(np.asarray(y.cpu()))

    pool, logits, labels = [np.concatenate(item, 0) for item in [pool, logits, labels]]

    print("Calculating inception metrics...")
    IS_mean, IS_std = inception_utils.calculate_inception_score(logits)
    print(
        "Training data from dataset %s has IS of %5.5f +/- %5.5f"
        % (config["which_dataset"], IS_mean, IS_std)
    )
    # Prepare mu and sigma, save to disk. Remove "hdf5" by default
    # (the FID code also knows to strip "hdf5")
    print("Calculating means and covariances...")
    mu, sigma = np.mean(pool, axis=0), np.cov(pool, rowvar=False)
    print("Saving calculated means and covariances to disk...")
    np.savez(
        os.path.join(
            config["out_path"],
            "%s%i_%s%s%s_inception_moments.npz"
            % (
                dataset_name_prefix,
                config["resolution"],
                "longtail"
                if config["which_dataset"] == "imagenet_lt"
                and config["split"] == "train"
                else "",
                "_val" if config["split"] == "val" else "",
                "_test" if test_part else "",
            ),
        ),
        **{"mu": mu, "sigma": sigma}
    )
    # Compute stratified moments for ImageNet-LT dataset.
    # Strata follow the ImageNet-LT convention (per the thresholds below):
    #   many-shot: >= 100 samples per class, low-shot: 21-99, few-shot: <= 20.
    if config["stratified_moments"]:
        samples_per_class = np.load(
            "BigGAN_PyTorch/imagenet_lt/imagenet_lt_samples_per_class.npy",
            allow_pickle=True,
        )
        for strat_name in ["_many", "_low", "_few"]:
            if strat_name == "_many":
                mask = samples_per_class[labels] >= 100
                logits_ = logits[mask]
                pool_ = pool[mask]
            elif strat_name == "_low":
                # First drop many-shot classes, then drop few-shot ones.
                mask = samples_per_class[labels] < 100
                logits_ = logits[mask]
                pool_ = pool[mask]
                labels_ = labels[mask]
                submask = samples_per_class[labels_] > 20
                logits_ = logits_[submask]
                pool_ = pool_[submask]
            elif strat_name == "_few":
                mask = samples_per_class[labels] <= 20
                logits_ = logits[mask]
                pool_ = pool[mask]
            print(
                "Calculating inception metrics for strat ",
                strat_name,
                " with number of samples ",
                len(logits_),
                "...",
            )
            IS_mean, IS_std = inception_utils.calculate_inception_score(logits_)
            print(
                "Training data from dataset %s has IS of %5.5f +/- %5.5f"
                % (config["which_dataset"], IS_mean, IS_std)
            )
            print("Calculating means and covariances...")
            mu, sigma = np.mean(pool_, axis=0), np.cov(pool_, rowvar=False)
            print("Saving calculated means and covariances to disk...")
            # NOTE(review): stratified moments go to data_root (not out_path)
            # with a hard-coded "__val" infix — preserved as-is; verify this
            # matches where the FID evaluation code expects to find them.
            np.savez(
                os.path.join(
                    config["data_root"],
                    "%s%i__val%s_inception_moments.npz"
                    % (dataset_name_prefix, config["resolution"], strat_name),
                ),
                **{"mu": mu, "sigma": sigma}
            )