# metaflow/plugins/aws/batch/batch.py
def wait(self, stdout_location, stderr_location, echo=None):
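    """Wait for the AWS Batch job to launch, tail its stdout/stderr logs
    from Amazon S3 while it runs, and raise a BatchException if it
    crashes or fails."""
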
    def wait_for_launch(job, child_jobs):
        status = job.status
        echo(
            "Task is starting (status %s)..." % status,
            "stderr",
            batch_id=job.id,
        )
        t = time.time()
        while True:
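            # Refresh the status line whenever the status changes, or at
            # least every 30 seconds.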
            if status != job.status or (time.time() - t) > 30:
                if not child_jobs:
                    child_statuses = ""
                else:
                    status_keys = set(
                        [child_job.status for child_job in child_jobs]
                    )
                    # Count how many parallel nodes are in each status.
                    status_counts = [
                        (
                            status,
                            len(
                                [
                                    child_job
                                    for child_job in child_jobs
                                    if child_job.status == status
                                ]
                            ),
                        )
                        for status in status_keys
                    ]
                    child_statuses = " (parallel node status: [{}])".format(
                        ", ".join(
                            [
                                "{}:{}".format(status, num)
                                for (status, num) in sorted(status_counts)
                            ]
                        )
                    )
                status = job.status
                echo(
                    "Task is starting (status %s)... %s" % (status, child_statuses),
                    "stderr",
                    batch_id=job.id,
                )
                t = time.time()
            if job.is_running or job.is_done or job.is_crashed:
                break
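            # Poll with a 200 ms timeout; this sleeps briefly between
            # status checks without busy-waiting.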
            select.poll().poll(200)
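
    # Prefix every forwarded log line with the Batch job id.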
    prefix = b"[%s] " % util.to_bytes(self.job.id)
    stdout_tail = S3Tail(stdout_location)
    stderr_tail = S3Tail(stderr_location)
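
    # For multi-node parallel jobs, each additional node runs as a child
    # job whose id is the main job id suffixed with "#<node>"; build
    # handles for them so their statuses can be reported while waiting
    # for launch.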
    child_jobs = []
    if self.num_parallel > 1:
        for node in range(1, self.num_parallel):
            child_job = copy.copy(self.job)
            child_job._id = child_job._id + "#{}".format(node)
            child_jobs.append(child_job)

    # 1) Loop until the job has started
    wait_for_launch(self.job, child_jobs)

    # 2) Tail logs until the job has finished
    tail_logs(
        prefix=prefix,
        stdout_tail=stdout_tail,
        stderr_tail=stderr_tail,
        echo=echo,
        has_log_updates=lambda: self.job.is_running,
    )

    # In case of hard crashes (OOM), the final save_logs won't happen.
    # We can fetch the remaining logs from AWS CloudWatch and persist them
    # to Amazon S3.
    if self.job.is_crashed:
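        # Surface the most specific crash reason available, falling back
        # to a generic message.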
        msg = next(
            msg
            for msg in [
                self.job.reason,
                self.job.status_reason,
                "Task crashed.",
            ]
            if msg is not None
        )
        raise BatchException(
            "%s This could be a transient error. Use @retry to retry." % msg
        )
    else:
        if self.job.is_running:
            # Kill the job if it is still running by throwing an exception.
            raise BatchException("Task failed!")
        echo(
            "Task finished with exit code %s." % self.job.status_code,
            "stderr",
            batch_id=self.job.id,
        )