# Copyright 2022 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing class for AWS's EMR services.

Clusters can be created and deleted.
"""

import collections
import dataclasses
import gzip
import json
import logging
import os
from typing import Any, Dict

from absl import flags
from perfkitbenchmarker import disk
from perfkitbenchmarker import dpb_constants
from perfkitbenchmarker import dpb_service
from perfkitbenchmarker import errors
from perfkitbenchmarker import provider_info
from perfkitbenchmarker import temp_dir
from perfkitbenchmarker import vm_util
from perfkitbenchmarker.providers.aws import aws_disk
from perfkitbenchmarker.providers.aws import aws_dpb_emr_serverless_prices
from perfkitbenchmarker.providers.aws import aws_network
from perfkitbenchmarker.providers.aws import aws_virtual_machine
from perfkitbenchmarker.providers.aws import s3
from perfkitbenchmarker.providers.aws import util

FLAGS = flags.FLAGS
flags.DEFINE_string(
    'dpb_emr_release_label', None, 'DEPRECATED use dpb_service.version.'
)

INVALID_STATES = ['TERMINATED_WITH_ERRORS', 'TERMINATED']
READY_CHECK_SLEEP = 30
READY_CHECK_TRIES = 60
READY_STATE = 'WAITING'
JOB_WAIT_SLEEP = 30
EMR_TIMEOUT = 14400

disk_to_hdfs_map = {
    aws_disk.ST1: 'HDD (ST1)',
    aws_disk.GP2: 'SSD (GP2)',
    disk.LOCAL: 'Local SSD',
}

DATAPROC_TO_EMR_CONF_FILES = {
    # https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
    'core': 'core-site',
    'hdfs': 'hdfs-site',
    # https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-spark-configure.html
    'spark': 'spark-defaults',
}


def _GetClusterConfiguration(cluster_properties: list[str]) -> str:
  """Return a JSON string containing dpb_cluster_properties."""
  properties = collections.defaultdict(dict)
  for entry in cluster_properties:
    file, kv = entry.split(':')
    key, value = kv.split('=')
    if file not in DATAPROC_TO_EMR_CONF_FILES:
      raise errors.Config.InvalidValue(
          'Unsupported EMR configuration file "{}". '.format(file)
          + 'Please add it to aws_dpb_emr.DATAPROC_TO_EMR_CONF_FILES.'
      )
    properties[DATAPROC_TO_EMR_CONF_FILES[file]][key] = value
  json_conf = []
  for file, props in properties.items():
    json_conf.append({
        # https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
        'Classification': file,
        'Properties': props,
    })
  return json.dumps(json_conf)


class EMRRetryableException(Exception):
  pass


@dataclasses.dataclass(frozen=True)
class _AwsDpbEmrServerlessJobRun:
  """Holds one EMR Serverless job run info, such as IDs and usage stats.

  Attributes:
    application_id: The application ID the job run belongs to.
    job_run_id: The job run's ID.
    region: Job run's region.
    memory_gb_hour: RAM GB * hour used.
    storage_gb_hour: Shuffle storage GB * hour used.
    vcpu_hour: vCPUs * hour used.
