in NMT/src/model/attention.py
def generate_beam(self, encoded, lang_id, beam_size=20, max_len=175, sample=False, temperature=None):
"""
Generate a sentence from a given initial state.
Input:
- FloatTensor of size (batch_size, hidden_dim) representing
sentences encoded in the latent space
Output:
- LongTensor of size (seq_len, batch_size), word indices
- LongTensor of size (batch_size,), sentence x_len
"""
latent = encoded.dec_input
x_len = encoded.input_len
is_cuda = latent.is_cuda
one_hot = [] if temperature is not None else None
# check inputs
assert type(lang_id) is int
assert beam_size >= 1
assert latent.size() == (x_len.max(), x_len.size(0), self.emb_dim)
assert temperature is None or sample is True
# source / target
n_words = self.n_words[lang_id]
emb_layer = self.embeddings[lang_id]
lstm_layer1 = self.lstm1[lang_id]
lstm_layer2 = self.lstm2[lang_id]
proj_layer = self.proj[lang_id]
# initialize generated sentences batch
slen, bs = latent.size(0), latent.size(1)
assert x_len.max() == slen and x_len.size(0) == bs
cur_len = 1
decoded = torch.LongTensor(max_len, bs * beam_size).fill_(self.pad_index)
decoded = decoded.cuda() if is_cuda else decoded
decoded[0] = self.bos_index[lang_id]
# expand tensors for beam search
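# (beams are flattened into the batch dimension: columns [i * beam_size, (i + 1) * beam_size) hold the beams of sentence i)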
expanded_latent = latent.unsqueeze(2).expand(slen, bs, beam_size, self.emb_dim).contiguous().view(slen, bs * beam_size, self.emb_dim)
expanded_x_len = x_len.unsqueeze(1).expand(x_len.size(0), beam_size).contiguous().view(-1)
# currently finished sentences / current scores in all beams
current_hyps = [[] for _ in range(bs)]
# current_hyp_scores = torch.FloatTensor(bs * beam_size).zero_().cuda()
# at first step, only look at the first beam
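# (all beams hold the same BOS token, so all but the first beam of each sentence are masked with -inf to avoid duplicate hypotheses)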
current_hyp_scores = latent.data.new(sum([[0] + ([-np.inf] * (beam_size - 1)) for _ in range(bs)], []))
# attention mask over the (expanded) source positions, computed from the source lengths
expanded_mask = get_mask(expanded_x_len, True, cuda=is_cuda) == 0
h_c_1, h_c_2 = None, None
hidden_states = [latent.data.new(1, bs * beam_size, self.hidden_dim).zero_()]
while cur_len < max_len:
# previous word embeddings
embeddings = emb_layer(decoded[cur_len - 1])
embeddings = F.dropout(embeddings, p=self.dropout, training=self.training)
# attention layer
attention = self.get_attention(expanded_latent, hidden_states[-1][0], embeddings, expanded_mask, lang_id)
# lstm step
lstm_input = embeddings.unsqueeze(0)
if self.input_feeding:
lstm_input = torch.cat([lstm_input, attention], 2)
lstm_output, h_c_1 = lstm_layer1(lstm_input, h_c_1)
assert lstm_output.size() == (1, bs * beam_size, self.hidden_dim)
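# keep the LSTM1 output: it is used as the attention query at the next step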
hidden_states.append(lstm_output)
# lstm (layers > 1)
if self.n_dec_layers > 1:
lstm_output = F.dropout(lstm_output, p=self.dropout, training=self.training)
if not self.input_feeding:
lstm_output = torch.cat([lstm_output, attention], 2)
lstm_output, h_c_2 = lstm_layer2(lstm_output, h_c_2)
assert lstm_output.size() == (1, bs * beam_size, self.hidden_dim)
# word scores
lstm_output = F.dropout(lstm_output, p=self.dropout, training=self.training)
scores = self.log_sm(proj_layer(lstm_output.view(-1, self.hidden_dim)).view(bs * beam_size, n_words))
# beam search: add the step log-probabilities to the cumulative beam scores, then pick the best continuations per sentence
scores2 = scores.data + current_hyp_scores.unsqueeze(1).expand(bs * beam_size, n_words)
scores2 = scores2.contiguous().view(bs, beam_size * n_words)
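# take 2 * beam_size candidates per sentence: enough to refill the beam with beam_size continuations even if up to beam_size of them end with EOS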
best_values, best_indexes = scores2.topk(2 * beam_size, dim=1, largest=True, sorted=True)
all_next = [] # [(current hyp value, next word, position in the total batch)]
for sent_id in range(bs):
if len(current_hyps[sent_id]) == beam_size: # this sentence is done
all_next.extend([(0, self.pad_index, 0)] * beam_size) # pad the batch
continue
offset = sent_id * beam_size
current_next = []
for beam_k in range(2 * beam_size):
word_pos = best_indexes[sent_id, beam_k]
beam_id = word_pos // n_words
word_id = word_pos % n_words
assert 0 <= beam_id < beam_size and 0 <= word_id < n_words
if word_id == self.eos_index or cur_len + 1 == max_len:
current_hyps[sent_id].append((
decoded[:cur_len, offset + beam_id].clone(), best_values[sent_id, beam_k]
))
else:
current_next.append((best_values[sent_id, beam_k], word_id, offset + beam_id))
if len(current_hyps[sent_id]) == beam_size: # this sentence is done
current_next = [(0, self.pad_index, 0)] * beam_size # pad the batch
break
if len(current_next) == beam_size: # enough next hypothesis in the beam
break
assert len(current_next) == beam_size
all_next.extend(current_next)
assert len(all_next) == beam_size * (sent_id + 1)
# update current_hyp_scores
assert len(all_next) == bs * beam_size, (len(all_next), bs * beam_size)
current_hyp_scores = latent.data.new([x[0] for x in all_next])
# update decoded tensor, and LSTM 1 + LSTM 2 internal states
# TODO: this beam reordering is done element by element; the (currently unused) index_select path below would be faster
slow = True
if slow:
_decoded = decoded.clone()
_h_c_1 = (h_c_1[0].data.clone(), h_c_1[1].data.clone())
_h_c_2 = None if h_c_2 is None else (h_c_2[0].data.clone(), h_c_2[1].data.clone())
for sent_id in range(bs):
for beam_k in range(beam_size):
k = sent_id * beam_size + beam_k
previous_beam_id = all_next[k][2]
# copy the parent beam's decoded prefix, then append the selected word
_decoded[:, k].copy_(decoded[:, previous_beam_id])
_decoded[cur_len, k] = all_next[k][1]
# reorder the LSTM states so that beam k continues from its parent beam
# (read from the original states, not from the partially updated copies)
_h_c_1[0][0, k].copy_(h_c_1[0].data[0, previous_beam_id])
_h_c_1[1][0, k].copy_(h_c_1[1].data[0, previous_beam_id])
if _h_c_2 is not None:
for jj in range(self.n_dec_layers - 1):
_h_c_2[0][jj, k].copy_(h_c_2[0].data[jj, previous_beam_id])
_h_c_2[1][jj, k].copy_(h_c_2[1].data[jj, previous_beam_id])
decoded = _decoded
h_c_1 = (_h_c_1[0], _h_c_1[1])
if _h_c_2 is not None:
h_c_2 = (_h_c_2[0], _h_c_2[1])
else:
next_ids = x_len.new([x[2] for x in all_next])
decoded = decoded.index_select(1, next_ids)
decoded[cur_len] = x_len.new([x[1] for x in all_next])
h_c_1 = (
h_c_1[0].index_select(1, next_ids),
h_c_1[1].index_select(1, next_ids)
)
if h_c_2 is not None:
h_c_2 = (
h_c_2[0].index_select(1, next_ids),
h_c_2[1].index_select(1, next_ids)
)
cur_len += 1
# stop when there are `beam_size` hypothesis for each sentence
if all(len(hyps) == beam_size for hyps in current_hyps):
break
def score_sent(score, sent_len):
# length-normalized score used to rank finished hypotheses
return score / sent_len
# alternatives:
# return score
# return score / sent_len ** 0.5
# return score / sent_len + 0.01 * sent_len
# best hypothesis
lengths = torch.LongTensor(bs)
sentences = []
for i in range(bs):
hyps = current_hyps[i]
assert len(hyps) == beam_size
# best_hypo = max(hyps, key=lambda x: x[1])[0]
best_hypo = max(hyps, key=lambda x: score_sent(x[1], len(x[0])))[0]
lengths[i] = len(best_hypo) + 1  # +1 for the EOS token appended below
sentences.append(best_hypo)
decoded = torch.LongTensor(lengths.max(), bs).fill_(self.pad_index)
decoded = decoded.cuda() if is_cuda else decoded
for i, hypo in enumerate(sentences):
decoded[:lengths[i] - 1, i] = hypo
decoded[lengths[i] - 1, i] = self.eos_index
# if one_hot is not None:
# one_hot = torch.cat([x.unsqueeze(0) for x in one_hot], 0)
# assert one_hot.size() == (cur_len - 1, bs, n_words)
return decoded[:cur_len], lengths, one_hot
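# Usage sketch (an assumption, not part of this file): given a decoder `dec` of this
# class and a batch `encoded` produced by the matching encoder (exposing `dec_input`
# and `input_len` as used above), beam decoding into language `lang_id` might look like:
#
#     sentences, lengths, _ = dec.generate_beam(encoded, lang_id=1, beam_size=5, max_len=175)
#
# `sentences` is a (seq_len, batch_size) LongTensor of word indices padded with
# `pad_index`, and `lengths` gives the length of each hypothesis including the final EOS.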