in src/model.py [0:0]
def forward(self, input):
    """Run token embeddings (optionally prefixed with metadata embeddings)
    through the LSTM and project hidden states to vocabulary logits.

    Args:
        input: dict with keys
            "input": token-id tensor, cast to long here; (batch, seq_len),
                with position 0 assumed to be the SOS token when metadata
                is used — TODO confirm with the data pipeline;
            "text_len": per-example valid lengths of "input";
            "md": raw metadata, only read when ``self.use_md`` is True.

    Returns:
        Vocabulary logits of shape (batch, padded_seq_len, vocab_size),
        where padded_seq_len includes any inserted metadata positions.
    """
    input_ids = input["input"].long()
    # NOTE(review): newer torch requires pack_padded_sequence lengths to be
    # a CPU tensor or list — confirm input["text_len"] is never on GPU.
    input_lens = input["text_len"]
    input_embs = self.embeddings(input_ids)
    if self.use_md:
        md_input = self.metadata_constructor.preprocess_md(input["md"], self.embeddings)
        md_emb = self.metadata_constructor(md_input)
        # Insert the metadata embeddings right after the SOS token so the
        # LSTM consumes them before any text tokens.
        sos_embs = input_embs[:, :1, :]
        text_embs = input_embs[:, 1:, :]
        joined_embs = torch.cat((sos_embs, md_emb, text_embs), dim=1)
        # Extend lengths by however many metadata positions were inserted.
        # The original hard-coded 1; md_emb.size(1) is identical when the
        # metadata is a single fixed embedding and stays correct if the
        # constructor ever emits a multi-position prefix.
        joined_lens = input_lens + md_emb.size(1)
        packed_input = pack_padded_sequence(joined_embs, joined_lens,
                                            batch_first=True, enforce_sorted=False)
    else:
        packed_input = pack_padded_sequence(input_embs, input_lens,
                                            batch_first=True, enforce_sorted=False)
    # lstm_out after unpacking: (batch, padded_seq_len, hidden_size).
    lstm_out, _ = self.lstm(packed_input)
    lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)
    if self.use_weight_tying:
        # Project hidden states back into embedding space first so the
        # final vocab projection can share weights with the embeddings.
        vocab_predictions = self.vocab_projection(self.embedding_projection(lstm_out))
    else:
        vocab_predictions = self.vocab_projection(lstm_out)
    return vocab_predictions