scheduler/job_requester/requester.py:
import json
import logging
import os
import re
import sys
import time
from datetime import datetime
from threading import Lock
from functools import cmp_to_key
import boto3
from job_requester import Message
MAX_TIMEOUT_IN_SEC = 5000
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.DEBUG)
LOGGER.addHandler(logging.StreamHandler(sys.stdout))
class JobRequester:
def __init__(self, timeout=MAX_TIMEOUT_IN_SEC):
self.s3_ticket_bucket = "dlc-test-tickets"
self.s3_ticket_bucket_folder = "request_tickets"
self.timeout_limit = min(timeout, MAX_TIMEOUT_IN_SEC)
self.s3_client = boto3.client("s3")
self.s3_resource = boto3.resource("s3")
self.ticket_name_counter = 0
self.request_lock = Lock()
def create_ticket_content(self, image, context, num_of_instances, request_time):
"""
Create content of the ticket to be sent to S3
:param image: <string> ECR URI
:param context: <string> build context (PR/MAINLINE/NIGHTLY/DEV)
:param num_of_instances: <int> number of instances required by the test job
:param request_time: <string> datetime timestamp of when request was made
:return: <dict> content of the request ticket
"""
content = {
"CONTEXT": context,
"TIMESTAMP": request_time,
"ECR-URI": image,
"SCHEDULING_TRIES": 0,
"INSTANCES_NUM": num_of_instances,
"TIMEOUT_LIMIT": self.timeout_limit,
"COMMIT": os.getenv("CODEBUILD_RESOLVED_SOURCE_VERSION", "default"),
}
return content
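    # Illustrative only (not part of the original module): with assumed values, a ticket
    # produced by create_ticket_content might look like:
    #   {
    #       "CONTEXT": "PR",
    #       "TIMESTAMP": "2023-01-02-03-04-05",
    #       "ECR-URI": "<account>.dkr.ecr.<region>.amazonaws.com/pytorch-training:<tag>",
    #       "SCHEDULING_TRIES": 0,
    #       "INSTANCES_NUM": 1,
    #       "TIMEOUT_LIMIT": 5000,
    #       "COMMIT": "abc1234",
    #   }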
def get_ticket_name_prefix(self):
"""
        Create a 7-character prefix for the ticket name
:return: <string> prefix for request ticket name
"""
source_version = os.getenv("PR_NUMBER", "default")
if "pr/" in source_version:
            # mod the PR ID by 100000 so the prefix is exactly 7 characters: "pr" plus 5 zero-padded digits
return f"pr{(int(source_version.split('/')[-1]) % 100000):05}"
else:
return source_version[:7]
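    # Illustrative examples with assumed environment values:
    #   PR_NUMBER="pr/1234"   -> "pr01234" ("pr" + PR ID mod 100000, zero-padded to 5 digits)
    #   PR_NUMBER="deadbeef0" -> "deadbee" (first 7 characters of the source version)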
    def send_ticket(self, ticket_content, framework):
        """
        Send a request ticket to the S3 bucket, self.s3_ticket_bucket
        Can run in a multi-threaded context; each thread receives a unique ticket name
        :param ticket_content: <dict> content of the ticket
        :param framework: <string> framework under test (mxnet/pytorch/tensorflow)
        :return: <string> name of the ticket
        """
        # ticket name: {CB source version}-{framework}{ticket name counter}_{datetime string}.json
        ticket_name_prefix = self.get_ticket_name_prefix()
        request_time = ticket_content["TIMESTAMP"]
        with self.request_lock:
            ticket_name = f"{ticket_name_prefix}-{framework}{self.ticket_name_counter}_{request_time}.json"
            self.ticket_name_counter += 1
        ticket_key = f"{self.s3_ticket_bucket_folder}/{ticket_name}"
        # write the ticket once, with its body, instead of creating an empty object first
        # and then overwriting it
        self.s3_client.put_object(
            Bucket=self.s3_ticket_bucket,
            Key=ticket_key,
            Body=json.dumps(ticket_content).encode("utf-8"),
        )
        # change the object ACL to make the ticket accessible to the dev account
        self.s3_client.put_object_acl(
            ACL="bucket-owner-full-control",
            Bucket=self.s3_ticket_bucket,
            Key=ticket_key,
        )
        LOGGER.info(f"Ticket sent successfully, ticket name: {ticket_name}")
        return ticket_name
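    # Illustrative example with assumed values: prefix "pr01234", framework "pytorch",
    # counter 0 and timestamp "2023-01-02-03-04-05" yield the S3 key
    #   request_tickets/pr01234-pytorch0_2023-01-02-03-04-05.json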
    def assign_sagemaker_instance_type(self, image):
        """
        Assign the instance type that the input image needs for testing
        :param image: <string> ECR URI
        :return: <string> type of instance used by the image
        """
        if "gpu" in image:
            return "ml.g5.12xlarge"
        return "ml.c5.4xlarge" if "tensorflow" in image else "ml.c5.9xlarge"
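    # Illustrative mapping, with hypothetical image URIs:
    #   "...pytorch-training:<tag>-gpu..."     -> "ml.g5.12xlarge" (any GPU image)
    #   "...tensorflow-inference:<tag>-cpu..." -> "ml.c5.4xlarge"
    #   "...pytorch-inference:<tag>-cpu..."    -> "ml.c5.9xlarge"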
def extract_timestamp(self, ticket_key):
"""
        Extract the timestamp string from an S3 request ticket key
:param ticket_key: <string> key of the request ticket
:return: <string> timestamp in format "%Y-%m-%d-%H-%M-%S" that is encoded in the ticket name
"""
return re.match(r".*_(\d{4}(-\d{2}){5})\.json", ticket_key).group(1)
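    # Illustrative example with an assumed key:
    #   extract_timestamp("request_tickets/pr01234-pytorch0_2023-01-02-03-04-05.json")
    #   -> "2023-01-02-03-04-05"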
    def ticket_timestamp_cmp_function(self, ticket1_name, ticket2_name):
        """
        Compare the timestamps of two request tickets, for use with functools.cmp_to_key
        :param ticket1_name, ticket2_name: <string> keys of the request tickets
        :return: <int> negative if ticket1 is older than ticket2, zero if equal, positive if newer
        """
        ticket1_timestamp, ticket2_timestamp = (
            self.extract_timestamp(ticket1_name),
            self.extract_timestamp(ticket2_name),
        )
        # cmp_to_key expects a negative/zero/positive int, not a bool; timestamps in
        # "%Y-%m-%d-%H-%M-%S" format sort chronologically under plain string comparison
        return (ticket1_timestamp > ticket2_timestamp) - (ticket1_timestamp < ticket2_timestamp)
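    # Minimal usage sketch (assumed names and a hypothetical `requester` instance),
    # sorting ticket keys oldest-first the same way query_status does below:
    #   names = [
    #       "pr00002-pytorch0_2023-01-02-03-04-06.json",
    #       "pr00001-pytorch0_2023-01-02-03-04-05.json",
    #   ]
    #   names.sort(key=cmp_to_key(requester.ticket_timestamp_cmp_function))
    #   # names[0] is now the oldest ticket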
def construct_query_response(self, status, reason=None, queueNum=None):
"""
Create query response for query_status calls
:param status: <string> queuing/preparing/completed/runtimeError
:param reason: <string> maxRetries/timeout
:param queueNum: <int>
:return: <dict> response for the ticket query
"""
query_response = {"status": status}
        if reason is not None:
            query_response["reason"] = reason
        if queueNum is not None:
            query_response["queueNum"] = queueNum
return query_response
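    # Illustrative shapes of the responses this helper produces (assumed arguments):
    #   construct_query_response("queuing", queueNum=3)
    #   -> {"status": "queuing", "queueNum": 3}
    #   construct_query_response("failed", reason="timeout")
    #   -> {"status": "failed", "reason": "timeout"}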
def search_ticket_folder(self, folder, path):
"""
Search folder/path on S3 to find the target ticket. If found, return a query response for the search. Otherwise
return None.
:param folder: <string> folder to search
:param path: <string> path within the folder
:return: <dict or None>
"""
objects = self.s3_client.list_objects(
Bucket=self.s3_ticket_bucket, Prefix=f"{folder}/{path}"
)
if "Contents" in objects:
ticket_key = objects["Contents"][0]["Key"]
            suffix_pattern = re.compile(r".*-(.*)\.json")
            suffix = suffix_pattern.match(ticket_key).group(1)
            if folder in ("dead_letter_queue", "duplicate_pr_requests"):
return self.construct_query_response("failed", reason=suffix)
else:
return self.construct_query_response(suffix)
return None
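    # Illustrative example (assuming the executor renames tickets with a trailing status or
    # reason suffix): a key such as "dead_letter_queue/pr01234-...-timeout.json" would yield
    #   {"status": "failed", "reason": "timeout"}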
def send_request(self, image, build_context, num_of_instances):
"""
Sending a request to test job executor (place request ticket to S3)
Could run under multi-threading context
:param num_of_instances: <int> number of instances needed for the test
:param image: <string> ECR uri
:param build_context: <string> PR/MAINLINE/NIGHTLY/DEV
:return: <Message object>
"""
assert (
"training" in image or "inference" in image
), f"Job type (training/inference) not stated in image tag: {image}"
        # named request_time rather than "time" to avoid shadowing the imported time module
        request_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        ticket_content = self.create_ticket_content(
            image, build_context, num_of_instances, request_time
        )
framework = (
"mxnet" if "mxnet" in image else "pytorch" if "pytorch" in image else "tensorflow"
)
ticket_name = self.send_ticket(ticket_content, framework)
instance_type = self.assign_sagemaker_instance_type(image)
job_type = "training" if "training" in image else "inference"
        identifier = Message(
            self.s3_ticket_bucket, ticket_name, image, instance_type, job_type, request_time
        )
return identifier
def receive_logs(self, identifier):
"""
Requesting for the test logs
:param identifier: <Message object> returned from send_request
:return: <json or None> if log received, return the json log. Otherwise return None.
"""
        # str.rstrip(".json") would strip any trailing ".", "j", "s", "o", "n" characters,
        # not the literal suffix, so slice the extension off instead
        ticket_name_without_extension = identifier.ticket_name[: -len(".json")]
        ticket_prefix = f"resource_pool/{identifier.instance_type}-{identifier.job_type}/{ticket_name_without_extension}"
        objects = self.s3_client.list_objects(
            Bucket=self.s3_ticket_bucket, Prefix=ticket_prefix
        )
        if "Contents" in objects:
            entry = objects["Contents"][0]
            ticket_object = self.s3_client.get_object(
                Bucket=self.s3_ticket_bucket, Key=entry["Key"]
            )
ticket_body = json.loads(ticket_object["Body"].read().decode("utf-8"))
return ticket_body["LOGS"]
return None
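    # Minimal polling sketch (assumed usage, hypothetical poll interval):
    #   logs = None
    #   while logs is None:
    #       time.sleep(30)
    #       logs = requester.receive_logs(identifier)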
def cancel_request(self, identifier):
"""
Cancel the test request by removing ticket from the queue.
If the test request is already running, do nothing.
:param identifier: <Message object> the response object returned from send_request
"""
# check if ticket is on the queue
ticket_in_queue = self.search_ticket_folder(
"request_tickets", identifier.ticket_name.rstrip(".json")
)
if ticket_in_queue:
self.s3_client.delete_object(
Bucket=self.s3_ticket_bucket, Key=f"request_tickets/{identifier.ticket_name}"
)
return
# check if ticket is a PR duplicate
ticket_in_duplicate = self.search_ticket_folder(
"duplicate_pr_requests", identifier.ticket_name.rstrip(".json")
)
if ticket_in_duplicate:
LOGGER.info(
f"{identifier.ticket_name} is a duplicate PR test, test request will not be scheduled."
)
return
LOGGER.info(
f"{identifier.ticket_name} test has begun, test request could not be cancelled."
)
def query_status(self, identifier):
"""
:param identifier: <Message object> unique identifier returned from call to send_request
:return: <dict> {"status": <string> queuing/preparing/completed/failed/runtimeError,
"reason" (if status == failed): <string> maxRetries/timeout/duplicatePR,
"queueNum" (if status == queuing): <int>
}
"""
retries = 2
request_ticket_name = identifier.ticket_name
        ticket_without_extension = request_ticket_name[: -len(".json")]
instance_type = identifier.instance_type
job_type = identifier.job_type
for _ in range(retries):
# check if ticket is on the queue
ticket_objects = self.s3_client.list_objects(
Bucket=self.s3_ticket_bucket, Prefix="request_tickets/"
)
# "Contents" in the API response only if there are objects satisfy the prefix
if "Contents" in ticket_objects:
ticket_name_pattern = re.compile(".*\/(.*)")
ticket_names_list = [
ticket_name_pattern.match(ticket["Key"]).group(1)
for ticket in ticket_objects["Contents"]
if ticket["Key"].endswith(".json")
]
# ticket is on the queue, find the queue number
if request_ticket_name in ticket_names_list:
ticket_names_list.sort(key=cmp_to_key(self.ticket_timestamp_cmp_function))
queue_num = ticket_names_list.index(request_ticket_name)
return self.construct_query_response("queuing", queueNum=queue_num)
# check if ticket is on the dead letter queue
ticket_in_dead_letter = self.search_ticket_folder(
"dead_letter_queue", ticket_without_extension
)
if ticket_in_dead_letter:
return ticket_in_dead_letter
ticket_in_duplicate = self.search_ticket_folder(
"duplicate_pr_requests", ticket_without_extension
)
if ticket_in_duplicate:
return ticket_in_duplicate
ticket_in_progress = self.search_ticket_folder(
"resource_pool", f"{instance_type}-{job_type}/{ticket_without_extension}"
)
if ticket_in_progress:
return ticket_in_progress
time.sleep(2)
raise AssertionError(f"Request ticket name {request_ticket_name} could not be found.")
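# Minimal end-to-end usage sketch, not part of the original module. Assumptions: valid AWS
# credentials with access to the "dlc-test-tickets" bucket, and a hypothetical ECR image URI.
if __name__ == "__main__":
    requester = JobRequester()
    image_uri = "123456789012.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:example-gpu"
    identifier = requester.send_request(image_uri, "PR", num_of_instances=1)
    status = requester.query_status(identifier)
    LOGGER.info(f"Initial status: {status}")
    if status["status"] == "queuing":
        # a request still sitting in the queue can be withdrawn
        requester.cancel_request(identifier)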