in src/model.py [0:0]
def forward(self, md_embs, query=None):
    """Fuse several metadata embeddings into one context embedding.

    Each value of ``md_embs`` is handled in the mapping's iteration order.
    The i-th value is:

    * optionally passed through ``self.attention_modules[i]`` together with
      ``query``, when ``self.attention_mechanism`` is truthy;
    * then always passed through ``self.projection_layers[i]``.

    The projected groups are merged in one of two ways:

    * If ``self.use_hierarchical_attention`` is truthy, they are stacked with
      ``torch.stack`` and fed, along with ``query``, to
      ``self.hierarchical_attention_module``. ``torch.stack`` requires all
      projected tensors to have the same shape.
    * Otherwise they are merged by ``self.concat_md``.

    The merged result goes through ``self.context_projection``, then
    ``self.context_normalization``, and finally ``torch.tanh``.

    Args:
        md_embs: mapping whose values are the per-group metadata embeddings.
            Only ``.values()`` is read; keys are ignored.
        query: forwarded to the attention modules and to the hierarchical
            attention module. Query can be None if attention mechanism is
            not used.

    Returns:
        The tanh-squashed context embedding.
    """
    projected = []
    for position, embedding in enumerate(md_embs.values()):
        if self.attention_mechanism:
            embedding = self.attention_modules[position](embedding, query)
        # NOTE(review): an older comment here said projection is only needed
        # when more than one metadata group is used, but every group is
        # projected unconditionally. Confirm whether single-group setups
        # rely on an identity projection.
        projected.append(self.projection_layers[position](embedding))

    if not self.use_hierarchical_attention:
        fused = self.concat_md(projected)
    else:
        fused = self.hierarchical_attention_module(torch.stack(projected), query)

    return torch.tanh(self.context_normalization(self.context_projection(fused)))