def _task_embeddings_batch()

in botorch/models/kernels/contextual_lcea.py


    def _task_embeddings_batch(self) -> Tensor:
        """Generate embedding features of contexts when model has multiple batches.

        Returns:
            A (ns) x num_contexts x n_embs tensor, where ns is the batch size,
            i.e. the number of posterior samples, and n_embs is the sum of the
            embedding dimensions, i.e. sum(embs_dim_list).
        """
        # stack the categorical features of all contexts into a
        # num_contexts x n_categorical_features tensor
        context_features = torch.cat(
            [
                self.context_cat_feature[i, :].unsqueeze(0)
                for i in range(self.num_contexts)
            ]
        )
        embeddings = []
        # look up embeddings separately for each batch (posterior sample),
        # using that batch's embedding weight matrix
        for b in range(self.batch_shape.numel()):  # pyre-ignore
            for i in range(len(self.emb_weight_matrix_list)):
                # loop over emb layers and concat embs from each layer
                embeddings.append(
                    torch.cat(
                        [
                            torch.nn.functional.embedding(
                                context_features[:, 0].to(
                                    dtype=torch.long, device=self.device
                                ),
                                self.emb_weight_matrix_list[i][b, :],
                            ).unsqueeze(0)
                        ],
                        dim=1,
                    )
                )
        # combine the per-batch lookups along the batch dimension
        embeddings = torch.cat(embeddings, dim=0)
        # add given embeddings if any
        if self.context_emb_feature is not None:
            embeddings = torch.cat(
                [
                    embeddings,
                    self.context_emb_feature.expand(
                        *self.batch_shape + self.context_emb_feature.shape
                    ),
                ],
                dim=-1,
            )
        return embeddings
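
The batched lookup above can be seen in isolation with a minimal standalone sketch. This is not BoTorch code; the sizes (num_batches, num_contexts, emb_dim) and the given feature tensor are illustrative assumptions, used only to show how torch.nn.functional.embedding with a per-batch weight matrix, plus the optional concatenation of pre-given embedding features, yields a tensor of shape ns x num_contexts x n_embs as described in the docstring.

import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions): 3 posterior samples, 4 contexts, 2-dim embeddings.
num_batches, num_contexts, emb_dim = 3, 4, 2

# One embedding weight matrix per batch, analogous to an entry of
# emb_weight_matrix_list: num_batches x num_contexts x emb_dim.
weight = torch.randn(num_batches, num_contexts, emb_dim)

# Integer context indices, analogous to context_cat_feature[:, 0].
context_idx = torch.arange(num_contexts)

# Look up each batch's embedding table separately and stack the results,
# giving a num_batches x num_contexts x emb_dim tensor.
embeddings = torch.cat(
    [F.embedding(context_idx, weight[b]).unsqueeze(0) for b in range(num_batches)],
    dim=0,
)
print(embeddings.shape)  # torch.Size([3, 4, 2])

# Optionally append fixed, pre-given embedding features along the last
# dimension, mirroring the context_emb_feature branch.
given_emb = torch.randn(num_contexts, 1)
embeddings = torch.cat([embeddings, given_emb.expand(num_batches, -1, -1)], dim=-1)
print(embeddings.shape)  # torch.Size([3, 4, 3])

Each batch uses its own weight matrix because, per the docstring above, the batches correspond to posterior samples of the embedding weights.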