# jobs/dap-collector-ppa-dev/dap_collector_ppa_dev/main.py
import asyncio
import datetime
import math
import time
import click
import requests
from google.cloud import bigquery
LEADER = "https://dap-09-3.api.divviup.org"
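# Command template for the Janus `collect` binary; the placeholders are filled in per task and per collection interval.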
CMD = f"./collect --task-id {{task_id}} --leader {LEADER} --vdaf {{vdaf}} {{vdaf_args}} --authorization-bearer-token {{auth_token}} --batch-interval-start {{timestamp}} --batch-interval-duration {{duration}} --hpke-config {{hpke_config}} --hpke-private-key {{hpke_private_key}}"
MINUTES_IN_DAY = 1440
# The modulo prime for the field used by Prio3SumVec, and its size in bits. We use these to detect conversion counts
# that have wrapped around the field modulus (which happens when differential privacy noise pushes a count below zero)
# and map them back to negative integers.
#
# Note that these values are specific to the data type we use for our tasks. If we start using a different type (e.g.
# Prio3Histogram), the values will need to be adjusted.
#
# https://github.com/divviup/libprio-rs/blob/a85d271ddee087f13dfd847a7170786f35abd0b9/src/vdaf/prio3.rs#L88
# https://github.com/divviup/libprio-rs/blob/a85d271ddee087f13dfd847a7170786f35abd0b9/src/fp.rs#L87
FIELD_PRIME = 340282366920938462946865773367900766209
FIELD_SIZE = 128
ADS_SCHEMA = [
bigquery.SchemaField("collection_time", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField("placement_id", "STRING", mode="REQUIRED"),
bigquery.SchemaField("ad_id", "STRING", mode="REQUIRED"),
bigquery.SchemaField("conversion_key", "STRING", mode="REQUIRED"),
bigquery.SchemaField("task_size", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("task_id", "STRING", mode="REQUIRED"),
bigquery.SchemaField("task_index", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("conversion_count", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("advertiser_id", "STRING", mode="REQUIRED"),
bigquery.SchemaField("advertiser_name", "STRING", mode="REQUIRED"),
bigquery.SchemaField("campaign_id", "STRING", mode="REQUIRED"),
]
REPORT_SCHEMA = [
bigquery.SchemaField("collection_time", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField("collection_duration", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("task_id", "STRING", mode="REQUIRED"),
bigquery.SchemaField("metric_type", "STRING", mode="REQUIRED"),
bigquery.SchemaField("slot_start", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField("report_count", "INTEGER"),
bigquery.SchemaField("error", "STRING"),
bigquery.SchemaField("value", "INTEGER", mode="REPEATED"),
]
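# Ad metadata loaded from the ad config URL at startup; each entry maps a (taskId, taskIndex) pair to advertiser info.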
ads = []
def read_json(config_url):
    """Read JSON configuration from a URL (e.g. a Google Cloud bucket)."""
    resp = requests.get(config_url)
    # Surface HTTP errors immediately instead of failing later on unparseable JSON.
    resp.raise_for_status()
    return resp.json()
def toh(timestamp):
"""Turn a timestamp into a datetime object which prints human readably."""
return datetime.datetime.fromtimestamp(timestamp, datetime.timezone.utc)
async def collect_once(task, timestamp, duration, hpke_private_key, auth_token):
"""Runs collection for a single time interval.
This uses the Janus collect binary. The result is formatted to fit the BQ table.
"""
collection_time = str(datetime.datetime.now(datetime.timezone.utc).timestamp())
print(f"{collection_time} Collecting {toh(timestamp)} - {toh(timestamp+duration)}")
# Prepare output
res = {}
res["reports"] = []
res["counts"] = []
rpt = build_base_report(task["task_id"], timestamp, task["metric_type"], collection_time)
# Convert VDAF description to string for command line use
vdaf_args = ""
for k, v in task["vdaf_args_structured"].items():
vdaf_args += f" --{k} {v}"
cmd = CMD.format(
timestamp=timestamp,
duration=duration,
hpke_private_key=hpke_private_key,
auth_token=auth_token,
task_id=task["task_id"],
vdaf=task["vdaf"],
vdaf_args=vdaf_args,
hpke_config=task["hpke_config"],
)
# How long an individual collection can take before it is killed.
timeout = 100
start_counter = time.perf_counter()
try:
proc = await asyncio.wait_for(
asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
),
timeout,
)
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout)
stdout = stdout.decode()
stderr = stderr.decode()
except asyncio.exceptions.TimeoutError:
rpt["collection_duration"] = time.perf_counter() - start_counter
rpt["error"] = "TIMEOUT"
res["reports"].append(rpt)
return res
print(f"{timestamp} Result: code {proc.returncode}")
rpt["collection_duration"] = time.perf_counter() - start_counter
# Parse the output of the collect binary
if proc.returncode == 1:
if (
stderr
== "Error: HTTP response status 400 Bad Request - The number of reports included in the batch is invalid.\n"
):
rpt["error"] = "BATCH TOO SMALL"
else:
rpt["error"] = f"UNHANDLED ERROR: {stderr}"
else:
for line in stdout.splitlines():
if line.startswith("Aggregation result:"):
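                # The line looks like "Aggregation result: [c0, c1, ...]"; strip the prefix and the
                # surrounding brackets before parsing the comma-separated counts.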
entries = parse_vector(line[21:-1])
rpt["value"] = entries
for i, entry in enumerate(entries):
ad = get_ad(task["task_id"], i)
print(task["task_id"], i, ad)
if ad is not None:
cnt = {}
cnt["collection_time"] = timestamp
cnt["placement_id"] = ad["advertiserInfo"]["placementId"]
cnt["advertiser_id"] = ad["advertiserInfo"]["advertiserId"]
cnt["advertiser_name"] = ad["advertiserInfo"]["advertiserName"]
cnt["ad_id"] = ad["advertiserInfo"]["adId"]
cnt["conversion_key"] = ad["advertiserInfo"]["conversionKey"]
cnt["task_id"] = task["task_id"]
cnt["task_index"] = i
cnt["task_size"] = task["task_size"]
cnt["campaign_id"] = ad["advertiserInfo"]["campaignId"]
cnt["conversion_count"] = entry
res["counts"].append(cnt)
elif line.startswith("Number of reports:"):
rpt["report_count"] = int(line.split()[-1].strip())
elif (
line.startswith("Interval start:")
or line.startswith("Interval end:")
or line.startswith("Interval length:")
):
# irrelevant since we are using time interval queries
continue
else:
print(f"UNHANDLED OUTPUT LINE: {line}")
raise NotImplementedError
res["reports"].append(rpt)
return res
def parse_vector(histogram_str):
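    """Parse the comma-separated vector printed by the collect binary into a list of ints,
    correcting any values that wrapped around the field modulus."""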
count_strs = histogram_str.split(",")
return [
correct_wraparound(int(count_str))
for count_str in count_strs
]
def correct_wraparound(num):
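    """Map field elements in the upper half of the field back to negative integers.
    Differential privacy noise can push a count below zero, which shows up as a value close to FIELD_PRIME."""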
cutoff = 2 ** (FIELD_SIZE - 1)
if num > cutoff:
return num - FIELD_PRIME
return num
def build_base_report(task_id, timestamp, metric_type, collection_time):
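    """Build the fields shared by all rows written to the report table."""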
row = {}
row["task_id"] = task_id
row["slot_start"] = timestamp
row["metric_type"] = metric_type
row["collection_time"] = collection_time
return row
def build_error_result(task_id, timestamp, metric_type, error):
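    """Build a result containing a single error report (and no counts) for a task that could not be collected."""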
collection_time = str(datetime.datetime.now(datetime.timezone.utc).timestamp())
results = {}
results["counts"] = []
results["reports"] = []
slot_start = int(timestamp.timestamp())
rpt = build_base_report(task_id, slot_start, metric_type, collection_time)
rpt["collection_duration"] = 0
rpt["error"] = error
results["reports"].append(rpt)
return results
def get_ad(task_id, index):
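    """Look up the ad metadata for a task id and vector index; returns None if no ad is configured for them."""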
global ads
for ad in ads:
if ad["taskId"] == task_id and ad["taskIndex"] == index:
return ad
async def process_queue(q: asyncio.Queue, results: dict):
"""Worker for parallelism. Processes items from the queue until it is empty."""
while not q.empty():
job = q.get_nowait()
res = await collect_once(*job)
results["reports"] += res["reports"]
results["counts"] += res["counts"]
async def collect_many(
task, time_from, time_until, interval_length, hpke_private_key, auth_token
):
"""Collects data for a given time interval.
Creates a configurable amount of workers which process jobs from a queue
for parallelism.
"""
time_from = int(time_from.timestamp())
time_until = int(time_until.timestamp())
    # Round the start up to the next boundary that is aligned to the interval length.
    start = math.ceil(time_from / interval_length) * interval_length
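    # Queue capacity of 288 corresponds to one day of 5-minute intervals.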
jobs = asyncio.Queue(288)
results = {}
results["reports"] = []
results["counts"] = []
while start + interval_length <= time_until:
await jobs.put((task, start, interval_length, hpke_private_key, auth_token))
start += interval_length
workers = []
for _ in range(10):
workers.append(process_queue(jobs, results))
await asyncio.gather(*workers)
return results
def check_collection_date(date):
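    """Return an error string if the collection date is not at the beginning of a day, otherwise None."""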
# collector should collect through to the beginning of a day
if date.hour != 0 or date.minute != 0 or date.second != 0:
return f"Collection date is not at beginning of a day {date}"
else:
return None
def check_time_precision(time_precision_minutes, end_collection_date):
"""Check that a given time precision is valid for the collection date
"""
end_collection_date_seconds = int(end_collection_date.timestamp())
if time_precision_minutes is None:
# task is missing a time precision setting
return f"Task missing time time_precision_minutes value"
elif time_precision_minutes < MINUTES_IN_DAY:
if MINUTES_IN_DAY % time_precision_minutes > 0:
# time precision has to evenly divide a day in order for this collector code to query all aggregations
return f"Task has time precision that does not evenly divide a day"
elif time_precision_minutes % MINUTES_IN_DAY != 0:
# time precision is a day or longer, but is not a multiple of a day
return f"Task has time precision that is not an even multiple of a day"
    elif end_collection_date_seconds % (time_precision_minutes * 60) != 0:
        # time precision is a multiple of a day, but the end date does not align with this task's aggregation buckets
return f"{end_collection_date} does not align with task aggregation buckets"
return None
async def collect_task(task, auth_token, hpke_private_key, date):
"""Collects data for the given task through to the given day.
    For tasks with a time precision smaller than a day, this collects data for all aggregations from the day prior to date.
    For tasks with a time precision of a day or a multiple of a day, it collects data for the single aggregation that ends on date.
If date does not align with the end of an aggregation, it will not collect anything.
"""
end_collection_date = datetime.datetime.fromisoformat(date)
end_collection_date = end_collection_date.replace(tzinfo=datetime.timezone.utc)
time_precision_minutes = task["time_precision_minutes"]
err = check_collection_date(end_collection_date)
if err is not None:
return build_error_result(task["task_id"], end_collection_date, task["metric_type"], err)
err = check_time_precision(time_precision_minutes, end_collection_date)
if err is not None:
return build_error_result(task["task_id"], end_collection_date, task["metric_type"], err)
# task precision and date are valid
if time_precision_minutes < MINUTES_IN_DAY:
# time precision is shorter than daily
# query for the last day of aggregations
start_collection_date = end_collection_date - datetime.timedelta(days=1)
else:
# time precision is a multiple of a day
# query for the aggregation that ends at end_collection_date
        aggregation_days = time_precision_minutes / MINUTES_IN_DAY
start_collection_date = end_collection_date - datetime.timedelta(days=aggregation_days)
return await collect_many(
task, start_collection_date, end_collection_date, time_precision_minutes * 60, hpke_private_key, auth_token
)
def ensure_table(bqclient, table_id, schema):
"""Checks if the table exists in BQ and creates it otherwise.
Fails if the table exists but has the wrong schema.
"""
table = bigquery.Table(table_id, schema=schema)
print(f"Making sure the table {table_id} exists.")
table = bqclient.create_table(table, exists_ok=True)
def store_data(results, bqclient, table_id):
    """Inserts the results into BQ. Assumes that they are already in the right format."""
    if results:
        insert_errors = bqclient.insert_rows_json(table=table_id, json_rows=results)
        if insert_errors:
            print(insert_errors)
            raise RuntimeError(f"Failed to insert rows into {table_id}: {insert_errors}")
@click.command()
@click.option("--project", help="GCP project id", required=True)
@click.option(
"--ad-table-id",
help="The aggregated DAP measurements will be stored in this table.",
required=True,
)
@click.option(
"--report-table-id",
help="The aggregated DAP measurements will be stored in this table.",
required=True,
)
@click.option(
"--auth-token",
envvar='AUTH_TOKEN',
help="HTTP bearer token to authenticate to the leader",
required=True,
)
@click.option(
"--hpke-private-key",
envvar='HPKE_PRIVATE_KEY',
help="The private key used to decrypt shares from the leader and helper.",
required=True,
)
@click.option(
"--date",
help="Date at which the backfill will start, going backwards (YYYY-MM-DD)",
required=True,
)
@click.option(
"--task-config-url",
help="URL where a JSON definition of the tasks to be collected can be found.",
required=True,
)
@click.option(
"--ad-config-url",
help="URL where a JSON definition of the ads to task map can be found.",
required=True,
)
def main(project, ad_table_id, report_table_id, auth_token, hpke_private_key, date, task_config_url, ad_config_url):
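    """Ensure the BigQuery tables exist, then collect every configured task for the given date and store the results."""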
global ads
ad_table_id = project + "." + ad_table_id
report_table_id = project + "." + report_table_id
bqclient = bigquery.Client(project=project)
ads = read_json(ad_config_url)
ensure_table(bqclient, ad_table_id, ADS_SCHEMA)
ensure_table(bqclient, report_table_id, REPORT_SCHEMA)
for task in read_json(task_config_url):
print(f"Now processing task: {task['task_id']}")
results = asyncio.run(collect_task(task, auth_token, hpke_private_key, date))
store_data(results["reports"], bqclient, report_table_id)
store_data(results["counts"], bqclient, ad_table_id)
if __name__ == "__main__":
main()