in evaluation/utils/evaluator.py [0:0]
def update_status(self):
    """Poll SageMaker for the state of all submitted transform jobs,
    route finished jobs into the failed/completed queues, collect
    CloudWatch host metrics for newly completed jobs, and persist state.
    """

    def _refresh(job):
        # Best-effort fetch of the job's latest description; any client
        # error is logged and the job is retried on a later poll cycle.
        try:
            sagemaker = self.client("sagemaker", job.dataset_name)
            job.status = sagemaker.describe_transform_job(
                TransformJobName=job.job_name
            )
        except Exception as ex:
            logger.exception(ex)
            logger.info(f"Postponing fetching status for {job.job_name}")
            return False
        return True

    logger.info("Updating status")
    finished = set()
    for job in self._submitted:
        if not _refresh(job):
            continue
        state = job.status["TransformJobStatus"]
        if state == "InProgress":
            continue
        if state != "Completed":
            # Any terminal state other than Completed counts as a failure.
            self._failed.append(job)
            update_evaluation_status(job.model_id, job.dataset_name, "failed")
            finished.add(job.job_name)
        elif job.perturb_prefix or round_end_dt(
            job.status["TransformEndTime"]
        ) < datetime.now(tzlocal()):
            # The transform itself is done, but scores still need to be
            # calculated, so the dataset-level status becomes "evaluating".
            self._completed.append(job)
            update_evaluation_status(
                job.model_id, job.dataset_name, "evaluating"
            )
            finished.add(job.job_name)
        # NOTE: a Completed job whose rounded end time has not yet passed
        # (and has no perturb prefix) intentionally stays in _submitted.
    self._submitted = [
        job for job in self._submitted if job.job_name not in finished
    ]

    logger.info("Fetch metrics")
    # Pull AWS metrics for completed jobs that don't have them yet.
    for job in self._completed:
        if job.aws_metrics or job.perturb_prefix:
            continue
        cloudwatch = self.client("cloudwatch", job.dataset_name)
        cloudwatchlog = self.client("logs", job.dataset_name)
        streams = cloudwatchlog.describe_log_streams(
            logGroupName=self.cloudwatch_namespace,
            logStreamNamePrefix=f"{job.job_name}/",
        )["logStreams"]
        if not streams:
            continue
        # Stream names with exactly one "/" identify machine instances;
        # dropping the trailing "-<suffix>" yields the host name.
        hosts = {
            "-".join(stream["logStreamName"].split("-")[:-1])
            for stream in streams
            if stream["logStreamName"].count("/") == 1
        }
        for host in hosts:
            listed = cloudwatch.list_metrics(
                Namespace=self.cloudwatch_namespace,
                Dimensions=[{"Name": "Host", "Value": host}],
            )
            if not listed["Metrics"]:
                continue
            window_start = round_start_dt(job.status["TransformStartTime"])
            window_end = round_end_dt(job.status["TransformEndTime"])
            # The API returns at most 1440 datapoints per query, and the
            # period must be a multiple of 60 seconds (minimum 60).
            period = (window_end - window_start).total_seconds() / 1440
            period = int(math.ceil(period / 60) * 60)
            period = max(60, period)
            for metric in listed["Metrics"]:
                stats = cloudwatch.get_metric_statistics(
                    Namespace=metric["Namespace"],
                    MetricName=metric["MetricName"],
                    Dimensions=metric["Dimensions"],
                    StartTime=window_start,
                    EndTime=window_end,
                    Period=period,
                    Statistics=["Average"],
                )
                if stats["Datapoints"]:
                    job.aws_metrics.setdefault(
                        metric["MetricName"], []
                    ).append(process_aws_metrics(stats["Datapoints"]))
    # dump the updated status
    self.dump()