def get_required_keys()

in source/lambdas/pipeline_orchestration/lambda_helpers.py [0:0]


def get_required_keys(pipeline_type: str, use_model_registry: str, problem_type: str = None) -> List[str]:
    """Return the request keys that must be present for the given pipeline type.

    Args:
        pipeline_type: The requested pipeline, e.g. "byom_realtime_builtin",
            "byom_batch_custom", "byom_model_quality_monitor" or "byom_image_builder".
        use_model_registry: When exactly "Yes", the real-time/batch pipelines require
            "model_package_name" instead of the framework/image and
            artifact-location keys. Any other value keeps those keys.
        problem_type: Validated for the model quality, model bias and model
            explainability monitors, where it must be one of "Regression",
            "MulticlassClassification" or "BinaryClassification". Other pipeline
            types do not use it, so it may be None for them.

    Returns:
        A newly built list of required key names in a fixed order.

    Raises:
        BadRequest: If a model quality/bias/explainability monitor gets an
            unsupported problem_type, or if pipeline_type is not supported.
    """
    supported_problem_types = ["Regression", "MulticlassClassification", "BinaryClassification"]

    common_keys = ["pipeline_type", "model_name", "inference_instance"]
    model_location = ["model_artifact_location"]
    builtin_model_keys = ["model_framework", "model_framework_version"] + model_location
    custom_model_keys = ["custom_image_uri"] + model_location
    # if model registry is used, the model package replaces framework/image + artifact location
    if use_model_registry == "Yes":
        builtin_model_keys = custom_model_keys = ["model_package_name"]

    realtime_specific_keys = ["data_capture_location"]
    batch_specific_keys = ["batch_inference_data", "batch_job_output_location"]

    # model monitors whose required keys depend on problem_type
    monitors = ["byom_model_quality_monitor", "byom_model_bias_monitor", "byom_model_explainability_monitor"]
    if pipeline_type in monitors and problem_type not in supported_problem_types:
        # report the pipeline actually requested (previously always named the quality monitor)
        raise BadRequest(f"Bad request format. Unsupported problem_type in {pipeline_type} pipeline")

    # common required keys between model monitor types
    common_monitor_keys = [
        "pipeline_type",
        "model_name",
        "endpoint_name",
        "baseline_data",
        "baseline_job_output_location",
        "data_capture_location",
        "monitoring_output_location",
        "schedule_expression",
        "monitor_max_runtime_seconds",
        "instance_type",
        "instance_volume_size",
    ]

    # ModelQuality specific keys
    model_quality_keys = ["baseline_inference_attribute", "baseline_ground_truth_attribute"]
    # common model related monitors
    common_model_keys = ["problem_type"]
    # add required keys based on problem type
    if problem_type in ["Regression", "MulticlassClassification"]:
        common_model_keys.append("monitor_inference_attribute")
    # BinaryClassification quality monitor: depending on the model output,
    # monitor_inference_attribute, monitor_probability_attribute, and probability_threshold_attribute
    # can be passed all together, or in pairs, so only the baseline probability key is mandatory
    elif problem_type == "BinaryClassification" and pipeline_type == "byom_model_quality_monitor":
        model_quality_keys.append("baseline_probability_attribute")

    # shared_model_quality_bias keys
    shared_model_quality_bias_keys = ["monitor_ground_truth_input"]

    # add model_predicted_label_config if "byom_model_bias_monitor" and
    # the problem is "BinaryClassification" or "MulticlassClassification"
    extra_bias_keys = []
    if pipeline_type == "byom_model_bias_monitor" and problem_type in [
        "BinaryClassification",
        "MulticlassClassification",
    ]:
        extra_bias_keys.append("model_predicted_label_config")

    # create pipeline_type -> required_keys map
    pipeline_keys_map = {
        "byom_realtime_builtin": [*common_keys, *builtin_model_keys, *realtime_specific_keys],
        "byom_realtime_custom": [*common_keys, *custom_model_keys, *realtime_specific_keys],
        "byom_batch_builtin": [*common_keys, *builtin_model_keys, *batch_specific_keys],
        "byom_batch_custom": [*common_keys, *custom_model_keys, *batch_specific_keys],
        "byom_data_quality_monitor": common_monitor_keys,
        "byom_model_quality_monitor": [
            *common_monitor_keys,
            *model_quality_keys,
            *common_model_keys,
            *shared_model_quality_bias_keys,
        ],
        "byom_model_bias_monitor": [
            *common_monitor_keys,
            *common_model_keys,
            *shared_model_quality_bias_keys,
            *extra_bias_keys,
            "bias_config",
        ],
        "byom_model_explainability_monitor": [
            *common_monitor_keys,
            *common_model_keys,
            "shap_config",
        ],
        "byom_image_builder": ["pipeline_type", "custom_algorithm_docker", "ecr_repo_name", "image_tag"],
    }

    # get the required keys based on the pipeline_type; unknown types map to None
    required_keys = pipeline_keys_map.get(pipeline_type)
    if not required_keys:
        raise BadRequest(
            "Bad request format. Pipeline type not supported. Check documentation for API & config formats"
        )
    return required_keys