# src/sagemaker_core/main/logs.py
import botocore.client
import botocore.exceptions
from boto3.session import Session
from botocore.config import Config
from typing import Generator, List, Tuple

from sagemaker_core.main.utils import SingletonMeta


class CloudWatchLogsClient(metaclass=SingletonMeta):
    """
    A singleton class that creates and reuses a single CloudWatch Logs client.
    """

    client: botocore.client.BaseClient = None

    def __init__(self):
        if not self.client:
            session = Session()
            self.client = session.client(
                "logs",
                region_name=session.region_name,
                config=Config(retries={"max_attempts": 10, "mode": "standard"}),
            )
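

# Illustrative note (assumes SingletonMeta returns the same instance on every call):
# repeated construction reuses one underlying boto3 "logs" client, e.g.
#
#     CloudWatchLogsClient().client is CloudWatchLogsClient().client  # -> True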


class LogStreamHandler:
    """
    Reads log events from a single CloudWatch log stream, tracking the forward token
    between calls so that each call only returns events not yet seen.
    """

    log_group_name: str = None
    log_stream_name: str = None
    stream_id: int = None
    next_token: str = None
    cw_client: botocore.client.BaseClient = None

    def __init__(self, log_group_name: str, log_stream_name: str, stream_id: int):
        self.log_group_name = log_group_name
        self.log_stream_name = log_stream_name
        self.cw_client = CloudWatchLogsClient().client
        self.stream_id = stream_id
    def get_latest_log_events(self) -> Generator[Tuple[str, dict], None, None]:
        """
        Get all of the latest log events for this stream that exist at this moment in time.

        cw_client.get_log_events() always returns a nextForwardToken, even if the current batch
        of events is empty. The same token can be passed back to cw_client.get_log_events()
        until a new batch of log events exists.
        API Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/logs/client/get_log_events.html

        Returns:
            Generator[Tuple[str, dict], None, None]: Generator that yields a tuple of two values:
                str: stream_name
                dict: event dict in the format
                    {
                        "ingestionTime": number,
                        "message": "string",
                        "timestamp": number
                    }
        """
        while True:
            # The first call omits the token; subsequent calls resume from nextForwardToken.
            if not self.next_token:
                token_args = {}
            else:
                token_args = {"nextToken": self.next_token}

            response = self.cw_client.get_log_events(
                logGroupName=self.log_group_name,
                logStreamName=self.log_stream_name,
                startFromHead=True,
                **token_args,
            )
            self.next_token = response["nextForwardToken"]

            # An empty batch means no new events are available right now.
            if not response["events"]:
                break

            for event in response["events"]:
                yield self.log_stream_name, event
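

# Usage sketch (illustrative only; the log group and stream names below are hypothetical
# placeholders, not values produced by this module):
#
#     handler = LogStreamHandler(
#         log_group_name="/aws/sagemaker/TrainingJobs",
#         log_stream_name="my-training-job/algo-1-1700000000",
#         stream_id=0,
#     )
#     for stream_name, event in handler.get_latest_log_events():
#         print(f"[{stream_name}] {event['message']}")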


class MultiLogStreamHandler:
    """
    Aggregates log events from every CloudWatch log stream in a log group that shares a
    common prefix, e.g. one stream per container writing to stdout/stderr.
    """

    log_group_name: str = None
    log_stream_name_prefix: str = None
    expected_stream_count: int = None
    streams: List[LogStreamHandler] = None
    cw_client: botocore.client.BaseClient = None

    def __init__(
        self, log_group_name: str, log_stream_name_prefix: str, expected_stream_count: int
    ):
        self.log_group_name = log_group_name
        self.log_stream_name_prefix = log_stream_name_prefix
        self.expected_stream_count = expected_stream_count
        # Use an instance-level list so handlers for different log groups do not share state.
        self.streams = []
        self.cw_client = CloudWatchLogsClient().client
    def get_latest_log_events(self) -> Generator[Tuple[str, dict], None, None]:
        """
        Get all of the latest log events from each stream that exist at this moment.

        Returns:
            Generator[Tuple[str, dict], None, None]: Generator that yields a tuple of two values:
                str: stream_name
                dict: event dict in the format
                    {
                        "ingestionTime": number,
                        "message": "string",
                        "timestamp": number
                    }
        """
        if not self.ready():
            # Not all expected streams exist yet, so there is nothing to yield.
            return

        for stream in self.streams:
            yield from stream.get_latest_log_events()
    def ready(self) -> bool:
        """
        Check whether MultiLogStreamHandler is ready to serve new log events at this moment.

        If self.streams is already populated, return True. Otherwise, check whether the current
        number of log streams in the log group matches the expected stream count.

        Returns:
            bool: Whether MultiLogStreamHandler is ready to serve new log events.
        """
if len(self.streams) >= self.expected_stream_count:
return True
try:
response = self.cw_client.describe_log_streams(
logGroupName=self.log_group_name,
logStreamNamePrefix=self.log_stream_name_prefix + "/",
orderBy="LogStreamName",
)
stream_names = [stream["logStreamName"] for stream in response["logStreams"]]
next_token = response.get("nextToken")
while next_token:
response = self.cw_client.describe_log_streams(
logGroupName=self.log_group_name,
logStreamNamePrefix=self.log_stream_name_prefix + "/",
orderBy="LogStreamName",
nextToken=next_token,
)
stream_names.extend([stream["logStreamName"] for stream in response["logStreams"]])
next_token = response.get("nextToken", None)
if len(stream_names) >= self.expected_stream_count:
self.streams = [
LogStreamHandler(self.log_group_name, log_stream_name, index)
for index, log_stream_name in enumerate(stream_names)
]
return True
else:
# Log streams are created whenever a container starts writing to stdout/err,
# so if the stream count is less than the expected number, return False
return False
except botocore.exceptions.ClientError as e:
# On the very first training job run on an account, there's no log group until
# the container starts logging, so ignore any errors thrown about that
if e.response["Error"]["Code"] == "ResourceNotFoundException":
return False
else:
raise
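

# Minimal end-to-end sketch (assumption-laden): the log group name, stream prefix, and
# expected stream count below are hypothetical placeholders chosen for illustration, not
# values this module derives on its own. It shows the intended polling pattern: wait until
# ready(), then repeatedly drain whatever events currently exist.
if __name__ == "__main__":
    import time

    handler = MultiLogStreamHandler(
        log_group_name="/aws/sagemaker/TrainingJobs",  # placeholder log group
        log_stream_name_prefix="my-training-job",  # placeholder job name / stream prefix
        expected_stream_count=1,  # placeholder container count
    )
    for _ in range(5):  # poll a few times for demonstration instead of looping forever
        if handler.ready():
            for stream_name, event in handler.get_latest_log_events():
                print(f"[{stream_name}] {event['message']}")
        time.sleep(10)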