in cp_examples/mip_finetune/mip_model.py [0:0]
def _pool(self, image_feats, lens):
if self.pooling == "last_timestep":
pooled_feats = []
for b, l in enumerate(lens.tolist()):
pooled_feats.append(image_feats[b, int(l) - 1])
elif self.pooling == "sum":
pooled_feats = []
for b, l in enumerate(lens.tolist()):
pooled_feats.append(image_feats[b, : int(l)].sum(0))
else:
raise ValueError(f"Unkown pooling method: {self.pooling}")
pooled_feats = torch.stack(pooled_feats)
pooled_feats = F.adaptive_avg_pool2d(pooled_feats, (1, 1))
return pooled_feats.squeeze(3).squeeze(2)