sync/gh.py (531 lines of code) (raw):

import itertools
import random
import re
import time
import urllib.parse
from datetime import datetime
from io import StringIO
from typing import Any, Dict, List, Optional, Tuple, Union

import github
import newrelic
# ``agent`` is a submodule; import it explicitly so that
# ``newrelic.agent.record_exception`` (used in ``remove_labels``) does not
# depend on some other module having imported it first.
import newrelic.agent
from github.Branch import Branch
from github.Commit import Commit
from github.PullRequest import PullRequest
from github.Repository import Repository

from . import log
from .env import Environment

logger = log.get_logger(__name__)

env = Environment()


class CheckRun(github.GithubObject.NonCompletableGithubObject):
    """Minimal PyGithub object for one entry of a ``check-runs`` listing.

    Only the fields this module reads are modelled: ``id``, ``status``,
    ``name``, ``conclusion``, ``url`` and ``head_sha``.
    """

    def _initAttributes(self):
        self._id = github.GithubObject.NotSet
        self._status = github.GithubObject.NotSet
        self._name = github.GithubObject.NotSet
        self._conclusion = github.GithubObject.NotSet
        self._url = github.GithubObject.NotSet
        # Without this, reading ``head_sha`` for a payload that lacks the
        # "head_sha" key raised AttributeError instead of behaving like the
        # other unset fields.
        self._head_sha = github.GithubObject.NotSet

    def _useAttributes(self, attributes):
        if "id" in attributes:
            self._id = self._makeIntAttribute(attributes["id"])
        if "status" in attributes:
            self._status = self._makeStringAttribute(attributes["status"])
        if "name" in attributes:
            self._name = self._makeStringAttribute(attributes["name"])
        if "conclusion" in attributes:
            self._conclusion = self._makeStringAttribute(attributes["conclusion"])
        if "url" in attributes:
            self._url = self._makeStringAttribute(attributes["url"])
        if "head_sha" in attributes:
            self._head_sha = self._makeStringAttribute(attributes["head_sha"])

    @property
    def id(self):
        return self._id.value

    @property
    def status(self):
        return self._status.value

    @property
    def name(self):
        return self._name.value

    @property
    def conclusion(self):
        return self._conclusion.value

    @property
    def url(self):
        return self._url.value

    @property
    def head_sha(self):
        return self._head_sha.value


