def predict_fn()

in notebooks/src/code/inference.py [0:0]


def predict_fn(input_data: dict, model: dict):
    """Classify WORD blocks on a Textract result using a LayoutLMForTokenClassification model

    Parameters
    ----------
    input_data : { doc_json, page_num, s3_output, target_page_only }
        Parsed JSON of Textract result, plus additional control parameters. If `target_page_only`
        is truthy and `page_num` is not None, only that page is extracted and classified (and
        returned); otherwise every page of `doc_json` is classified.
    model : { collator, config, device, model, tokenizer }
        The core token classification model, plus: the data collator used to turn a list of
        examples into a batch of model inputs, the model config (read for
        `max_position_embeddings`), the tokenizer (used to split pages into acceptable-length
        sequences), and the PyTorch device to run inference on.

    Returns
    -------
    result : dict
        A dictionary with keys:

        doc_json : Union[List, Dict]
            Input Textract JSON (or the single extracted page, per `target_page_only`) with WORD
            blocks annotated with additional properties describing their classification according
            to the model: PredictedClass (integer ID of highest-scoring class),
            ClassificationProbabilities (list of floats scoring confidence for each possible
            class), and PredictedClassConfidence (float confidence of highest-scoring class).
        s3_output : S3ObjectSpec
            Passed through from input_data
    """
    collator, config, device, trained_model, tokenizer = itemgetter(
        "collator", "config", "device", "model", "tokenizer"
    )(model)
    doc_json, page_num, s3_output, target_page_only = itemgetter(
        "doc_json", "page_num", "s3_output", "target_page_only"
    )(input_data)
    trp_doc = trp.Document(doc_json)

    # Save memory by extracting individual page, if that's acceptable per the request:
    if target_page_only and page_num is not None:
        doc_json, trp_doc = itemgetter("doc_json", "trp_doc")(
            extract_textract_page(doc_json, page_num, trp_doc)
        )

    # We can't use transformers' pipeline/TokenClassificationPipeline, because
    # LayoutLMForTokenClassification takes its bbox input separately (and *optionally*), rather
    # than from the tokenizer. So the logic here is heavily inspired by that pipeline, but with
    # some customizations:
    # https://github.com/huggingface/transformers/blob/f51188cbe74195c14c5b3e2e8f10c2f435f9751a/src/transformers/pipelines/token_classification.py#L115
    with torch.no_grad():
        # Split the page(s) into sequences of acceptable length for inference. The 2 positions
        # held back from max_position_embeddings are presumably for the tokenizer's special
        # tokens (e.g. [CLS]/[SEP]) - TODO confirm against NaiveExampleSplitter.
        examples = []
        example_word_block_ids = []
        for page in trp_doc.pages:
            page_words = [word for line in page.lines for word in line.words]
            page_word_texts = [word.text for word in page_words]
            page_word_boxes = layoutlm_boxes_from_trp_blocks(page_words)
            splits = NaiveExampleSplitter.split(
                page_word_texts,
                tokenizer,
                max_content_seq_len=config.max_position_embeddings - 2,
            )
            for startword, endword in splits:
                examples.append(
                    TextractLayoutLMExampleForWordClassification(
                        word_boxes_normalized=page_word_boxes[startword:endword],
                        word_texts=page_word_texts[startword:endword],
                    )
                )
                # Parallel to `examples`: the Textract block ID of each word in the example
                example_word_block_ids.append([word.id for word in page_words[startword:endword]])

        # Iterate batches:
        block_results_map = defaultdict(list)
        for ixbatchstart in range(0, len(examples), INFERENCE_BATCH_SIZE):
            batch_examples = examples[ixbatchstart : ixbatchstart + INFERENCE_BATCH_SIZE]
            batch = collator(batch_examples)
            for name in batch:  # Collect batch tensors to same GPU/target device:
                batch[name] = batch[name].to(device)
            # Call the module itself (not .forward()) so any registered hooks also run:
            output = trained_model(**batch)
            # output.logits is (batch_size, seq_len, n_labels)

            # Convert logits to probabilities and retrieve to numpy. .cpu() is a no-op for tensors
            # already on CPU, and is needed for ANY non-CPU device (not only CUDA) before .numpy():
            probs = torch.nn.functional.softmax(output.logits, dim=-1).cpu().numpy()

            # Map (sub-word, token-level) predictions per Textract BLOCK:
            for ixoffset in range(len(batch_examples)):
                word_block_ids = example_word_block_ids[ixbatchstart + ixoffset]
                for ixtoken, ixword in enumerate(batch.word_ids(ixoffset)):
                    # Tokens with no word index (presumably special/padding tokens) are skipped:
                    if ixword is not None:
                        block_results_map[word_block_ids[ixword]].append(
                            probs[ixoffset, ixtoken, :]
                        )

        # Aggregate per-block results (mean over each block's tokens) and save to Textract JSON.
        # NOTE(review): this relies on getBlockById returning the same dicts held in doc_json, so
        # that the annotations show up in the returned doc_json - confirm against trp.
        for block_id, token_probs in block_results_map.items():
            block = trp_doc.getBlockById(block_id)
            block_probs = np.mean(np.stack(token_probs), axis=0)
            predicted_class = int(np.argmax(block_probs))
            # Remember numpy dtypes may not be JSON serializable, so convert to native types:
            block["ClassificationProbabilities"] = block_probs.tolist()
            block["PredictedClass"] = predicted_class
            block["PredictedClassConfidence"] = float(block_probs[predicted_class])

    return {"doc_json": doc_json, "s3_output": s3_output}