in botorch/models/kernels/contextual_lcea.py [0:0]
def _task_embeddings_batch(self) -> Tensor:
    """Generate embedding features of contexts when model has multiple batches.

    For every batch index ``b`` and every embedding layer ``i``, the ``i``-th
    categorical feature column of ``context_cat_feature`` is looked up in the
    ``b``-th weight matrix of ``emb_weight_matrix_list[i]``. The per-layer
    embeddings are concatenated along the embedding (last) dimension. If
    ``context_emb_feature`` is set, it is broadcast over ``batch_shape`` and
    appended along the last dimension as well.

    Each weight matrix is indexed on its leading dimension with
    ``b in range(batch_shape.numel())``, so a single batch dimension is assumed.

    Returns:
        a (ns) x num_contexts x n_embs tensor. ns is the batch size i.e num of
        posterior samples; n_embs is the sum of embedding dimensions i.e.
        sum(embs_dim_list) (plus the width of ``context_emb_feature`` if given).
    """
    # Categorical features of the first ``num_contexts`` contexts, converted
    # once (loop-invariant) to integer indices usable by the embedding lookup.
    context_features = self.context_cat_feature[: self.num_contexts].to(
        dtype=torch.long, device=self.device
    )
    batch_embeddings = []
    for b in range(self.batch_shape.numel()):  # pyre-ignore
        # Layer i embeds categorical column i; concatenating along the last
        # dim yields num_contexts x sum(embs_dim_list) for this batch.
        layer_embeddings = [
            torch.nn.functional.embedding(context_features[:, i], weight[b, :])
            for i, weight in enumerate(self.emb_weight_matrix_list)
        ]
        batch_embeddings.append(torch.cat(layer_embeddings, dim=-1).unsqueeze(0))
    embeddings = torch.cat(batch_embeddings, dim=0)
    # add given embeddings if any
    if self.context_emb_feature is not None:
        embeddings = torch.cat(
            [
                embeddings,
                self.context_emb_feature.expand(
                    *self.batch_shape + self.context_emb_feature.shape
                ),
            ],
            dim=-1,
        )
    return embeddings