# awswrangler/emr_serverless.py
"""EMR Serverless module."""
from __future__ import annotations
import logging
import pprint
import time
from typing import Any, Literal, TypedDict
import boto3
from typing_extensions import NotRequired, Required
from awswrangler import _utils, exceptions
from awswrangler._config import apply_configs
from awswrangler.annotations import Experimental
_logger: logging.Logger = logging.getLogger(__name__)
_EMR_SERVERLESS_JOB_WAIT_POLLING_DELAY: float = 5 # SECONDS
_EMR_SERVERLESS_JOB_FINAL_STATES: list[str] = ["SUCCESS", "FAILED", "CANCELLED"]
class SparkSubmitJobArgs(TypedDict):
"""Typed dictionary defining the Spark submit job arguments."""
entryPoint: Required[str]
"""The entry point for the Spark submit job run."""
entryPointArguments: NotRequired[list[str]]
"""The arguments for the Spark submit job run."""
sparkSubmitParameters: NotRequired[str]
"""The parameters for the Spark submit job run."""
class HiveRunJobArgs(TypedDict):
"""Typed dictionary defining the Hive job run arguments."""
query: Required[str]
"""The S3 location of the query file for the Hive job run."""
initQueryFile: NotRequired[str]
"""The S3 location of the query file for the Hive job run."""
parameters: NotRequired[str]
"""The parameters for the Hive job run."""
@Experimental
def create_application(
name: str,
release_label: str,
application_type: Literal["Spark", "Hive"] = "Spark",
initial_capacity: dict[str, str] | None = None,
maximum_capacity: dict[str, str] | None = None,
tags: dict[str, str] | None = None,
autostart: bool = True,
autostop: bool = True,
idle_timeout: int = 15,
network_configuration: dict[str, str] | None = None,
architecture: Literal["ARM64", "X86_64"] = "X86_64",
image_uri: str | None = None,
worker_type_specifications: dict[str, str] | None = None,
boto3_session: boto3.Session | None = None,
) -> str:
"""
Create an EMR Serverless application.
https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html
Parameters
----------
name
Name of the EMR Serverless application.
release_label
Release label, e.g. `emr-6.10.0`.
application_type
Application type: "Spark" or "Hive". Defaults to "Spark".
initial_capacity
The capacity to initialize when the application is created.
maximum_capacity
The maximum capacity to allocate when the application is created.
This is cumulative across all workers at any given point in time,
not just when an application is created. No new resources will
be created once any one of the defined limits is hit.
tags
Key/Value collection to put tags on the application.
e.g. {"foo": "boo", "bar": "xoo"})
autostart
Enables the application to automatically start on job submission. Defaults to true.
autostop
Enables the application to automatically stop after it has been idle for a certain amount of time. Defaults to true.
idle_timeout
The amount of idle time in minutes after which your application will automatically stop. Defaults to 15 minutes.
network_configuration
The network configuration for customer VPC connectivity.
architecture
The CPU architecture of an application: "ARM64" or "X86_64". Defaults to "X86_64".
image_uri
The URI of an image in the Amazon ECR registry.
worker_type_specifications
The key-value pairs that specify worker type.
boto3_session
The default boto3 session will be used if **boto3_session** is ``None``.
Returns
-------
Application ID.
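Examples
--------
A minimal sketch (the application name is illustrative):
>>> import awswrangler as wr
>>> application_id = wr.emr_serverless.create_application(
...     name="my-spark-application",
...     release_label="emr-6.10.0",
...     application_type="Spark",
... )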
"""
emr_serverless = _utils.client(service_name="emr-serverless", session=boto3_session)
application_args: dict[str, Any] = {
"name": name,
"releaseLabel": release_label,
"type": application_type,
"autoStartConfiguration": {
"enabled": autostart,
},
"autoStopConfiguration": {
"enabled": autostop,
"idleTimeoutMinutes": idle_timeout,
},
"architecture": architecture,
}
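# Optional settings are only forwarded to the API when explicitly provided, so the service defaults apply otherwise.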
if initial_capacity:
application_args["initialCapacity"] = initial_capacity
if maximum_capacity:
application_args["maximumCapacity"] = maximum_capacity
if tags:
application_args["tags"] = tags
if network_configuration:
application_args["networkConfiguration"] = network_configuration
if worker_type_specifications:
application_args["workerTypeSpecifications"] = worker_type_specifications
if image_uri:
application_args["imageConfiguration"] = {
"imageUri": image_uri,
}
response: dict[str, str] = emr_serverless.create_application(**application_args) # type: ignore[assignment]
_logger.debug("response: \n%s", pprint.pformat(response))
return response["applicationId"]
@Experimental
@apply_configs
def run_job(
application_id: str,
execution_role_arn: str,
job_driver_args: dict[str, Any] | SparkSubmitJobArgs | HiveRunJobArgs,
job_type: Literal["Spark", "Hive"] = "Spark",
wait: bool = True,
configuration_overrides: dict[str, Any] | None = None,
tags: dict[str, str] | None = None,
execution_timeout: int | None = None,
name: str | None = None,
emr_serverless_job_wait_polling_delay: float = _EMR_SERVERLESS_JOB_WAIT_POLLING_DELAY,
boto3_session: boto3.Session | None = None,
) -> str | dict[str, Any]:
"""
Run an EMR Serverless job.
https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html
Parameters
----------
application_id
The ID of the application on which to run the job.
execution_role_arn
The execution role ARN for the job run.
job_driver_args
The job driver arguments for the job run.
job_type
Type of the job: "Spark" or "Hive". Defaults to "Spark".
wait
Whether to wait for job completion. Defaults to true.
configuration_overrides
The configuration overrides for the job run.
tags
Key/Value collection to put tags on the application.
e.g. {"foo": "boo", "bar": "xoo"})
execution_timeout
The maximum duration in minutes for the job run. If the job run exceeds this duration,
it will be automatically cancelled.
name
Name of the job.
emr_serverless_job_wait_polling_delay
Time to wait, in seconds, between polling attempts.
boto3_session
The default boto3 session will be used if **boto3_session** is ``None``.
Returns
-------
Job run ID if ``wait=False``, or the job run details otherwise.
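Examples
--------
A minimal sketch, assuming the application and the IAM execution role already exist
and the entry point script is already in S3 (all identifiers below are illustrative):
>>> import awswrangler as wr
>>> job_run = wr.emr_serverless.run_job(
...     application_id="application-id",
...     execution_role_arn="arn:aws:iam::123456789012:role/emr-serverless-execution-role",
...     job_driver_args={
...         "entryPoint": "s3://my-bucket/scripts/my_script.py",
...         "entryPointArguments": ["--output", "s3://my-bucket/output/"],
...     },
...     job_type="Spark",
... )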
"""
emr_serverless = _utils.client(service_name="emr-serverless", session=boto3_session)
job_args: dict[str, Any] = {
"applicationId": application_id,
"executionRoleArn": execution_role_arn,
}
if job_type == "Spark":
job_args["jobDriver"] = {
"sparkSubmit": job_driver_args,
}
elif job_type == "Hive":
job_args["jobDriver"] = {
"hive": job_driver_args,
}
else:
raise exceptions.InvalidArgumentValue(f"Unsupported job type `{job_type}`")
if configuration_overrides:
job_args["configurationOverrides"] = configuration_overrides
if tags:
job_args["tags"] = tags
if execution_timeout:
job_args["executionTimeoutMinutes"] = execution_timeout
if name:
job_args["name"] = name
response = emr_serverless.start_job_run(**job_args)
_logger.debug("Job run response: %s", response)
job_run_id: str = response["jobRunId"]
if wait:
return wait_job(
application_id=application_id,
job_run_id=job_run_id,
emr_serverless_job_wait_polling_delay=emr_serverless_job_wait_polling_delay,
)
return job_run_id
@Experimental
@apply_configs
def wait_job(
application_id: str,
job_run_id: str,
emr_serverless_job_wait_polling_delay: float = _EMR_SERVERLESS_JOB_WAIT_POLLING_DELAY,
boto3_session: boto3.Session | None = None,
) -> dict[str, Any]:
"""
Wait for the EMR Serverless job to finish.
https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html
Parameters
----------
application_id
The ID of the application on which the job is running.
job_run_id
The ID of the job run.
emr_serverless_job_wait_polling_delay
Time to wait, in seconds, between polling attempts.
boto3_session
The default boto3 session will be used if **boto3_session** is ``None``.
Returns
-------
Job run details.
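Examples
--------
A minimal sketch, assuming a job was previously started with ``wait=False`` (identifiers are illustrative):
>>> import awswrangler as wr
>>> job_run = wr.emr_serverless.wait_job(
...     application_id="application-id",
...     job_run_id="job-run-id",
... )
>>> job_run["jobRun"]["state"]
'SUCCESS'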
"""
emr_serverless = _utils.client(service_name="emr-serverless", session=boto3_session)
response = emr_serverless.get_job_run(
applicationId=application_id,
jobRunId=job_run_id,
)
state = response["jobRun"]["state"]
while state not in _EMR_SERVERLESS_JOB_FINAL_STATES:
time.sleep(emr_serverless_job_wait_polling_delay)
response = emr_serverless.get_job_run(
applicationId=application_id,
jobRunId=job_run_id,
)
state = response["jobRun"]["state"]
_logger.debug("Job state: %s", state)
if state != "SUCCESS":
_logger.debug("Job run response: %s", response)
raise exceptions.EMRServerlessJobError(response.get("jobRun", {}).get("stateDetails"))
return response # type: ignore[return-value]