in xlml/apis/task.py [0:0]
def run(self) -> DAGNode:
    """Build the task group that executes one test job end to end.

    When ``self.existing_instance_name`` is set, no new group is built here:
    the call is delegated to ``self.run_with_existing_instance()`` and its
    result is returned unchanged.

    Returns:
      A task group whose ``group_id`` is the test's ``benchmark_id``. It
      contains these tasks chained in order: provision, run_model,
      post_process, clean_up.
    """
    # NOTE(piz): the queued-resource path is skipped for GPU for now, because
    # there is no queued-resource command for GPU.
    if self.existing_instance_name is not None:
        return self.run_with_existing_instance()

    with TaskGroup(
        group_id=self.task_test_config.benchmark_id, prefix_group_id=True
    ) as task_group:
        # provision() returns the provisioning task plus the values that the
        # downstream tasks consume.
        (
            provision_task,
            host_ip,
            provisioned_name,
            remote_ssh_keys,
            gcs_output_location,
        ) = self.provision()

        # Export the runtime-generated GCS folder to the remote job only when
        # the metric config opts in. If the test config script already sets
        # `task_metric_config.json_lines`, the exported location has no effect.
        metric_cfg = self.task_metric_config
        wants_runtime_gcs = (
            metric_cfg and metric_cfg.use_runtime_generated_gcs_folder
        )
        ssh_env = (
            {
                f"{metric_config.SshEnvVars.GCS_OUTPUT.name}": (
                    gcs_output_location
                )
            }
            if wants_runtime_gcs
            else None
        )

        run_model_task = self.run_model(host_ip, remote_ssh_keys, ssh_env)
        post_process_task = self.post_process(gcs_output_location)
        clean_up_task = self.clean_up(
            provisioned_name,
            self.task_gcp_config.project_name,
            self.task_gcp_config.zone,
        )

        # The four stages run strictly one after another.
        provision_task >> run_model_task >> post_process_task >> clean_up_task

    return task_group