in ttw/models/beam_search.py [0:0]
def beam_search(self, initial_input, initial_state=None, context=None):
    """Runs beam search sequence generation on a batch of inputs.

    Args:
        initial_input: A list of length batch_size; each element is the
            initial token sequence for that entry.
        initial_state (optional): A list of length batch_size holding the
            initial decoder state for every entry.
        context (optional): A list of length batch_size holding the
            conditioning context for every entry.

    Returns:
        A list of length batch_size, each element the most likely Sequence
        among the beam_size candidates for that entry.
    """
    batch_size = len(initial_input)
    partial_sequences = [TopN(self.beam_size) for _ in range(batch_size)]
    complete_sequences = [TopN(self.beam_size) for _ in range(batch_size)]
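    # Take the first decoding step on the whole batch: for every entry, get
    # the top beam_size candidate words, their log-probabilities, and the
    # updated decoder state.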
    words, logprobs, new_state = self.decode_step(
        initial_input, initial_state, context,
        k=self.beam_size)
    for b in range(batch_size):
        # Create the first beam_size candidate hypotheses for each entry
        # in the batch.
        for k in range(self.beam_size):
            seq = Sequence(
                output=initial_input[b] + [words[b][k]],
                state=new_state[b],
                logprob=logprobs[b][k],
                score=logprobs[b][k],
                context=context[b])
            partial_sequences[b].push(seq)
    # Run beam search. The first step was already taken above, so at most
    # max_sequence_length - 1 steps remain.
    for _ in range(self.max_sequence_length - 1):
        partial_sequences_list = [p.extract() for p in partial_sequences]
        for p in partial_sequences:
            p.reset()
        # Keep a flattened list of partial hypotheses so they can be fed
        # through the model as a single batch.
        flattened_partial = [
            s for sub_partial in partial_sequences_list for s in sub_partial]
        input_feed = [c.output[-1] for c in flattened_partial]
        state_feed = [c.state for c in flattened_partial]
        context_feed = [c.context for c in flattened_partial]
        if not input_feed:
            # We have run out of partial candidates; this can happen when
            # beam_size=1.
            break
        # Feed the current hypotheses through the model to receive new
        # words and states; the log-probabilities are used to rank them.
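        # Ask for beam_size + 1 proposals per hypothesis: if the best word
        # is EOS, that hypothesis completes, and the extra proposal leaves
        # enough continuations to refill the beam.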
        words, logprobs, new_states = self.decode_step(
            input_feed, state_feed, context_feed,
            k=self.beam_size + 1)
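        # Scatter the flattened results back to their batch entries; idx
        # walks the flattened hypothesis list in the order it was built.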
        idx = 0
        for b in range(batch_size):
            # For every entry in the batch, keep only the most likely
            # beam_size hypotheses.
            for partial in partial_sequences_list[b]:
                state = new_states[idx]
                k = 0
                num_hyp = 0
                # Examine proposals until the beam is refilled; the bound
                # on k guards against indexing past the beam_size + 1
                # proposals when several of them are EOS.
                while num_hyp < self.beam_size and k < len(words[idx]):
                    w = words[idx][k]
                    output = partial.output + [w]
                    logprob = partial.logprob + logprobs[idx][k]
                    score = logprob
                    k += 1
                    num_hyp += 1
                    if w == self.eos_id:
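                        # Length-normalize the scores of completed
                        # hypotheses with a GNMT-style penalty:
                        #     ((L + |output|) / (L + 1)) ** alpha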
                        if self.length_normalization_factor > 0:
                            L = self.length_normalization_const
                            length_penalty = (L + len(output)) / (L + 1)
                            score /= length_penalty ** self.length_normalization_factor
                        beam = Sequence(output, state,
                                        logprob, score, context=context[b])
                        complete_sequences[b].push(beam)
                        # This hypothesis is complete, so another proposal
                        # can take its slot in the beam.
                        num_hyp -= 1
                    else:
                        beam = Sequence(output, state,
                                        logprob, score, context=context[b])
                        partial_sequences[b].push(beam)
                idx += 1
    # If there are no complete sequences, fall back to the partial ones,
    # but never output a mixture of complete and partial sequences: a
    # partial sequence could outscore every complete one.
    for b in range(batch_size):
        if not complete_sequences[b].size():
            complete_sequences[b] = partial_sequences[b]
    seqs = [complete.extract(sort=True)[0]
            for complete in complete_sequences]
    return seqs
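
# --- Illustrative sketch; not part of the original file. ---
# beam_search relies on two helpers, Sequence and TopN, defined elsewhere
# in the project. The minimal versions below are an assumption about their
# contract, inferred purely from how beam_search uses them above.
import heapq
from collections import namedtuple

# A partial or complete hypothesis: the tokens emitted so far, the decoder
# state to resume from, the accumulated log-probability, the (possibly
# length-normalized) score used for ranking, and the conditioning context.
Sequence = namedtuple(
    'Sequence', ['output', 'state', 'logprob', 'score', 'context'])


class TopN(object):
    """Keeps the n highest-scoring Sequences pushed into it."""

    def __init__(self, n):
        self._n = n
        self._data = []

    def size(self):
        return len(self._data)

    def push(self, x):
        # Min-heap keyed on score; id() breaks ties so that Sequences
        # holding non-comparable states are never compared directly.
        item = (x.score, id(x), x)
        if len(self._data) < self._n:
            heapq.heappush(self._data, item)
        else:
            heapq.heappushpop(self._data, item)

    def extract(self, sort=False):
        # Non-destructive: beam_search calls reset() separately.
        items = [x for _, _, x in self._data]
        if sort:
            items.sort(key=lambda s: s.score, reverse=True)
        return items

    def reset(self):
        self._data = []

# Hypothetical usage, assuming a decoder that provides beam_size,
# max_sequence_length, eos_id, length_normalization_factor,
# length_normalization_const, and a decode_step(inputs, states, contexts,
# k) method returning per-hypothesis top-k words, log-probabilities, and
# updated states (every name below is a placeholder):
#
#     best = decoder.beam_search(
#         initial_input=[[bos_id] for _ in range(batch_size)],
#         initial_state=encoder_states,   # one state per batch entry
#         context=encoder_contexts)       # one context per batch entry
#     tokens = best[0].output             # most likely token sequence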