def predict()

in predict.py [0:0]


    def predict(self, image, gen_model="icgan", conditional_class=None, num_samples=1, seed=0):
        """Generate a mosaic of IC-GAN samples, optionally conditioned on an input image.

        ``num_samples * 10`` candidates are generated in batches. When instance
        features are available (an input image was given), each candidate's
        features are re-extracted and the ``num_samples`` candidates closest (in
        euclidean distance between L2-normalised features) to the conditioning
        features are kept. Without instance features, the first ``num_samples``
        candidates are kept in generation order.

        Args:
            image: Input image; converted with ``str()`` and passed to
                ``preprocess_input_image``. ``None`` or an empty value means
                "no input image".
            gen_model: ``"icgan"`` (class conditioning disabled) or
                ``"cc_icgan"`` (class conditioning from ``conditional_class``).
                Any value other than ``"icgan"`` loads the cc_icgan experiment.
            conditional_class: Class name looked up in ``NAME2IND``; required for
                ``"cc_icgan"``, ignored (after lookup) for ``"icgan"``.
            num_samples: Number of images placed in the output mosaic (>= 1).
            seed: Integer seed; ``0`` means "do not seed" (non-deterministic).

        Returns:
            ``pathlib.Path`` to ``out.png`` inside a fresh temporary directory.

        Raises:
            AssertionError: if ``seed`` is not an int, ``num_samples < 1``, or
                ``conditional_class`` is missing for ``"cc_icgan"``.
        """
        assert isinstance(seed, int), "seed should be an integer"
        assert num_samples >= 1, "num_samples should be at least 1"
        if gen_model == 'cc_icgan':
            assert conditional_class is not None, 'please set conditional_class for cc_icgan'

        num_samples_ranked = num_samples
        # Over-generate so the candidates closest to the instance can be selected.
        num_samples_total = num_samples * 10
        truncation = 0.7
        experiment_name = (
            "icgan_biggan_imagenet_res256"
            if gen_model == "icgan"
            else "cc_icgan_biggan_imagenet_res256"
        )
        feature_extractor_name = ("classification" if gen_model == "cc_icgan" else "selfsupervised")

        # Bound up front so it is defined for every gen_model value.
        class_index = None
        if conditional_class is not None:
            class_index = NAME2IND[conditional_class]
        if gen_model == "icgan":
            # The instance-only model takes no class label.
            class_index = None

        # seed == 0 means "random": no RandomState for truncnorm, and
        # np.random.seed(None) reseeds the global RNG from fresh entropy.
        if seed == 0:
            seed = None
        state = None if seed is None else np.random.RandomState(seed)
        np.random.seed(seed)

        input_image_instance = str(image)
        # Index into the stored k-means instance bank. Nothing in this method
        # sets it, but it must be bound so the checks below cannot raise.
        input_feature_index = None

        # Load feature extractor (outlier filtering and optionally input image feature extraction)
        self.feature_extractor, self.last_feature_extractor = load_feature_extractor(
            gen_model, self.last_feature_extractor, self.feature_extractor)

        # Load instance features
        if input_image_instance not in ("None", ""):
            print("Obtaining instance features from input image!")
            input_image_tensor = preprocess_input_image(input_image_instance, self.size)
            with torch.no_grad():
                input_features, _ = self.feature_extractor(input_image_tensor.cuda())
            # L2-normalise so distances to generated-image features are comparable.
            input_features /= torch.linalg.norm(input_features, dim=-1, keepdims=True)
        elif input_feature_index is not None:
            print("Selecting an instance from pre-extracted vectors!")
            input_features = np.load(
                "stored_instances/imagenet_res"
                + str(self.size)
                + "_rn50_"
                + feature_extractor_name
                + "_kmeans_k1000_instance_features.npy",
                allow_pickle=True,
            ).item()["instance_features"][input_feature_index: input_feature_index + 1]
        else:
            input_features = None

        # Load generative model
        self.model, self.last_gen_model = load_generative_model(
            gen_model, self.last_gen_model, experiment_name, self.model)
        replace_to_inplace_relu(self.model)

        # Create noise, instance and class vectors
        noise_vector = truncnorm.rvs(
            -2 * truncation,
            2 * truncation,
            size=(num_samples_total, self.noise_size),
            random_state=state,
        ).astype(np.float32)
        noise_vector = torch.tensor(noise_vector, requires_grad=False, device="cuda")
        if input_features is not None:
            # as_tensor accepts both the tensor (image path) and the NumPy array
            # (stored-instance path) without copy-constructing from a tensor.
            instance_vector = torch.as_tensor(input_features, device="cuda").repeat(num_samples_total, 1)
        else:
            instance_vector = None
        input_label = (
            torch.LongTensor([class_index] * num_samples_total)
            if class_index is not None
            else None
        )
        if input_feature_index is not None:
            print("Conditioning on instance with index: ", input_feature_index)

        all_outs, all_dists = [], []
        for start in range(0, num_samples_total, self.batch_size):
            end = min(start + self.batch_size, num_samples_total)
            out = get_output(
                noise_vector[start:end],
                input_label[start:end] if input_label is not None else None,
                instance_vector[start:end] if instance_vector is not None else None,
                self.model,
                truncation,
                channels=3,
            )

            if instance_vector is not None:
                # Get features from generated images + feature extractor
                out_ = preprocess_generated_image(out)
                with torch.no_grad():
                    out_features, _ = self.feature_extractor(out_.cuda())
                out_features /= torch.linalg.norm(out_features, dim=-1, keepdims=True)
                # Each sample is compared only with its own conditioning row: the
                # row-wise norm equals the diagonal of the full pairwise matrix.
                dists = torch.linalg.norm(out_features - instance_vector[start:end], dim=-1)
                all_dists.append(dists.cpu().numpy())
            # Keep every batch, even without instance features, so the mosaic
            # always has images to draw from.
            all_outs.append(out.detach().cpu())
            del out
        all_outs = torch.cat(all_outs)

        if all_dists:
            # Order samples by distance to the conditioning feature vector.
            selected_idxs = np.argsort(np.concatenate(all_dists))[:num_samples_ranked]
        else:
            # Nothing to rank against: keep candidates in generation order.
            selected_idxs = np.arange(num_samples_ranked)

        all_images_mosaic = self._build_mosaic(
            [all_outs[j] for j in selected_idxs], self.size)

        out_path = Path(tempfile.mkdtemp()) / "out.png"
        save(all_images_mosaic[np.newaxis, ...], str(out_path), torch_format=False)
        return out_path

    @staticmethod
    def _build_mosaic(images, size):
        """Tile ``images`` into a single ``(3, size * grid, size * grid)`` array.

        ``grid`` is ``ceil(sqrt(len(images)))``, so every image gets its own cell
        even when the count is not a perfect square. Cells are filled column by
        column (top to bottom, then left to right); unused cells stay zero.
        Each image is assumed to be ``(3, size, size)`` — TODO confirm against
        ``get_output``.
        """
        grid = max(1, int(np.ceil(np.sqrt(len(images)))))
        mosaic = np.zeros((3, size * grid, size * grid))
        for k, img in enumerate(images):
            row, col = k % grid, k // grid
            mosaic[:, row * size:(row + 1) * size, col * size:(col + 1) * size] = img
        return mosaic