def forward()

in src/model.py
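
Manually unrolls a stacked, metadata-conditioned LSTM: token ids are embedded, each layer steps through time with its own slice of the flat parameter list, and the final hidden states are projected onto the vocabulary, optionally with tied embedding weights and a metadata-driven softmax adaptation.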


    def forward(self, input):
        input_ids = input["input"].long()
        input_lens = input["text_len"]
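        # Note: input_lens is not used below; sequences run to the full padded length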

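        # Embed token ids: (batch, seq_len) -> (batch, seq_len, emb_dim)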
        input_embs = self.embeddings(input_ids)
        max_batch_size = input_embs.size(0)
        seq_len = input_embs.size(1)

        # Assuming input is batch-first; permuting to sequence-first
        input_embs = input_embs.permute(1, 0, 2)

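        # Zero-initialised hidden and cell states, one slice per layer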
        zeros = torch.zeros(self.n_layers, max_batch_size, self.hidden_dim, device=input_embs.device)

        h_init = zeros
        c_init = zeros
        inputs = input_embs
        outputs = []

        md_input = self.metadata_constructor.preprocess_md(input["md"], self.embeddings)

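        # Metadata that does not depend on the hidden state can be computed once up front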
        if self.metadata_constructor.is_precomputable():
            md = self.metadata_constructor(md_input, input_embs)

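        # Unroll the stacked LSTM one layer at a time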
        for layer in range(self.n_layers):
            h = h_init[layer]
            c = c_init[layer]

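            # Each layer owns a contiguous slice of the flat parameter list:
            # metadata, input, and hidden weights plus the two bias vectors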
            weight_start_index = layer * self._params_per_layer
            weight_end_index = (layer+1) * self._params_per_layer
            w_mh, w_ih, w_hh, b_ih, b_hh = self._all_weights[weight_start_index: weight_end_index]

            # The metadata projection can be computed in advance when not using attention

            if self.metadata_constructor.is_precomputable():
                precomputed_md = torch.matmul(md, w_mh.t())

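            # Step through time; attention-style metadata depends on the
            # current hidden state, so it must be recomputed at every step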
            for t in range(seq_len):
                if not self.metadata_constructor.is_precomputable():
                    md = self.metadata_constructor(md_input, h)
                    md_layer = torch.matmul(md, w_mh.t())
                else:
                    md_layer = precomputed_md[t] if precomputed_md.dim() == 3 else precomputed_md

                h, c = self._run_cell(inputs[t], md_layer, (h, c), w_ih, w_hh, b_ih, b_hh)
                outputs.append(h)

            inputs = outputs
            outputs = []

        # After the last layer, inputs holds the final layer's outputs
        # Permuting back to (batch, seq_len, hidden_dim)
        lstm_out = torch.stack(inputs).permute(1, 0, 2)

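        # With weight tying, hidden states are first mapped to the embedding
        # dimension (presumably so vocab_projection can share the embedding matrix)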
        if self.use_weight_tying:
            vocab_predictions = self.vocab_projection(self.embedding_projection(lstm_out))
        else:
            vocab_predictions = self.vocab_projection(lstm_out)

        if self.use_softmax_adaptation:
            # Assumption: the flattened metadata embeddings (md_input) feed the
            # vocabulary adaptation term; md_embs is otherwise undefined here
            md_embs = md_input.view(max_batch_size, -1)
            md_context = self.md_vocab_projection(md_embs).unsqueeze(1)
            vocab_predictions += md_context

        return vocab_predictions
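
Usage sketch (illustrative only): the constructor, sizes, and metadata format are
assumptions - only the batch keys "input", "text_len", and "md" come from the code
above.

    import torch

    model = ...  # however the model in src/model.py is constructed

    batch = {
        "input": torch.randint(0, 10_000, (8, 20)),  # (batch, seq_len) token ids
        "text_len": torch.full((8,), 20),            # per-example lengths
        "md": ...,                                   # format set by metadata_constructor
    }

    vocab_predictions = model(batch)  # (batch, seq_len, vocab_size)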