in predict.py [0:0]
def predict(self, image, gen_model="icgan", conditional_class=None, num_samples=1, seed=0):
    """Generate IC-GAN samples conditioned on an input image and save them as a mosaic.

    ``num_samples * 10`` candidates are generated. When conditioning instance
    features are available, candidates are ranked by the Euclidean distance
    between their L2-normalised extracted features and the conditioning
    features, and the ``num_samples`` closest are kept. Otherwise, the first
    ``num_samples`` candidates are kept in generation order.

    Args:
        image: Conditioning input. ``str(image)`` is passed to
            ``preprocess_input_image`` (presumably a file path — confirm
            against that helper). A value whose string form is ``"None"``
            or ``""`` means "no input image".
        gen_model: ``"icgan"`` or ``"cc_icgan"``. Selects the experiment
            checkpoint and the feature extractor (``"classification"`` for
            ``"cc_icgan"``, ``"selfsupervised"`` otherwise).
        conditional_class: Class name looked up in ``NAME2IND``. Required
            for ``"cc_icgan"``; ignored for ``"icgan"``.
        num_samples: Number of images placed in the output mosaic (>= 1).
        seed: Integer seed; ``0`` means "do not seed".

    Returns:
        ``pathlib.Path`` of ``out.png`` inside a freshly created temporary
        directory. Samples fill a square grid ``ceil(sqrt(num_samples))``
        cells wide, top-to-bottom within a column and then column by column.
        Unused cells are left as zeros.

    Raises:
        AssertionError: on a non-integer ``seed``, ``num_samples < 1``, an
            unknown ``gen_model``, or a missing ``conditional_class`` for
            ``"cc_icgan"``.
    """
    assert isinstance(seed, int), "seed should be an integer"
    assert gen_model in ("icgan", "cc_icgan"), "gen_model should be 'icgan' or 'cc_icgan'"
    assert num_samples >= 1, "num_samples should be a positive integer"
    if gen_model == 'cc_icgan':
        assert conditional_class is not None, 'please set conditional_class for cc_icgan'
    num_samples_ranked = num_samples
    experiment_name = (
        "icgan_biggan_imagenet_res256"
        if gen_model == "icgan"
        else "cc_icgan_biggan_imagenet_res256"
    )
    # Over-generate 10x so that only the samples closest to the conditioning
    # instance end up in the mosaic.
    num_samples_total = num_samples * 10
    truncation = 0.7
    # "icgan" discards any class, so only look the class up when it is used;
    # this avoids a spurious KeyError for an unused class name.
    class_index = None
    if gen_model == "cc_icgan":
        class_index = NAME2IND[conditional_class]
    input_image_instance = str(image)
    # Index into the pre-extracted instance bank. It must be defined before the
    # branches below read it. Nothing in this method sets it to a non-None
    # value, so the pre-extracted branch is currently unreachable.
    input_feature_index = None
    # seed == 0 means "unseeded": no RandomState for truncnorm, and the global
    # NumPy RNG is reseeded from fresh entropy.
    if seed == 0:
        seed = None
    state = None if not seed else np.random.RandomState(seed)
    np.random.seed(seed)
    feature_extractor_name = ("classification" if gen_model == "cc_icgan" else "selfsupervised")
    # Load feature extractor (outlier filtering and optionally input image feature extraction)
    self.feature_extractor, self.last_feature_extractor = load_feature_extractor(
        gen_model, self.last_feature_extractor, self.feature_extractor)

    # Conditioning instance features: from the input image, from the
    # pre-extracted bank, or none at all.
    if input_image_instance not in ("None", ""):
        print("Obtaining instance features from input image!")
        input_image_tensor = preprocess_input_image(input_image_instance, self.size)
        with torch.no_grad():
            input_features, _ = self.feature_extractor(input_image_tensor.cuda())
            input_features /= torch.linalg.norm(input_features, dim=-1, keepdim=True)
    elif input_feature_index is not None:
        print("Selecting an instance from pre-extracted vectors!")
        input_features = np.load(
            "stored_instances/imagenet_res"
            + str(self.size)
            + "_rn50_"
            + feature_extractor_name
            + "_kmeans_k1000_instance_features.npy",
            allow_pickle=True,
        ).item()["instance_features"][input_feature_index: input_feature_index + 1]
    else:
        input_features = None

    # Load generative model
    self.model, self.last_gen_model = load_generative_model(
        gen_model, self.last_gen_model, experiment_name, self.model)
    replace_to_inplace_relu(self.model)

    # Noise (truncated normal in [-2*truncation, 2*truncation]), instance and class vectors.
    noise_vector = truncnorm.rvs(
        -2 * truncation,
        2 * truncation,
        size=(num_samples_total, self.noise_size),
        random_state=state,
    ).astype(np.float32)
    noise_vector = torch.tensor(noise_vector, requires_grad=False, device="cuda")
    if input_features is not None:
        # as_tensor accepts both the CUDA tensor (image path) and the NumPy
        # array (pre-extracted path). Unlike torch.tensor, it does not warn
        # when copy-constructing from an existing tensor.
        instance_vector = torch.as_tensor(input_features, device="cuda").repeat(num_samples_total, 1)
    else:
        instance_vector = None
    if class_index is not None:
        input_label = torch.LongTensor([class_index] * num_samples_total)
    else:
        input_label = None
    if input_feature_index is not None:
        print("Conditioning on instance with index: ", input_feature_index)

    # Generate in batches. When an instance is available, record each sample's
    # distance to its conditioning vector.
    all_outs, all_dists = [], []
    for start in range(0, num_samples_total, self.batch_size):
        end = min(start + self.batch_size, num_samples_total)
        # Pure inference: no autograd graph is needed for the generator or the
        # feature extractor.
        with torch.no_grad():
            out = get_output(
                noise_vector[start:end],
                input_label[start:end] if input_label is not None else None,
                instance_vector[start:end] if instance_vector is not None else None,
                self.model,
                truncation,
                channels=3,
            )
            if instance_vector is not None:
                out_features, _ = self.feature_extractor(preprocess_generated_image(out).cuda())
                out_features /= torch.linalg.norm(out_features, dim=-1, keepdim=True)
                # Row-wise Euclidean distance between each sample and its own
                # conditioning vector. This equals the diagonal of the full
                # pairwise matrix, without computing the off-diagonal entries.
                dists = torch.linalg.norm(out_features - instance_vector[start:end], dim=-1)
                all_dists.append(dists.cpu().numpy())
        all_outs.append(out.detach().cpu())
        del out
    all_outs = torch.cat(all_outs)

    if all_dists:
        # Keep the num_samples_ranked samples closest to the conditioning features.
        selected_idxs = np.argsort(np.concatenate(all_dists))[:num_samples_ranked]
    else:
        # Nothing to rank against: keep samples in generation order.
        selected_idxs = np.arange(min(num_samples_ranked, len(all_outs)))

    # Square mosaic wide enough for every selected sample. This is identical to
    # a sqrt-sized grid for perfect squares, and does not overwrite cells otherwise.
    side = int(np.ceil(np.sqrt(num_samples_ranked)))
    size = self.size
    all_images_mosaic = np.zeros((3, size * side, size * side))
    for k, j in enumerate(selected_idxs):
        # Fill down a column first, then move to the next column.
        row_i, col_i = k % side, k // side
        all_images_mosaic[
            :,
            row_i * size: (row_i + 1) * size,
            col_i * size: (col_i + 1) * size,
        ] = all_outs[j]
    out_path = Path(tempfile.mkdtemp()) / "out.png"
    save(all_images_mosaic[np.newaxis, ...], str(out_path), torch_format=False)
    return out_path