jobs/dap-collector/dap_collector/main.py:

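"""Backfills aggregated DAP measurements into BigQuery.

For every task in a remotely hosted task configuration, this script runs the
Janus ``collect`` binary once per five-minute interval of the given day and
writes the results (or errors) to a BigQuery table.
"""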
import asyncio
import datetime
import math
import time

import click
import requests
from google.cloud import bigquery

LEADER = "https://dap-07-1.api.divviup.org"
CMD = (
    f"./collect --task-id {{task_id}} --leader {LEADER} --vdaf {{vdaf}}"
    " {vdaf_args} --authorization-bearer-token {auth_token}"
    " --batch-interval-start {timestamp} --batch-interval-duration {duration}"
    " --hpke-config {hpke_config} --hpke-private-key {hpke_private_key}"
)
# Length of a single collection interval in seconds (5 minutes).
INTERVAL_LENGTH = 300


def read_tasks(task_config_url):
    """Read task configuration from a Google Cloud bucket."""
    resp = requests.get(task_config_url)
    # Fail loudly if the configuration could not be fetched.
    resp.raise_for_status()
    return resp.json()


def toh(timestamp):
    """Turn a Unix timestamp into a UTC datetime object which prints human-readably."""
    return datetime.datetime.fromtimestamp(timestamp, datetime.timezone.utc)


async def collect_once(task, timestamp, duration, hpke_private_key, auth_token):
    """Run collection for a single time interval.

    This uses the Janus collect binary. The result is formatted to fit the
    BQ table.
    """
    collection_time = str(datetime.datetime.now(datetime.timezone.utc).timestamp())
    print(f"{collection_time} Collecting {toh(timestamp)} - {toh(timestamp + duration)}")

    # Prepare the output row.
    res = {}
    res["metric_type"] = task["metric_type"]
    res["task_id"] = task["task_id"]
    res["collection_time"] = collection_time
    res["slot_start"] = timestamp

    # Convert the structured VDAF description into command line flags.
    vdaf_args = ""
    for k, v in task["vdaf_args_structured"].items():
        vdaf_args += f" --{k} {v}"

    cmd = CMD.format(
        timestamp=timestamp,
        duration=duration,
        hpke_private_key=hpke_private_key,
        auth_token=auth_token,
        task_id=task["task_id"],
        vdaf=task["vdaf"],
        vdaf_args=vdaf_args,
        hpke_config=task["hpke_config"],
    )

    # How long an individual collection may take before we give up and record
    # a timeout.
    timeout = 100
    start_counter = time.perf_counter()
    try:
        proc = await asyncio.wait_for(
            asyncio.create_subprocess_shell(
                cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
            ),
            timeout,
        )
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout)
        stdout = stdout.decode()
        stderr = stderr.decode()
    except asyncio.TimeoutError:
        res["collection_duration"] = time.perf_counter() - start_counter
        res["error"] = "TIMEOUT"
        return res
    res["collection_duration"] = time.perf_counter() - start_counter

    # Parse the output of the collect binary.
    if proc.returncode == 1:
        if (
            stderr
            == "Error: HTTP response status 400 Bad Request - The number of reports included in the batch is invalid.\n"
        ):
            res["error"] = "BATCH TOO SMALL"
        else:
            res["error"] = f"UNHANDLED ERROR: {stderr}"
    else:
        for line in stdout.splitlines():
            if line.startswith("Aggregation result:"):
                if task["vdaf"] in ["countvec", "sumvec"]:
                    # Strip the "Aggregation result: [" prefix and "]" suffix.
                    entries = line[21:-1]
                    res["value"] = list(map(int, entries.split(",")))
                elif task["vdaf"] == "sum":
                    # Strip the "Aggregation result: " prefix.
                    res["value"] = [int(line[20:])]
                else:
                    raise RuntimeError(f"Unknown VDAF: {task['vdaf']}")
            elif line.startswith("Number of reports:"):
                res["report_count"] = int(line.split()[-1].strip())
            elif (
                line.startswith("Interval start:")
                or line.startswith("Interval end:")
                or line.startswith("Interval length:")
            ):
                # Irrelevant since we are using time interval queries.
                continue
            else:
                print(f"UNHANDLED OUTPUT LINE: {line}")
                raise NotImplementedError
    return res


async def process_queue(q: asyncio.Queue, results: list):
    """Worker for parallelism.

    Processes items from the queue until it is empty.
    """
    while not q.empty():
        job = q.get_nowait()
        res = await collect_once(*job)
        results.append(res)


async def collect_many(
    task, time_from, time_until, interval_length, hpke_private_key, auth_token
):
    """Collect data for a given time range.

    Creates a pool of workers which process jobs from a queue in parallel.
    """
    time_from = int(time_from.timestamp())
    time_until = int(time_until.timestamp())
    # Round up to the next interval boundary.
    start = math.ceil(time_from / interval_length) * interval_length

    # 288 five-minute slots fit into one day.
    jobs = asyncio.Queue(288)
    results = []
    while start + interval_length <= time_until:
        await jobs.put((task, start, interval_length, hpke_private_key, auth_token))
        start += interval_length

    workers = [process_queue(jobs, results) for _ in range(10)]
    await asyncio.gather(*workers)

    return results


async def collect_task(task, auth_token, hpke_private_key, date):
    """Collect data for the given task and the given day."""
    start = datetime.datetime.fromisoformat(date)
    start = start.replace(tzinfo=datetime.timezone.utc)
    end = start + datetime.timedelta(days=1)
    results = await collect_many(
        task, start, end, INTERVAL_LENGTH, hpke_private_key, auth_token
    )
    return results


def ensure_table(bqclient, table_id):
    """Check if the table exists in BQ and create it otherwise.

    Fails if the table exists but has the wrong schema.
    """
    schema = [
        bigquery.SchemaField("collection_time", "TIMESTAMP", mode="REQUIRED"),
        bigquery.SchemaField("collection_duration", "FLOAT", mode="REQUIRED"),
        bigquery.SchemaField("task_id", "STRING", mode="REQUIRED"),
        bigquery.SchemaField("metric_type", "STRING", mode="REQUIRED"),
        bigquery.SchemaField("slot_start", "TIMESTAMP", mode="REQUIRED"),
        bigquery.SchemaField("report_count", "INTEGER"),
        bigquery.SchemaField("error", "STRING"),
        bigquery.SchemaField("value", "INTEGER", mode="REPEATED"),
    ]
    table = bigquery.Table(table_id, schema=schema)
    print(f"Making sure the table {table_id} exists.")
    table = bqclient.create_table(table, exists_ok=True)


def store_data(results, bqclient, table_id):
    """Insert the results into BQ.

    Assumes that they are already in the right format.
    """
    insert_res = bqclient.insert_rows_json(table=table_id, json_rows=results)
    if len(insert_res) != 0:
        print(insert_res)
    assert len(insert_res) == 0


@click.command()
@click.option("--project", help="GCP project id", required=True)
@click.option(
    "--table-id",
    help="The aggregated DAP measurements will be stored in this table.",
    required=True,
)
@click.option(
    "--auth-token",
    help="HTTP bearer token to authenticate to the leader",
    required=True,
)
@click.option(
    "--hpke-private-key",
    help="The private key used to decrypt shares from the leader and helper.",
    required=True,
)
@click.option(
    "--date",
    help="Date at which the backfill will start, going backwards (YYYY-MM-DD)",
    required=True,
)
@click.option(
    "--task-config-url",
    help="URL where a JSON definition of the tasks to be collected can be found.",
    required=True,
)
def main(project, table_id, auth_token, hpke_private_key, date, task_config_url):
    """Backfill one day of DAP measurements into BigQuery."""
    table_id = project + "." + table_id
    bqclient = bigquery.Client(project=project)
    ensure_table(bqclient, table_id)
    for task in read_tasks(task_config_url):
        print(f"Now processing task: {task['task_id']}")
        results = asyncio.run(collect_task(task, auth_token, hpke_private_key, date))
        store_data(results, bqclient, table_id)


if __name__ == "__main__":
    main()
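
# ---------------------------------------------------------------------------
# Usage sketch (not part of the program). All values below are made-up
# placeholders illustrating the expected shapes, not real tasks or
# credentials.
#
# The JSON served from --task-config-url is a list of task objects. The field
# names follow how collect_once() reads each task; the vdaf_args_structured
# keys are passed verbatim as flags to the collect binary, so they must match
# whatever flags that binary accepts for the chosen VDAF:
#
#   [
#     {
#       "task_id": "<base64url task id>",
#       "metric_type": "example-metric",
#       "vdaf": "sumvec",
#       "vdaf_args_structured": {"length": 20, "bits": 8},
#       "hpke_config": "<base64url HPKE config>"
#     }
#   ]
#
# Example invocation (e.g. from the dap_collector directory):
#
#   python main.py \
#     --project my-gcp-project \
#     --table-id my_dataset.dap_measurements \
#     --auth-token <bearer token> \
#     --hpke-private-key <base64url private key> \
#     --date 2024-01-01 \
#     --task-config-url https://example.com/tasks.json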