src/slurm_plugin/console_logger.py (68 lines of code) (raw):
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the
# License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Any, Callable, Iterable
import boto3
from slurm_plugin.common import ComputeInstanceDescriptor, TaskController
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class ConsoleLogger:
    """Class for retrieving and logging instance console output.

    When enabled, ``report_console_output_from_nodes`` queues a task on a ``TaskController``. That task
    fetches the EC2 console output of each given instance and passes it to ``console_output_consumer``.
    """

    def __init__(self, enabled: bool, region: str, console_output_consumer: Callable[[str, str, str], None]):
        """Create a console logger.

        :param enabled: when False, ``report_console_output_from_nodes`` does nothing and returns None.
        :param region: AWS region used to build the EC2 client.
        :param console_output_consumer: called as ``consumer(name, instance_id, console_output)`` once per
            instance. ``name`` is None if the descriptor has no "Name" entry. ``console_output`` is None when
            EC2 returned no (or empty) output.
        """
        self._region = region
        self._console_logging_enabled = enabled
        self._console_output_consumer = console_output_consumer
        # A brand-new boto3 Session is created for every client.
        # NOTE(review): presumably because boto3 Sessions are not thread-safe and the client is built inside
        # the queued task - confirm.
        self._boto3_client_factory = lambda service_name: boto3.session.Session().client(
            service_name, region_name=region
        )

    def report_console_output_from_nodes(
        self,
        compute_instances: Iterable[ComputeInstanceDescriptor],
        task_controller: TaskController,
        task_wait_function: Callable[[], None],
    ):
        """Queue a task that will retrieve the console output for failed compute nodes.

        :param compute_instances: instance descriptors, read with ``.get("Name")`` and ``.get("InstanceId")``.
            The iterable is consumed immediately, before the task is queued.
        :param task_controller: controller the task is queued on. Its ``raise_if_shutdown`` is used by the
            task to abort when a shutdown is requested.
        :param task_wait_function: called by the task before any EC2 request is made.
        :return: None when console logging is disabled or there are no instances. Otherwise, whatever
            ``task_controller.queue_task`` returns for the queued task.
        """
        if not self._console_logging_enabled:
            return None

        # Only schedule a task if we have any compute_instances to query. We also need to realize any lazy instance ID
        # lookups before we schedule the task since the instance ID mapping may change after we return from this
        # call but before the task is executed.
        compute_instances = tuple(compute_instances)
        if not compute_instances:
            return None

        task = self._get_console_output_task(
            raise_if_shutdown=task_controller.raise_if_shutdown,
            task_wait_function=task_wait_function,
            client_factory=self._boto3_client_factory,
            compute_instances=compute_instances,
        )
        return task_controller.queue_task(task)

    def _get_console_output_task(
        self,
        task_wait_function: Callable[[], None],
        raise_if_shutdown: Callable[[], None],
        client_factory: Callable[[str], Any],
        compute_instances: Iterable[ComputeInstanceDescriptor],
    ):
        """Build the zero-argument task that fetches and reports console output.

        :param task_wait_function: called first, before the EC2 client is created.
        :param raise_if_shutdown: called before each EC2 request and before each consumer call. It is expected
            to raise when a shutdown has been requested.
        :param client_factory: called as ``client_factory("ec2")`` to obtain the EC2 client.
        :param compute_instances: instance descriptors to query, iterated once when the task runs.
        :return: a callable taking no arguments. When run, any exception it hits (from the wait function,
            client creation, the EC2 call, the shutdown check or the consumer) is logged and re-raised.
            Any remaining instances are then skipped.
        """

        def console_collector():
            try:
                # Sleep to allow EC2 time to publish the console output after the node terminates.
                task_wait_function()
                ec2client = client_factory("ec2")
                for output in ConsoleLogger._get_console_output_from_nodes(
                    ec2client, compute_instances, raise_if_shutdown=raise_if_shutdown
                ):
                    # If shutdown, raise an exception so that any interested threads will know this task was
                    # not completed. The generator already checks before each EC2 request. This check also
                    # catches a shutdown requested while that request was in flight.
                    raise_if_shutdown()
                    self._console_output_consumer(
                        output.get("Name"),
                        output.get("InstanceId"),
                        output.get("ConsoleOutput"),
                    )
            except Exception as e:
                logger.error("Encountered exception while retrieving compute console output: %s", e)
                raise

        return console_collector

    @staticmethod
    def _get_console_output_from_nodes(ec2client, compute_instances, raise_if_shutdown=None):
        """Lazily yield the console output of each instance, one EC2 request per instance.

        :param ec2client: client exposing ``get_console_output(InstanceId=...)``.
        :param compute_instances: descriptors, read with ``.get("Name")`` and ``.get("InstanceId")``.
        :param raise_if_shutdown: optional callable invoked before each EC2 request, so that a shutdown is
            honoured without issuing another API call. None (the default) skips the check.
        :return: generator of dicts with keys "Name", "InstanceId" and "ConsoleOutput". "ConsoleOutput" has
            every CRLF/LF line break replaced by a bare CR, or is None when EC2 returned no or empty output.
        """
        pattern = re.compile(r"\r\n|\n")
        for instance in compute_instances:
            # Check for shutdown before spending an API call on this instance.
            if raise_if_shutdown is not None:
                raise_if_shutdown()
            instance_name = instance.get("Name")
            instance_id = instance.get("InstanceId")
            logger.info("Retrieving Console Output for node %s (%s)", instance_id, instance_name)
            response = ec2client.get_console_output(InstanceId=instance_id)
            output = response.get("Output")
            yield {
                "Name": instance_name,
                "InstanceId": instance_id,
                # NOTE(review): line breaks become bare CR, presumably so the downstream log handler keeps the
                # whole console dump in one record - confirm against the consumer.
                "ConsoleOutput": pattern.sub("\r", output) if output else None,
            }