def predict_fn()

in bring-your-own-container/fairseq_translation/fairseq/sagemaker_translate.py [0:0]


def predict_fn(input_data, model):
    """Translate ``input_data`` and return the hypotheses as one string.

    Args:
        input_data: The single input item to translate. It is wrapped in a
            one-element list before being handed to ``make_batches``.
        model: Mapping that must provide the keys ``"args"``, ``"task"``,
            ``"max_positions"``, ``"device"``, ``"translator"``,
            ``"align_dict"`` and ``"tgt_dict"``.

    Returns:
        The hypotheses of every result joined with newlines. Results are
        visited in ascending order of the indices reported by
        ``make_batches``.

    Raises:
        KeyError: If ``model`` lacks one of the keys listed above.

    Side effects:
        Every alignment that is not ``None`` is printed to stdout.
    """
    args, task = model["args"], model["task"]
    max_positions = model["max_positions"]
    device = model["device"]
    translator = model["translator"]
    align_dict, tgt_dict = model["align_dict"], model["tgt_dict"]

    # Batches may come back in a different order than the inputs, so keep the
    # index reported for each processed item alongside its result.
    order = []
    translations = []
    for batch, batch_indices in make_batches([input_data], args, task, max_positions):
        order.extend(batch_indices)
        translations.extend(
            process_batch(batch, translator, device, args, align_dict, tgt_dict)
        )

    lines = []
    for position in np.argsort(order):
        translation = translations[position]
        # pos_scores is unused but stays in the zip: zip stops at the shortest
        # sequence, so removing it could change how many hypotheses are emitted.
        triples = zip(
            translation.hypos, translation.pos_scores, translation.alignments
        )
        for hypo, _pos_scores, alignment in triples:
            lines.append(hypo)
            if alignment is None:
                continue
            print(alignment)
    return "\n".join(lines)