in evaluations/th_evaluator.py [0:0]
def get_activations(self, data, num_samples, global_batch_size, pr_samples=50000):
    """Collect feature activations (and raw images for precision/recall).

    ``all_gather`` is called on every iteration (presumably a cross-rank
    collective), so every rank should call this method with the same
    arguments. Only the root rank accumulates and returns results.

    Args:
        data: Iterator whose ``next()`` yields a 3-tuple; only the first
            element (the image batch, moved to ``dist_util.dev()``) is used.
        num_samples: Number of activation rows to return after gathering.
        global_batch_size: The loop runs ``ceil(num_samples /
            global_batch_size)`` iterations; presumably the combined batch
            size across all ranks — confirm against caller.
        pr_samples: Maximum number of images to keep for precision/recall.

    Returns:
        On the root rank, a tuple ``(preds, spatial_preds, clip_preds,
        pr_images)`` of NumPy arrays; the first three are truncated to
        ``num_samples`` rows and ``pr_images`` has at most ``pr_samples``
        rows. On every other rank, ``[], [], [], []``.

    NOTE: ``pr_images`` is built from the root rank's *local* batches only
    (the ``np_batch`` returned by ``get_preds`` is not gathered), so in a
    multi-process run it may contain fewer than ``pr_samples`` images.
    """
    is_root = self.is_root
    preds, spatial_preds, clip_preds, pr_images = [], [], [], []
    pr_count = 0  # number of images currently held in pr_images

    num_batches = int(np.ceil(num_samples / global_batch_size))
    for _ in tqdm(range(num_batches)):
        # Only the image batch feeds get_preds; the conditioning dict and the
        # third item are unused, so they are not transferred to the device.
        batch, _, _ = next(data)
        batch = batch.to(dist_util.dev())
        pred, spatial_pred, clip_pred, _, np_batch = self.get_preds(batch)
        # Gather on every rank (not only root) so all ranks stay in step.
        pred, spatial_pred, clip_pred = (
            all_gather(pred).cpu().numpy(),
            all_gather(spatial_pred).cpu().numpy(),
            all_gather(clip_pred).cpu().numpy(),
        )
        if not is_root:
            continue
        preds.append(pred)
        spatial_preds.append(spatial_pred)
        clip_preds.append(clip_pred)
        # Keep exactly as many images as are still needed, independent of
        # batch size. On the first batch an (possibly empty) slice is always
        # kept so the concatenation below never receives an empty list and
        # the result keeps the batch's trailing shape and dtype.
        needed = pr_samples - pr_count
        if needed > 0 or not pr_images:
            kept = np_batch[: max(needed, 0)]
            pr_images.append(kept)
            pr_count += kept.shape[0]

    if not is_root:
        return [], [], [], []
    return (
        np.concatenate(preds, axis=0)[:num_samples],
        np.concatenate(spatial_preds, axis=0)[:num_samples],
        np.concatenate(clip_preds, axis=0)[:num_samples],
        np.concatenate(pr_images, axis=0),
    )