def _extract_data_and_correlation_id()

in assets/model_monitoring/components/src/model_data_collector_preprocessor/run.py [0:0]


def _extract_data_and_correlation_id(df: DataFrame, extract_correlation_id: bool, datastore: str = None) -> DataFrame:
    """
    Extract data and correlation id from the MDC logs.

    If the data column is present, return the JSON contents it holds;
    otherwise, return the dataref content, which is a URL to the JSON file.
    """

    def safe_dumps(x):
        # serialize nested dict/list/ndarray values to JSON strings so that Spark
        # can treat the resulting column as a plain string column
        if isinstance(x, (dict, list)):
            return json.dumps(x)
        elif isinstance(x, np.ndarray):
            return json.dumps(x.tolist())
        else:
            return x

    def convert_object_to_str(dataframe: pd.DataFrame) -> pd.DataFrame:
        # stringify object-typed (nested) columns so the pandas -> Spark conversion
        # does not fail on mixed or nested types
        for column in dataframe.columns:
            if dataframe[column].dtype == "object":
                dataframe[column] = dataframe[column].apply(safe_dumps)

        return dataframe

    def read_data(row) -> str:
        data = getattr(row, MDC_DATA_COLUMN, None)
        if data:
            return data

        dataref = getattr(row, MDC_DATAREF_COLUMN, None)
        # convert the https path to an azureml long-form path, which the azureml filesystem
        # recognizes and pd.read_json() can read
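        # e.g. (illustrative, not from the source):
        #   https://<account>.blob.core.windows.net/<container>/some/path/data.json
        #   -> azureml://subscriptions/<sub>/resourcegroups/<rg>/workspaces/<ws>/datastores/<datastore>/paths/some/path/data.json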
        data_url = _convert_to_azureml_long_form(dataref, datastore)
        return data_url
        # TODO: Move this to tracking stream if both data and dataref are NULL

    def row_to_pdf(row) -> pd.DataFrame:
        # parse one raw log row (inline json or dataref url) into a pandas DataFrame
        df = pd.read_json(read_data(row))
        df = convert_object_to_str(df)
        return df

    data_columns = _get_data_columns(df)
    # sample a handful of rows and build a pandas DataFrame from their payloads so that
    # Spark can infer the output schema used by the mapInPandas calls below
    data_rows = df.select(data_columns).rdd.take(SCHEMA_INFER_ROW_COUNT)  # TODO: make it an argument user can define

    spark = init_spark()
    infer_pdf = pd.concat([row_to_pdf(row) for row in data_rows], ignore_index=True)
    data_as_df = spark.createDataFrame(infer_pdf)
    # data_as_df.show()
    # data_as_df.printSchema()

    def extract_data_and_correlation_id(entry, correlationid):
        # explode one log entry into its payload rows and tag each row with
        # "<correlationid>_<row index>" so individual rows stay traceable to the request
        result = pd.read_json(entry)
        result = convert_object_to_str(result)
        result[MDC_CORRELATION_ID_COLUMN] = ""
        for index, row in result.iterrows():
            result.loc[index, MDC_CORRELATION_ID_COLUMN] = (
                correlationid + "_" + str(index)
            )
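        # e.g. (illustrative) a two-row payload with correlation id "abc" yields
        # correlationid values "abc_0" and "abc_1"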
        return result

    def transform_df_function_with_correlation_id(iterator):
        # mapInPandas transform: expand each raw log row into its payload rows,
        # carrying the per-row correlation ids along
        for df in iterator:
            yield pd.concat(
                extract_data_and_correlation_id(
                    read_data(row),
                    getattr(row, MDC_CORRELATION_ID_COLUMN),
                )
                for row in df.itertuples()
            )

    def transform_df_function_without_correlation_id(iterator):
        # mapInPandas transform: expand each raw log row into its payload rows
        # without adding correlation ids
        for df in iterator:
            pdf = pd.concat(
                convert_object_to_str(pd.read_json(read_data(row))) for row in df.itertuples()
            )
            yield pdf

    if extract_correlation_id:
        # Add empty column to get the correlationId in the schema
        data_as_df = data_as_df.withColumn(MDC_CORRELATION_ID_COLUMN, lit(""))
        data_columns.append(MDC_CORRELATION_ID_COLUMN)
        transformed_df = df.select(data_columns).mapInPandas(
            transform_df_function_with_correlation_id, schema=data_as_df.schema
        )
    else:
        # TODO: if neither data nor dataref is present, move to the tracking stream (or throw ModelMonitoringException?)
        transformed_df = df.select(data_columns).mapInPandas(
            transform_df_function_without_correlation_id, schema=data_as_df.schema
        )
    return transformed_df
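
A minimal usage sketch (not part of run.py): it assumes a Spark session created via init_spark(), a raw MDC DataFrame with the data/dataref/correlationid columns, and the module-level constants and helpers defined in this file; the variable names, input path, and datastore name below are hypothetical.

spark = init_spark()
raw_df = spark.read.json("<path to the raw MDC *.jsonl logs>")  # hypothetical input location
preprocessed_df = _extract_data_and_correlation_id(
    raw_df,
    extract_correlation_id=True,        # keep a correlationid column in the output
    datastore="workspaceblobstore",     # hypothetical datastore name, used to resolve https dataref urls
)
preprocessed_df.printSchema()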