def run()

in python/dataproc_templates/pubsublite/pubsublite_to_bigtable.py

Streams messages from a Pub/Sub Lite subscription and writes each micro-batch to a Bigtable table.

    def run(self, spark: SparkSession, args: Dict[str, Any]) -> None:
        logger: Logger = self.get_logger(spark=spark)

        # Arguments
        subscription_path: str = args[constants.PUBSUBLITE_BIGTABLE_SUBSCRIPTION_PATH]
        timeout: int = args[constants.PUBSUBLITE_BIGTABLE_STREAMING_TIMEOUT]
        trigger: str = args[constants.PUBSUBLITE_BIGTABLE_STREAMING_TRIGGER]
        checkpoint_location: str = args[constants.PUBSUBLITE_BIGTABLE_STREAMING_CHECKPOINT_PATH]
        project: str = args[constants.PUBSUBLITE_BIGTABLE_OUTPUT_PROJECT]
        instance_id: str = args[constants.PUBSUBLITE_BIGTABLE_OUTPUT_INSTANCE]
        table_id: str = args[constants.PUBSUBLITE_BIGTABLE_OUTPUT_TABLE]
        column_families_list: str = args[constants.PUBSUBLITE_BIGTABLE_OUTPUT_COLUMN_FAMILIES]
        max_versions: int = args[constants.PUBSUBLITE_BIGTABLE_OUTPUT_MAX_VERSIONS]

        # Log the job parameters, leaving out the subscription path, the
        # checkpoint path, and the project id.
        ignore_keys = {
            constants.PUBSUBLITE_BIGTABLE_SUBSCRIPTION_PATH,
            constants.PUBSUBLITE_BIGTABLE_STREAMING_CHECKPOINT_PATH,
            constants.PUBSUBLITE_BIGTABLE_OUTPUT_PROJECT,
        }
        filtered_args = {
            key: val for key, val in args.items() if key not in ignore_keys
        }
        logger.info(
            "Starting Pub/Sub Lite to Bigtable spark job with parameters:\n"
            f"{pprint.pformat(filtered_args)}"
        )

        # Read: open a streaming DataFrame on the Pub/Sub Lite subscription.
        input_data: DataFrame = (
            spark.readStream.format(constants.FORMAT_PUBSUBLITE)
            .option(constants.PUBSUBLITE_SUBSCRIPTION, subscription_path)
            .load()
        )

        # The message payload arrives as binary, so cast it to a string for
        # downstream parsing.
        input_data = input_data.withColumn("data", input_data.data.cast(StringType()))

        # Write: pass the checkpoint location, if one was provided, so the
        # streaming query can resume from prior progress after a restart.
        options = {}
        if checkpoint_location:
            options = {constants.PUBSUBLITE_CHECKPOINT_LOCATION: checkpoint_location}

        # admin=True allows table administration (e.g., creating the table and
        # its column families in get_table).
        client = Client(project=project, admin=True)
        table = self.get_table(
            client,
            instance_id,
            table_id,
            column_families_list,
            max_versions,
            logger,
        )

        def write_to_bigtable(batch_df: DataFrame, batch_id: int):
            # foreachBatch hands each micro-batch to this callback as a plain
            # DataFrame; populate_table writes its rows to the Bigtable table.
            self.populate_table(batch_df, table, logger)

        # Start the streaming query; the trigger is a processing-time interval
        # such as "5 seconds".
        query = (
            input_data.writeStream.foreachBatch(write_to_bigtable)
            .options(**options)
            .trigger(processingTime=trigger)
            .start()
        )
        # Let the stream run until the timeout (in seconds) elapses, then stop.
        query.awaitTermination(timeout)
        query.stop()

        client.close()
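
run() delegates to two helpers, get_table and populate_table, which are defined elsewhere in the template class and not shown in this section. For orientation only, below is a minimal standalone sketch of what such helpers could look like, using the google-cloud-bigtable client already imported for Client above. The signatures mirror the calls in run() (minus self), but the MaxVersionsGCRule garbage-collection policy and, in particular, the message schema (a JSON object with rowkey and columns fields) are illustrative assumptions, not the template's actual implementation.

    import json
    from logging import Logger

    from google.cloud.bigtable.client import Client
    from google.cloud.bigtable.column_family import MaxVersionsGCRule
    from google.cloud.bigtable.table import Table
    from pyspark.sql import DataFrame


    def get_table(
        client: Client,
        instance_id: str,
        table_id: str,
        column_families_list: str,
        max_versions: int,
        logger: Logger,
    ) -> Table:
        # Resolve the table handle; create the table on first use, giving every
        # requested column family a max-versions garbage-collection rule.
        table = client.instance(instance_id).table(table_id)
        if not table.exists():
            families = {
                name.strip(): MaxVersionsGCRule(max_versions)
                for name in column_families_list.split(",")
                if name.strip()
            }
            table.create(column_families=families)
            logger.info("Created table %s with column families %s",
                        table_id, sorted(families))
        return table


    def populate_table(batch_df: DataFrame, table: Table, logger: Logger) -> None:
        # Hypothetical message schema, assumed for this sketch:
        #   {"rowkey": "...", "columns": [{"family": "...", "name": "...", "value": "..."}]}
        rows = []
        for record in batch_df.collect():
            message = json.loads(record["data"])
            row = table.direct_row(message["rowkey"].encode("utf-8"))
            for column in message["columns"]:
                row.set_cell(
                    column["family"],
                    column["name"].encode("utf-8"),
                    column["value"].encode("utf-8"),
                )
            rows.append(row)
        if rows:
            # mutate_rows returns one Status per row; a non-zero code is a failure.
            statuses = table.mutate_rows(rows)
            failed = sum(1 for status in statuses if status.code != 0)
            if failed:
                logger.error("%d of %d row mutations failed", failed, len(rows))
            else:
                logger.info("Wrote %d rows to Bigtable", len(rows))

Note that this sketch collects each micro-batch to the driver, which is only reasonable for modest batch sizes; a production implementation might instead write from the executors.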