def forEachBatch()

in awsglue/context.py [0:0]


    def forEachBatch(self, frame, batch_function, options=None):
        """Run ``batch_function`` on every micro-batch of a streaming DataFrame.

        Starts a Structured Streaming query on ``frame`` with a processing-time
        trigger and calls ``batch_function(data_frame, batchId)`` for each
        non-empty micro-batch. If the query raises, it is restarted after a
        short back-off, up to ``batchMaxRetries`` times before the exception
        is propagated.

        :param frame: streaming DataFrame to consume (must expose ``writeStream``).
        :param batch_function: callable invoked as ``batch_function(data_frame, batchId)``.
        :param options: dict of settings:
            ``windowSize`` (required) -- passed as the ``processingTime`` trigger;
            ``checkpointLocation`` (required) -- query checkpoint location;
                ``s3://`` is rewritten to ``s3a://`` on Glue 2.x/3.x;
            ``persistDataFrame`` -- ``"false"`` (case-insensitive) or ``False``
                skips persisting each batch; anything else persists it;
            ``storageLevel`` -- name of a ``pyspark.StorageLevel`` attribute used
                when persisting (default ``"MEMORY_AND_DISK"``);
            ``batchMaxRetries`` -- integer in [0, 100] (default 3).
        :returns: None, once the query terminates without raising.
        :raises Exception: if ``windowSize`` or ``checkpointLocation`` is missing.
        :raises ValueError: if ``batchMaxRetries`` is outside [0, 100].
        :raises AttributeError: if ``storageLevel`` does not name a storage level.
        The last query exception is re-raised once the retry budget is exhausted.
        """
        if options is None:
            options = {}
        if "windowSize" not in options:
            raise Exception("Missing windowSize argument")
        if "checkpointLocation" not in options:
            raise Exception("Missing checkpointLocation argument")

        windowSize = options["windowSize"]
        checkpointLocation = options["checkpointLocation"]

        # Validate before touching any JVM / Hadoop state so bad input fails fast.
        batch_max_retries = int(options.get('batchMaxRetries', 3))
        if batch_max_retries < 0 or batch_max_retries > 100:
            raise ValueError('Please specify the number of retries as an integer in the range of [0, 100].')

        # Resolve per-batch settings once. Doing it here means an invalid
        # storage level fails immediately instead of failing inside every
        # micro-batch and being retried with back-off.
        persist_flag = options.get("persistDataFrame", True)
        if isinstance(persist_flag, bool):
            persist_data_frame = persist_flag
        else:
            persist_data_frame = persist_flag.lower() != "false"
        storage_level = None
        if persist_data_frame:
            storage_level = getattr(pyspark.StorageLevel,
                                    options.get("storageLevel", "MEMORY_AND_DISK").upper())

        # Check the Glue version
        glue_ver = self.getConf('spark.glue.GLUE_VERSION', '')
        java_import(self._jvm, "org.apache.spark.metrics.source.StreamingSource")

        # Converting the S3 scheme to S3a for the Glue Streaming checkpoint location in connector jars.
        # S3 scheme on checkpointLocation currently doesn't work on Glue 2.0 (non-EMR).
        # Will remove this once the connector package is imported as brazil package.
        if glue_ver in ('2.0', '2', '3.0', '3') and checkpointLocation.startswith('s3://'):
            java_import(self._jvm, "com.amazonaws.regions.RegionUtils")
            java_import(self._jvm, "com.amazonaws.services.s3.AmazonS3")
            self._jsc.hadoopConfiguration().set("fs.s3a.endpoint", self._jvm.RegionUtils.getRegion(
                self._jvm.AWSConnectionUtils.getRegion()).getServiceEndpoint(self._jvm.AmazonS3.ENDPOINT_PREFIX))
            checkpointLocation = checkpointLocation.replace('s3://', 's3a://', 1)

        # Counters shared with the nested batch function; kept in dicts so the
        # closure can mutate them without ``nonlocal``.
        run = {'value': 0}
        retry_attempt = {'value': 0}

        def batch_function_with_persist(data_frame, batchId):
            # Heuristic: when run exceeds retry_attempt on entry, the previous
            # batch is assumed to have succeeded, so both counters are reset.
            if run['value'] > retry_attempt['value']:
                run['value'] = 0
                if retry_attempt['value'] > 0:
                    retry_attempt['value'] = 0
                    logging.warning("The batch is now succeeded. Resetting retry attempt counter to zero.")
            run['value'] += 1

            # process the batch
            startTime = self.currentTimeMillis()
            if not persist_data_frame:
                # Skip empty batches.
                if data_frame.take(1):
                    batch_function(data_frame, batchId)
            else:
                data_frame.persist(storage_level)
                try:
                    num_records = data_frame.count()
                    if num_records > 0:
                        batch_function(data_frame, batchId)
                finally:
                    # Release the cached batch even when batch_function raises,
                    # so a failed batch does not stay persisted across retries.
                    data_frame.unpersist()
                self._jvm.StreamingSource.updateNumRecords(num_records)
            self._jvm.StreamingSource.updateBatchProcessingTimeInMs(self.currentTimeMillis() - startTime)

        query = (frame.writeStream
                 .foreachBatch(batch_function_with_persist)
                 .trigger(processingTime=windowSize)
                 .option("checkpointLocation", checkpointLocation))

        while True:
            try:
                query.start().awaitTermination()
            except Exception as e:
                retry_attempt['value'] += 1
                logging.warning("StreamingQueryException caught. Retry number " + str(retry_attempt['value']))

                if retry_attempt['value'] > batch_max_retries:
                    logging.error("Exceeded maximuim number of retries in streaming interval, exception thrown")
                    raise e
                # Back off 1s after the first failure, 2s after the second, 5s afterwards.
                backOffTime = retry_attempt['value'] if retry_attempt['value'] < 3 else 5
                time.sleep(backOffTime)
            else:
                # awaitTermination() returned without raising, i.e. the query
                # ended without an error (e.g. it was stopped); restarting it
                # here would undo that stop and never return.
                break