in dialogue_personalization/utils/beam_ptr.py [0:0]
def beam_search_sample(self, enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0):
    # The batch should contain a single example, duplicated beam_size times.
    encoder_outputs, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
    s_t_0 = self.model.reduce_state(encoder_hidden)

    dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
    dec_h = dec_h.squeeze(0)
    dec_c = dec_c.squeeze(0)

    # Decoder batch preparation: start with beam_size identical hypotheses,
    # each seeded with SOS and the initial decoder state/context.
    beams = [Beam(tokens=[config.SOS_idx],
                  log_probs=[0.0],
                  state=(dec_h[0], dec_c[0]),
                  context=c_t_0[0],
                  coverage=(coverage_t_0[0] if config.is_coverage else None))
             for _ in range(config.beam_size)]
    results = []
    steps = 0
    while steps < config.max_dec_step and len(results) < config.beam_size:
        # Feed each beam's latest token; out-of-vocabulary (copied) token ids are
        # mapped back to UNK before being embedded by the decoder.
        latest_tokens = [h.latest_token for h in beams]
        latest_tokens = [t if t < self.vocab_size else config.UNK_idx
                         for t in latest_tokens]
        y_t_1 = torch.LongTensor(latest_tokens)
        if config.USE_CUDA:
            y_t_1 = y_t_1.cuda()

        # Re-batch the per-beam decoder states and context vectors.
        all_state_h = []
        all_state_c = []
        all_context = []
        for h in beams:
            state_h, state_c = h.state
            all_state_h.append(state_h)
            all_state_c.append(state_c)
            all_context.append(h.context)
        s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
        c_t_1 = torch.stack(all_context, 0)

        coverage_t_1 = None
        if config.is_coverage:
            all_coverage = []
            for h in beams:
                all_coverage.append(h.coverage)
            coverage_t_1 = torch.stack(all_coverage, 0)

        final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
                                                                                encoder_outputs, enc_padding_mask, c_t_1,
                                                                                extra_zeros, enc_batch_extend_vocab, coverage_t_1, steps, training=False)
        # Keep the top 2*beam_size candidates per beam so that enough
        # non-EOS hypotheses survive the pruning step below.
        topk_log_probs, topk_ids = torch.topk(final_dist, config.beam_size * 2)

        dec_h, dec_c = s_t
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        all_beams = []
        # At step 0 all beams are identical, so expand only the first one.
        num_orig_beams = 1 if steps == 0 else len(beams)
        for i in range(num_orig_beams):
            h = beams[i]
            state_i = (dec_h[i], dec_c[i])
            context_i = c_t[i]
            coverage_i = (coverage_t[i] if config.is_coverage else None)

            for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hyps:
                new_beam = h.extend(token=topk_ids[i, j].item(),
                                    log_prob=topk_log_probs[i, j].item(),
                                    state=state_i,
                                    context=context_i,
                                    coverage=coverage_i)
                all_beams.append(new_beam)

        # Prune: hypotheses that emit EOS after min_dec_steps are moved to results;
        # the rest refill the beam up to beam_size.
        beams = []
        for h in self.sort_beams(all_beams):
            if h.latest_token == config.EOS_idx:
                if steps >= config.min_dec_steps:
                    results.append(h)
            else:
                beams.append(h)
            if len(beams) == config.beam_size or len(results) == config.beam_size:
                break

        steps += 1

    # If no hypothesis finished within max_dec_step, fall back to the live beams.
    if len(results) == 0:
        results = beams

    beams_sorted = self.sort_beams(results)
    return beams_sorted[0]
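
For reference, below is a minimal sketch of the Beam hypothesis container and the sort_beams ordering that beam_search_sample relies on. The attribute and method names (tokens, log_probs, state, context, coverage, latest_token, extend, avg_log_prob) are inferred from how they are used above; the actual definitions in beam_ptr.py may differ.

# Hypothetical sketch, not copied verbatim from the repository.
class Beam(object):
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens          # decoded token ids so far (starts with SOS)
        self.log_probs = log_probs    # per-token scores accumulated so far
        self.state = state            # (h, c) decoder state after the last token
        self.context = context        # last attention context vector
        self.coverage = coverage      # coverage vector, or None if disabled

    def extend(self, token, log_prob, state, context, coverage):
        # Return a new hypothesis with one more decoded token appended.
        return Beam(tokens=self.tokens + [token],
                    log_probs=self.log_probs + [log_prob],
                    state=state,
                    context=context,
                    coverage=coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # Length-normalized score used for ranking hypotheses.
        return sum(self.log_probs) / len(self.tokens)


def sort_beams(self, beams):
    # Best (highest-scoring) hypothesis first.
    return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)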