bq-connector/docai_bq_connector/connector/BqDocumentMapper.py

#
# Copyright 2022 Google LLC
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from datetime import datetime
import json
import logging
import re
from typing import List, Optional, Sequence

from docai_bq_connector.connector.BqMetadataMapper import BqMetadataMapper
from docai_bq_connector.connector.ConversionError import ConversionError
from docai_bq_connector.doc_ai_processing.DocumentField import DocumentField
from docai_bq_connector.doc_ai_processing.DocumentField import DocumentRow
from docai_bq_connector.doc_ai_processing.ProcessedDocument import ProcessedDocument
from docai_bq_connector.helper import clean_number
from docai_bq_connector.helper import find
from docai_bq_connector.helper import get_bool_value
from google.cloud.bigquery import SchemaField
from google.cloud.documentai_v1 import Document

PARSING_METHOD_ENTITIES = "entities"
PARSING_METHOD_FORM = "form"
PARSING_METHOD_NORMALIZED_VALUES = "normalized_values"


class BqDocumentMapper:
    def __init__(
        self,
        document: ProcessedDocument,
        bq_schema: List[SchemaField],
        metadata_mapper: BqMetadataMapper,
        custom_fields: Optional[dict] = None,
        include_raw_entities: bool = True,
        include_error_fields: bool = True,
        continue_on_error: bool = False,
        parsing_methodology: str = PARSING_METHOD_ENTITIES,
    ):
        self.processed_document = document
        self.bq_schema = bq_schema
        self.metadata_mapper = metadata_mapper
        self.custom_fields = custom_fields
        self.include_raw_entities = include_raw_entities
        self.include_error_fields = include_error_fields
        self.continue_on_error = continue_on_error
        self.parsing_methodology = parsing_methodology
        self.errors: List[ConversionError] = []
        self.fields = self._parse_document()
        self.dictionary = self._map_document_to_bigquery_schema(self.fields, bq_schema)

    def _parse_document(self) -> List[DocumentField]:
        row: DocumentRow
        if self.parsing_methodology in [
            PARSING_METHOD_ENTITIES,
            PARSING_METHOD_NORMALIZED_VALUES,
        ]:
            row = self._parse_entities(self.processed_document.document.entities)
        elif self.parsing_methodology == PARSING_METHOD_FORM:
            row = self._parse_form_entities(self.processed_document.document)
        else:
            raise Exception("Unsupported parsing methodology")
        return row.fields
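    # A sketch of the structure _parse_entities builds below (field names are
    # hypothetical; attributes shown for illustration, not the exact
    # constructor call):
    #
    #   DocumentRow
    #     fields:
    #       - DocumentField(name="invoice_id", value="INV-001", page_number=1)
    #       - DocumentField(name="line_item", ...)  # entity with properties
    #           children:
    #             - DocumentRow(fields=[DocumentField(name="amount", value="10.00", ...)])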
    def _parse_entities(self, entities) -> DocumentRow:
        row = DocumentRow()
        for entity in entities:
            if len(entity.page_anchor.page_refs) != 1:
                continue
            content = entity.mention_text
            value = content if content is not None and content.strip() != "" else None
            if len(entity.properties) == 0:
                if row.find_field_by_name(entity.type_) is not None:
                    self.errors.append(
                        ConversionError(
                            entity.type_,
                            value,
                            "Duplicate field definition",
                            None,
                            ConversionError.error_type_duplicate_field,
                            identifier=entity.id,
                        )
                    )
                    continue
                row.fields.append(
                    DocumentField(
                        entity.type_,
                        value,
                        entity.normalized_value,
                        entity.confidence,
                        entity.page_anchor.page_refs[0].page + 1,
                    )
                )
            else:
                # Entity has child properties: parse them recursively and hang
                # them off a single parent DocumentField.
                parent_field = row.find_field_by_name(entity.type_)
                if parent_field is None:
                    parent_field = DocumentField(
                        entity.type_,
                        value,
                        entity.normalized_value,
                        entity.confidence,
                        entity.page_anchor.page_refs[0].page + 1,
                    )
                    row.fields.append(parent_field)
                row_children = self._parse_entities(entity.properties)
                if len(parent_field.children) > 0:
                    parent_field.children[0].fields.extend(row_children.fields)
                else:
                    parent_field.children.append(row_children)
        return row

    @staticmethod
    def _parse_form_entities(document: Document) -> DocumentRow:
        row = DocumentRow()
        for page in document.pages:
            for field in page.form_fields:
                name = BqDocumentMapper.__get_text(field.field_name, document)
                safe_name = BqDocumentMapper.convert_to_underscore(name)
                # name_confidence = round(field.field_name.confidence, 4)
                value = BqDocumentMapper.__get_text(field.field_value, document)
                value_confidence = round(field.field_value.confidence, 4)
                row.fields.append(
                    DocumentField(
                        name=safe_name,
                        value=value,
                        normalized_value=None,
                        confidence=value_confidence,
                        page_number=page.page_number,
                    )
                )
        return row

    @staticmethod
    def convert_to_underscore(name):
        # e.g. "Invoice Number:" -> "invoice_number"
        name = name.strip("@").strip("#").strip("$").strip(":").replace(" ", "")
        sub_str = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", sub_str).lower()

    @staticmethod
    def __get_text(doc_element: dict, document: dict):
        """
        Document AI identifies form fields by their offsets in document text.
        This function converts offsets to text snippets.
        """
        response = ""
        # If a text segment spans several lines, it will
        # be stored in different text segments.
        for segment in doc_element.text_anchor.text_segments:  # type: ignore[attr-defined]
            # start_index defaults to 0 when unset in the proto.
            start_index = int(segment.start_index)
            end_index = int(segment.end_index)
            temp = document.text[start_index:end_index].strip()  # type: ignore[attr-defined]
            response += temp.replace("\n", " ")
        return response.strip()
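    # Worked example of the offset arithmetic above (hypothetical values): for
    # document.text == "Total: $10.00" and a text_anchor holding a single
    # segment with start_index=7 and end_index=13, __get_text returns "$10.00".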
    def to_bq_row(
        self,
        append_parsed_fields: bool = True,
        exclude_fields: Optional[List[str]] = None,
    ):
        row = {}
        if self.custom_fields is not None and len(self.custom_fields.keys()) > 0:
            row.update(self.custom_fields)
        if self.include_raw_entities is True:
            row["raw_entities"] = json.dumps(self.to_raw_entities())
        if append_parsed_fields is True:
            if exclude_fields is not None and len(exclude_fields) > 0:
                _dict = self.dictionary.copy()
                for field_name in exclude_fields:
                    if field_name in _dict:
                        error_val = _dict[field_name]
                        self.errors.append(
                            ConversionError(
                                field_name,
                                error_val,
                                "Excluding field due to BQ insert error",
                                None,
                                ConversionError.error_type_exclude_field,
                            )
                        )
                        del _dict[field_name]
                row.update(_dict)
            else:
                row.update(self.dictionary)
        if self.include_error_fields is True:
            row["has_errors"] = len(self.errors) != 0
            row["errors"] = self._error_list_dictionary()
        return row

    def process_insert_errors(self, errors: Sequence[dict]):
        error_records = []
        if len(errors) > 0:
            for err_list in errors:
                _errors = err_list.get("errors")
                if not _errors:
                    continue
                for err in _errors:
                    field_name = err.get("location")
                    # If a nested field has an error, exclude the top-level
                    # field, e.g. "items[0].price" -> "items". Guard against
                    # locations that carry no repeated-field index.
                    if "." in field_name:
                        parent = field_name.split(".")[0]
                        bracket_pos = parent.rfind("[")
                        field_name = (
                            parent[:bracket_pos] if bracket_pos != -1 else parent
                        )
                    error_val = self.dictionary.get(field_name)
                    error_records.append(
                        ConversionError(
                            field_name,
                            error_val,
                            err.get("reason"),
                            err.get("message"),
                            ConversionError.error_type_bq_insert,
                        )
                    )
        self.errors.extend(error_records)
        return list(map(lambda x: x.key, error_records))

    def to_raw_entities(self):
        result = []
        fields = self.fields
        for field in fields:
            result.append(field.to_dictionary())
        return result

    def _map_document_to_bigquery_schema(
        self, fields: List[DocumentField], bq_schema: List[SchemaField]
    ):
        result: dict = {}
        for field in fields:
            field_name = field.to_bigquery_safe_name()
            if field.value is None:
                continue
            bq_field = find(
                lambda schema_field: schema_field.name == field_name, bq_schema
            )
            if bq_field is None:
                logging.warning(
                    "Parsed field '%s' not found in BigQuery schema. Field will "
                    "be excluded from the BigQuery payload",
                    field_name,
                )
                continue
            if bq_field.mode == "REPEATED":
                if len(bq_field.fields) == 0:
                    logging.warning("BQ field '%s' has no child fields", field_name)
                    continue
                if field_name not in result:
                    result[field_name] = []
                for child_row in field.children:
                    child_dict = self._map_document_to_bigquery_schema(
                        child_row.fields, bq_field.fields
                    )
                    if len(child_dict) > 0:
                        result[field_name].append(child_dict)
            else:
                _value = self._cast_type(field, bq_field.field_type)
                if isinstance(_value, ConversionError):
                    self.errors.append(_value)
                else:
                    result[field_name] = _value
        metadata_dict = self._map_document_metadata_to_bigquery_schema(bq_schema)
        result = result | metadata_dict
        return result

    def _map_document_metadata_to_bigquery_schema(self, bq_schema: List[SchemaField]):
        result: dict = {}
        mapped_metadata = self.metadata_mapper.map_metadata()
        for cur_metadata_mapping in mapped_metadata:
            col_name = cur_metadata_mapping["bq_column_name"]
            col_value = cur_metadata_mapping["bq_column_value"]
            if col_value is None:
                continue
            bq_field = find(
                lambda schema_field: schema_field.name == col_name, bq_schema
            )
            if bq_field is None:
                logging.warning(
                    "Parsed field '%s' not found in BigQuery schema. Field will "
                    "be excluded from the BigQuery payload",
                    col_name,
                )
                continue
            _value = self._cast_type(
                DocumentField(
                    name=col_name,
                    value=col_value,
                    normalized_value=col_value,
                    confidence=-1,
                    page_number=-1,
                ),
                bq_field.field_type,
            )
            if not isinstance(_value, ConversionError):
                result[col_name] = _value
        return result
    def _error_list_dictionary(self):
        return list(map(lambda x: x.to_dict(), self.errors))

    def _cast_type(self, field: DocumentField, bq_datatype):
        try:
            raw_value = (
                field.value.strip() if isinstance(field.value, str) else field.value
            )
            if self.parsing_methodology in [
                PARSING_METHOD_ENTITIES,
                PARSING_METHOD_FORM,
            ]:
                if field.value is None:
                    return None
                if bq_datatype == "STRING":
                    return raw_value
                if bq_datatype == "BOOLEAN":
                    return get_bool_value(raw_value)
                if bq_datatype == "DATETIME":
                    if isinstance(field.value, datetime):
                        dt: datetime = field.value
                        return dt.isoformat()
                    return raw_value
                if bq_datatype in ("DECIMAL", "FLOAT", "NUMERIC"):
                    return float(clean_number(raw_value))
                if bq_datatype == "INTEGER":
                    return int(clean_number(raw_value))
                return raw_value
            elif self.parsing_methodology in [PARSING_METHOD_NORMALIZED_VALUES]:
                normalized_value = field.normalized_value
                if normalized_value is None:
                    return None
                if bq_datatype == "STRING":
                    return normalized_value.text
                if bq_datatype == "BOOLEAN":
                    return normalized_value.boolean_value
                if bq_datatype == "DATETIME":
                    return normalized_value.datetime_value
                if bq_datatype in ("DECIMAL", "FLOAT", "NUMERIC"):
                    return normalized_value.float_value
                if bq_datatype == "INTEGER":
                    return normalized_value.integer_value
                return raw_value
        except ValueError as ve:
            return ConversionError(
                field.name,
                field.value,
                f"ValueError: casting to {bq_datatype}",
                str(ve),
                ConversionError.error_type_conversion,
            )
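

# ---------------------------------------------------------------------------
# Illustrative usage sketch only. It assumes a ProcessedDocument
# (`processed_document`) and a BqMetadataMapper (`metadata_mapper`) were built
# upstream; their construction is out of scope for this module.
# ---------------------------------------------------------------------------
#
#   from google.cloud import bigquery
#
#   bq_client = bigquery.Client()
#   table = bq_client.get_table("my-project.my_dataset.my_table")  # hypothetical id
#   mapper = BqDocumentMapper(
#       document=processed_document,
#       bq_schema=table.schema,
#       metadata_mapper=metadata_mapper,
#       parsing_methodology=PARSING_METHOD_ENTITIES,
#   )
#   insert_errors = bq_client.insert_rows_json(table, [mapper.to_bq_row()])
#   if insert_errors:
#       # Record the failing columns, then retry the row without them.
#       excluded = mapper.process_insert_errors(insert_errors)
#       bq_client.insert_rows_json(
#           table, [mapper.to_bq_row(exclude_fields=excluded)]
#       )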