def _pool()

in cp_examples/mip_finetune/mip_model.py [0:0]


    def _pool(self, image_feats, lens):
        if self.pooling == "last_timestep":
            pooled_feats = []
            for b, l in enumerate(lens.tolist()):
                pooled_feats.append(image_feats[b, int(l) - 1])
        elif self.pooling == "sum":
            pooled_feats = []
            for b, l in enumerate(lens.tolist()):
                pooled_feats.append(image_feats[b, : int(l)].sum(0))
        else:
            raise ValueError(f"Unkown pooling method: {self.pooling}")

        pooled_feats = torch.stack(pooled_feats)
        pooled_feats = F.adaptive_avg_pool2d(pooled_feats, (1, 1))
        return pooled_feats.squeeze(3).squeeze(2)