in bring-your-own-container/fairseq_translation/fairseq/sagemaker_translate.py [0:0]
def predict_fn(input_data, model):
    """Run the loaded translator on one input and return the joined hypotheses.

    Args:
        input_data: The single item to process. It is wrapped in a
            one-element list before being handed to ``make_batches``.
        model: Mapping that must provide the keys ``"args"``, ``"task"``,
            ``"max_positions"``, ``"device"``, ``"translator"``,
            ``"align_dict"`` and ``"tgt_dict"``.

    Returns:
        A single string: every hypothesis, taken in ascending order of the
        indices reported by ``make_batches``, joined with newlines.

    Raises:
        KeyError: If ``model`` lacks any of the required keys.

    Side effects:
        Any alignment that is not ``None`` is printed to stdout.
    """
    # Pull every component out of the bundle up front, in a fixed order, so a
    # missing key fails before any batching work starts.
    (
        args,
        task,
        max_positions,
        device,
        translator,
        align_dict,
        tgt_dict,
    ) = (
        model[key]
        for key in (
            "args",
            "task",
            "max_positions",
            "device",
            "translator",
            "align_dict",
            "tgt_dict",
        )
    )

    batch_order = []
    batch_outputs = []
    batches = make_batches([input_data], args, task, max_positions)
    for batch, batch_indices in batches:
        batch_order.extend(batch_indices)
        batch_outputs.extend(
            process_batch(batch, translator, device, args, align_dict, tgt_dict)
        )

    # Batches may arrive in a different order than the inputs; argsort over
    # the collected indices walks the outputs in ascending index order.
    lines = []
    for position in np.argsort(batch_order):
        entry = batch_outputs[position]
        # pos_scores is zipped along (it bounds the iteration) but not used.
        for hypo, _pos_scores, align in zip(
            entry.hypos, entry.pos_scores, entry.alignments
        ):
            lines.append(hypo)
            if align is None:
                continue
            print(align)
    return "\n".join(lines)