scripts/update_github_status.py:
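"""
Fetch the head commits for a set of branches, pull the job status for the last
'history_size' commits on each branch from the GitHub GraphQL API (both commit
status contexts and GitHub Actions check runs), and write the results as
gzipped JSON to S3. Runs either as an AWS Lambda (triggered from EventBridge)
or as a CLI script.
"""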
import argparse
import asyncio
import datetime
import gzip
import io
import json
import math
import os
import re
import time
from typing import Any, Callable, Dict, List, Optional

import aiohttp  # type: ignore
import boto3  # type: ignore
import jwt
import requests
from cryptography.hazmat.backends import default_backend

BUCKET = os.getenv("bucket", "ossci-job-status")
APP_ID = int(os.environ["app_id"])
# The private key needs to maintain its newlines, get it via
# $ cat key.pem | tr '\n' '|' | pbcopy
PRIVATE_KEY = os.environ["private_key"].replace("|", "\n")
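
# Build the JWT 'Authorization' header used to authenticate as the GitHub App
# (required for the /app/installations endpoints below)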
def app_headers() -> Dict[str, str]:
    cert_bytes = PRIVATE_KEY.encode()
    private_key = default_backend().load_pem_private_key(cert_bytes, None)  # type: ignore

    time_since_epoch_in_seconds = int(time.time())
    payload = {
        # issued at time
        "iat": time_since_epoch_in_seconds,
        # JWT expiration time (10 minute maximum)
        "exp": time_since_epoch_in_seconds + (10 * 60),
        # GitHub App's identifier
        "iss": APP_ID,
    }

    actual_jwt = jwt.encode(payload, private_key, algorithm="RS256")
    headers = {
        "Authorization": f"Bearer {actual_jwt}",
        "Accept": "application/vnd.github.machine-man-preview+json",
    }
    return headers

def jprint(obj: Any) -> None:
    print(json.dumps(obj, indent=2))

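# Look up the installation id of this GitHub App for the given account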
def installation_id(user: str) -> int:
    r_bytes = requests.get(
        "https://api.github.com/app/installations", headers=app_headers()
    )
    r = json.loads(r_bytes.content.decode())
    for item in r:
        if item["account"]["login"] == user:
            return int(item["id"])

    raise RuntimeError(f"User {user} not found in {r}")

def user_token(user: str) -> str:
    """
    Authorize this request with the GitHub app set by the 'app_id' and
    'private_key' environment variables.
    1. Get the installation ID for the user that has installed the app
    2. Request a new token for that user
    3. Return it so it can be used in future API requests
    """
    # Hardcode the installation to PyTorch so we can always get a valid ID key
    id = installation_id("pytorch")
    url = f"https://api.github.com/app/installations/{id}/access_tokens"
    r_bytes = requests.post(url, headers=app_headers())
    r = json.loads(r_bytes.content.decode())
    token = str(r["token"])
    return token

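# Set up the boto3 session used to upload the per-branch JSON blobs to S3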
if "AWS_KEY_ID" in os.environ and "AWS_SECRET_KEY" in os.environ:
# Use keys for local development
session = boto3.Session(
aws_access_key_id=os.environ.get("AWS_KEY_ID"),
aws_secret_access_key=os.environ.get("AWS_SECRET_KEY"),
)
else:
# In the Lambda, use permissions on the Lambda's role
session = boto3.Session()
s3 = session.resource("s3")
def compress_query(query: str) -> str:
    query = query.replace("\n", "")
    query = re.sub(r"\s+", " ", query)
    return query

def head_commit_query(user: str, repo: str, branches: List[str]) -> str:
    """
    Fetch the head commit for a list of branches
    """

    def branch_part(branch: str, num: int) -> str:
        return f"""
        r{num}: repository(name: "{repo}", owner: "{user}") {{
          ref(qualifiedName:"refs/heads/{branch}") {{
            name
            target {{
              ... on Commit {{
                oid
              }}
            }}
          }}
        }}
        """

    parts = [branch_part(branch, i) for i, branch in enumerate(branches)]
    return "{" + "\n".join(parts) + "}"

def extract_gha(suites: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    jobs = []
    for suite in suites:
        suite = suite["node"]
        if suite["workflowRun"] is None:
            # If no jobs were triggered this will be empty
            continue
        workflow = suite["workflowRun"]["workflow"]["name"]
        for run in suite["checkRuns"]["nodes"]:
            conclusion = run["conclusion"]
            if conclusion is None:
                if run["status"].lower() == "queued":
                    conclusion = "queued"
                elif run["status"].lower() == "in_progress":
                    conclusion = "pending"
                else:
                    raise RuntimeError(f"unexpected run {run}")
            jobs.append(
                {
                    "name": f"{workflow} / {run['name']}",
                    "status": conclusion.lower(),
                    "url": run["detailsUrl"],
                }
            )

    return jobs

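# Flatten commit status contexts (jobs reported through the status API rather
# than GitHub Actions) into the same {name, status, url} shape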
def extract_status(contexts: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    jobs = []
    for context in contexts:
        jobs.append(
            {
                "name": context["context"],
                "status": context["state"].lower(),
                "url": context["targetUrl"],
            }
        )

    return jobs

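# Normalize each raw commit from the GraphQL response into a flat record with
# its sha, message, author, date, and the combined list of jobs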
def extract_jobs(raw_commits: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    commits = []

    for raw_commit in raw_commits:
        if raw_commit["status"] is None:
            # Will be None if no non-GHA jobs were triggered
            status = []
        else:
            status = extract_status(raw_commit["status"]["contexts"])
        gha = extract_gha(raw_commit["checkSuites"]["edges"])

        jobs = status + gha
        if raw_commit["author"]["user"] is None:
            author = raw_commit["author"]["name"]
        else:
            author = raw_commit["author"]["user"]["login"]
        commits.append(
            {
                "sha": raw_commit["oid"],
                "headline": raw_commit["messageHeadline"],
                "body": raw_commit["messageBody"],
                "author": author,
                "date": raw_commit["authoredDate"],
                "jobs": jobs,
            }
        )

    return commits

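# Fetches the commit/job history for a single branch in batches and uploads
# the merged result to S3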
class BranchHandler:
    def __init__(
        self,
        gql: Any,
        user: str,
        repo: str,
        name: str,
        head: str,
        history_size: int,
        fetch_size: int,
    ):
        self.gql = gql
        self.user = user
        self.repo = repo
        self.name = name
        self.head = head
        self.fetch_size = fetch_size
        self.history_size = history_size

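    # gzip the JSON blob and upload it under a per-branch key, e.g. (illustrative)
    # v6/pytorch/pytorch/release_1.10.json (slashes in branch names become '_')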
    def write_to_s3(self, data: Any) -> None:
        content = json.dumps(data, default=str)
        buf = io.BytesIO()
        gzipfile = gzip.GzipFile(fileobj=buf, mode="w")
        gzipfile.write(content.encode())
        gzipfile.close()

        bucket = s3.Bucket(BUCKET)
        prefix = f"v6/{self.user}/{self.repo}/{self.name.replace('/', '_')}.json"
        bucket.put_object(
            Key=prefix,
            Body=buf.getvalue(),
            ContentType="application/json",
            ContentEncoding="gzip",
            Expires="0",
        )
        print(f"Wrote {len(data)} commits from {self.name} to {prefix}")

    def query(self, offset: int) -> str:
        after = ""
        # The cursor for fetches is formatted like 'after: "<sha> <offset>"', but
        # the first commit isn't included, so shift all the offsets and don't
        # use an "after" for the first batch
        if offset > 0:
            after = f', after: "{self.head} {offset - 1}"'

        return f"""
        {{
          repository(name: "{self.repo}", owner: "{self.user}") {{
            ref(qualifiedName:"refs/heads/{self.name}") {{
              name
              target {{
                ... on Commit {{
                  history(first:{self.fetch_size}{after}) {{
                    nodes {{
                      oid
                      messageBody
                      messageHeadline
                      author {{
                        name
                        user {{
                          login
                        }}
                      }}
                      authoredDate
                      checkSuites(first:100) {{
                        edges {{
                          node {{
                            checkRuns(first:100) {{
                              nodes {{
                                name
                                status
                                conclusion
                                detailsUrl
                              }}
                            }}
                            workflowRun {{
                              workflow {{
                                name
                              }}
                            }}
                          }}
                        }}
                      }}
                      status {{
                        contexts {{
                          context
                          state
                          targetUrl
                        }}
                      }}
                    }}
                  }}
                }}
              }}
            }}
          }}
        }}
        """

    def check_response(self, gql_response: Any) -> None:
        # Just check that this path in the dict exists
        gql_response["data"]["repository"]["ref"]["target"]["history"]["nodes"]

    async def run(self) -> None:
        """
        Fetch history for the branch (in batches) and merge them all together
        """
        # GitHub's API errors out if you try to fetch too much data at once, so
        # split up the 100 commits into batches of 'self.fetch_size'
        fetches = math.ceil(self.history_size / self.fetch_size)

        async def fetch(i: int) -> Any:
            try:
                return await self.gql.query(
                    self.query(offset=self.fetch_size * i), verify=self.check_response
                )
            except Exception as e:
                print(
                    f"Error: {e}\nFailed to fetch {self.user}/{self.repo}/{self.name} on batch {i} / {fetches}"
                )
                return None

        coros = [fetch(i) for i in range(fetches)]
        result = await asyncio.gather(*coros)
        raw_commits = []

        print(f"Parsing results {self.name}")

        # Merge all the batches
        for r in result:
            if r is None:
                continue
            try:
                commits_batch = r["data"]["repository"]["ref"]["target"]["history"][
                    "nodes"
                ]
                raw_commits += commits_batch
            except Exception as e:
                # Errors here are expected if the branch has fewer than HISTORY_SIZE
                # commits (GitHub will just time out). There's no easy way to find
                # this number ahead of time and avoid errors, but if we had that
                # then we could delete this try-except.
                print(f"Error: Didn't find history in commit batch: {e}\n{r}")

        # Pull out the data and format it
        commits = extract_jobs(raw_commits)

        print(f"Writing results for {self.name} to S3")

        # Store gzip'ed data to S3
        self.write_to_s3(commits)

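# Small async wrapper around the GitHub GraphQL API that logs rate limit
# headers and retries failed queries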
class GraphQL:
    def __init__(self, session: aiohttp.ClientSession) -> None:
        self.session = session

    def log_rate_limit(self, headers: Any) -> None:
        remaining = headers.get("X-RateLimit-Remaining")
        used = headers.get("X-RateLimit-Used")
        total = headers.get("X-RateLimit-Limit")
        reset_timestamp = int(headers.get("X-RateLimit-Reset", 0))  # type: ignore
        reset = datetime.datetime.fromtimestamp(reset_timestamp).strftime(
            "%a, %d %b %Y %H:%M:%S"
        )
        print(
            f"[rate limit] Used {used}, {remaining} / {total} remaining, reset at {reset}"
        )

    async def query(
        self,
        query: str,
        verify: Optional[Callable[[Any], None]] = None,
        retries: int = 5,
    ) -> Any:
        """
        Run an authenticated GraphQL query
        """
        # Remove unnecessary white space
        query = compress_query(query)
        if retries <= 0:
            raise RuntimeError(f"Query {query[:100]} failed, no retries left")

        url = "https://api.github.com/graphql"
        try:
            async with self.session.post(url, json={"query": query}) as resp:
                self.log_rate_limit(resp.headers)
                r = await resp.json()
                if "data" not in r:
                    raise RuntimeError(r)
                if verify is not None:
                    verify(r)
                return r
        except Exception as e:
            print(
                f"Retrying query {query[:100]}, remaining attempts: {retries - 1}\n{e}"
            )
            return await self.query(query, verify=verify, retries=retries - 1)

async def main(
    user: str, repo: str, branches: List[str], history_size: int, fetch_size: int
) -> None:
    """
    Grab a list of all the head commits for each branch, then fetch all the jobs
    for the last 'history_size' commits on that branch
    """
    async with aiohttp.ClientSession(
        headers={
            "Authorization": "token {}".format(user_token(user)),
            "Accept": "application/vnd.github.machine-man-preview+json",
        }
    ) as aiosession:
        gql = GraphQL(aiosession)

        print(f"Querying branches: {branches}")
        heads = await gql.query(head_commit_query(user, repo, branches))

        handlers = []
        for head in heads["data"].values():
            sha = head["ref"]["target"]["oid"]
            branch = head["ref"]["name"]
            handlers.append(
                BranchHandler(gql, user, repo, branch, sha, history_size, fetch_size)
            )

        await asyncio.gather(*[h.run() for h in handlers])

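# Expected configuration keys (from environment variables or the Lambda event
# payload), e.g.:
#   {"branches": "release/1.10", "user": "pytorch", "repo": "pytorch",
#    "history_size": 100, "fetch_size": 10}
# where "branches" is a comma-separated list of branch names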
def lambda_handler(event: Any, context: Any) -> None:
    """
    'event' here is the payload configured from EventBridge (or set manually
    via environment variables)
    """
    data: Dict[str, Any] = {
        "branches": None,
        "user": None,
        "repo": None,
        "history_size": None,
        "fetch_size": None,
    }

    for key in data.keys():
        if key in os.environ:
            data[key] = os.environ[key]
        else:
            # .get() so a missing key falls through to the error below
            data[key] = event.get(key)

    if any(x is None for x in data.values()):
        raise RuntimeError(
            "Data missing from configuration, it must be set as an environment "
            f"variable or as the input JSON payload in the Lambda event:\n{data}"
        )

    data["history_size"] = int(data["history_size"])
    data["fetch_size"] = int(data["fetch_size"])
    data["branches"] = data["branches"].split(",")

    asyncio.run(main(**data))

# if os.getenv("DEBUG", "0") == "1":
# # For local development
# lambda_handler(
# {
# "branches": "release/1.10",
# "user": "pytorch",
# "repo": "pytorch",
# "history_size": 100,
# "fetch_size": 10,
# },
# None,
# )
parser = argparse.ArgumentParser(description="Update JSON in S3 for a branch")
parser.add_argument("--branch", required=True)
parser.add_argument("--repo", required=True)
parser.add_argument("--user", required=True)
parser.add_argument("--fetch_size", default=10)
parser.add_argument("--history_size", default=100)
args = parser.parse_args()
lambda_handler(
{
"branches": args.branch,
"user": args.user,
"repo": args.repo,
"history_size": int(args.history_size),
"fetch_size": int(args.fetch_size),
},
None,
)
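# Example local run (illustrative; assumes app_id/private_key and AWS
# credentials are set in the environment):
#   python scripts/update_github_status.py --user pytorch --repo pytorch --branch master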