class GitHub:
    """Wrapper around a PyGithub client bound to a single repository.

    PullRequest objects obtained through :meth:`get_pull`, :meth:`create_pull`
    or :meth:`load_pull` are cached in ``pr_cache``, keyed by PR number.
    """

    def __init__(self, token: str, url: str) -> None:
        self.gh = github.Github(token)
        # "https://github.com/<owner>/<repo>" -> "<owner>/<repo>"
        self.repo_name = urllib.parse.urlsplit(url).path.lstrip("/")
        self.pr_cache: Dict[int, PullRequest] = {}
        # Resolved lazily by the ``repo`` property.
        self._repo: Optional[Repository] = None

    def pr_url(self, pr_id: int) -> str:
        """Return the web URL of PR ``pr_id`` in the configured wpt repo."""
        return ("%s/pull/%s" % (env.config["web-platform-tests"]["repo"]["url"],
                                pr_id))

    def load_pull(self, data: Dict[str, Any]) -> None:
        """Build a PullRequest from raw API ``data`` and cache it."""
        pr = self.gh.create_from_raw_data(github.PullRequest.PullRequest, data)
        self.pr_cache[pr.number] = pr

    @property
    def repo(self) -> Repository:
        """The Repository object for ``repo_name``, fetched on first access."""
        if self._repo is None:
            self._repo = self.gh.get_repo(self.repo_name)
        assert self._repo is not None
        return self._repo

    def get_pull(self, id: int) -> PullRequest:
        """Return PR number ``id`` (coerced with ``int``), using the cache.

        :raises ValueError: if the repository lookup returns None.
        """
        id = int(id)
        if id not in self.pr_cache:
            pr = self.repo.get_pull(id)
            if pr is None:
                raise ValueError("No pull request with id %s" % id)
            self.pr_cache[id] = pr
        return self.pr_cache[id]

    def create_pull(self,
                    title: str,
                    body: str,
                    base: str,
                    head: str
                    ) -> int:
        """Open a PR from ``head`` into ``base`` and return its number.

        If creation raises ``GithubException``, fall back to an existing PR
        whose head is ``<owner>:<head>``; the original exception is re-raised
        when there is none. The resulting PR is labelled
        ``mozilla:gecko-sync`` and cached.

        :raises ValueError: if more than one existing PR matches ``head``.
        """
        try:
            pr = self.repo.create_pull(title=title, body=body, base=base, head=head)
            logger.info("Created PR %s" % pr.number)
        except github.GithubException:
            # Check if there's already a PR for this head
            user = self.repo_name.split("/")[0]
            entries = list(self.repo.get_pulls(head=f"{user}:{head}"))
            if len(entries) == 0:
                raise
            elif len(entries) > 1:
                raise ValueError("Found multiple existing pulls for branch")
            # Reuse the already-fetched list rather than re-indexing the
            # paginated result (which may trigger another request).
            pr = entries[0]
        self.add_labels(pr.number, "mozilla:gecko-sync")
        self.pr_cache[pr.number] = pr
        return pr.number

    def has_branch(self, name: str) -> bool:
        """Return True if branch ``name`` can be fetched from the repo."""
        return self._get_branch(name) is not None

    def _get_branch(self, name: str) -> Optional[Branch]:
        """Return branch ``name``, or None if the lookup raises a GitHub error."""
        try:
            return self.repo.get_branch(name)
        except (github.GithubException, github.UnknownObjectException):
            return None

    def get_status(self,
                   pr_id: int,
                   context: str
                   ) -> Optional[str]:
        """Return the state of the newest status with ``context`` on the PR head.

        The head commit is resolved from ``pr.head.ref``. Returns None when no
        status with that context exists.
        """
        pr = self.get_pull(pr_id)
        head_commit = self.repo.get_commit(pr.head.ref)
        statuses = [item for item in head_commit.get_statuses()
                    if item.context == context]
        # Highest id first, i.e. the most recently created status.
        statuses.sort(key=lambda x: -x.id)
        if statuses:
            return statuses[0].state
        return None

    def set_status(self,
                   pr_id: int,
                   status: str,
                   target_url: Optional[str],
                   description: Optional[str],
                   context: str
                   ) -> None:
        """Create a commit status on the PR head; None-valued options are omitted."""
        pr = self.get_pull(pr_id)
        head_commit = self.repo.get_commit(pr.head.ref)
        kwargs: Dict[str, str] = {}
        if target_url is not None:
            kwargs["target_url"] = target_url
        if description is not None:
            kwargs["description"] = description
        head_commit.create_status(status, context=context, **kwargs)

    def add_labels(self,
                   pr_id: int,
                   *labels: str
                   ) -> None:
        """Add ``labels`` to the issue backing PR ``pr_id``."""
        logger.debug("Adding labels {} to PR {}".format(", ".join(labels), pr_id))
        pr_id = self._convert_pr_id(pr_id)
        issue = self.repo.get_issue(pr_id)
        issue.add_to_labels(*labels)

    def remove_labels(self,
                      pr_id: int,
                      *labels: str
                      ) -> None:
        """Remove ``labels`` from PR ``pr_id``, one at a time, best effort.

        A "Label does not exist" error is ignored; any other GitHub error is
        logged and reported to New Relic, and removal continues with the next
        label.
        """
        logger.debug(f"Removing labels {labels} from PR {pr_id}")
        pr_id = self._convert_pr_id(pr_id)
        issue = self.repo.get_issue(pr_id)
        for label in labels:
            try:
                issue.remove_from_labels(label)
            except github.GithubException as e:
                # ``e.data`` may be a decoded JSON dict or a plain value.
                if isinstance(e.data, dict):
                    msg = e.data.get("message", "")
                else:
                    msg = e.data
                if msg != "Label does not exist":
                    logger.warning("Error handling label removal: %s" % e)
                    newrelic.agent.record_exception()

    def _convert_pr_id(self,
                       pr_id: Union[str, int]
                       ) -> int:
        """Return ``pr_id`` as an int.

        :raises ValueError: if ``pr_id`` is not an int and ``int()`` rejects it.
        """
        if not isinstance(pr_id, int):
            try:
                pr_id = int(pr_id)
            except ValueError as e:
                raise ValueError('PR ID is not a valid number') from e
        return pr_id

    def required_checks(self, branch_name: str) -> List[str]:
        """Return required status-check contexts from the branch's raw data.

        Returns an empty list if the branch cannot be fetched or the
        protection data lacks the relevant keys.
        """
        branch = self._get_branch(branch_name)
        if branch is None:
            # TODO: Maybe raise an exception here
            return []
        return (branch.raw_data
                .get("protection", {})
                .get("required_status_checks", {})
                .get("contexts", []))

    def get_check_runs(self, pr_id: int) -> Dict[str, Dict[str, Any]]:
        """Return the newest check run per check name for the PR head sha.

        Each value has keys "status", "conclusion", "url", "head_sha" and
        "required" (whether the name is a required check on the base branch).
        """
        pr = self.get_pull(pr_id)
        check_runs = list(self._get_check_runs(pr.head.sha))
        required_contexts = self.required_checks(pr.base.ref)

        rv: Dict[str, Dict[str, Any]] = {}
        id_by_name: Dict[str, int] = {}
        for item in check_runs:
            # Keep only the run with the highest id for each name.
            if item.name in id_by_name and item.id < id_by_name[item.name]:
                continue
            id_by_name[item.name] = item.id
            rv[item.name] = {
                "status": item.status,
                "conclusion": item.conclusion,
                "url": item.url,
                "required": item.name in required_contexts,
                "head_sha": item.head_sha
            }
        return rv

    def _get_check_runs(self, sha1: str, check_name: Optional[str] = None):
        """Return a PaginatedList of CheckRun objects for commit ``sha1``.

        When ``check_name`` is given it is passed as the ``check_name`` query
        parameter. Relies on PyGithub private helpers (``_parentUrl``,
        ``_requester``) because the Checks API is not wrapped here.
        """
        query = []
        if check_name:
            query.append(("check_name", check_name))
        commit = self.repo.get_commit(sha1)
        url = commit._parentUrl(commit.url) + "/" + commit.sha + "/check-runs"
        headers = {"Accept": "application/vnd.github.antiope-preview+json"}
        return github.PaginatedList.PaginatedList(CheckRun,
                                                  commit._requester,
                                                  url,
                                                  query,
                                                  headers=headers,
                                                  list_item="check_runs")

    def pull_state(self, pr_id: int) -> str:
        """Return the ``state`` attribute of PR ``pr_id``."""
        pr = self.get_pull(pr_id)
        return pr.state

    def reopen_pull(self, pr_id: int) -> None:
        """Set PR ``pr_id`` to the "open" state."""
        pr = self.get_pull(pr_id)
        pr.edit(state="open")

    def close_pull(self, pr_id: int) -> None:
        """Set PR ``pr_id`` to the "closed" state."""
        pr = self.get_pull(pr_id)
        # TODO: consider also labelling the issue "mozilla:backed-out" here.
        pr.edit(state="closed")

    def is_approved(self, pr_id: int) -> bool:
        """Return True if any reviewer's most recent review is "APPROVED"."""
        pr = self.get_pull(pr_id)
        # We get a chronological list of all reviews, so we want to
        # check if the last review by any reviewer was in the approved
        # state.
        # NOTE(review): a later review in any other state (including
        # "COMMENTED") overrides that reviewer's earlier approval here;
        # confirm this is intended.
        review_by_reviewer: Dict[str, str] = {}
        for review in pr.get_reviews():
            review_by_reviewer[review.user.login] = review.state
        return "APPROVED" in review_by_reviewer.values()

    def merge_sha(self, pr_id: int) -> Optional[str]:
        """Return the merge commit sha if the PR is merged, else None."""
        pr = self.get_pull(pr_id)
        if pr.merged:
            return pr.merge_commit_sha
        return None

    def is_mergeable(self, pr_id: int) -> bool:
        """Return whether GitHub reports PR ``pr_id`` as mergeable.

        GitHub sometimes has not computed mergeability yet, in which case
        ``pr.mergeable`` is None. The PR is re-fetched up to six times, with
        exponential back-off between attempts. A still-unknown result is
        reported as False.
        """
        max_attempts = 6
        mergeable = None
        for attempt in range(max_attempts):
            pr = self.get_pull(pr_id)
            mergeable = pr is not None and pr.mergeable
            # Drop the cached copy so the next get_pull re-fetches fresh data.
            # ``pop`` tolerates a missing entry or a non-int ``pr_id``.
            self.pr_cache.pop(int(pr_id), None)
            if mergeable is not None:
                break
            # Only wait if another attempt will follow.
            if attempt < max_attempts - 1:
                time.sleep(2 ** attempt)
        return bool(mergeable)

    def merge_pull(self, pr_id: int) -> str:
        """Merge PR ``pr_id`` using the rebase method; return the resulting sha."""
        pr = self.get_pull(pr_id)
        merge_status = pr.merge(merge_method="rebase")
        return merge_status.sha

    def pr_for_commit(self, sha: str) -> Optional[int]:
        """Return the lowest-numbered PR whose search results match ``sha``.

        Returns None when the issue search finds no PR. A warning is logged
        when several PRs match.
        """
        logger.info("Looking up PR for commit %s" % sha)
        owner, repo = self.repo_name.split("/")
        prs = list(self.gh.search_issues(query=f"is:pr repo:{owner}/{repo} sha:{sha}"))
        if len(prs) == 0:
            return None

        if len(prs) > 1:
            logger.warning("Got multiple PRs related to commit %s: %s" %
                           (sha, ", ".join(str(item.number) for item in prs)))
            prs = sorted(prs, key=lambda x: x.number)

        return prs[0].number

    def get_commits(self, pr_id: int) -> List[Commit]:
        """Return all commits of PR ``pr_id`` as a list."""
        return list(self.get_pull(pr_id).get_commits())

    def cleanup_pr_body(self, text: Optional[str]) -> Optional[str]:
        """Strip the Reviewable start/end marker section from a PR body.

        The pattern is greedy and spans newlines, so everything from the first
        start marker to the last end marker is removed. None passes through.
        """
        if text is None:
            return None
        r = re.compile(re.escape("<!-- Reviewable:start -->") + ".*" +
                       re.escape("<!-- Reviewable:end -->"), re.DOTALL)
        return r.sub("", text)

    def _construct_check_data(self,
                              name: str,
                              commit_sha: Optional[str] = None,
                              check_id: Optional[int] = None,
                              url: Optional[str] = None,
                              external_id: Optional[str] = None,
                              status: Optional[str] = None,
                              started_at: Optional[datetime] = None,
                              conclusion: Optional[str] = None,
                              completed_at: Optional[datetime] = None,
                              output: Optional[Dict[str, str]] = None,
                              actions: Optional[Any] = None,
                              ) -> Tuple[str, Dict[str, Any]]:
        """Validate check-run arguments and build the request method and body.

        Returns ``("POST", data)`` when ``check_id`` is None (create), else
        ``("PATCH", data)`` (update). Datetimes are serialised with
        ``isoformat()``. Every field key is present in ``data``, with None for
        unset values.

        :raises ValueError: if both ``check_id`` and ``commit_sha`` are given;
            if ``status`` or ``conclusion`` is not an allowed value; if
            ``status`` is "completed" without a ``conclusion``; if a
            ``conclusion`` has no ``completed_at``; or if ``output`` lacks a
            "title" or "summary".
        """
        if check_id is not None and commit_sha is not None:
            raise ValueError("Only one of check_id and commit_sha may be supplied")

        if status is not None:
            if status not in ("queued", "in_progress", "completed"):
                raise ValueError("Invalid status %s" % status)

        if started_at is not None:
            started_at_text: Optional[str] = started_at.isoformat()
        else:
            started_at_text = None

        if status == "completed" and conclusion is None:
            raise ValueError("Got a completed status but no conclusion")

        if conclusion is not None and completed_at is None:
            raise ValueError("Got a conclusion but no completion time")

        if conclusion is not None:
            if conclusion not in ("success", "failure", "neutral", "cancelled",
                                  "timed_out", "action_required"):
                raise ValueError("Invalid conclusion %s" % conclusion)

        # Previously this was only bound when completed_at was given, so any
        # call without it (e.g. status="queued") raised UnboundLocalError.
        if completed_at is not None:
            completed_at_text: Optional[str] = completed_at.isoformat()
        else:
            completed_at_text = None

        if output is not None:
            if "title" not in output:
                raise ValueError("Output requires a title")
            if "summary" not in output:
                raise ValueError("Output requires a summary")

        req_data: Dict[str, Any] = {
            "name": name,
        }

        # NOTE(review): None values are sent as explicit nulls; MockGitHub
        # relies on "head_sha" always being present in the returned dict.
        for (key, value) in [("head_sha", commit_sha),
                             ("id", check_id),
                             ("url", url),
                             ("external_id", external_id),
                             ("status", status),
                             ("started_at", started_at_text),
                             ("conclusion", conclusion),
                             ("completed_at", completed_at_text),
                             ("output", output),
                             ("actions", actions)]:
            req_data[key] = value

        req_method = "POST" if check_id is None else "PATCH"
        return req_method, req_data

    def set_check(self,
                  name: str,
                  commit_sha: Optional[str] = None,
                  check_id: Optional[int] = None,
                  url: Optional[str] = None,
                  external_id: Optional[str] = None,
                  status: Optional[str] = None,
                  started_at: Optional[datetime] = None,
                  conclusion: Optional[str] = None,
                  completed_at: Optional[datetime] = None,
                  output: Optional[Dict[str, str]] = None,
                  actions: Optional[List[str]] = None
                  ) -> Dict[str, Any]:
        """Create (no ``check_id``) or update a check run via the Checks API.

        Arguments are validated by :meth:`_construct_check_data`. Returns the
        decoded JSON response body.
        """
        req_method, req_data = self._construct_check_data(name, commit_sha, check_id, url,
                                                          external_id, status, started_at,
                                                          conclusion, completed_at, output,
                                                          actions)

        req_headers = {"Accept": "application/vnd.github.antiope-preview+json"}

        req_url = self.repo.url + "/check-runs"
        if check_id is not None:
            req_url += ("/%s" % check_id)

        headers, data = self.repo._requester.requestJsonAndCheck(  # type: ignore
            req_method,
            req_url,
            input=req_data,
            headers=req_headers
        )
        # Not sure what to return here
        return data


