python-batch/batch.py (238 lines of code) (raw):

#!/usr/bin/env python3 # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Tools to run Google Cloud Batch API """ __author__ = "J Ross Thomson drj@" __version__ = "0.1.0" import json import os import uuid import yaml import sys from absl import app from absl import flags from google.api_core.operation import Operation from google.cloud import batch_v1 from google.cloud import pubsub_v1 from typing import Iterable from yaml.loader import SafeLoader """ We define multiple command line flags: """ FLAGS = flags.FLAGS flags.DEFINE_string("config_file", None, "Config file in YAML") flags.DEFINE_boolean("pubsub", False , "Run Pubsub Topic and Subscriber") flags.DEFINE_string("previous_job_id", None, "For Pubsub restart, specifies topic to read from") flags.DEFINE_string("project_id", None, "Google Cloud Project ID, not name") flags.DEFINE_spaceseplist("volumes", None, "List of GCS paths to mount. Example, \"bucket_name1:mountpath1 bucket_name2:mountpath2\"" ) flags.DEFINE_boolean("create_job", False, "Creates job, otherwise just prints config.") flags.DEFINE_boolean("list_jobs", False, "If true, list jobs for config.") flags.DEFINE_string("delete_job", "", "Job name to delete.") flags.DEFINE_boolean("debug", False, "If true, print debug info.") # Required flag. flags.mark_flag_as_required("config_file") class PubSub: def __init__(self, job_id: str, config, previous_job_id=None) -> None: """Class to create Pub/Sub related cloud resources. Used to create a topic and a subscription. Topic name identical to job_id and subscription is `sub-` + job_id. Args: config: Contains the structure defined by the Config.yaml file. job_id: The name of the job, which becomes the name of the pubsub topic. Raises: None `project_id` is set based on config, env, or argv in that order. """ self.config = config self.job_id = job_id if self.config["project_id"]: self.project_id = self.config["project_id"] if os.environ.get('GOOGLE_CLOUD_PROJECT'): self.project_id = os.environ.get('GOOGLE_CLOUD_PROJECT') if FLAGS.project_id: self.project_id = FLAGS.project_id """ Sending messages to the same region ensures they are received in order even when multiple publishers are used. """ publisher_options = pubsub_v1.types.PublisherOptions(enable_message_ordering=True) client_options = {"api_endpoint": f'{self.config["region"]}-pubsub.googleapis.com:443'} self.publisher = pubsub_v1.PublisherClient( publisher_options=publisher_options, client_options=client_options ) def create_topic(self): """ Creates a topic with name the same as Job. The `topic_path` method creates a fully qualified identifier in the form `projects/{project_id}/topics/{topic_id}` Args: None Returns: None """ self.topic_path = self.publisher.topic_path(self.project_id, self.job_id) self.topic = self.publisher.create_topic(request={"name": self.topic_path}) print(f"Created topic {self.topic_path}\n", file=sys.stderr) def create_subscription(self): """ Creates a subscription with name the same as `sub-` + Job. Subscription will be ordered and have exactly one time delivery. Args: Returns: None """ self.subscriber = pubsub_v1.SubscriberClient() self.subscription_id = "sub-" + self.job_id self.topic_path = self.publisher.topic_path(self.project_id, self.job_id) self.subscription_path = self.subscriber.subscription_path(self.project_id, self.subscription_id) with self.subscriber: self.subscription = self.subscriber.create_subscription( request={ "name": self.subscription_path, "topic": self.topic_path, "enable_message_ordering": True, "enable_exactly_once_delivery": True, } ) print(f"Created subscription {self.subscription_path}\n", file=sys.stderr) def publish_fifo_ids(self): """ Publish integer ID messages to the pubsub topic. Simple queue create via the Pubsub Topic to create queue. If the queue is not fully pulled, the project can restart with the same queue. Picking up where it left off. Args: None Returns: None """ self.order_key = "fifo" for i in range(0, self.config["task_count"] ): # Data must be a bytestring data = str(i).encode("utf-8") # When you publish a message, the client returns a future. future = self.publisher.publish(self.topic_path, data=data, ordering_key=self.order_key) if FLAGS.debug: print(f'Future: {future.result()}, Message: {data}', file=sys.stderr) class CloudBatchJobs: def __init__(self, job_id: str, config, env_vars) -> None: """ Class to create all cloud resources for Batch Jobs Args: job_id: an arbitrary string to name the job. config: data structure from YAML to represent everything in the config file. Returns: None """ self.config = config self.job_id = job_id self.env_vars = env_vars #set "project_id" based on config, env, or argv in that order. if self.config["project_id"]: self.project_id = self.config["project_id"] if os.environ.get('GOOGLE_CLOUD_PROJECT'): self.project_id = os.environ.get('GOOGLE_CLOUD_PROJECT') if FLAGS.project_id: self.project_id = FLAGS.project_id self.client = batch_v1.BatchServiceClient() def _create_runnable(self) -> batch_v1.Runnable: self.runnable = batch_v1.Runnable() self.runnable.environment = batch_v1.Environment() self.runnable.environment.variables = self.env_vars if "container" in self.config: self.runnable.container = batch_v1.Runnable.Container() self.runnable.container.image_uri = self.config["container"]["image_uri"] self.runnable.container.entrypoint = self.config["container"]["entry_point"] self.runnable.container.commands = self.config["container"]["commands"] self.runnable.container.options = "--privileged" if "install_gpu_drivers" in self.config: self.runnable.container.volumes.append("/var/lib/nvidia/lib64:/usr/local/nvidia/lib64") self.runnable.container.volumes.append("/var/lib/nvidia/bin:/usr/local/nvidia/bin") else: self.runnable.script = batch_v1.Runnable.Script() self.runnable.script.text = self.config["script_text"] return(self.runnable) def _create_task(self) -> batch_v1.TaskSpec: """ _summary_ :return: _description_ :rtype: batch_v1.TaskSpec """ self.task = batch_v1.TaskSpec() self.task.max_retry_count = 2 self.task.max_run_duration = "604800s" self.task.volumes = [] if "nfs_server" in self.config: nfs_server = batch_v1.NFS() nfs_server.server = self.config["nfs_server"] nfs_server.remote_path = self.config["nfs_remote_path"] if "nfs_remote_path" in self.config else "/nfshare" nfs_volume = batch_v1.Volume() nfs_volume.nfs = nfs_server nfs_volume.mount_path = self.config["nfs_path"] if "nfs_path" in self.config else "/mnt/nfs" self.task.volumes.append(nfs_volume) # Use commmand line flags first, then config file, to set # gcs bucket mounts if(FLAGS.volumes): self.volumes_list = [] for volume_pair in FLAGS.volumes: (bucket_name, gcs_path) = volume_pair.split(":") self.volumes_list.append({"bucket_name":bucket_name, "gcs_path":gcs_path}) elif "volumes" in self.config: self.volumes_list = self.config["volumes"] else: self.volumes_list = None if(not self.volumes_list is None): gcs_volume = batch_v1.Volume() for volume in self.volumes_list: gcs_bucket = batch_v1.GCS() gcs_bucket.remote_path = volume["bucket_name"] gcs_volume.mount_path = volume["gcs_path"] gcs_volume.gcs = gcs_bucket self.task.volumes.append(gcs_volume) if "container" in self.config: self.runnable.container.volumes.append( volume["gcs_path"]+":"+volume["gcs_path"]+":rw") # We can specify what resources are requoested by each task. resources = batch_v1.ComputeResource() resources.cpu_milli = self.config["cpu_milli"] if "cpu_milli" in self.config else 1000 resources.memory_mib = self.config["memory_mib"] if "memory_mib" in self.config else 102400 resources.boot_disk_mib = self.config["boot_disk_mib"] if "boot_disk_mib" in self.config else 102400 self.task.compute_resource = resources return(self.task) def _create_allocation_policy(self) -> batch_v1.AllocationPolicy: # Policies are used to define on what kind of virtual machines the tasks will run on. # In this case, we tell the system to use an instance template that defines all the # required parameters. self.allocation_policy = batch_v1.AllocationPolicy() self.instance = batch_v1.AllocationPolicy.InstancePolicyOrTemplate() self.instance_policy = batch_v1.AllocationPolicy.InstancePolicy() self.accelerator = batch_v1.AllocationPolicy.Accelerator() self.network_policy = batch_v1.AllocationPolicy.NetworkPolicy() self.network = batch_v1.AllocationPolicy.NetworkInterface() if "template_link" in self.config: self.instance.instance_template = self.config["template_link"] elif "machine_type" in self.config: self.instance_policy.machine_type = self.config["machine_type"] self.instance.policy = self.instance_policy if "accelerator" in self.config: self.accelerator.type_ = self.config["accelerator"]["type"] self.accelerator.count = self.config["accelerator"]["count"] self.instance_policy.accelerators = [self.accelerator] if "install_gpu_drivers" in self.config: self.instance.install_gpu_drivers = True self.instance.policy = self.instance_policy else: raise(Error("No instance policy defined.")) location_policy = batch_v1.AllocationPolicy.LocationPolicy() self.allocation_policy.instances = [self.instance] if "network" in self.config: if "no_external_ip_address" in self.config: self.network.no_external_ip_address = self.config["no_external_ip_address"] self.network.subnetwork = f'regions/{self.config["region"] }/subnetworks/{self.config["subnetwork"]}' self.network.network = f'projects/{ self.project_id }/global/networks/{self.config["network"]}' self.network_policy.network_interfaces = [self.network] self.allocation_policy.network = self.network_policy if "allowed_locations" in self.config: location_policy.allowed_locations = self.config["allowed_locations"] self.allocation_policy.location = location_policy return(self.allocation_policy) def _create_taskgroup(self): """ Tasks are grouped inside a job using TaskGroups. Currently, it's possible to have only one task group. """ self.group = batch_v1.TaskGroup() self.group.task_spec = self.task self.group.task_count = self.config["task_count"] if "task_count" in self.config else 1 self.group.parallelism = self.config["parallelism"] if "parallelism" in self.config else 1 self.group.task_count_per_node = self.config["task_count_per_node"] if "task_count_per_node" in self.config else 1 return(self.group) def _create_job_request(self): """Creates job after configuration is completed. """ self.job = batch_v1.Job() self.job.labels = self.config["labels"] if "labels" in self.config else {"env": "hpc", "type": "hpc"} self.job.task_groups = [self.group] self.job.allocation_policy = self.allocation_policy # We use Cloud Logging as it's an out of the box available option self.job.logs_policy = batch_v1.LogsPolicy() self.job.logs_policy.destination = batch_v1.LogsPolicy.Destination.CLOUD_LOGGING return(self.job) def create_job_request(self) -> batch_v1.Job: # Define what will be done as part of the job. self._create_runnable() self._create_task() self.task.runnables = [self.runnable] self._create_taskgroup() self._create_allocation_policy() self._create_job_request() create_request = batch_v1.CreateJobRequest() create_request.job = self.job create_request.job_id = self.job_id # The job's parent is the region in which the job will run create_request.parent = f'projects/{ self.project_id }/locations/{self.config["region"]}' return(create_request) # [END batch_create_job_with_template] def delete_job(self, job_to_delete) -> Operation: """ Triggers the deletion of a Job. Args: delete_job: name of the job to delete Returns: An operation object related to the deletion. You can call `.result()` on it to wait for its completion. """ client = batch_v1.BatchServiceClient() return client.delete_job(name=f"projects/{self.project_id}/locations/{self.config['region']}/jobs/{job_to_delete}") def list_jobs(self) -> Iterable[batch_v1.Job]: """ Get a list of all jobs defined in given region. Returns: An iterable collection of Job object. """ client = batch_v1.BatchServiceClient() return client.list_jobs(parent=f"projects/{self.project_id}/locations/{self.config['region']}") def parse_yaml_file(file_name: str): """ Parse the provided YAML file to get the job configuration :param file_name: The path to the YAML configuration file :type file_name: str """ with open(file_name) as f: data = yaml.load(f, Loader=SafeLoader) return(data) def main(argv): """ This is where the Job and Pubsub objects are created and used. :param argv: Standard commandline arguments :type argv: _type_ """ # Read in config file. config = parse_yaml_file(FLAGS.config_file) # Create unique job ID, required for Batch and k8s Jobs job_id = config["job_prefix"] + uuid.uuid4().hex[:8] # "previous_job_id": Used for Pubsub. TOPIC is created from JobID, a restart can provide an existing Job # Id for existing pubsub queue # TODO: support more env_vars env_vars = {} if FLAGS.previous_job_id: env_vars = dict(env_vars, TOPIC_ID = FLAGS.previous_job_id) else: env_vars = dict(env_vars, TOPIC_ID = job_id) # Create jobs object jobs = CloudBatchJobs(job_id, config, env_vars) # If pubsub queue is required if FLAGS.pubsub and not FLAGS.debug: pubsub = PubSub(job_id, config, previous_job_id=FLAGS.previous_job_id) pubsub.create_topic() pubsub.create_subscription() pubsub.publish_fifo_ids() # Delete job. JobID must be passed. if(FLAGS.delete_job): print("Deleting job", file=sys.stderr) deleted_job = jobs.delete_job(FLAGS.delete_job) print(deleted_job.result, file=sys.stderr) exit() # Prints list of jobs, in queue, running or complteted. if(FLAGS.list_jobs): print("Listing jobs", file=sys.stderr) job_list = jobs.list_jobs() for job in job_list: print(job.name,"\t",job.status.state, file=sys.stderr) exit() # Prints if FLAGS.debug: print(config, file=sys.stderr) if FLAGS.create_job: # Create the job print(jobs.client.create_job(jobs.create_job_request()), file=sys.stderr) else: print(jobs.create_job_request().job, file=sys.stderr) if __name__ == "__main__": """ This is executed when run from the command line """ app.run(main)