in NMT/src/model/attention.py [0:0]
def generate(self, encoded, lang_id, max_len=200, sample=False, temperature=None):
"""
Generate a sentence from a given initial state.
Input:
- FloatTensor of size (batch_size, hidden_dim) representing
sentences encoded in the latent space
Output:
- LongTensor of size (seq_len, batch_size), word indices
- LongTensor of size (batch_size,), sentence x_len
"""
    latent = encoded.dec_input
    x_len = encoded.input_len
    is_cuda = latent.is_cuda
    one_hot = None  # [] if temperature is not None else None

    # check inputs
    assert type(lang_id) is int
    assert latent.size() == (x_len.max(), x_len.size(0), self.emb_dim)
    assert (sample is True) ^ (temperature is None)
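    # NB: the XOR enforces that a temperature is provided if and only if
    # we sample; greedy decoding takes no temperature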

    # source / target
    n_words = self.n_words[lang_id]
    emb_layer = self.embeddings[lang_id]
    lstm_layer1 = self.lstm1[lang_id]
    lstm_layer2 = self.lstm2[lang_id]
    lstm_proj_layer = self.lstm_proj_layers[lang_id]
    proj_layer = self.proj[lang_id]

    # initialize generated sentences batch
    slen, bs = latent.size(0), latent.size(1)
    assert x_len.max() == slen and x_len.size(0) == bs
    cur_len = 1
    decoded = torch.LongTensor(max_len, bs).fill_(self.pad_index)
    decoded = decoded.cuda() if is_cuda else decoded
    decoded[0] = self.bos_index[lang_id]

    # compute attention
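    # mask out source positions beyond each sentence's length, so attention
    # only covers real words (assuming get_mask, defined elsewhere in the
    # codebase, marks valid positions with 1)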
    mask = get_mask(x_len, True, cuda=is_cuda) == 0
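
    # the decoder starts from zero states: the zero vector in `hidden_states`
    # serves as the first attention query, and passing None as the LSTM state
    # lets PyTorch initialize (h, c) to zeros on the first step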
    h_c_1, h_c_2 = None, None
    hidden_states = [latent.data.new(1, bs, self.hidden_dim).zero_()]

    while cur_len < max_len:

        # previous word embeddings
        embeddings = emb_layer(decoded[cur_len - 1])
        embeddings = F.dropout(embeddings, p=self.dropout, training=self.training)

        # attention layer
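        # the attention query is the decoder hidden state from the previous
        # step; hidden_states[-1] has size (1, bs, hidden_dim), hence the [0]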
        attention = self.get_attention(latent, hidden_states[-1][0], embeddings, mask, lang_id)

        # lstm step
        lstm_input = embeddings.unsqueeze(0)
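        # input feeding (Luong et al., 2015): concatenate the attention
        # context to the word embedding before the first LSTM layer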
        if self.input_feeding:
            lstm_input = torch.cat([lstm_input, attention], 2)
        lstm_output, h_c_1 = lstm_layer1(lstm_input, h_c_1)
        assert lstm_output.size() == (1, bs, self.hidden_dim)
        hidden_states.append(lstm_output)

        # lstm (layers > 1)
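        # with a 2-layer decoder, if the context was not fed at the input,
        # it is concatenated to the first layer's output instead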
        if self.n_dec_layers > 1:
            lstm_output = F.dropout(lstm_output, p=self.dropout, training=self.training)
            if not self.input_feeding:
                lstm_output = torch.cat([lstm_output, attention], 2)
            lstm_output, h_c_2 = lstm_layer2(lstm_output, h_c_2)
            assert lstm_output.size() == (1, bs, self.hidden_dim)

        # word scores
        output = F.dropout(lstm_output, p=self.dropout, training=self.training).view(-1, self.hidden_dim)
        if lstm_proj_layer is not None:
            output = F.relu(lstm_proj_layer(output))
        scores = proj_layer(output).view(bs, n_words)
        scores = scores.data

        # do not sample words that are not in the language vocabulary
        if self.vocab_mask_neg is not None:
            scores.index_fill_(1, self.vocab_mask_neg[lang_id], -1e30)
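        # (scores of -1e30 give these words a near-zero probability after the exponential)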

        # select next words: sample (multinomial / Gumbel softmax) or greedy (argmax)
        if sample:
            # if temperature is not None:
            #     gumbel = gumbel_softmax(scores, temperature, hard=True)
            #     next_words = gumbel.max(1)[1]
            #     one_hot.append(gumbel)
            # else:
            next_words = torch.multinomial((scores / temperature).exp(), 1).squeeze(1)
        else:
            next_words = torch.topk(scores, 1)[1].squeeze(1)
        assert next_words.size() == (bs,)
        decoded[cur_len] = next_words
        cur_len += 1

        # stop when there is a </s> in each sentence
        if decoded.eq(self.eos_index).sum(0).ne(0).sum() == bs:
            break

    # compute the length of each generated sentence, and
    # put some padding after the end of each sentence
    lengths = torch.LongTensor(bs).fill_(cur_len)
    for i in range(bs):
        for j in range(cur_len):
            if decoded[j, i] == self.eos_index:
                if j + 1 < max_len:
                    decoded[j + 1:, i] = self.pad_index
                lengths[i] = j + 1
                break
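        # if a sentence reached max_len without </s>, force one at its last position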
        if lengths[i] == max_len:
            decoded[-1, i] = self.eos_index

    if one_hot is not None:
        one_hot = torch.cat([x.unsqueeze(0) for x in one_hot], 0)
        assert one_hot.size() == (cur_len - 1, bs, n_words)

    return decoded[:cur_len], lengths, one_hot