in src/model.py [0:0]
    def forward(self, input):
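        # `input` is a dict with "input" (token ids), "text_len" (sequence lengths) and "md" (metadata)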
        input_ids = input["input"].long()
        input_lens = input["text_len"]
        input_embs = self.embeddings(input_ids)
        max_batch_size = input_embs.size(0)
        seq_len = input_embs.size(1)
        # Assuming the input is batch_first - permuting to sequence-first
        input_embs = input_embs.permute(1, 0, 2)
        zeros = torch.zeros(self.n_layers, max_batch_size, self.hidden_dim, device=input_embs.device)
        h_init = zeros
        c_init = zeros
        inputs = input_embs
        outputs = []
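        # The first layer consumes the token embeddings; `outputs` collects the hidden state of every timestep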
        md_input = self.metadata_constructor.preprocess_md(input["md"], self.embeddings)
        if self.metadata_constructor.is_precomputable():
            md = self.metadata_constructor(md_input, input_embs)
        for layer in range(self.n_layers):
            h = h_init[layer]
            c = c_init[layer]
            weight_start_index = layer * self._params_per_layer
            weight_end_index = (layer+1) * self._params_per_layer
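            # Each layer owns five parameter tensors: the metadata-to-hidden, input-to-hidden and
            # hidden-to-hidden weights plus the two bias vectors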
            w_mh, w_ih, w_hh, b_ih, b_hh = self._all_weights[weight_start_index: weight_end_index]
            # Metadata can be computed in advance when not using attention
            if self.metadata_constructor.is_precomputable():
                precomputed_md = torch.matmul(md, w_mh.t())
            for t in range(seq_len):
                if not self.metadata_constructor.is_precomputable():
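                    # Metadata depends on the current hidden state (e.g. attention), so it is recomputed every step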
                    md = self.metadata_constructor(md_input, h)
                    md_layer = torch.matmul(md, w_mh.t())
                else:
                    md_layer = precomputed_md[t] if precomputed_md.dim() == 3 else precomputed_md
                h, c = self._run_cell(inputs[t], md_layer, (h, c), w_ih, w_hh, b_ih, b_hh)
                outputs.append(h)
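            # This layer's outputs become the next layer's inputs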
            inputs = outputs
            outputs = []
        # After the last layer, `inputs` holds the final layer's hidden states for every timestep
        # Permuting back to (batch, seq_len, hidden_dim)
        lstm_out = torch.stack(inputs).permute(1, 0, 2)
        if self.use_weight_tying:
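            # Project hidden states down to the embedding dimension so the output layer can share weights with the input embeddings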
            vocab_predictions = self.vocab_projection(self.embedding_projection(lstm_out))
        else:
            vocab_predictions = self.vocab_projection(lstm_out)
        if self.use_softmax_adaptation:
            # Assuming the adaptation term is built from the preprocessed metadata embeddings (md_input)
            md_embs = md_input.view(max_batch_size, -1)
            md_context = self.md_vocab_projection(md_embs).unsqueeze(1)
            vocab_predictions += md_context
        return vocab_predictions
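    # A minimal usage sketch (names and the exact "md" format are assumptions; the metadata
    # layout is whatever metadata_constructor.preprocess_md expects):
    #   batch = {"input": token_ids,    # (batch, seq_len) LongTensor of token ids
    #            "text_len": lengths,   # (batch,) sequence lengths
    #            "md": metadata}
    #   logits = model(batch)           # (batch, seq_len, vocab_size) unnormalised predictions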