in curiosity/paraphrase_models.py [0:0]
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Finalize predictions: convert predicted token indices into token strings,
    keeping only the best beam and truncating at the first end symbol.
    """
    predicted_indices = output_dict["predictions"]
    if not isinstance(predicted_indices, numpy.ndarray):
        predicted_indices = predicted_indices.detach().cpu().numpy()
    all_predicted_tokens = []
    for indices in predicted_indices:
        # Beam search gives us the top k results for each source sentence
        # in the batch, but we just want the single best.
        if len(indices.shape) > 1:
            indices = indices[0]
        indices = list(indices)
        # Collect indices up to the first end symbol.
        if self._end_index in indices:
            indices = indices[: indices.index(self._end_index)]
        predicted_tokens = [
            self.vocab.get_token_from_index(x, namespace=self._target_namespace)
            for x in indices
        ]
        all_predicted_tokens.append(predicted_tokens)
    output_dict["predicted_tokens"] = all_predicted_tokens
    return output_dict
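
Below is a minimal usage sketch of the method's input/output contract, not part of the source file. SimpleNamespace stands in for the paraphrase model's state (the enclosing class is not shown here), StubVocab is a hypothetical replacement for the AllenNLP Vocabulary exposing only the one lookup decode() uses, and the token list and tensor values are purely illustrative.

# Imports needed if decode() above is pasted into the same file.
from types import SimpleNamespace
from typing import Dict

import numpy
import torch


class StubVocab:
    """Hypothetical stand-in for the AllenNLP Vocabulary used by decode()."""

    def __init__(self, tokens):
        self._tokens = tokens

    def get_token_from_index(self, index, namespace="tokens"):
        return self._tokens[index]


# Stand-in for the model attributes that decode() reads via `self`.
model_state = SimpleNamespace(
    vocab=StubVocab(["@start@", "@end@", "the", "cat", "sat"]),
    _end_index=1,
    _target_namespace="tokens",
)

# predictions shaped (batch=1, beam=2, length=4): decode() keeps only the best
# beam (row 0) and truncates at the first end index.
output_dict = {"predictions": torch.tensor([[[2, 3, 4, 1], [2, 2, 1, 1]]])}
print(decode(model_state, output_dict)["predicted_tokens"])
# -> [['the', 'cat', 'sat']]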