def forward()

in src/model.py


    def forward(self, md_embs, query=None):
        '''
        Fuse per-group metadata embeddings into a single context embedding.

        md_embs: mapping from metadata group name to its embedding tensor.
        query: query tensor consumed by the attention modules; may be None
            when no attention mechanism is used.
        '''
        processed_mds = []
        # Iterate over groups in insertion order so that idx lines up with
        # the per-group attention and projection modules.
        for idx, md in enumerate(md_embs.values()):
            if self.attention_mechanism:
                # Attend over the items in this group using the shared query.
                attention_module = self.attention_modules[idx]
                md = attention_module(md, query)

            # Project each group into a common space before the groups are
            # fused (only strictly needed when more than one group is used).
            projection_layer = self.projection_layers[idx]
            processed_md = projection_layer(md)
            processed_mds.append(processed_md)

        if self.use_hierarchical_attention:
            # Second-level attention: stack the groups along a new leading
            # dimension and weight them against the query.
            combined_md = self.hierarchical_attention_module(
                torch.stack(processed_mds), query)
        else:
            # Otherwise simply concatenate the projected group embeddings.
            combined_md = self.concat_md(processed_mds)

        # Project the fused metadata, normalize it, and squash to (-1, 1).
        context_emb = self.context_projection(combined_md)
        context_emb = self.context_normalization(context_emb)
        context_emb = torch.tanh(context_emb)
        return context_emb
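
For reference, below is a minimal, self-contained sketch of a module this forward could live in, plus a usage example. Only forward appears in the source, so everything else is an assumption: the constructor signature, the Linear projections, the LayerNorm standing in for context_normalization, and the toy additive attention (ToyGroupAttention) are hypothetical; the hierarchical-attention branch is omitted for brevity.

    import torch
    import torch.nn as nn

    class ToyGroupAttention(nn.Module):
        # Hypothetical attention: scores each item in a group against the query.
        def __init__(self, emb_dim):
            super().__init__()
            self.score = nn.Linear(emb_dim, 1)

        def forward(self, md, query):
            # md: (batch, items, emb_dim); query: (batch, emb_dim)
            scores = self.score(torch.tanh(md + query.unsqueeze(1)))  # (batch, items, 1)
            weights = torch.softmax(scores, dim=1)
            return (weights * md).sum(dim=1)  # (batch, emb_dim)

    class MetadataContextEncoder(nn.Module):
        # Assumed constructor: the source only defines forward.
        def __init__(self, num_groups, emb_dim, proj_dim, context_dim,
                     attention_mechanism=True):
            super().__init__()
            self.attention_mechanism = attention_mechanism
            self.attention_modules = nn.ModuleList(
                [ToyGroupAttention(emb_dim) for _ in range(num_groups)])
            self.projection_layers = nn.ModuleList(
                [nn.Linear(emb_dim, proj_dim) for _ in range(num_groups)])
            self.context_projection = nn.Linear(num_groups * proj_dim, context_dim)
            self.context_normalization = nn.LayerNorm(context_dim)

        def concat_md(self, processed_mds):
            # Concatenate the per-group embeddings along the feature dimension.
            return torch.cat(processed_mds, dim=-1)

        def forward(self, md_embs, query=None):
            processed_mds = []
            for idx, md in enumerate(md_embs.values()):
                if self.attention_mechanism:
                    md = self.attention_modules[idx](md, query)
                processed_mds.append(self.projection_layers[idx](md))
            combined_md = self.concat_md(processed_mds)  # hierarchical branch omitted
            context_emb = self.context_projection(combined_md)
            context_emb = self.context_normalization(context_emb)
            return torch.tanh(context_emb)

    # Usage sketch: two metadata groups, each a (batch, items, emb_dim) tensor.
    encoder = MetadataContextEncoder(num_groups=2, emb_dim=16, proj_dim=8,
                                     context_dim=32)
    md_embs = {"genre": torch.randn(4, 3, 16), "tags": torch.randn(4, 5, 16)}
    query = torch.randn(4, 16)
    context = encoder(md_embs, query)  # shape: (4, 32)

Note that passing md_embs as a dict relies on Python's insertion-ordered dicts to keep the group order aligned with the per-index attention and projection module lists.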