ansible/roles/slurm/files/scripts/load_bq.py
#!/usr/bin/env python3
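"""Load completed Slurm job accounting records (from sacct) into a BigQuery table.

Intended to be run repeatedly: a timestamp file tracks the start of the next query
window, and a small shelve-based cache of job DB indexes keeps overlapping windows
from inserting duplicate rows.
"""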
import argparse
import os
import shelve
import uuid
from collections import namedtuple
from datetime import datetime, timezone, timedelta
from pathlib import Path
from pprint import pprint
from google.cloud.bigquery import SchemaField
from google.cloud import bigquery as bq
from google.api_core import retry, exceptions
import util
from util import run
from util import cfg
SACCT = "sacct"
script = Path(__file__).resolve()
DEFAULT_TIMESTAMP_FILE = script.parent / "bq_timestamp"
timestamp_file = Path(os.environ.get("TIMESTAMP_FILE", DEFAULT_TIMESTAMP_FILE))
# cluster_id_file = script.parent / 'cluster_uuid'
# try:
# cluster_id = cluster_id_file.read_text().rstrip()
# except FileNotFoundError:
# cluster_id = uuid.uuid4().hex
# cluster_id_file.write_text(cluster_id)
job_idx_cache_path = script.parent / "bq_job_idx_cache"
SLURM_TIME_FORMAT = r"%Y-%m-%dT%H:%M:%S"
def make_datetime(time_string):
return datetime.strptime(time_string, SLURM_TIME_FORMAT).replace(
tzinfo=timezone.utc
)
def make_time_interval(seconds):
sign = 1
if seconds < 0:
sign = -1
seconds = abs(seconds)
d, r = divmod(seconds, 60 * 60 * 24)
h, r = divmod(r, 60 * 60)
m, s = divmod(r, 60)
d *= sign
h *= sign
    return f"{d}D {h:02}:{m:02}:{s:02}"
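# For example, make_time_interval(90061) returns "1D 01:01:01"
# (90061 seconds = 1 day + 1 hour + 1 minute + 1 second).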
converters = {
"DATETIME": make_datetime,
"INTERVAL": make_time_interval,
"STRING": str,
"INT64": lambda n: int(n or 0),
}
def schema_field(field_name, data_type, description, required=False):
return SchemaField(
field_name,
data_type,
description=description,
mode="REQUIRED" if required else "NULLABLE",
)
schema_fields = [
schema_field("cluster_name", "STRING", "cluster name", required=True),
schema_field("cluster_id", "STRING", "UUID for the cluster", required=True),
schema_field("entry_uuid", "STRING", "entry UUID for the job row", required=True),
schema_field(
"job_db_uuid", "INT64", "job db index from the slurm database", required=True
),
schema_field("job_id_raw", "INT64", "raw job id", required=True),
schema_field("job_id", "STRING", "job id", required=True),
schema_field("state", "STRING", "final job state", required=True),
schema_field("job_name", "STRING", "job name"),
schema_field("partition", "STRING", "job partition"),
schema_field("submit_time", "DATETIME", "job submit time"),
schema_field("start_time", "DATETIME", "job start time"),
schema_field("end_time", "DATETIME", "job end time"),
schema_field("elapsed_raw", "INT64", "STRING", "job run time in seconds"),
# schema_field("elapsed_time", "INTERVAL", "STRING", "job run time interval"),
schema_field("timelimit_raw", "STRING", "job timelimit in minutes"),
schema_field("timelimit", "STRING", "job timelimit"),
# schema_field("num_tasks", "INT64", "number of allocated tasks in job"),
schema_field("nodelist", "STRING", "names of nodes allocated to job"),
schema_field("user", "STRING", "user responsible for job"),
schema_field("uid", "INT64", "uid of job user"),
schema_field("group", "STRING", "group of job user"),
schema_field("gid", "INT64", "gid of job user"),
schema_field("wckey", "STRING", "job wckey"),
schema_field("qos", "STRING", "job qos"),
schema_field("comment", "STRING", "job comment"),
schema_field("admin_comment", "STRING", "job admin comment"),
# extra will be added in 23.02
# schema_field("extra", "STRING", "job extra field"),
schema_field("exitcode", "STRING", "job exit code"),
schema_field("alloc_cpus", "INT64", "count of allocated CPUs"),
schema_field("alloc_nodes", "INT64", "number of nodes allocated to job"),
schema_field("alloc_tres", "STRING", "allocated trackable resources (TRES)"),
# schema_field("system_cpu", "INTERVAL", "cpu time used by parent processes"),
# schema_field("cpu_time", "INTERVAL", "CPU time used (elapsed * cpu count)"),
schema_field("cpu_time_raw", "INT64", "CPU time used (elapsed * cpu count)"),
# schema_field("ave_cpu", "INT64", "Average CPU time of all tasks in job"),
# schema_field(
# "tres_usage_tot",
# "STRING",
# "Tres total usage by all tasks in job",
# ),
]
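# Columns marked REQUIRED must be present in every inserted row; as noted in
# init_table(), required fields cannot be added to an existing table's schema.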
slurm_field_map = {
"job_db_uuid": "DBIndex",
"job_id_raw": "JobIDRaw",
"job_id": "JobID",
"state": "State",
"job_name": "JobName",
"partition": "Partition",
"submit_time": "Submit",
"start_time": "Start",
"end_time": "End",
"elapsed_raw": "ElapsedRaw",
"elapsed_time": "Elapsed",
"timelimit_raw": "TimelimitRaw",
"timelimit": "Timelimit",
"num_tasks": "NTasks",
"nodelist": "Nodelist",
"user": "User",
"uid": "Uid",
"group": "Group",
"gid": "Gid",
"wckey": "Wckey",
"qos": "Qos",
"comment": "Comment",
"admin_comment": "AdminComment",
# "extra": "Extra",
"exit_code": "ExitCode",
"alloc_cpus": "AllocCPUs",
"alloc_nodes": "AllocNodes",
"alloc_tres": "AllocTres",
"system_cpu": "SystemCPU",
"cpu_time": "CPUTime",
"cpu_time_raw": "CPUTimeRaw",
"ave_cpu": "AveCPU",
"tres_usage_tot": "TresUsageInTot",
}
# The BigQuery field name (not the sacct name) is the key in job_schema; it is
# used to look up the data type when building job rows.
job_schema = {field.name: field for field in schema_fields}
# Order is important here, as that is how they are parsed from sacct output
Job = namedtuple("Job", job_schema.keys())
client = bq.Client(
project=cfg.project,
credentials=util.default_credentials(),
client_options=util.create_client_options(util.ApiEndpoint.BQ),
)
dataset_id = f"{cfg.slurm_cluster_name}_job_data"
dataset = bq.DatasetReference(project=cfg.project, dataset_id=dataset_id)
table = bq.Table(
bq.TableReference(dataset, f"jobs_{cfg.slurm_cluster_name}"), schema_fields
)
class JobInsertionFailed(Exception):
pass
def make_job_row(job):
job_row = {
        field_name: converters.get(field.field_type)(job[field_name])
for field_name, field in job_schema.items()
if field_name in job
}
job_row["entry_uuid"] = uuid.uuid4().hex
job_row["cluster_id"] = cfg.cluster_id
job_row["cluster_name"] = cfg.slurm_cluster_name
return job_row
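# Illustrative only (values are invented): a resulting row looks roughly like
#   {"job_db_uuid": 987, "job_id": "1234", "state": "COMPLETED", ...,
#    "entry_uuid": "<random hex>", "cluster_id": "<cfg.cluster_id>",
#    "cluster_name": "<cfg.slurm_cluster_name>"}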
def load_slurm_jobs(start, end):
states = ",".join(
(
"BOOT_FAIL",
"CANCELLED",
"COMPLETED",
"DEADLINE",
"FAILED",
"NODE_FAIL",
"OUT_OF_MEMORY",
"PREEMPTED",
"REQUEUED",
"REVOKED",
"TIMEOUT",
)
)
start_iso = start.isoformat(timespec="seconds")
end_iso = end.isoformat(timespec="seconds")
# slurm_fields and bq_fields will be in matching order
slurm_fields = ",".join(slurm_field_map.values())
bq_fields = slurm_field_map.keys()
cmd = (
f"{SACCT} --start {start_iso} --end {end_iso} -X -D --format={slurm_fields} "
f"--state={states} --parsable2 --noheader --allusers --duplicates"
)
text = run(cmd).stdout.splitlines()
# zip pairs bq_fields with the value from sacct
jobs = [dict(zip(bq_fields, line.split("|"))) for line in text]
    # The job index cache lets us avoid re-sending jobs that were already
    # submitted; this sidesteps a race condition with updates to the Slurm database.
with shelve.open(str(job_idx_cache_path), flag="r") as job_idx_cache:
job_rows = [
make_job_row(job)
for job in jobs
if str(job["job_db_uuid"]) not in job_idx_cache
]
return job_rows
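# Illustrative only (field values are invented): with --parsable2 --noheader, sacct
# emits pipe-delimited lines such as
#   987|1234|1234|COMPLETED|wrap|debug|2024-01-01T11:58:01|2024-01-01T11:58:05|...
# and zip() pairs each field positionally with the slurm_field_map keys above.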
def init_table():
global dataset
global table
dataset = client.create_dataset(dataset, exists_ok=True)
table = client.create_table(table, exists_ok=True)
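    # A table created moments ago may not be visible yet; retry get_table()
    # until it no longer raises NotFound.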
until_found = retry.Retry(predicate=retry.if_exception_type(exceptions.NotFound))
table = client.get_table(table, retry=until_found)
# cannot add required fields to an existing schema
table.schema = schema_fields
table = client.update_table(table, ["schema"])
def purge_job_idx_cache():
purge_time = datetime.now() - timedelta(minutes=30)
with shelve.open(str(job_idx_cache_path), writeback=True) as cache:
to_delete = []
for idx, stamp in cache.items():
if stamp < purge_time:
to_delete.append(idx)
for idx in to_delete:
del cache[idx]
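# Cache entries map str(job_db_uuid) -> the window-end datetime recorded by
# update_job_idx_cache(); anything older than 30 minutes is dropped, which
# comfortably exceeds the 10-minute window overlap used in get_time_window().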
def bq_submit(jobs):
try:
result = client.insert_rows(table, jobs)
except exceptions.NotFound as e:
print(f"failed to upload job data, table not yet found: {e}")
raise e
except Exception as e:
print(f"failed to upload job data: {e}")
raise e
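    # insert_rows() returns a list of per-row error mappings; an empty list
    # means every row was accepted.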
if result:
pprint(jobs)
pprint(result)
raise JobInsertionFailed("failed to upload job data to big query")
print(f"successfully loaded {len(jobs)} jobs")
def get_time_window():
if not timestamp_file.is_file():
timestamp_file.touch()
try:
timestamp = datetime.strptime(
timestamp_file.read_text().rstrip(), SLURM_TIME_FORMAT
)
# time window will overlap the previous by 10 minutes. Duplicates will be filtered out by the job_idx_cache
start = timestamp - timedelta(minutes=10)
except ValueError:
# timestamp 1 is 1 second after the epoch; timestamp 0 is special for sacct
start = datetime.fromtimestamp(1)
# end is now() truncated to the last second
end = datetime.now().replace(microsecond=0)
return start, end
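# Illustrative: if the timestamp file holds "2024-01-01T12:00:00", the returned
# window is (2024-01-01 11:50:00, <now, truncated to whole seconds>); if the file
# is empty or unparsable, the window starts one second after the epoch.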
def write_timestamp(time):
timestamp_file.write_text(time.isoformat(timespec="seconds"))
def update_job_idx_cache(jobs, timestamp):
with shelve.open(str(job_idx_cache_path), writeback=True) as job_idx_cache:
for job in jobs:
job_idx = str(job["job_db_uuid"])
job_idx_cache[job_idx] = timestamp
def main():
if not cfg.enable_bigquery_load:
print("bigquery load is not currently enabled")
exit(0)
init_table()
start, end = get_time_window()
jobs = load_slurm_jobs(start, end)
    # On failure, an exception prevents the timestamp from being rewritten, so the
    # same window is retried on the next run. If only some rows were inserted, we
    # currently have no way to avoid re-submitting them next time.
if jobs:
bq_submit(jobs)
write_timestamp(end)
update_job_idx_cache(jobs, end)
parser = argparse.ArgumentParser(description="submit slurm job data to big query")
parser.add_argument(
"timestamp_file",
nargs="?",
action="store",
type=Path,
help="specify timestamp file for reading and writing the time window start. Precedence over TIMESTAMP_FILE env var.",
)
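# Drop stale cache entries on every run; opening the shelve here also creates the
# cache file if it does not exist yet, before load_slurm_jobs() opens it read-only.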
purge_job_idx_cache()
if __name__ == "__main__":
args = parser.parse_args()
if args.timestamp_file:
timestamp_file = args.timestamp_file.resolve()
main()