in evaluations/th_evaluator.py [0:0]
def get_preds(self, batch, captions=None):
    """Extract Inception and CLIP features for a batch of images.

    Args:
        batch: image tensor in NCHW layout (the (0, 2, 3, 1) transpose below
            fixes that), presumably valued in [-1, 1] since it is mapped to
            [0, 255] via 127.5 * (x + 1) — TODO confirm against caller.
        captions: optional sequence of strings; when given, text features come
            from the CLIP text tower, otherwise they are synthesized by
            projecting the image features.

    Returns:
        Tuple of (flattened inception features, flattened spatial inception
        features, CLIP image features, CLIP text features, uint8 NHWC numpy
        copy of the batch).
    """
    with torch.no_grad():
        # Rescale to [0, 255]; the float version feeds Inception, the uint8
        # NHWC numpy copy feeds CLIP preprocessing (and is also returned).
        scaled = 127.5 * (batch + 1)
        images_np = scaled.to(torch.uint8).cpu().numpy().transpose((0, 2, 3, 1))

        incep_feat, incep_spatial = self.inception(scaled)
        incep_feat = incep_feat.reshape([incep_feat.shape[0], -1])
        incep_spatial = incep_spatial.reshape([incep_spatial.shape[0], -1])

        preprocessed = [clip_preproc(self.clip_preproc_fn, img) for img in images_np]
        clip_in = torch.stack(preprocessed)
        clip_pred = self.clip_visual(clip_in.half().to(dist_util.dev()))

        if captions is None:
            # Hack to easily deal with no captions: project image features
            # into the text space and L2-normalize (this branch only).
            text_pred = self.clip_proj(clip_pred.half())
            text_pred = text_pred / text_pred.norm(dim=-1, keepdim=True)
        else:
            tokens = self.clip_tokenizer(captions)
            text_pred = self.clip_text(tokens.to(dist_util.dev()))

        return incep_feat, incep_spatial, clip_pred, text_pred, images_np