def handler()

in pipeline/postprocessing/fn-postprocess/main.py [0:0]
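
The excerpt references module-level names (json, reduce, s3, ssm, trp, the DEFAULT_ENTITY_CONFIG defaults, plus the helpers FieldConfiguration, MalformedRequest and extract_entities) that are defined elsewhere in main.py. A minimal sketch of the assumed setup follows; the initialization details and environment variable handling are inferred from the handler's own error message rather than copied from the real module header, and the helper imports are deliberately omitted.

import json
import os
from functools import reduce

import boto3
import trp  # amazon-textract-response-parser

# FieldConfiguration, MalformedRequest and extract_entities are assumed to come from this
# module's / package's own helpers (imports omitted here).

s3 = boto3.resource("s3")
ssm = boto3.client("ssm")

# Optional defaults: inline JSON via env var, or the name of an SSM parameter holding the JSON:
_default_cfg_raw = os.environ.get("DEFAULT_ENTITY_CONFIG")
DEFAULT_ENTITY_CONFIG = json.loads(_default_cfg_raw) if _default_cfg_raw else None
DEFAULT_ENTITY_CONFIG_PARAM = os.environ.get("DEFAULT_ENTITY_CONFIG_PARAM")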


def handler(event, context):
    """Consolidate detected entity mentions into per-field values and confidence scores

    Loads the Textract response JSON at Input.Bucket/Input.Key from S3, extracts entity mentions
    for each configured field class, de-duplicates repeated detections of the same text, and
    returns per-field results plus an overall (minimum) confidence for the document.
    """
    try:
        srcbucket = event["Input"]["Bucket"]
        srckey = event["Input"]["Key"]
        entity_config = event.get("EntityConfig", DEFAULT_ENTITY_CONFIG)
        if entity_config is None and DEFAULT_ENTITY_CONFIG_PARAM:
            entity_config = json.loads(
                ssm.get_parameter(Name=DEFAULT_ENTITY_CONFIG_PARAM)["Parameter"]["Value"]
            )
    except KeyError as ke:
        raise MalformedRequest(f"Missing field {ke}, please check your input payload") from ke
    if entity_config is None:
        raise MalformedRequest(
            "Request did not specify EntityConfig, and neither env var DEFAULT_ENTITY_CONFIG (for "
            "inline json) nor DEFAULT_ENTITY_CONFIG_PARAM (for SSM parameter) are set"
        )
    entity_config = [FieldConfiguration.from_dict(cfg) for cfg in entity_config]

    # Load the Textract response JSON from S3 and parse it with the Textract Response Parser:
    doc = json.loads(s3.Bucket(srcbucket).Object(srckey).get()["Body"].read())
    doc = trp.Document(doc)

    # Collect entity mentions for the configured field classes:
    entities = extract_entities(doc, entity_config)

    result_fields = {}
    # Build a result entry for each non-ignored field class, preserving configuration order:
    for ixtype, cfg in enumerate(cfg for cfg in entity_config if not cfg.ignore):
        # Filter the list of detected entity mentions for this class only:
        field_entities = list(filter(lambda e: e.cls_id == cfg.class_id, entities))

        # Consolidate multiple detections of exactly the same value (text):
        field_values = {}
        for ixe, e in enumerate(field_entities):
            if e.text in field_values:
                field_values[e.text]["Detections"].append(e)
                field_values[e.text]["IxLastDetection"] = ixe
            else:
                field_values[e.text] = {
                    "Text": e.text,
                    "Detections": [e],
                    "IxFirstDetection": ixe,
                    "IxLastDetection": ixe,
                }
        field_values_list = list(field_values.values())
        # To approximate confidence for values detected multiple times, model each detection as an
        # uncorrelated observation of that value (naive, probably biased to over-estimate):
        for v in field_values_list:
            # e.g. {0.84, 0.86, 0.90} -> 1 - (0.16 * 0.14 * 0.1) = 0.998
            v["Confidence"] = 1 - reduce(
                lambda acc, next: acc * (1 - next.confidence),
                v["Detections"],
                1.0,
            )
        # TODO: Better account for disagreement between different detected values
        # If the combined confidences of competing values sum above 1.0, scale each one down by
        # that sum so the total is at most 1.0; otherwise leave them unchanged:
        value_conf_norm = reduce(lambda acc, next: acc + next["Confidence"], field_values_list, 0.0)
        for v in field_values_list:
            v["Confidence"] = v["Confidence"] / max(1.0, value_conf_norm)

        field_result = {
            "ClassId": cfg.class_id,
            "Confidence": 0.0,
            "NumDetections": len(field_entities),
            "NumDetectedValues": len(field_values),
            "SortOrder": ixtype,
        }
        result_fields[cfg.name] = field_result
        if cfg.optional is not None:
            field_result["Optional"] = cfg.optional

        if cfg.select is not None:
            # Single-valued field: select the 'best' matched value per the configured sort:
            selector = cfg.select
            field_values_sorted = sorted(
                field_values_list,
                key=selector.sort,
                reverse=selector.desc,
            )
            if field_values_sorted:
                field_result["Value"] = field_values_sorted[0]["Text"]
                field_result["Confidence"] = field_values_sorted[0]["Confidence"]
                field_result["Detections"] = list(
                    map(
                        lambda e: e.to_dict(),
                        field_values_sorted[0]["Detections"],
                    )
                )
            else:
                field_result["Value"] = ""
                field_result["Detections"] = []
        else:
            # Multi-valued field: pass through all matched values, highest confidence first
            field_result["Values"] = list(
                map(
                    lambda v: {
                        "Confidence": v["Confidence"],
                        "Value": v["Text"],
                        "Detections": list(
                            map(
                                lambda e: e.to_dict(),
                                v["Detections"],
                            )
                        ),
                    },
                    sorted(field_values_list, key=lambda v: v["Confidence"], reverse=True),
                )
            )
            if len(field_result["Values"]):
                # For multi value, take field confidence = average value confidence
                field_result["Confidence"] = (
                    reduce(
                        lambda acc, next: acc + next["Confidence"],
                        field_result["Values"],
                        0.0,
                    )
                    / len(field_result["Values"])
                )

    # Overall document confidence = the worst field confidence, ignoring optional fields for
    # which nothing was detected:
    return {
        "Confidence": min(
            r["Confidence"]
            for r in result_fields.values()
            if not (r["Confidence"] == 0 and r.get("Optional"))
        ),
        "Fields": result_fields,
    }
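
To make the per-value confidence aggregation concrete, the snippet below reproduces the same arithmetic outside the handler, using the numbers from the inline comment. It is an illustration only: the value texts and the competing 0.60 score are made up for the example.

from functools import reduce

# Repeated detections of one candidate value, treated as independent observations:
detections = [0.84, 0.86, 0.90]
combined = 1 - reduce(lambda acc, conf: acc * (1 - conf), detections, 1.0)
# combined == 1 - (0.16 * 0.14 * 0.10) ~= 0.998 -- agreement boosts confidence, which is why the
# handler's comment notes this is naive and probably biased towards over-estimating.

# Normalization across competing values for the same field: it only kicks in when the combined
# confidences sum above 1.0, in which case each value is scaled down by that sum.
values = {"2023-01-05": combined, "2023-01-15": 0.60}
norm = max(1.0, sum(values.values()))
normalized = {text: conf / norm for text, conf in values.items()}
# -> roughly {"2023-01-05": 0.625, "2023-01-15": 0.375}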