class AttrDict(dict):
    """dict whose keys can also be read and written as attributes.

    Used by :class:`MockGitHub` as a stand-in for PyGithub objects.
    """

    def __getattr__(self, name: str) -> Any:
        if name in self:
            return self[name]
        else:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Store attribute writes as items so ``obj.x = v`` and ``obj["x"]``
        # stay in sync. Previously the write landed in the instance __dict__
        # and the item silently kept its old value.
        self[name] = value


class MockGitHub(GitHub):
    """In-memory stand-in for :class:`GitHub`, used in tests.

    PRs are stored as :class:`AttrDict` records in ``prs``. Several operations
    append a line describing themselves to the ``output`` buffer via
    :meth:`_log`.
    """

    def __init__(self):
        # Does not call GitHub.__init__, so there is no client or repo.
        self.prs: Dict[int, AttrDict] = {}
        self.commit_prs: Dict[str, int] = {}
        self._id = itertools.count(1)
        self.output = StringIO()
        self.checks: Dict[Optional[str], Tuple[int, str, Dict[str, Any]]] = {}
        # The inherited ``is_mergeable`` (used by ``merge_pull``) clears
        # entries from ``pr_cache``; without this it raised AttributeError.
        self.pr_cache: Dict[int, Any] = {}

    def _log(self, data: str) -> None:
        """Append ``data`` and a newline to ``output``."""
        self.output.write(data)
        self.output.write("\n")

    @property
    def repo(self):
        raise NotImplementedError

    def get_pull(self, id: int) -> Any:
        """Return the stored PR record for ``id``, or None."""
        self._log("Getting PR %s" % id)
        return self.prs.get(int(id))

    def create_pull(self,
                    title: str,
                    body: str,
                    base: str,
                    head: str,
                    _commits: Optional[List[AttrDict]] = None,
                    _id: Optional[int] = None,
                    _user: Optional[str] = None,
                    ) -> int:
        """Store a new open, mergeable, approved PR record and return its number.

        ``_id`` forces the PR number (it must be unused), ``_user`` sets the
        author login (default: the configured wpt GitHub user), and
        ``_commits`` replaces the single random test commit.
        """
        if _id is None:
            id = next(self._id)
        else:
            id = int(_id)
        assert id not in self.prs
        if _user is None:
            _user = env.config["web-platform-tests"]["github"]["user"]
        if _commits is None:
            _commits = [AttrDict(**{"sha": "%040x" % random.getrandbits(160),
                                    "message": "Test commit",
                                    "_statuses": [],
                                    "_checks": []})]
        data = AttrDict(**{
            "number": id,
            "title": title,
            "body": body,
            "base": {"ref": base},
            "head": head,
            "merged": False,
            "merge_commit_sha": "%040x" % random.getrandbits(160),
            "state": "open",
            "mergeable": True,
            "_approved": True,
            "_commits": _commits,
            "user": {
                "login": _user
            },
            "labels": []
        })
        self.prs[id] = data
        for commit in _commits:
            self.commit_prs[commit["sha"]] = id
        self._log("Created PR with id %s" % id)
        return id

    def has_branch(self, name: str) -> bool:
        """Log the check and report every branch as existing."""
        self._log("Checked branch %s" % name)
        return True

    def add_labels(self, pr_id: int, *labels: str) -> None:
        """Append ``labels`` to the PR record's label list."""
        self.get_pull(pr_id)["labels"].extend(labels)

    def remove_labels(self, pr_id: int, *labels: str) -> None:
        """Drop every occurrence of ``labels`` from the PR record."""
        pr = self.get_pull(pr_id)
        pr["labels"] = [item for item in pr["labels"] if item not in labels]

    def load_pull(self, data: Dict[str, Any]) -> None:
        """Copy "merged" and "state" from ``data`` onto the stored PR record."""
        pr = self.get_pull(data["number"])
        pr.merged = data["merged"]
        pr.state = data["state"]

    def required_checks(self, branch_name: str) -> List[str]:
        """Return a fixed list of required check names."""
        return ["wpt-decision-task", "sink-task"]

    def get_check_runs(self, pr_id: int) -> Dict[str, Dict[str, Any]]:
        """Return the last commit's ``_checks`` keyed by name (name removed)."""
        pr = self.get_pull(pr_id)
        rv = {}
        if pr:
            self._log("Got status for PR %s " % pr_id)
            for item in pr["_commits"][-1]["_checks"]:
                rv[item["name"]] = item.copy()
                del rv[item["name"]]["name"]
        return rv

    def pull_state(self, pr_id: int) -> str:
        """Return the PR record's state; ValueError if the PR is unknown."""
        pr = self.get_pull(pr_id)
        if not pr:
            raise ValueError
        return pr["state"]

    def reopen_pull(self, pr_id: int) -> None:
        """Mark the PR record "open"; ValueError if the PR is unknown."""
        pr = self.get_pull(pr_id)
        if not pr:
            raise ValueError
        pr["state"] = "open"

    def close_pull(self, pr_id: int) -> None:
        """Mark the PR record "closed"; ValueError if the PR is unknown."""
        pr = self.get_pull(pr_id)
        if not pr:
            raise ValueError
        pr["state"] = "closed"

    def is_approved(self, pr_id: int) -> bool:
        """Return the PR record's ``_approved`` flag."""
        pr = self.get_pull(pr_id)
        return pr._approved

    def merge_sha(self, pr_id: int) -> Optional[str]:
        """Return the merge commit sha if the PR record is merged, else None."""
        pr = self.get_pull(pr_id)
        if pr.merged:
            return pr.merge_commit_sha
        return None

    def merge_pull(self, pr_id: int) -> str:
        """Mark the PR merged by the configured user and return its merge sha.

        :raises ValueError: if the PR is not mergeable.
        """
        pr = self.get_pull(pr_id)
        if self.is_mergeable(pr_id):
            pr.merged = True
        else:
            # TODO: raise the right kind of error here
            raise ValueError
        pr["merged_by"] = {"login": env.config["web-platform-tests"]["github"]["user"]}
        self._log("Merged PR with id %s" % pr_id)
        return pr.merge_commit_sha

    def pr_for_commit(self, sha: str) -> Optional[int]:
        """Return the number of the PR created with commit ``sha``, or None."""
        return self.commit_prs.get(sha)

    def get_pulls(self, minimum_id: Optional[int] = None):
        """Yield stored PRs, restricted to numbers >= ``minimum_id`` if given.

        Previously nothing was yielded when ``minimum_id`` was None or 0.
        """
        for number in self.prs:
            if minimum_id is None or number >= minimum_id:
                yield self.get_pull(number)

    def get_status(self,
                   pr_id: int,
                   context: str
                   ) -> Optional[str]:
        """Return the newest status state with ``context`` on the last commit."""
        pr = self.get_pull(pr_id)
        statuses = [item for item in pr._commits[-1]._statuses
                    if item.context == context]
        statuses.sort(key=lambda x: -x.id)
        if statuses:
            return statuses[0].state
        return None

    def set_status(self,
                   pr_id: int,
                   status: str,
                   target_url: Optional[str],
                   description: Optional[str],
                   context: str
                   ) -> None:
        """Append a status record to the PR's last commit."""
        pr = self.get_pull(pr_id)
        head_commit = pr._commits[-1]
        kwargs = {
            "state": status,
            "context": context,
            # Sequential ids so get_status can pick the newest.
            "id": len(pr._commits[-1]._statuses),
        }
        if target_url is not None:
            kwargs["target_url"] = target_url
        if description is not None:
            kwargs["description"] = description
        head_commit._statuses.append(AttrDict(**kwargs))

    def set_check(self,
                  name: str,
                  commit_sha: Optional[str] = None,
                  check_id: Optional[int] = None,
                  url: Optional[str] = None,
                  external_id: Optional[str] = None,
                  status: Optional[str] = None,
                  started_at: Optional[datetime] = None,
                  conclusion: Optional[str] = None,
                  completed_at: Optional[datetime] = None,
                  output: Optional[Dict[str, str]] = None,
                  actions: Optional[List[str]] = None,
                  ) -> Dict[str, Any]:
        """Record a check in ``checks`` keyed by head sha; return ``{"id": ...}``.

        A POST for an unseen head sha allocates a new id. A PATCH reuses the
        id already recorded for that head sha.
        """
        req_method, req_data = self._construct_check_data(name, commit_sha, check_id, url,
                                                          external_id, status, started_at,
                                                          conclusion, completed_at, output,
                                                          actions)
        head_sha = req_data["head_sha"]
        if head_sha not in self.checks:
            assert req_method == "POST"
            check_id = len(self.checks)
        else:
            assert req_method == "PATCH"
            # Previously looked up the literal key "head_sha" (KeyError).
            # NOTE(review): a PATCH always has head_sha None (commit_sha and
            # check_id are mutually exclusive), so this branch only matches
            # an entry stored under None; confirm the intended mock semantics.
            check_id = self.checks[head_sha][0]
        self.checks[head_sha] = (check_id, req_method, req_data)
        return {"id": check_id}