def run()

in 5-app-infra/3-artifact-publish/docker/flex-templates/pubsub_to_bq_deidentification/decrypt_pubsub_to_bq.py
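
This excerpt references helpers defined elsewhere in decrypt_pubsub_to_bq.py (DecryptMessages, MaskDetectedDetails, CalculateTotalBytes, normalize_data, from_list_dicts_to_table, from_table_to_list_dict). A minimal sketch of the module-level imports the function body assumes (the actual file may organize these differently):

import argparse
import json
import logging

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.transforms import window
from apache_beam.transforms.util import BatchElements

logger = logging.getLogger(__name__)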


def run(argv=None, save_main_session=True):

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--output_table',
        required=True,
        help=(
            'Output BigQuery table for results specified as: '
            'PROJECT:DATASET.TABLE or DATASET.TABLE.'
        )
    )
    parser.add_argument(
        '--bq_schema',
        required=True,
        help=(
            'Output BigQuery table schema specified as string with format: '
            'FIELD_1:STRING,FIELD_2:STRING,...'
        )
    )
    parser.add_argument(
        '--dlp_project',
        required=True,
        help=(
            'ID of the project that holds the DLP template.'
        )
    )
    parser.add_argument(
        '--dlp_location',
        required=False,
        help=(
            'The Location of the DLP template resource.'
        )
    )
    parser.add_argument(
        '--deidentification_template_name',
        required=True,
        help=(
            'Name of the DLP Structured De-identification Template '
            'of the form "projects/<PROJECT>/locations/<LOCATION>'
            '/deidentifyTemplates/<TEMPLATE_ID>"'
        )
    )
    parser.add_argument(
        "--window_interval_sec",
        default=30,
        type=int,
        help=(
            'Window interval in seconds for grouping incoming messages.'
        )
    )
    parser.add_argument(
        "--batch_size",
        default=1000,
        type=int,
        help=(
            'Number of records to be sent in a batch in '
            'the call to the Data Loss Prevention (DLP) API.'
        )
    )

    group = parser.add_mutually_exclusive_group(required=True)
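    # Exactly one message source must be supplied: --input_topic or
    # --input_subscription.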
    group.add_argument(
        '--input_topic',
        help=(
            'Input PubSub topic of the form '
            '"projects/<PROJECT>/topics/<TOPIC>".'
            'A temporary subscription will be created from '
            'the specified topic.'
        )
    )
    group.add_argument(
        '--input_subscription',
        help=(
            'Input PubSub subscription of the form '
            '"projects/<PROJECT>/subscriptions/<SUBSCRIPTION>."'
        )
    )
    
    parser.add_argument(
        '--cryptoKeyName',
        required=True,
        help=(
            'GCP KMS key URI of the form '
            '"projects/<PROJECT_ID>/locations/<LOCATION>/keyRings/<KEY_RING>/cryptoKeys/<KEY_NAME>".'
        )
    )
    parser.add_argument(
        '--wrappedKey',
        required=True,
        help=(
            'Base64-encoded KMS-wrapped Tink keyset, stored in Secret Manager '
            'as "projects/<PROJECT_ID>/secrets/<SECRET_NAME>/versions/<VERSION>".'
        )
    )

    known_args, pipeline_args = parser.parse_known_args(argv)
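    # Flags defined above are captured in known_args; anything else (e.g.
    # --runner, --project, --temp_location) is forwarded to Beam through
    # pipeline_args.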

    # Define pipeline options. save_main_session makes module-level imports
    # available to DoFns executing on remote workers.
    pipeline_options = PipelineOptions(
        pipeline_args,
        # runner='DirectRunner',  # uncomment for local testing
        # direct_num_workers=1,   # reduce parallelism to avoid threading issues
        streaming=True,
        save_main_session=save_main_session,
        # Use cloudpickle rather than dill for serialization; dill exhibited
        # serialization issues on Dataflow.
        pickle_library='cloudpickle',
    )

    # Build the Beam pipeline
    with beam.Pipeline(options=pipeline_options) as p:
        logger.info(f"Processing pipeline started..")
        
        # Read from whichever Pub/Sub source was supplied on the command line;
        # reading from a topic creates a temporary subscription.
        if known_args.input_subscription:
            source = beam.io.ReadFromPubSub(
                subscription=known_args.input_subscription)
        else:
            source = beam.io.ReadFromPubSub(topic=known_args.input_topic)

        messages = (
            p
            | 'Read from Pub/Sub' >> source.with_output_types(bytes)
            | 'Decrypt Messages' >> beam.ParDo(
                DecryptMessages(known_args.cryptoKeyName,
                                known_args.wrappedKey))
            | 'Parse JSON payload' >> beam.Map(json.loads)
            | 'Normalize' >> beam.FlatMap(normalize_data)
            # Group messages into fixed windows of --window_interval_sec seconds.
            | 'Fixed-size windows' >> beam.WindowInto(
                window.FixedWindows(known_args.window_interval_sec))
        )
        
        de_identified_messages = (
            messages
            # min == max yields fixed-size batches for each DLP API call.
            | 'Batching' >> BatchElements(
                min_batch_size=known_args.batch_size,
                max_batch_size=known_args.batch_size
            )
            | 'Convert dicts to table' >> beam.Map(from_list_dicts_to_table)
            | 'Call DLP de-identification' >> MaskDetectedDetails(
                project=known_args.dlp_project,
                location=known_args.dlp_location,
                template_name=known_args.deidentification_template_name
            )
            | 'Convert table to dicts' >> beam.FlatMap(from_table_to_list_dict)
            | 'Calculate Total Bytes' >> beam.ParDo(
                CalculateTotalBytes(known_args.output_table))
        )

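        # STREAMING_INSERTS writes each row as it arrives; the FILE_LOADS
        # method would require a triggering_frequency on an unbounded source.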
        de_identified_messages | 'Write to BQ' >> beam.io.WriteToBigQuery(
            known_args.output_table,
            schema=known_args.bq_schema,
            create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
            method="STREAMING_INSERTS"
        )
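
For reference, a hedged sketch of how this entry point might be launched; every resource name below is a placeholder, and the file's real __main__ block may differ:

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    run(argv=[
        # Pipeline-specific flags (placeholder resource names).
        '--output_table=my-project:my_dataset.deidentified',
        '--bq_schema=card_number:STRING,card_holder:STRING',
        '--dlp_project=my-project',
        '--dlp_location=global',
        '--deidentification_template_name='
        'projects/my-project/locations/global/deidentifyTemplates/my-template',
        '--input_subscription=projects/my-project/subscriptions/my-sub',
        '--cryptoKeyName=projects/my-project/locations/global/'
        'keyRings/my-ring/cryptoKeys/my-key',
        '--wrappedKey=projects/my-project/secrets/my-secret/versions/latest',
        # Remaining flags fall through to Beam/Dataflow.
        '--runner=DataflowRunner',
        '--project=my-project',
        '--region=us-central1',
        '--temp_location=gs://my-bucket/tmp',
    ])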