python-batch/batch.py
#!/usr/bin/env python3
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tools for running jobs via the Google Cloud Batch API.
"""
__author__ = "J Ross Thomson drj@"
__version__ = "0.1.0"
import os
import sys
import uuid
from typing import Iterable

import yaml
from absl import app
from absl import flags
from google.api_core.operation import Operation
from google.cloud import batch_v1
from google.cloud import pubsub_v1
from yaml.loader import SafeLoader
"""
We define multiple command line flags:
"""
FLAGS = flags.FLAGS
flags.DEFINE_string("config_file", None, "Config file in YAML")
flags.DEFINE_boolean("pubsub", False , "Run Pubsub Topic and Subscriber")
flags.DEFINE_string("previous_job_id", None, "For Pubsub restart, specifies topic to read from")
flags.DEFINE_string("project_id", None, "Google Cloud Project ID, not name")
flags.DEFINE_spaceseplist("volumes", None, "List of GCS paths to mount. Example, \"bucket_name1:mountpath1 bucket_name2:mountpath2\"" )
flags.DEFINE_boolean("create_job", False, "Creates job, otherwise just prints config.")
flags.DEFINE_boolean("list_jobs", False, "If true, list jobs for config.")
flags.DEFINE_string("delete_job", "", "Job name to delete.")
flags.DEFINE_boolean("debug", False, "If true, print debug info.")
# Required flag.
flags.mark_flag_as_required("config_file")
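# Example invocations (a sketch; file, bucket, and job names are hypothetical):
#
#   python3 batch.py --config_file=config.yaml                 # print job config only
#   python3 batch.py --config_file=config.yaml --create_job    # submit the job
#   python3 batch.py --config_file=config.yaml --list_jobs     # list jobs in the region
#   python3 batch.py --config_file=config.yaml --delete_job=myjob-1a2b3c4d
#   python3 batch.py --config_file=config.yaml --create_job \
#       --volumes="bucket_name1:/mnt/path1 bucket_name2:/mnt/path2"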
class PubSub:
def __init__(self, job_id: str, config, previous_job_id=None) -> None:
"""Class to create Pub/Sub related cloud resources.
Used to create a topic and a subscription.
        The topic name is identical to job_id; the subscription is named `sub-` + job_id.
Args:
config:
Contains the structure defined by the Config.yaml file.
job_id:
The name of the job, which becomes the name of the pubsub topic.
Raises:
None
        `project_id` is resolved from the config file, then the GOOGLE_CLOUD_PROJECT
        environment variable, then the --project_id flag; later sources override
        earlier ones.
"""
self.config = config
self.job_id = job_id
if self.config["project_id"]:
self.project_id = self.config["project_id"]
if os.environ.get('GOOGLE_CLOUD_PROJECT'):
self.project_id = os.environ.get('GOOGLE_CLOUD_PROJECT')
if FLAGS.project_id:
self.project_id = FLAGS.project_id
"""
Sending messages to the same region ensures they are received in order
even when multiple publishers are used.
"""
publisher_options = pubsub_v1.types.PublisherOptions(enable_message_ordering=True)
client_options = {"api_endpoint": f'{self.config["region"]}-pubsub.googleapis.com:443'}
self.publisher = pubsub_v1.PublisherClient(
publisher_options=publisher_options, client_options=client_options
)
def create_topic(self):
"""
        Creates a topic whose name is the job ID.
The `topic_path` method creates a fully qualified identifier
in the form `projects/{project_id}/topics/{topic_id}`
Args:
None
Returns:
None
"""
self.topic_path = self.publisher.topic_path(self.project_id, self.job_id)
self.topic = self.publisher.create_topic(request={"name": self.topic_path})
print(f"Created topic {self.topic_path}\n", file=sys.stderr)
def create_subscription(self):
"""
        Creates a subscription named `sub-` + job_id.
        The subscription is ordered and uses exactly-once delivery.
        Args:
            None
        Returns:
            None
"""
self.subscriber = pubsub_v1.SubscriberClient()
self.subscription_id = "sub-" + self.job_id
self.topic_path = self.publisher.topic_path(self.project_id, self.job_id)
self.subscription_path = self.subscriber.subscription_path(self.project_id, self.subscription_id)
with self.subscriber:
self.subscription = self.subscriber.create_subscription(
request={
"name": self.subscription_path,
"topic": self.topic_path,
"enable_message_ordering": True,
"enable_exactly_once_delivery": True,
}
)
print(f"Created subscription {self.subscription_path}\n", file=sys.stderr)
def publish_fifo_ids(self):
"""
        Publish integer task-ID messages to the Pub/Sub topic.
        This builds a simple work queue on top of the topic. If the queue is
        not fully drained, a later run can restart against the same queue and
        pick up where the previous run left off.
Args:
None
Returns:
None
"""
self.order_key = "fifo"
        for i in range(self.config["task_count"]):
# Data must be a bytestring
data = str(i).encode("utf-8")
# When you publish a message, the client returns a future.
future = self.publisher.publish(self.topic_path, data=data, ordering_key=self.order_key)
if FLAGS.debug:
print(f'Future: {future.result()}, Message: {data}', file=sys.stderr)
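    # A worker task would typically drain this queue with a synchronous pull.
    # A minimal sketch, assuming the conventions above (the TOPIC_ID
    # environment variable set in main() and the "sub-" subscription prefix):
    #
    #   subscriber = pubsub_v1.SubscriberClient()
    #   sub_path = subscriber.subscription_path(
    #       project_id, "sub-" + os.environ["TOPIC_ID"])
    #   response = subscriber.pull(
    #       request={"subscription": sub_path, "max_messages": 1})
    #   for msg in response.received_messages:
    #       task_id = int(msg.message.data)  # the work item published above
    #       subscriber.acknowledge(
    #           request={"subscription": sub_path, "ack_ids": [msg.ack_id]})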
class CloudBatchJobs:
def __init__(self, job_id: str, config, env_vars) -> None:
"""
Class to create all cloud resources for Batch Jobs
Args:
            job_id: an arbitrary string to name the job.
            config: data structure parsed from the YAML config file.
            env_vars: environment variables passed to each task's runnable.
Returns:
None
"""
self.config = config
self.job_id = job_id
self.env_vars = env_vars
#set "project_id" based on config, env, or argv in that order.
if self.config["project_id"]:
self.project_id = self.config["project_id"]
if os.environ.get('GOOGLE_CLOUD_PROJECT'):
self.project_id = os.environ.get('GOOGLE_CLOUD_PROJECT')
if FLAGS.project_id:
self.project_id = FLAGS.project_id
self.client = batch_v1.BatchServiceClient()
    def _create_runnable(self) -> batch_v1.Runnable:
        """Builds the Runnable: a container if configured, otherwise a shell script."""
self.runnable = batch_v1.Runnable()
self.runnable.environment = batch_v1.Environment()
self.runnable.environment.variables = self.env_vars
if "container" in self.config:
self.runnable.container = batch_v1.Runnable.Container()
self.runnable.container.image_uri = self.config["container"]["image_uri"]
self.runnable.container.entrypoint = self.config["container"]["entry_point"]
self.runnable.container.commands = self.config["container"]["commands"]
self.runnable.container.options = "--privileged"
if "install_gpu_drivers" in self.config:
self.runnable.container.volumes.append("/var/lib/nvidia/lib64:/usr/local/nvidia/lib64")
self.runnable.container.volumes.append("/var/lib/nvidia/bin:/usr/local/nvidia/bin")
else:
self.runnable.script = batch_v1.Runnable.Script()
self.runnable.script.text = self.config["script_text"]
return(self.runnable)
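    # For reference, a container job's config fragment consumed above might
    # look like this (a sketch; image and commands are illustrative):
    #
    #   container:
    #     image_uri: gcr.io/my-project/my-image:latest
    #     entry_point: /bin/bash
    #     commands: ["-c", "echo hello"]
    #
    # A script job instead sets:
    #
    #   script_text: |
    #     echo "Hello from a Batch task"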
def _create_task(self) -> batch_v1.TaskSpec:
"""
        Builds the TaskSpec: retry policy, max run duration, NFS and GCS
        volume mounts, and per-task compute resources.
        :return: the configured task spec
:rtype: batch_v1.TaskSpec
"""
self.task = batch_v1.TaskSpec()
self.task.max_retry_count = 2
self.task.max_run_duration = "604800s"
self.task.volumes = []
if "nfs_server" in self.config:
nfs_server = batch_v1.NFS()
nfs_server.server = self.config["nfs_server"]
nfs_server.remote_path = self.config["nfs_remote_path"] if "nfs_remote_path" in self.config else "/nfshare"
nfs_volume = batch_v1.Volume()
nfs_volume.nfs = nfs_server
nfs_volume.mount_path = self.config["nfs_path"] if "nfs_path" in self.config else "/mnt/nfs"
self.task.volumes.append(nfs_volume)
        # Use command-line flags first, then the config file, to set
        # GCS bucket mounts.
        if FLAGS.volumes:
self.volumes_list = []
for volume_pair in FLAGS.volumes:
(bucket_name, gcs_path) = volume_pair.split(":")
self.volumes_list.append({"bucket_name":bucket_name, "gcs_path":gcs_path})
elif "volumes" in self.config:
self.volumes_list = self.config["volumes"]
else:
self.volumes_list = None
        if self.volumes_list is not None:
            for volume in self.volumes_list:
                # Create a fresh Volume per bucket; reusing a single object
                # would make every mount alias the last bucket in the list.
                gcs_volume = batch_v1.Volume()
                gcs_bucket = batch_v1.GCS()
                gcs_bucket.remote_path = volume["bucket_name"]
                gcs_volume.mount_path = volume["gcs_path"]
                gcs_volume.gcs = gcs_bucket
                self.task.volumes.append(gcs_volume)
if "container" in self.config:
self.runnable.container.volumes.append(
volume["gcs_path"]+":"+volume["gcs_path"]+":rw")
        # We can specify what resources are requested by each task.
resources = batch_v1.ComputeResource()
        resources.cpu_milli = self.config.get("cpu_milli", 1000)
        resources.memory_mib = self.config.get("memory_mib", 102400)
        resources.boot_disk_mib = self.config.get("boot_disk_mib", 102400)
self.task.compute_resource = resources
return(self.task)
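    # For reference, a config fragment exercising the volume and resource
    # options above might look like this (a sketch; values are illustrative):
    #
    #   nfs_server: 10.0.0.2
    #   nfs_remote_path: /nfshare
    #   nfs_path: /mnt/nfs
    #   volumes:
    #     - bucket_name: my-bucket
    #       gcs_path: /mnt/data
    #   cpu_milli: 2000
    #   memory_mib: 16384
    #   boot_disk_mib: 102400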
def _create_allocation_policy(self) -> batch_v1.AllocationPolicy:
# Policies are used to define on what kind of virtual machines the tasks will run on.
# In this case, we tell the system to use an instance template that defines all the
# required parameters.
self.allocation_policy = batch_v1.AllocationPolicy()
self.instance = batch_v1.AllocationPolicy.InstancePolicyOrTemplate()
self.instance_policy = batch_v1.AllocationPolicy.InstancePolicy()
self.accelerator = batch_v1.AllocationPolicy.Accelerator()
self.network_policy = batch_v1.AllocationPolicy.NetworkPolicy()
self.network = batch_v1.AllocationPolicy.NetworkInterface()
if "template_link" in self.config:
self.instance.instance_template = self.config["template_link"]
elif "machine_type" in self.config:
self.instance_policy.machine_type = self.config["machine_type"]
self.instance.policy = self.instance_policy
if "accelerator" in self.config:
self.accelerator.type_ = self.config["accelerator"]["type"]
self.accelerator.count = self.config["accelerator"]["count"]
self.instance_policy.accelerators = [self.accelerator]
if "install_gpu_drivers" in self.config:
self.instance.install_gpu_drivers = True
self.instance.policy = self.instance_policy
else:
raise(Error("No instance policy defined."))
location_policy = batch_v1.AllocationPolicy.LocationPolicy()
self.allocation_policy.instances = [self.instance]
if "network" in self.config:
if "no_external_ip_address" in self.config:
self.network.no_external_ip_address = self.config["no_external_ip_address"]
self.network.subnetwork = f'regions/{self.config["region"] }/subnetworks/{self.config["subnetwork"]}'
self.network.network = f'projects/{ self.project_id }/global/networks/{self.config["network"]}'
self.network_policy.network_interfaces = [self.network]
self.allocation_policy.network = self.network_policy
if "allowed_locations" in self.config:
location_policy.allowed_locations = self.config["allowed_locations"]
self.allocation_policy.location = location_policy
return(self.allocation_policy)
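    # For reference, a config fragment consumed by the allocation policy
    # above might look like this (a sketch; values are illustrative):
    #
    #   machine_type: n2-standard-8
    #   accelerator:
    #     type: nvidia-tesla-t4
    #     count: 1
    #   install_gpu_drivers: true
    #   region: us-central1
    #   network: default
    #   subnetwork: default
    #   no_external_ip_address: true
    #   allowed_locations: ["zones/us-central1-a"]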
def _create_taskgroup(self):
"""
Tasks are grouped inside a job using TaskGroups.
Currently, it's possible to have only one task group.
"""
self.group = batch_v1.TaskGroup()
self.group.task_spec = self.task
        self.group.task_count = self.config.get("task_count", 1)
        self.group.parallelism = self.config.get("parallelism", 1)
        self.group.task_count_per_node = self.config.get("task_count_per_node", 1)
return(self.group)
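    # For reference (a sketch; values are illustrative):
    #
    #   task_count: 100
    #   parallelism: 10
    #   task_count_per_node: 1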
def _create_job_request(self):
"""Creates job after configuration is completed.
"""
self.job = batch_v1.Job()
        self.job.labels = self.config.get("labels", {"env": "hpc", "type": "hpc"})
self.job.task_groups = [self.group]
self.job.allocation_policy = self.allocation_policy
        # We use Cloud Logging as it's an out-of-the-box option.
self.job.logs_policy = batch_v1.LogsPolicy()
self.job.logs_policy.destination = batch_v1.LogsPolicy.Destination.CLOUD_LOGGING
return(self.job)
def create_job_request(self) -> batch_v1.Job:
# Define what will be done as part of the job.
self._create_runnable()
self._create_task()
self.task.runnables = [self.runnable]
self._create_taskgroup()
self._create_allocation_policy()
self._create_job_request()
create_request = batch_v1.CreateJobRequest()
create_request.job = self.job
create_request.job_id = self.job_id
# The job's parent is the region in which the job will run
create_request.parent = f'projects/{ self.project_id }/locations/{self.config["region"]}'
return(create_request)
def delete_job(self, job_to_delete) -> Operation:
"""
Triggers the deletion of a Job.
Args:
            job_to_delete: name of the job to delete
Returns:
An operation object related to the deletion. You can call `.result()`
on it to wait for its completion.
"""
client = batch_v1.BatchServiceClient()
return client.delete_job(name=f"projects/{self.project_id}/locations/{self.config['region']}/jobs/{job_to_delete}")
def list_jobs(self) -> Iterable[batch_v1.Job]:
"""
Get a list of all jobs defined in given region.
Returns:
            An iterable collection of Job objects.
"""
client = batch_v1.BatchServiceClient()
return client.list_jobs(parent=f"projects/{self.project_id}/locations/{self.config['region']}")
def parse_yaml_file(file_name: str):
"""
Parse the provided YAML file to get the job configuration
    :param file_name: The path to the YAML configuration file
    :type file_name: str
    :return: the parsed configuration as a dict
"""
with open(file_name) as f:
data = yaml.load(f, Loader=SafeLoader)
return(data)
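# A minimal config file needs at least the keys read unconditionally in
# main() and the builders above. A sketch (values are illustrative):
#
#   job_prefix: myjob-
#   project_id: my-project
#   region: us-central1
#   task_count: 10
#   script_text: "echo Hello from task ${BATCH_TASK_INDEX}"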
def main(argv):
"""
This is where the Job and Pubsub objects are created and used.
:param argv: Standard commandline arguments
:type argv: _type_
"""
# Read in config file.
config = parse_yaml_file(FLAGS.config_file)
# Create unique job ID, required for Batch and k8s Jobs
job_id = config["job_prefix"] + uuid.uuid4().hex[:8]
# "previous_job_id": Used for Pubsub. TOPIC is created from JobID, a restart can provide an existing Job
# Id for existing pubsub queue
# TODO: support more env_vars
env_vars = {}
    if FLAGS.previous_job_id:
        env_vars["TOPIC_ID"] = FLAGS.previous_job_id
    else:
        env_vars["TOPIC_ID"] = job_id
# Create jobs object
jobs = CloudBatchJobs(job_id, config, env_vars)
    # Create the Pub/Sub work queue if requested (skipped in debug mode).
if FLAGS.pubsub and not FLAGS.debug:
pubsub = PubSub(job_id, config, previous_job_id=FLAGS.previous_job_id)
pubsub.create_topic()
pubsub.create_subscription()
pubsub.publish_fifo_ids()
# Delete job. JobID must be passed.
    if FLAGS.delete_job:
        print("Deleting job", file=sys.stderr)
        deleted_job = jobs.delete_job(FLAGS.delete_job)
        # result() blocks until the delete operation completes.
        print(deleted_job.result(), file=sys.stderr)
        sys.exit()
    # Prints list of jobs: queued, running, or completed.
    if FLAGS.list_jobs:
        print("Listing jobs", file=sys.stderr)
        job_list = jobs.list_jobs()
        for job in job_list:
            print(job.name, "\t", job.status.state, file=sys.stderr)
        sys.exit()
    # Print the parsed config when debugging.
if FLAGS.debug:
print(config, file=sys.stderr)
if FLAGS.create_job:
# Create the job
print(jobs.client.create_job(jobs.create_job_request()), file=sys.stderr)
else:
print(jobs.create_job_request().job, file=sys.stderr)
if __name__ == "__main__":
""" This is executed when run from the command line """
app.run(main)