in src/model.py
def forward(self, input):
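    """Run the metadata-conditioned LSTM stack over a batch and return vocab logits.

    `input` is a dict with "input" (token ids, batch-first), "text_len"
    (sequence lengths), and "md" (raw metadata).
    """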
    input_ids = input["input"].long()
    input_lens = input["text_len"]
    input_embs = self.embeddings(input_ids)
    max_batch_size = input_embs.size(0)
    seq_len = input_embs.size(1)
    # Assuming input is batch_first - permuting to sequence first
    input_embs = input_embs.permute(1, 0, 2)
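    # Zero-initialise hidden and cell states, one slice per layer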
    zeros = torch.zeros(self.n_layers, max_batch_size, self.hidden_dim, device=input_embs.device)
    h_init = zeros
    c_init = zeros
    inputs = input_embs
    outputs = []
    # Preprocess the raw metadata once, outside the recurrent loops
    md_input = self.metadata_constructor.preprocess_md(input["md"], self.embeddings)
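    # When the metadata representation does not depend on the hidden state,
    # it can be built once from the input embeddings and reused everywhere.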
    if self.metadata_constructor.is_precomputable():
        md = self.metadata_constructor(md_input, input_embs)
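    # Manually unrolled LSTM stack: iterate over layers, then over timesteps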
    for layer in range(self.n_layers):
        h = h_init[layer]
        c = c_init[layer]
        # Slice this layer's parameters out of the flat weight list
        weight_start_index = layer * self._params_per_layer
        weight_end_index = (layer + 1) * self._params_per_layer
        w_mh, w_ih, w_hh, b_ih, b_hh = self._all_weights[weight_start_index:weight_end_index]
        # Metadata can be computed in advance when not using attention
        if self.metadata_constructor.is_precomputable():
            precomputed_md = torch.matmul(md, w_mh.t())
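        # Unroll over time; `inputs[t]` is the t-th timestep of this layer's input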
        for t in range(seq_len):
            if not self.metadata_constructor.is_precomputable():
                # Attention-style metadata is recomputed from the current hidden state
                md = self.metadata_constructor(md_input, h)
                md_layer = torch.matmul(md, w_mh.t())
            else:
                # Per-timestep metadata if 3-d, otherwise shared across all timesteps
                md_layer = precomputed_md[t] if precomputed_md.dim() == 3 else precomputed_md
            h, c = self._run_cell(inputs[t], md_layer, (h, c), w_ih, w_hh, b_ih, b_hh)
            outputs.append(h)
        # This layer's outputs feed the next layer as inputs
        inputs = outputs
        outputs = []
    # After the last layer, `inputs` holds its per-timestep outputs;
    # stack and permute back to batch - seq_len - hidden_dim
    lstm_out = torch.stack(inputs).permute(1, 0, 2)
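    # Project hidden states to vocabulary logits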
    if self.use_weight_tying:
        # Map back to embedding size first so the output layer can share
        # its weights with the input embeddings
        vocab_predictions = self.vocab_projection(self.embedding_projection(lstm_out))
    else:
        vocab_predictions = self.vocab_projection(lstm_out)
    if self.use_softmax_adaptation:
        # Add a metadata-conditioned bias to the logits
        # (assuming the preprocessed metadata embeddings feed this term)
        md_embs = md_input.view(max_batch_size, -1)
        md_context = self.md_vocab_projection(md_embs).unsqueeze(1)
        vocab_predictions += md_context
    return vocab_predictions
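
# A minimal usage sketch (hypothetical names: `model`, `vocab_size`, and the
# raw metadata format are assumptions, not defined in this file):
#
#   batch = {
#       "input": torch.randint(0, vocab_size, (8, 32)),  # token ids, batch-first
#       "text_len": torch.full((8,), 32),                # per-example lengths
#       "md": raw_metadata,                              # whatever preprocess_md expects
#   }
#   logits = model(batch)  # -> [batch, seq_len, vocab_size]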