ansible/roles/slurm/files/scripts/load_bq.py (248 lines of code) (raw):

#!/usr/bin/env python3 import argparse import os import shelve import uuid from collections import namedtuple from datetime import datetime, timezone, timedelta from pathlib import Path from pprint import pprint from google.cloud.bigquery import SchemaField from google.cloud import bigquery as bq from google.api_core import retry, exceptions import util from util import run from util import cfg SACCT = "sacct" script = Path(__file__).resolve() DEFAULT_TIMESTAMP_FILE = script.parent / "bq_timestamp" timestamp_file = Path(os.environ.get("TIMESTAMP_FILE", DEFAULT_TIMESTAMP_FILE)) # cluster_id_file = script.parent / 'cluster_uuid' # try: # cluster_id = cluster_id_file.read_text().rstrip() # except FileNotFoundError: # cluster_id = uuid.uuid4().hex # cluster_id_file.write_text(cluster_id) job_idx_cache_path = script.parent / "bq_job_idx_cache" SLURM_TIME_FORMAT = r"%Y-%m-%dT%H:%M:%S" def make_datetime(time_string): return datetime.strptime(time_string, SLURM_TIME_FORMAT).replace( tzinfo=timezone.utc ) def make_time_interval(seconds): sign = 1 if seconds < 0: sign = -1 seconds = abs(seconds) d, r = divmod(seconds, 60 * 60 * 24) h, r = divmod(r, 60 * 60) m, s = divmod(r, 60) d *= sign h *= sign return f"{d}D {h:02}:{m:02}:{s}" converters = { "DATETIME": make_datetime, "INTERVAL": make_time_interval, "STRING": str, "INT64": lambda n: int(n or 0), } def schema_field(field_name, data_type, description, required=False): return SchemaField( field_name, data_type, description=description, mode="REQUIRED" if required else "NULLABLE", ) schema_fields = [ schema_field("cluster_name", "STRING", "cluster name", required=True), schema_field("cluster_id", "STRING", "UUID for the cluster", required=True), schema_field("entry_uuid", "STRING", "entry UUID for the job row", required=True), schema_field( "job_db_uuid", "INT64", "job db index from the slurm database", required=True ), schema_field("job_id_raw", "INT64", "raw job id", required=True), schema_field("job_id", "STRING", "job id", required=True), schema_field("state", "STRING", "final job state", required=True), schema_field("job_name", "STRING", "job name"), schema_field("partition", "STRING", "job partition"), schema_field("submit_time", "DATETIME", "job submit time"), schema_field("start_time", "DATETIME", "job start time"), schema_field("end_time", "DATETIME", "job end time"), schema_field("elapsed_raw", "INT64", "STRING", "job run time in seconds"), # schema_field("elapsed_time", "INTERVAL", "STRING", "job run time interval"), schema_field("timelimit_raw", "STRING", "job timelimit in minutes"), schema_field("timelimit", "STRING", "job timelimit"), # schema_field("num_tasks", "INT64", "number of allocated tasks in job"), schema_field("nodelist", "STRING", "names of nodes allocated to job"), schema_field("user", "STRING", "user responsible for job"), schema_field("uid", "INT64", "uid of job user"), schema_field("group", "STRING", "group of job user"), schema_field("gid", "INT64", "gid of job user"), schema_field("wckey", "STRING", "job wckey"), schema_field("qos", "STRING", "job qos"), schema_field("comment", "STRING", "job comment"), schema_field("admin_comment", "STRING", "job admin comment"), # extra will be added in 23.02 # schema_field("extra", "STRING", "job extra field"), schema_field("exitcode", "STRING", "job exit code"), schema_field("alloc_cpus", "INT64", "count of allocated CPUs"), schema_field("alloc_nodes", "INT64", "number of nodes allocated to job"), schema_field("alloc_tres", "STRING", "allocated trackable resources (TRES)"), # schema_field("system_cpu", "INTERVAL", "cpu time used by parent processes"), # schema_field("cpu_time", "INTERVAL", "CPU time used (elapsed * cpu count)"), schema_field("cpu_time_raw", "INT64", "CPU time used (elapsed * cpu count)"), # schema_field("ave_cpu", "INT64", "Average CPU time of all tasks in job"), # schema_field( # "tres_usage_tot", # "STRING", # "Tres total usage by all tasks in job", # ), ] slurm_field_map = { "job_db_uuid": "DBIndex", "job_id_raw": "JobIDRaw", "job_id": "JobID", "state": "State", "job_name": "JobName", "partition": "Partition", "submit_time": "Submit", "start_time": "Start", "end_time": "End", "elapsed_raw": "ElapsedRaw", "elapsed_time": "Elapsed", "timelimit_raw": "TimelimitRaw", "timelimit": "Timelimit", "num_tasks": "NTasks", "nodelist": "Nodelist", "user": "User", "uid": "Uid", "group": "Group", "gid": "Gid", "wckey": "Wckey", "qos": "Qos", "comment": "Comment", "admin_comment": "AdminComment", # "extra": "Extra", "exit_code": "ExitCode", "alloc_cpus": "AllocCPUs", "alloc_nodes": "AllocNodes", "alloc_tres": "AllocTres", "system_cpu": "SystemCPU", "cpu_time": "CPUTime", "cpu_time_raw": "CPUTimeRaw", "ave_cpu": "AveCPU", "tres_usage_tot": "TresUsageInTot", } # new field name is the key for job_schema. Used to lookup the datatype when # creating the job rows job_schema = {field.name: field for field in schema_fields} # Order is important here, as that is how they are parsed from sacct output Job = namedtuple("Job", job_schema.keys()) client = bq.Client( project=cfg.project, credentials=util.default_credentials(), client_options=util.create_client_options(util.ApiEndpoint.BQ), ) dataset_id = f"{cfg.slurm_cluster_name}_job_data" dataset = bq.DatasetReference(project=cfg.project, dataset_id=dataset_id) table = bq.Table( bq.TableReference(dataset, f"jobs_{cfg.slurm_cluster_name}"), schema_fields ) class JobInsertionFailed(Exception): pass def make_job_row(job): job_row = { field_name: dict.get(converters, field.field_type)(job[field_name]) for field_name, field in job_schema.items() if field_name in job } job_row["entry_uuid"] = uuid.uuid4().hex job_row["cluster_id"] = cfg.cluster_id job_row["cluster_name"] = cfg.slurm_cluster_name return job_row def load_slurm_jobs(start, end): states = ",".join( ( "BOOT_FAIL", "CANCELLED", "COMPLETED", "DEADLINE", "FAILED", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED", "REQUEUED", "REVOKED", "TIMEOUT", ) ) start_iso = start.isoformat(timespec="seconds") end_iso = end.isoformat(timespec="seconds") # slurm_fields and bq_fields will be in matching order slurm_fields = ",".join(slurm_field_map.values()) bq_fields = slurm_field_map.keys() cmd = ( f"{SACCT} --start {start_iso} --end {end_iso} -X -D --format={slurm_fields} " f"--state={states} --parsable2 --noheader --allusers --duplicates" ) text = run(cmd).stdout.splitlines() # zip pairs bq_fields with the value from sacct jobs = [dict(zip(bq_fields, line.split("|"))) for line in text] # The job index cache allows us to avoid sending duplicate jobs. This avoids a race condition with updating the database. with shelve.open(str(job_idx_cache_path), flag="r") as job_idx_cache: job_rows = [ make_job_row(job) for job in jobs if str(job["job_db_uuid"]) not in job_idx_cache ] return job_rows def init_table(): global dataset global table dataset = client.create_dataset(dataset, exists_ok=True) table = client.create_table(table, exists_ok=True) until_found = retry.Retry(predicate=retry.if_exception_type(exceptions.NotFound)) table = client.get_table(table, retry=until_found) # cannot add required fields to an existing schema table.schema = schema_fields table = client.update_table(table, ["schema"]) def purge_job_idx_cache(): purge_time = datetime.now() - timedelta(minutes=30) with shelve.open(str(job_idx_cache_path), writeback=True) as cache: to_delete = [] for idx, stamp in cache.items(): if stamp < purge_time: to_delete.append(idx) for idx in to_delete: del cache[idx] def bq_submit(jobs): try: result = client.insert_rows(table, jobs) except exceptions.NotFound as e: print(f"failed to upload job data, table not yet found: {e}") raise e except Exception as e: print(f"failed to upload job data: {e}") raise e if result: pprint(jobs) pprint(result) raise JobInsertionFailed("failed to upload job data to big query") print(f"successfully loaded {len(jobs)} jobs") def get_time_window(): if not timestamp_file.is_file(): timestamp_file.touch() try: timestamp = datetime.strptime( timestamp_file.read_text().rstrip(), SLURM_TIME_FORMAT ) # time window will overlap the previous by 10 minutes. Duplicates will be filtered out by the job_idx_cache start = timestamp - timedelta(minutes=10) except ValueError: # timestamp 1 is 1 second after the epoch; timestamp 0 is special for sacct start = datetime.fromtimestamp(1) # end is now() truncated to the last second end = datetime.now().replace(microsecond=0) return start, end def write_timestamp(time): timestamp_file.write_text(time.isoformat(timespec="seconds")) def update_job_idx_cache(jobs, timestamp): with shelve.open(str(job_idx_cache_path), writeback=True) as job_idx_cache: for job in jobs: job_idx = str(job["job_db_uuid"]) job_idx_cache[job_idx] = timestamp def main(): if not cfg.enable_bigquery_load: print("bigquery load is not currently enabled") exit(0) init_table() start, end = get_time_window() jobs = load_slurm_jobs(start, end) # on failure, an exception will cause the timestamp not to be rewritten. So # it will try again next time. If some writes succeed, we don't currently # have a way to not submit duplicates next time. if jobs: bq_submit(jobs) write_timestamp(end) update_job_idx_cache(jobs, end) parser = argparse.ArgumentParser(description="submit slurm job data to big query") parser.add_argument( "timestamp_file", nargs="?", action="store", type=Path, help="specify timestamp file for reading and writing the time window start. Precedence over TIMESTAMP_FILE env var.", ) purge_job_idx_cache() if __name__ == "__main__": args = parser.parse_args() if args.timestamp_file: timestamp_file = args.timestamp_file.resolve() main()