code/src/model/attention.py
def forward(self, x, lengths):
"""
Input:
- LongTensor of size (slen, bs), word indices
- List of length bs, containing the sentence lengths
Sentences have to be ordered by decreasing length
Output:
- FloatTensor of size (slen, bs, 2 * hidden_dim),
representing the encoded state of each sentence
"""
    is_cuda = x.is_cuda
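    # permutation that sorts sentences by decreasing length, and the inverse
    # permutation to restore the original order afterwards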
    sort_len = lengths.type_as(x.data).sort(0, descending=True)[1]
    sort_len_rev = sort_len.sort()[1]
    # embeddings
    slen, bs = x.size(0), x.size(1)
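    # x is either a 2D tensor of word indices that is fed to the embedding
    # layer, or a 3D tensor of scores over the vocabulary (e.g. one-hot
    # vectors) that is multiplied by the embedding matrix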
    if x.dim() == 2:
        embeddings = self.embeddings(x.index_select(1, sort_len))
    else:
        assert x.dim() == 3 and x.size(2) == self.n_words
        embeddings = x.view(slen * bs, -1).mm(self.embeddings.weight)
        embeddings = embeddings.view(slen, bs, self.emb_dim).index_select(1, sort_len)
    embeddings = embeddings.detach() if self.freeze_enc_emb else embeddings
    embeddings = F.dropout(embeddings, p=self.dropout, training=self.training)
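    # pack the sorted padded embeddings so the LSTM skips positions beyond each sentence length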
    lstm_input = pack_padded_sequence(embeddings, sorted(lengths.tolist(), reverse=True))
    assert lengths.max() == slen and lengths.size(0) == bs
    assert lstm_input.data.size() == (lengths.sum(), self.emb_dim)
    # LSTM
    lstm_output, (_, _) = self.lstm(lstm_input)
    assert lstm_output.data.size() == (lengths.sum(), 2 * self.hidden_dim)
    # get a padded version of the LSTM output
    padded_output, _ = pad_packed_sequence(lstm_output)
    assert padded_output.size() == (slen, bs, 2 * self.hidden_dim)
    # project the biLSTM output from 2 * hidden_dim down to emb_dim
    padded_output = self.proj(padded_output.view(slen * bs, -1)).view(slen, bs, self.emb_dim)
    # restore the sentences to their original order
    padded_output = padded_output.index_select(1, sort_len_rev)
    # pooling on latent representation
    if self.pool_latent is not None:
        pool, ks = self.pool_latent
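        # pad the time dimension so that it becomes a multiple of the kernel size ks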
        p = ks - slen if slen < ks else (0 if slen % ks == 0 else ks - (slen % ks))
        y = padded_output.transpose(0, 2)
        if p > 0:
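            # pad with 0 for average pooling and with a large negative value
            # for max pooling, so padded positions do not affect the result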
            value = 0 if pool == 'avg' else -1e9
            y = F.pad(y, (0, p), mode='constant', value=value)
        y = (F.avg_pool1d if pool == 'avg' else F.max_pool1d)(y, ks)
        padded_output = y.transpose(0, 2)
        lengths = (lengths.float() / ks).ceil().long()
    # discriminator input
    dis_input = lstm_output.data
    if self.dis_input_proj:
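        # keep only the non-padded time steps of the projected output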
        mask = get_mask(lengths, all_words=True, expand=self.emb_dim, batch_first=False, cuda=is_cuda)
        dis_input = padded_output.masked_select(mask).view(lengths.sum().item(), self.emb_dim)
    return LatentState(input_len=lengths, dec_input=padded_output, dis_input=dis_input)