""" application_id: str | None = None job_run_id: str | None = None region: str | None = None memory_gb_hour: float | None = None storage_gb_hour: float | None = None vcpu_hour: float | None = None def __bool__(self): """Returns if this represents a job run or is just a dummy placeholder.""" return None not in (self.application_id, self.job_run_id, self.region) def HasStats(self): """Returns whether there are stats collected for the job run.""" return None not in ( self.memory_gb_hour, self.storage_gb_hour, self.vcpu_hour, ) def ComputeJobRunCost(self) -> dpb_service.JobCosts: """Computes the cost of a run for region given this usage.""" if not self.HasStats(): return dpb_service.JobCosts() region_prices = aws_dpb_emr_serverless_prices.EMR_SERVERLESS_PRICES.get( self.region, {} ) memory_gb_hour_price = region_prices.get('memory_gb_hours') storage_gb_hour_price = region_prices.get('storage_gb_hours') vcpu_hour_price = region_prices.get('vcpu_hours') if ( memory_gb_hour_price is None or storage_gb_hour_price is None or vcpu_hour_price is None ): return dpb_service.JobCosts() vcpu_cost = self.vcpu_hour * vcpu_hour_price memory_cost = self.memory_gb_hour * memory_gb_hour_price storage_cost = self.storage_gb_hour * storage_gb_hour_price return dpb_service.JobCosts( total_cost=vcpu_cost + memory_cost + storage_cost, compute_cost=vcpu_cost, memory_cost=memory_cost, storage_cost=storage_cost, compute_units_used=self.vcpu_hour, memory_units_used=self.memory_gb_hour, storage_units_used=self.storage_gb_hour, compute_unit_cost=vcpu_hour_price, memory_unit_cost=memory_gb_hour_price, storage_unit_cost=storage_gb_hour_price, compute_unit_name='vCPU*hr', memory_unit_name='GB*hr', storage_unit_name='GB*hr', ) class AwsDpbEmr(dpb_service.BaseDpbService): """Object representing a AWS EMR cluster. Attributes: cluster_id: ID of the cluster. project: ID of the project in which the cluster is being launched. dpb_service_type: Set to 'emr'. cmd_prefix: Setting default prefix for the emr commands (region optional). network: Dedicated network for the EMR cluster storage_service: Region specific instance of S3 for bucket management. bucket_to_delete: Cluster associated bucket to be cleaned up. dpb_version: EMR version to use. """ CLOUD = provider_info.AWS SERVICE_TYPE = 'emr' SUPPORTS_NO_DYNALLOC = True def __init__(self, dpb_service_spec): super().__init__(dpb_service_spec) self.project = None self.cmd_prefix = list(util.AWS_PREFIX) if self.dpb_service_zone: self.region = util.GetRegionFromZone(self.dpb_service_zone) else: raise errors.Setup.InvalidSetupError( 'dpb_service_zone must be provided, for provisioning.' ) self.cmd_prefix += ['--region', self.region] self.network = aws_network.AwsNetwork.GetNetworkFromNetworkSpec( aws_network.AwsNetworkSpec(zone=self.dpb_service_zone) ) self.storage_service = s3.S3Service() self.storage_service.PrepareService(self.region) self.persistent_fs_prefix = 's3://' self.bucket_to_delete = None self._cluster_create_time: float | None = None self._cluster_ready_time: float | None = None self._cluster_delete_time: float | None = None if not self.GetDpbVersion(): raise errors.Setup.InvalidSetupError( 'dpb_service.version must be provided.' ) def GetDpbVersion(self) -> str | None: return FLAGS.dpb_emr_release_label or super().GetDpbVersion() def GetClusterCreateTime(self) -> float | None: """Returns the cluster creation time. On this implementation, the time returned is based on the timestamps reported by the EMR API (which is stored in the _cluster_create_time attribute). 

    Returns:
      A float representing the creation time in seconds or None.
    """
    return self._cluster_ready_time - self._cluster_create_time

  @property
  def security_group_id(self):
    """Returns the security group ID of this Cluster."""
    return self.network.regional_network.vpc.default_security_group_id

  def _CreateDependencies(self):
    """Set up the ssh key."""
    super()._CreateDependencies()
    aws_virtual_machine.AwsKeyFileManager.ImportKeyfile(self.region)

  def Create(self, restore: bool = False) -> None:
    """Overrides parent to register creation timeout as KNOWN_INTERMITTENT."""
    try:
      super().Create()
    except vm_util.RetryError as e:
      raise errors.Resource.ProvisionTimeoutError from e

  def _Create(self):
    """Creates the cluster."""
    name = 'pkb_' + FLAGS.run_uri

    # Set up ebs details if disk_spec is present in the config
    ebs_configuration = None
    if self.spec.worker_group.disk_spec:
      # Make sure nothing we are ignoring is included in the disk spec
      assert self.spec.worker_group.disk_spec.device_path is None
      assert self.spec.worker_group.disk_spec.disk_number is None
      assert self.spec.worker_group.disk_spec.provisioned_iops is None
      if self.spec.worker_group.disk_spec.disk_type != disk.LOCAL:
        ebs_configuration = {
            'EbsBlockDeviceConfigs': [{
                'VolumeSpecification': {
                    'SizeInGB': self.spec.worker_group.disk_spec.disk_size,
                    'VolumeType': self.spec.worker_group.disk_spec.disk_type,
                },
                'VolumesPerInstance': self.spec.worker_group.disk_count,
            }]
        }

    # Create the specification for the master and the worker nodes
    instance_groups = []
    core_instances = {
        'InstanceCount': self.spec.worker_count,
        'InstanceGroupType': 'CORE',
        'InstanceType': self.spec.worker_group.vm_spec.machine_type,
    }
    if ebs_configuration:
      core_instances.update({'EbsConfiguration': ebs_configuration})
    master_instance = {
        'InstanceCount': 1,
        'InstanceGroupType': 'MASTER',
        'InstanceType': self.spec.worker_group.vm_spec.machine_type,
    }
    if ebs_configuration:
      master_instance.update({'EbsConfiguration': ebs_configuration})
    instance_groups.append(core_instances)
    instance_groups.append(master_instance)

    # Spark SQL needs to access Hive
    cmd = self.cmd_prefix + [
        'emr',
        'create-cluster',
        '--name',
        name,
        '--release-label',
        self.GetDpbVersion(),
        '--use-default-roles',
        '--instance-groups',
        json.dumps(instance_groups),
        '--application',
        'Name=Spark',
        'Name=Hadoop',
        'Name=Hive',
        '--log-uri',
        self.base_dir,
    ]

    ec2_attributes = [
        'KeyName=' + aws_virtual_machine.AwsKeyFileManager.GetKeyNameForRun(),
        'SubnetId=' + self.network.subnet.id,
        # Place all VMs in default security group for simplicity and speed of
        # provisioning
        'EmrManagedMasterSecurityGroup=' + self.security_group_id,
        'EmrManagedSlaveSecurityGroup=' + self.security_group_id,
    ]
    cmd += ['--ec2-attributes', ','.join(ec2_attributes)]

    if self.GetClusterProperties():
      cmd += [
          '--configurations',
          _GetClusterConfiguration(self.GetClusterProperties()),
      ]

    stdout, _, _ = vm_util.IssueCommand(cmd)
    result = json.loads(stdout)
    self.cluster_id = result['ClusterId']
    logging.info('Cluster created with id %s', self.cluster_id)
    self._AddTags(util.MakeDefaultTags())

  def _AddTags(self, tags_dict: dict[str, str]):
    tag_args = [f'{key}={value}' for key, value in tags_dict.items()]
    cmd = (
        self.cmd_prefix
        + ['emr', 'add-tags', '--resource-id', self.cluster_id, '--tags']
        + tag_args
    )
    try:
      vm_util.IssueCommand(cmd)
    except errors.VmUtil.IssueCommandError as e:
      error_message = str(e)
      if 'ThrottlingException' in error_message:
        raise errors.Benchmarks.QuotaFailure.RateLimitExceededError(
            error_message
        ) from e
      raise

  def _Delete(self):
    if self.cluster_id:
      delete_cmd = self.cmd_prefix + [
          'emr',
          'terminate-clusters',
          '--cluster-ids',
          self.cluster_id,
      ]
      vm_util.IssueCommand(delete_cmd, raise_on_failure=False)

  def _DeleteDependencies(self):
    super()._DeleteDependencies()
    aws_virtual_machine.AwsKeyFileManager.DeleteKeyfile(self.region)

  def _Exists(self):
    """Check to see whether the cluster exists."""
    if not self.cluster_id:
      return False
    cmd = self.cmd_prefix + [
        'emr',
        'describe-cluster',
        '--cluster-id',
        self.cluster_id,
    ]
    stdout, _, retcode = vm_util.IssueCommand(cmd, raise_on_failure=False)
    if retcode != 0:
      return False
    result = json.loads(stdout)
    end_datetime = (
        result.get('Cluster', {})
        .get('Status', {})
        .get('Timeline', {})
        .get('EndDateTime')
    )
    if end_datetime is not None:
      self._cluster_delete_time = end_datetime
    if result['Cluster']['Status']['State'] in INVALID_STATES:
      return False
    else:
      return True

  def _IsReady(self):
    """Check to see if the cluster is ready."""
    logging.info('Checking _Ready cluster: %s', self.cluster_id)
    cmd = self.cmd_prefix + [
        'emr',
        'describe-cluster',
        '--cluster-id',
        self.cluster_id,
    ]
    stdout, _, _ = vm_util.IssueCommand(cmd)
    result = json.loads(stdout)
    # TODO(saksena): Handle error outcomes when spinning up emr clusters
    is_ready = result['Cluster']['Status']['State'] == READY_STATE
    if is_ready:
      self._cluster_create_time, self._cluster_ready_time = (
          self._ParseClusterCreateTime(result)
      )
    return is_ready

  @classmethod
  def _ParseClusterCreateTime(
      cls, data: dict[str, Any]
  ) -> tuple[float | None, float | None]:
    """Parses the cluster create & ready time from an API response dict."""
    try:
      creation_ts = data['Cluster']['Status']['Timeline']['CreationDateTime']
      ready_ts = data['Cluster']['Status']['Timeline']['ReadyDateTime']
      return creation_ts, ready_ts
    except (LookupError, TypeError):
      return None, None

  def _GetCompletedJob(self, job_id):
    """See base class."""
    cmd = self.cmd_prefix + [
        'emr',
        'describe-step',
        '--cluster-id',
        self.cluster_id,
        '--step-id',
        job_id,
    ]
    stdout, stderr, retcode = vm_util.IssueCommand(cmd, raise_on_failure=False)
    if retcode:
      if 'ThrottlingException' in stderr:
        logging.warning(
            'Rate limited while polling EMR step:\n%s\nRetrying.', stderr
        )
        return None
      else:
        raise errors.VmUtil.IssueCommandError(
            f'Getting step status failed:\n{stderr}'
        )
    result = json.loads(stdout)
    state = result['Step']['Status']['State']
    if state == 'FAILED':
      raise dpb_service.JobSubmissionError(
          result['Step']['Status']['FailureDetails']
      )
    if state == 'COMPLETED':
      pending_time = result['Step']['Status']['Timeline']['CreationDateTime']
      start_time = result['Step']['Status']['Timeline']['StartDateTime']
      end_time = result['Step']['Status']['Timeline']['EndDateTime']
      return dpb_service.JobResult(
          run_time=end_time - start_time, pending_time=start_time - pending_time
      )

  def SubmitJob(
      self,
      jarfile=None,
      classname=None,
      pyspark_file=None,
      query_file=None,
      job_poll_interval=None,
      job_arguments=None,
      job_files=None,
      job_jars=None,
      job_py_files=None,
      job_stdout_file=None,
      job_type=None,
      properties=None,
  ):
    """See base class."""
    if job_arguments:
      # Escape commas in arguments
      job_arguments = (arg.replace(',', '\\,') for arg in job_arguments)

    all_properties = self.GetJobProperties()
    all_properties.update(properties or {})

    if job_type == 'hadoop':
      if not (jarfile or classname):
        raise ValueError('You must specify jarfile or classname.')
      if jarfile and classname:
        raise ValueError('You cannot specify both jarfile and classname.')
      arg_list = []
      # Order is important
      if classname:
        # EMR does not support passing classnames as jobs.
        # Instead manually invoke `hadoop CLASSNAME` using command-runner.jar
        jarfile = 'command-runner.jar'
        arg_list = ['hadoop', classname]
      # Order is important
      arg_list += ['-D{}={}'.format(k, v) for k, v in all_properties.items()]
      if job_arguments:
        arg_list += job_arguments
      arg_spec = 'Args=[' + ','.join(arg_list) + ']'
      step_list = ['Jar=' + jarfile, arg_spec]
    elif job_type == dpb_constants.SPARK_JOB_TYPE:
      arg_list = []
      if job_files:
        arg_list += ['--files', ','.join(job_files)]
      if job_py_files:
        arg_list += ['--py-files', ','.join(job_py_files)]
      if job_jars:
        arg_list += ['--jars', ','.join(job_jars)]
      for k, v in all_properties.items():
        arg_list += ['--conf', '{}={}'.format(k, v)]
      # jarfile must be last before args
      arg_list += ['--class', classname, jarfile]
      if job_arguments:
        arg_list += job_arguments
      arg_spec = '[' + ','.join(arg_list) + ']'
      step_type_spec = 'Type=Spark'
      step_list = [step_type_spec, 'Args=' + arg_spec]
    elif job_type == dpb_constants.PYSPARK_JOB_TYPE:
      arg_list = []
      if job_files:
        arg_list += ['--files', ','.join(job_files)]
      if job_jars:
        arg_list += ['--jars', ','.join(job_jars)]
      for k, v in all_properties.items():
        arg_list += ['--conf', '{}={}'.format(k, v)]
      # pyspark_file must be last before args
      arg_list += [pyspark_file]
      if job_arguments:
        arg_list += job_arguments
      arg_spec = 'Args=[{}]'.format(','.join(arg_list))
      step_list = ['Type=Spark', arg_spec]
    elif job_type == dpb_constants.SPARKSQL_JOB_TYPE:
      assert not job_arguments
      arg_list = [query_file]
      jar_spec = 'Jar="command-runner.jar"'
      for k, v in all_properties.items():
        arg_list += ['--conf', '{}={}'.format(k, v)]
      arg_spec = 'Args=[spark-sql,-f,{}]'.format(','.join(arg_list))
      step_list = [jar_spec, arg_spec]

    step_string = ','.join(step_list)

    step_cmd = self.cmd_prefix + [
        'emr',
        'add-steps',
        '--cluster-id',
        self.cluster_id,
        '--steps',
        step_string,
    ]
    stdout, _, _ = vm_util.IssueCommand(step_cmd)
    result = json.loads(stdout)
    step_id = result['StepIds'][0]
    return self._WaitForJob(step_id, EMR_TIMEOUT, job_poll_interval)

  def DistributedCopy(self, source, destination, properties=None):
    """Method to copy data using a distributed job on the cluster."""
    job_arguments = ['s3-dist-cp']
    job_arguments.append('--src={}'.format(source))
    job_arguments.append('--dest={}'.format(destination))
    return self.SubmitJob(
        'command-runner.jar',
        job_arguments=job_arguments,
        job_type=dpb_constants.HADOOP_JOB_TYPE,
    )

  def GetHdfsType(self) -> str | None:
    """Gets human friendly disk type for metric metadata."""
    try:
      return disk_to_hdfs_map[self.spec.worker_group.disk_spec.disk_type]
    except KeyError:
      raise errors.Setup.InvalidSetupError(
          f'Invalid disk_type={self.spec.worker_group.disk_spec.disk_type!r} in'
          ' spec.'
      ) from None

  def _FetchLogs(self, step_id: str) -> str | None:
    local_stdout_path = os.path.join(
        temp_dir.GetRunDirPath(),
        f'emr_{self.cluster_id}_{step_id}.stdout.gz',
    )
    get_stdout_cmd = self.cmd_prefix + [
        's3',
        'cp',
        os.path.join(
            self.base_dir, f'{self.cluster_id}/steps/{step_id}/stderr.gz'
        ),
        local_stdout_path,
    ]
    _, _, _ = vm_util.IssueCommand(get_stdout_cmd)
    with gzip.open(local_stdout_path, 'rt') as f:
      stdout = f.read()
    return stdout


class AwsDpbEmrServerless(
    dpb_service.DpbServiceServerlessMixin, dpb_service.BaseDpbService
):
  """Resource that allows spawning EMR Serverless Jobs.

  Pre-initialization capacity is not supported yet.

  Docs:
  https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html
  """

  CLOUD = provider_info.AWS
  SERVICE_TYPE = 'emr_serverless'

  def __init__(self, dpb_service_spec):
    # TODO(odiego): Refactor the AwsDpbEmr and AwsDpbEmrServerless into a
    # hierarchy or move common code to a parent class.
    super().__init__(dpb_service_spec)
    self.project = None
    self.cmd_prefix = list(util.AWS_PREFIX)
    if self.dpb_service_zone:
      self.region = util.GetRegionFromZone(self.dpb_service_zone)
    else:
      raise errors.Setup.InvalidSetupError(
          'dpb_service_zone must be provided, for provisioning.'
      )
    self.cmd_prefix += ['--region', self.region]
    self.storage_service = s3.S3Service()
    self.storage_service.PrepareService(self.region)
    self.persistent_fs_prefix = 's3://'
    self._cluster_create_time = None
    if not self.GetDpbVersion():
      raise errors.Setup.InvalidSetupError(
          'dpb_service.version must be provided. Versions follow the format: '
          '"emr-x.y.z" and are listed at '
          'https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/'
          'release-versions.html'
      )
    self.role = FLAGS.aws_emr_serverless_role

    # Last job usage info
    self._job_run = _AwsDpbEmrServerlessJobRun()

    self._FillMetadata()

  def SubmitJob(
      self,
      jarfile=None,
      classname=None,
      pyspark_file=None,
      query_file=None,
      job_poll_interval=None,
      job_arguments=None,
      job_files=None,
      job_jars=None,
      job_py_files=None,
      job_stdout_file=None,
      job_type=None,
      properties=None,
  ):
    """See base class."""
    assert job_type

    # Set vars according to job type.
    if job_type == dpb_constants.PYSPARK_JOB_TYPE:
      application_type = 'SPARK'
      spark_props = self.GetJobProperties()
      if job_py_files:
        spark_props['spark.submit.pyFiles'] = ','.join(job_py_files)
      job_driver_dict = {
          'sparkSubmit': {
              'entryPoint': pyspark_file,
              'entryPointArguments': job_arguments,
              'sparkSubmitParameters': ' '.join(
                  f'--conf {prop}={val}' for prop, val in spark_props.items()
              ),
          }
      }
    else:
      raise NotImplementedError(
          f'Unsupported job type {job_type} for AWS EMR Serverless.'
      )

    s3_monitoring_config = {
        's3MonitoringConfiguration': {
            'logUri': os.path.join(self.base_dir, 'logs')
        }
    }

    # Create the application.
    stdout, _, _ = vm_util.IssueCommand(
        self.cmd_prefix
        + [
            'emr-serverless',
            'create-application',
            '--release-label',
            self.GetDpbVersion(),
            '--name',
            self.cluster_id,
            '--type',
            application_type,
            '--tags',
            json.dumps(util.MakeDefaultTags()),
            '--monitoring-configuration',
            json.dumps(s3_monitoring_config),
        ]
    )
    result = json.loads(stdout)
    application_id = result['applicationId']

    @vm_util.Retry(
        poll_interval=job_poll_interval,
        fuzz=0,
        retryable_exceptions=(EMRRetryableException,),
    )
    def WaitTilApplicationReady():
      result = self._GetApplication(application_id)
      if result['application']['state'] not in ('CREATED', 'STARTED'):
        raise EMRRetryableException(
            f'Application {application_id} not ready yet.'
        )
      return result

    WaitTilApplicationReady()

    # Run the job.
    stdout, _, _ = vm_util.IssueCommand(
        self.cmd_prefix
        + [
            'emr-serverless',
            'start-job-run',
            '--application-id',
            application_id,
            '--execution-role-arn',
            self.role,
            '--job-driver',
            json.dumps(job_driver_dict),
        ]
    )
    result = json.loads(stdout)
    self._job_run = _AwsDpbEmrServerlessJobRun(
        application_id=result['applicationId'],
        job_run_id=result['jobRunId'],
        region=self.region,
    )
    return self._WaitForJob(self._job_run, EMR_TIMEOUT, job_poll_interval)

  def CalculateLastJobCosts(self) -> dpb_service.JobCosts:
    @vm_util.Retry(
        fuzz=0,
        retryable_exceptions=(EMRRetryableException,),
    )
    def WaitTilUsageMetricsAvailable():
      self._CallGetJobRunApi(self._job_run)
      if not self._job_run.HasStats():
        raise EMRRetryableException(
            'Usage metrics not ready yet for EMR Serverless '
            f'application_id={self._job_run.application_id!r} '
            f'job_run_id={self._job_run.job_run_id!r}'
        )

    if not self._job_run:
      return _AwsDpbEmrServerlessJobRun().ComputeJobRunCost()
    if not self._job_run.HasStats():
      try:
        WaitTilUsageMetricsAvailable()
      except vm_util.TimeoutExceededRetryError:
        logging.warning('Timeout exceeded for retrieving usage metrics.')
    return self._job_run.ComputeJobRunCost()

  def GetJobProperties(self) -> Dict[str, str]:
    result = {'spark.dynamicAllocation.enabled': 'FALSE'}
    if self.spec.emr_serverless_core_count:
      result['spark.executor.cores'] = self.spec.emr_serverless_core_count
      result['spark.driver.cores'] = self.spec.emr_serverless_core_count
    if self.spec.emr_serverless_memory:
      result['spark.executor.memory'] = f'{self.spec.emr_serverless_memory}G'
    if self.spec.emr_serverless_executor_count:
      result['spark.executor.instances'] = (
          self.spec.emr_serverless_executor_count
      )
    if self.spec.worker_group.disk_spec.disk_size:
      result['spark.emr-serverless.driver.disk'] = (
          f'{self.spec.worker_group.disk_spec.disk_size}G'
      )
      result['spark.emr-serverless.executor.disk'] = (
          f'{self.spec.worker_group.disk_spec.disk_size}G'
      )
    result.update(super().GetJobProperties())
    return result

  def _GetApplication(self, application_id):
    stdout, _, _ = vm_util.IssueCommand(
        self.cmd_prefix
        + [
            'emr-serverless',
            'get-application',
            '--application-id',
            application_id,
        ]
    )
    result = json.loads(stdout)
    return result

  def _ComputeJobRunCost(
      self, memory_gb_hour: float, storage_gb_hour: float, vcpu_hour: float
  ) -> dpb_service.JobCosts:
    region_prices = aws_dpb_emr_serverless_prices.EMR_SERVERLESS_PRICES.get(
        self.region, {}
    )
    memory_gb_hour_price = region_prices.get('memory_gb_hours')
    storage_gb_hour_price = region_prices.get('storage_gb_hours')
    vcpu_hour_price = region_prices.get('vcpu_hours')
    if (
        memory_gb_hour_price is None
        or storage_gb_hour_price is None
        or vcpu_hour_price is None
    ):
      return dpb_service.JobCosts()
    vcpu_cost = vcpu_hour * vcpu_hour_price
    memory_cost = memory_gb_hour * memory_gb_hour_price
    storage_cost = storage_gb_hour * storage_gb_hour_price
    return dpb_service.JobCosts(
        total_cost=vcpu_cost + memory_cost + storage_cost,
        compute_cost=vcpu_cost,
        memory_cost=memory_cost,
        storage_cost=storage_cost,
        compute_units_used=vcpu_hour,
        memory_units_used=memory_gb_hour,
        storage_units_used=storage_gb_hour,
        compute_unit_cost=vcpu_hour_price,
        memory_unit_cost=memory_gb_hour_price,
        storage_unit_cost=storage_gb_hour_price,
        compute_unit_name='vCPU*hr',
        memory_unit_name='GB*hr',
        storage_unit_name='GB*hr',
    )

  def _GetCompletedJob(self, job_run):
    """See base class."""
    return self._CallGetJobRunApi(job_run)

  def _CallGetJobRunApi(
      self, job_run: _AwsDpbEmrServerlessJobRun
  ) -> dpb_service.JobResult | None:
    """Performs EMR Serverless GetJobRun API call."""
    cmd = (
        self.cmd_prefix
        + [
            'emr-serverless',
            'get-job-run',
            '--application-id',
            job_run.application_id,
            '--job-run-id',
            job_run.job_run_id,
        ]
    )
    stdout, stderr, retcode = vm_util.IssueCommand(cmd, raise_on_failure=False)
    if retcode:
      if 'ThrottlingException' in stderr:
        logging.warning(
            'Rate limited while polling EMR JobRun:\n%s\nRetrying.', stderr
        )
        return None
      raise errors.VmUtil.IssueCommandError(
          f'Getting JobRun status failed:\n{stderr}'
      )
    result = json.loads(stdout)
    state = result['jobRun']['state']
    if state in ('FAILED', 'CANCELLED'):
      raise dpb_service.JobSubmissionError(result['jobRun'].get('stateDetails'))
    if state == 'SUCCESS':
      start_time = result['jobRun']['createdAt']
      end_time = result['jobRun']['updatedAt']
      self._job_run = self._ParseCostMetrics(result)
      return dpb_service.JobResult(
          run_time=end_time - start_time,
          fetch_output_fn=lambda: (
              self._FetchLogs(job_run.application_id, job_run.job_run_id),
              None,
          ),
      )

  def _FetchLogs(self, application_id: str, job_run_id: str) -> str | None:
    local_stdout_path = os.path.join(
        temp_dir.GetRunDirPath(),
        f'emrs8s_{application_id}_{job_run_id}.stdout.gz',
    )
    get_stdout_cmd = self.cmd_prefix + [
        's3',
        'cp',
        os.path.join(
            self.base_dir,
            f'logs/applications/{application_id}/jobs/{job_run_id}/'
            'SPARK_DRIVER/stdout.gz',
        ),
        local_stdout_path,
    ]
    _, _, _ = vm_util.IssueCommand(get_stdout_cmd)
    with gzip.open(local_stdout_path, 'rt') as f:
      stdout = f.read()
    return stdout

  def _FillMetadata(self) -> None:
    """Gets a dict to initialize this DPB service instance's metadata."""
    basic_data = self.metadata
    dpb_disk_size = self.spec.worker_group.disk_spec.disk_size or 'default'
    core_count = str(self.spec.emr_serverless_core_count) or 'default'
    cluster_shape = f'emr-serverless-{core_count}'
    cluster_size = str(self.spec.emr_serverless_executor_count) or 'default'

    self.metadata = {
        'dpb_service': basic_data['dpb_service'],
        'dpb_version': basic_data['dpb_version'],
        'dpb_service_version': basic_data['dpb_service_version'],
        'dpb_cluster_shape': cluster_shape,
        'dpb_cluster_size': cluster_size,
        'dpb_hdfs_type': basic_data['dpb_hdfs_type'],
        'dpb_memory_per_node': self.spec.emr_serverless_memory or 'default',
        'dpb_disk_size': dpb_disk_size,
        'dpb_service_zone': basic_data['dpb_service_zone'],
        'dpb_job_properties': basic_data['dpb_job_properties'],
    }

  def GetHdfsType(self) -> str | None:
    """Gets human friendly disk type for metric metadata."""
    return 'default-disk'

  def _ParseCostMetrics(
      self, get_job_run_result: dict[Any, Any]
  ) -> _AwsDpbEmrServerlessJobRun:
    """Parses usage metrics from an EMR s8s GetJobRun API response."""
    resource_utilization = get_job_run_result.get('jobRun', {}).get(
        'totalResourceUtilization', {}
    )
    return dataclasses.replace(
        self._job_run,
        memory_gb_hour=resource_utilization.get('memoryGBHour'),
        storage_gb_hour=resource_utilization.get('storageGBHour'),
        vcpu_hour=resource_utilization.get('vCPUHour'),
    )