def run()

in data_utils/inception_tf13.py

# Module-level imports and constants referenced below, reconstructed so the
# listing is self-contained. The MODEL_DIR/DATA_URL values are the standard
# ones for the frozen TF1 Inception graph and are an assumption here, as are
# the None initializers for the two graph tensors.
import math
import os
import pickle
import sys
import tarfile
import time
import urllib.request

import h5py as h5
import numpy as np
import tensorflow as tf
from tqdm import trange

MODEL_DIR = "/tmp/imagenet"
DATA_URL = (
    "http://download.tensorflow.org/models/image/imagenet/"
    "inception-2015-12-05.tgz"
)
softmax = None
pool3 = None

def run(config):
    # Stratified evaluation (strat_name != "") is only supported on the
    # ImageNet-LT validation split.
    assert (
        config["strat_name"] != ""
        and config["which_dataset"] == "imagenet_lt"
        and config["split"] == "val"
    ) or config["strat_name"] == ""
    # Inception with TF1.3 or earlier.
    # Call this function with a list of images; each element should be a
    # numpy array with values ranging from 0 to 255.
    def get_inception_score(images, splits=10, normalize=True):
        assert type(images) == list
        assert type(images[0]) == np.ndarray
        assert len(images[0].shape) == 3
        #  assert(np.max(images[0]) > 10)
        #  assert(np.min(images[0]) >= 0.0)
        inps = []
        for img in images:
            if normalize:
                # Generated samples live in [-1, 1]; map them to [0, 255].
                img = np.uint8(255 * (img + 1) / 2.0)
            img = img.astype(np.float32)
            inps.append(np.expand_dims(img, 0))
        bs = config["batch_size"]
        with tf.Session() as sess:
            preds, pools = [], []
            n_batches = int(math.ceil(float(len(inps)) / float(bs)))
            for i in trange(n_batches):
                inp = inps[(i * bs) : min((i + 1) * bs, len(inps))]
                inp = np.concatenate(inp, 0)
                # "ExpandDims:0" is the frozen graph's image-input tensor.
                pred, pool = sess.run([softmax, pool3], {"ExpandDims:0": inp})
                preds.append(pred)
                pools.append(pool)
            preds = np.concatenate(preds, 0)
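            # Inception Score recap: IS = exp(E_x[KL(p(y|x) || p(y))]),
            # estimated per split with p(y) approximated by the split's mean
            # prediction. Uniform predictions give KL = 0 and hence IS = 1,
            # the minimum; confident and diverse predictions raise the score.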
            scores = []
            for i in range(splits):
                part = preds[
                    (i * preds.shape[0] // splits) : (
                        (i + 1) * preds.shape[0] // splits
                    ),
                    :,
                ]
                kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
                kl = np.mean(np.sum(kl, 1))
                scores.append(np.exp(kl))
            return np.mean(scores), np.std(scores), np.squeeze(np.concatenate(pools, 0))
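
    # Usage sketch (illustrative, not part of the pipeline): for generated
    # samples stored as HxWxC arrays in [-1, 1],
    #   mean, std, pools = get_inception_score(list(ims), splits=10)
    # returns the score's mean/std over splits plus pool_3 features for FID.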

    # Init inception
    def _init_inception():
        global softmax, pool3
        if not os.path.exists(MODEL_DIR):
            os.makedirs(MODEL_DIR)
        filename = DATA_URL.split("/")[-1]
        filepath = os.path.join(MODEL_DIR, filename)
        if not os.path.exists(filepath):

            def _progress(count, block_size, total_size):
                sys.stdout.write(
                    "\r>> Downloading %s %.1f%%"
                    % (filename, float(count * block_size) / float(total_size) * 100.0)
                )
                sys.stdout.flush()

            filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
            print()
            statinfo = os.stat(filepath)
            print("Succesfully downloaded", filename, statinfo.st_size, "bytes.")
        with tarfile.open(filepath, "r:gz") as tar:
            tar.extractall(MODEL_DIR)
        with tf.gfile.FastGFile(
            os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
        ) as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(graph_def, name="")
        # Patch the frozen graph's static shapes so the batch dimension is
        # None rather than 1, letting pool_3 run on arbitrary minibatch sizes.
        with tf.Session() as sess:
            pool3 = sess.graph.get_tensor_by_name("pool_3:0")
            ops = pool3.graph.get_operations()
            for op_idx, op in enumerate(ops):
                for o in op.outputs:
                    shape = o.get_shape()
                    shape = [s.value for s in shape]
                    new_shape = []
                    for j, s in enumerate(shape):
                        # Free only the leading (batch) dimension.
                        if s == 1 and j == 0:
                            new_shape.append(None)
                        else:
                            new_shape.append(s)
                    # Overriding the private TF1.x "_shape_val" attribute is
                    # why this file targets old TensorFlow versions.
                    o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
            # Rebuild the classifier head from the stored weight matrix so the
            # softmax also works on batched pool_3 activations.
            w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
            logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
            softmax = tf.nn.softmax(logits)

    # Always (re)build the graph; no need to guard with `if softmax is None`.
    _init_inception()

    if config["use_ground_truth_data"]:
        # HDF5 file name
        if config["which_dataset"] in ["imagenet", "imagenet_lt"]:
            dataset_name_prefix = "ILSVRC"
        elif config["which_dataset"] == "coco":
            dataset_name_prefix = "COCO"
        else:
            raise ValueError("Unsupported dataset: %s" % config["which_dataset"])
        hdf5_filename = "%s%i%s%s%s_xy.hdf5" % (
            dataset_name_prefix,
            config["resolution"],
            "longtail"
            if config["which_dataset"] == "imagenet_lt" and config["split"] == "train"
            else "",
            "_val" if config["split"] == "val" else "",
            "_test"
            if config["split"] == "val" and config["which_dataset"] == "coco"
            else "",
        )
        with h5.File(os.path.join(config["data_root"], hdf5_filename), "r") as f:
            data_imgs = f["imgs"][:]
            data_labels = f["labels"][:]
        ims = data_imgs.transpose(0, 2, 3, 1)  # NCHW -> NHWC for the TF graph.

    else:
        if config["strat_name"] != "":
            fname = "%s/%s/samples%s_seed%i_strat_%s.pickle" % (
                config["experiment_root"],
                config["experiment_name"],
                "_kmeans" + str(config["kmeans_subsampled"])
                if config["kmeans_subsampled"] > -1
                else "",
                config["seed"],
                config["strat_name"],
            )
        else:
            fname = "%s/%s/samples%s_seed%i.pickle" % (
                config["experiment_root"],
                config["experiment_name"],
                "_kmeans" + str(config["kmeans_subsampled"])
                if config["kmeans_subsampled"] > -1
                else "",
                config["seed"],
            )
        print("loading %s ..." % fname)
        file_to_read = open(fname, "rb")
        ims = pickle.load(file_to_read)["x"]
        print("loading %s ..." % fname)
        print("number of images saved are ", len(ims))
        file_to_read.close()
        ims = ims.swapaxes(1, 2).swapaxes(2, 3)

    t0 = time.time()
    # Ground-truth images are already uint8 in [0, 255]; only generated
    # samples (stored in [-1, 1]) need re-normalization.
    inc_mean, inc_std, pool_activations = get_inception_score(
        list(ims), splits=10, normalize=not config["use_ground_truth_data"]
    )
    t1 = time.time()
    print("Saving pool to numpy file for FID calculations...")
    mu = np.mean(pool_activations, axis=0)
    sigma = np.cov(pool_activations, rowvar=False)
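    # FID later compares two such Gaussians:
    #   FID = ||mu_1 - mu_2||^2 + Tr(sigma_1 + sigma_2 - 2 (sigma_1 sigma_2)^(1/2))
    # so only the mean and covariance of the pool_3 features need to be stored.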
    if config["use_ground_truth_data"]:
        np.savez(
            "%s/%s%s_res%i_tf_inception_moments_ground_truth.npz"
            % (
                config["data_root"],
                config["which_dataset"],
                "_val" if config["split"] == "val" else "",
                config["resolution"],
            ),
            **{"mu": mu, "sigma": sigma}
        )
    else:
        np.savez(
            "%s/%s/TF_pool%s_%s.npz"
            % (
                config["experiment_root"],
                config["experiment_name"],
                "_val" if config["split"] == "val" else "",
                "_strat_" + config["strat_name"] if config["strat_name"] != "" else "",
            ),
            **{"mu": mu, "sigma": sigma}
        )
    print(
        "Inception took %.3f seconds, score of %.3f +/- %.3f."
        % (t1 - t0, inc_mean, inc_std)
    )

    # When computing ground-truth moments on the ImageNet-LT validation split,
    # also save per-stratum moments for stratified FID.
    if (
        config["split"] == "val"
        and config["which_dataset"] == "imagenet_lt"
        and config["use_ground_truth_data"]
    ):
        samples_per_class = np.load(
            "BigGAN_PyTorch/imagenet_lt/imagenet_lt_samples_per_class.npy",
            allow_pickle=True,
        )
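        # Stratify validation images by training-set class frequency, matching
        # the masks below: "many" (>= 100 train images per class), "low"
        # (21-99), and "few" (<= 20); save FID moments for each stratum.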
        for strat_name in ["_many", "_low", "_few"]:
            if strat_name == "_many":
                pool_ = pool_activations[samples_per_class[data_labels] >= 100]
            elif strat_name == "_low":
                pool_ = pool_activations[samples_per_class[data_labels] < 100]
                labels_ = data_labels[samples_per_class[data_labels] < 100]
                pool_ = pool_[samples_per_class[labels_] > 20]
            elif strat_name == "_few":
                pool_ = pool_activations[samples_per_class[data_labels] <= 20]
            print("Size for strat ", strat_name, " is ", len(pool_))
            mu = np.mean(pool_, axis=0)
            sigma = np.cov(pool_, rowvar=False)

            np.savez(
                "%s/%s%s_res%i_tf_inception_moments%s_ground_truth.npz"
                % (
                    config["data_root"],
                    config["which_dataset"],
                    "_val" if config["split"] == "val" else "",
                    config["resolution"],
                    strat_name,
                ),
                **{"mu": mu, "sigma": sigma}
            )
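

# A minimal invocation sketch (hypothetical; the keys below are exactly the
# ones run() reads, but the values are illustrative assumptions, not repo
# defaults):
if __name__ == "__main__":
    example_config = {
        "strat_name": "",             # "" = no stratified sample file
        "which_dataset": "imagenet",  # or "imagenet_lt" / "coco"
        "split": "val",
        "batch_size": 100,
        "use_ground_truth_data": True,
        "data_root": "data",
        "resolution": 64,
        # Consulted only when use_ground_truth_data is False:
        "experiment_root": "experiments",
        "experiment_name": "my_run",
        "kmeans_subsampled": -1,      # -1 = no k-means subsampling suffix
        "seed": 0,
    }
    run(example_config)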