http_service/bugbug_http/models.py (234 lines of code) (raw):

# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.

import logging
import os
from datetime import timedelta
from functools import lru_cache
from typing import Sequence
from urllib.parse import urlparse

import orjson
import requests
import zstandard
from redis import Redis

from bugbug import bugzilla, repository, test_scheduling
from bugbug.github import Github
from bugbug.model import Model
from bugbug.models import testselect
from bugbug.utils import get_hgmo_stack
from bugbug_http.readthrough_cache import ReadthroughTTLCache

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger()

MODELS_NAMES = [
    "defectenhancementtask",
    "component",
    "invalidcompatibilityreport",
    "needsdiagnosis",
    "regression",
    "stepstoreproduce",
    "spambug",
    "testlabelselect",
    "testgroupselect",
    "accessibility",
    "performancebug",
    "worksforme",
    "fenixcomponent",
]

DEFAULT_EXPIRATION_TTL = 7 * 24 * 3600  # A week

# The Redis connection parameters come from REDIS_URL; a "rediss" scheme
# enables TLS.
url = urlparse(os.environ.get("REDIS_URL", "redis://localhost/0"))
assert url.hostname is not None
redis = Redis(
    host=url.hostname,
    port=url.port if url.port is not None else 6379,
    password=url.password,
    ssl=True if url.scheme == "rediss" else False,
    ssl_cert_reqs=None,
)

# Models are loaded on first use through the read-through cache, which is
# configured with a one hour TTL.
MODEL_CACHE: ReadthroughTTLCache[str, Model] = ReadthroughTTLCache(
    timedelta(hours=1), lambda m: Model.load(f"{m}model")
)
MODEL_CACHE.start_ttl_thread()

cctx = zstandard.ZstdCompressor(level=10)


def setkey(key: str, value: bytes, compress: bool = False) -> None:
    """Store ``value`` in Redis under ``key`` with the default expiration.

    Args:
        key: The Redis key to write.
        value: The raw payload to store.
        compress: When true, the payload is zstd-compressed before storing.
    """
    # Lazy logger arguments: the repr of a potentially large payload is only
    # built when DEBUG logging is actually enabled.
    LOGGER.debug("Storing data at %s: %r", key, value)
    if compress:
        value = cctx.compress(value)

    # Write the value and its TTL with a single command, so a failure between
    # two separate calls can never leave a key without an expiration.
    redis.set(key, value, ex=DEFAULT_EXPIRATION_TTL)


def classify_bug(model_name: str, bug_ids: Sequence[int], bugzilla_token: str) -> str:
    """Classify Bugzilla bugs with the given model and store the results.

    Every requested bug that Bugzilla did not return gets an
    ``{"available": False}`` result. Every retrieved bug gets its
    classification result and its last change time stored.

    Args:
        model_name: Name of the model to look up in ``MODEL_CACHE``.
        bug_ids: IDs of the bugs to classify.
        bugzilla_token: Token used to query Bugzilla.

    Returns:
        ``"OK"`` when the results were stored, ``"NOK"`` when no bug could be
        retrieved or the model is not available.
    """
    from bugbug_http.app import JobInfo

    bug_ids_set = set(map(int, bug_ids))

    # This should be called in a process worker so it should be safe to set
    # the token here
    bugzilla.set_token(bugzilla_token)

    bugs = bugzilla.get(bug_ids)

    missing_bugs = bug_ids_set.difference(bugs.keys())

    for bug_id in missing_bugs:
        job = JobInfo(classify_bug, model_name, bug_id)

        # TODO: Find a better error format
        setkey(job.result_key, orjson.dumps({"available": False}))

    if not bugs:
        return "NOK"

    model = MODEL_CACHE.get(model_name)

    if not model:
        LOGGER.info("Missing model %r, aborting", model_name)
        return "NOK"

    model_extra_data = model.get_extra_data()

    # TODO: Classify could choke on a single bug, which could make the whole
    # job fail. What should we do here?
    # NOTE(review): the second argument presumably requests probabilities
    # (the result is argmax-ed below) - confirm against Model.classify.
    probs = model.classify(list(bugs.values()), True)
    indexes = probs.argmax(axis=-1)
    suggestions = model.le.inverse_transform(indexes)

    probs_list = probs.tolist()
    indexes_list = indexes.tolist()
    suggestions_list = suggestions.tolist()

    for i, bug_id in enumerate(bugs):
        data = {
            "prob": probs_list[i],
            "index": indexes_list[i],
            "class": suggestions_list[i],
            "extra_data": model_extra_data,
        }

        job = JobInfo(classify_bug, model_name, bug_id)
        setkey(job.result_key, orjson.dumps(data), compress=True)

        # Save the bug last change
        setkey(job.change_time_key, bugs[bug_id]["last_change_time"].encode())

    return "OK"


def classify_issue(
    model_name: str, owner: str, repo: str, issue_nums: Sequence[int]
) -> str:
    """Classify GitHub issues with the given model and store the results.

    Args:
        model_name: Name of the model to look up in ``MODEL_CACHE``.
        owner: Owner of the GitHub repository.
        repo: Name of the GitHub repository.
        issue_nums: Numbers of the issues to classify.

    Returns:
        ``"OK"`` when the results were stored, ``"NOK"`` when there is no
        issue to classify or the model is not available.
    """
    from bugbug_http.app import JobInfo

    github = Github(owner=owner, repo=repo)

    issue_ids_set = set(map(int, issue_nums))

    issues = {
        issue_num: github.fetch_issue_by_number(owner, repo, issue_num, True)
        for issue_num in issue_nums
    }

    # NOTE(review): `issues` is keyed by every requested number, so for
    # integer input this difference looks like it is always empty - confirm
    # how `fetch_issue_by_number` signals an issue that does not exist.
    missing_issues = issue_ids_set.difference(issues.keys())

    for issue_id in missing_issues:
        job = JobInfo(classify_issue, model_name, owner, repo, issue_id)

        # TODO: Find a better error format
        setkey(job.result_key, orjson.dumps({"available": False}))

    if not issues:
        return "NOK"

    model = MODEL_CACHE.get(model_name)

    if not model:
        LOGGER.info("Missing model %r, aborting", model_name)
        return "NOK"

    model_extra_data = model.get_extra_data()

    # TODO: Classify could choke on a single issue, which could make the whole
    # job fail. What should we do here?
    probs = model.classify(list(issues.values()), True)
    indexes = probs.argmax(axis=-1)
    suggestions = model.le.inverse_transform(indexes)

    probs_list = probs.tolist()
    indexes_list = indexes.tolist()
    suggestions_list = suggestions.tolist()

    for i, issue_id in enumerate(issues):
        data = {
            "prob": probs_list[i],
            "index": indexes_list[i],
            "class": suggestions_list[i],
            "extra_data": model_extra_data,
        }

        job = JobInfo(classify_issue, model_name, owner, repo, issue_id)
        setkey(job.result_key, orjson.dumps(data), compress=True)

        # Save the issue last change
        setkey(job.change_time_key, issues[issue_id]["updated_at"].encode())

    return "OK"


def classify_broken_site_report(model_name: str, reports_data: list[dict]) -> str:
    """Classify broken site reports with the given model and store the results.

    Args:
        model_name: Name of the model to look up in ``MODEL_CACHE``.
        reports_data: Reports to classify; the ``"uuid"``, ``"title"`` and
            ``"body"`` entries of each report are read.

    Returns:
        ``"OK"`` when the results were stored, ``"NOK"`` when there is no
        report to classify or the model is not available.
    """
    from bugbug_http.app import JobInfo

    # Reports sharing a UUID collapse into a single entry (the last one wins).
    reports = {
        report["uuid"]: {"title": report["title"], "body": report["body"]}
        for report in reports_data
    }

    if not reports:
        return "NOK"

    model = MODEL_CACHE.get(model_name)

    if not model:
        LOGGER.info("Missing model %r, aborting", model_name)
        return "NOK"

    model_extra_data = model.get_extra_data()
    probs = model.classify(list(reports.values()), True)
    indexes = probs.argmax(axis=-1)
    suggestions = model.le.inverse_transform(indexes)

    probs_list = probs.tolist()
    indexes_list = indexes.tolist()
    suggestions_list = suggestions.tolist()

    for i, report_uuid in enumerate(reports):
        data = {
            "prob": probs_list[i],
            "index": indexes_list[i],
            "class": suggestions_list[i],
            "extra_data": model_extra_data,
        }

        job = JobInfo(classify_broken_site_report, model_name, report_uuid)
        setkey(job.result_key, orjson.dumps(data), compress=True)

    return "OK"


@lru_cache(maxsize=None)
def get_known_tasks() -> tuple[str, ...]:
    """Return the stripped lines of the ``known_tasks`` file.

    The file is resolved relative to the current working directory and is
    read only once: the result is cached for the lifetime of the process.
    """
    with open("known_tasks", "r") as f:
        return tuple(line.strip() for line in f)


def schedule_tests(branch: str, rev: str) -> str:
    """Select the tasks and groups to run for a push and store the result.

    Args:
        branch: Branch the revision belongs to.
        rev: Revision to analyze.

    Returns:
        ``"OK"`` when the selection was stored, ``"NOK"`` when the stack of
        patches for the revision could not be retrieved.
    """
    from bugbug_http import REPO_DIR
    from bugbug_http.app import JobInfo

    job = JobInfo(schedule_tests, branch, rev)
    LOGGER.info("Processing %s...", job)

    # Pull the revision to the local repository
    LOGGER.info("Pulling commits from the remote repository...")
    repository.pull(REPO_DIR, branch, rev)

    # Load the full stack of patches leading to that revision
    LOGGER.info("Loading commits to analyze using automationrelevance...")
    try:
        revs = get_hgmo_stack(branch, rev)
    except requests.exceptions.RequestException:
        LOGGER.warning("Push not found for %s @ %s!", branch, rev)
        return "NOK"

    test_selection_threshold = float(
        os.environ.get("TEST_SELECTION_CONFIDENCE_THRESHOLD", 0.5)
    )

    # On "try", consider commits from other branches too (see https://bugzilla.mozilla.org/show_bug.cgi?id=1790493).
    # On other repos, only consider "tip" commits (to exclude commits such as https://hg.mozilla.org/integration/autoland/rev/961f253985a4388008700a6a6fde80f4e17c0b4b).
    if branch == "try":
        repo_branch = None
    else:
        repo_branch = "tip"

    # Analyze patches.
    commits = repository.download_commits(
        REPO_DIR,
        revs=revs,
        branch=repo_branch,
        save=False,
        use_single_process=True,
        include_no_bug=True,
    )

    if len(commits) > 0:
        testlabelselect_model = MODEL_CACHE.get("testlabelselect")
        testgroupselect_model = MODEL_CACHE.get("testgroupselect")

        tasks = testlabelselect_model.select_tests(commits, test_selection_threshold)

        reduced = testselect.reduce_configs(
            set(t for t, c in tasks.items() if c >= 0.8), 1.0
        )

        reduced_higher = testselect.reduce_configs(
            set(t for t, c in tasks.items() if c >= 0.9), 1.0
        )

        groups = testgroupselect_model.select_tests(commits, test_selection_threshold)

        config_groups = testselect.select_configs(groups.keys(), 0.9)
    else:
        tasks = {}
        reduced = set()
        # Every name used to build `data` below must be bound on this path
        # too; `reduced_higher` used to be left undefined here.
        reduced_higher = set()
        groups = {}
        config_groups = {}

    data = {
        "tasks": tasks,
        "groups": groups,
        "config_groups": config_groups,
        "reduced_tasks": {t: c for t, c in tasks.items() if t in reduced},
        "reduced_tasks_higher": {t: c for t, c in tasks.items() if t in reduced_higher},
        "known_tasks": get_known_tasks(),
    }
    setkey(job.result_key, orjson.dumps(data), compress=True)

    return "OK"

def get_config_specific_groups(config: str) -> str:
    """Store the groups that have an equivalence set made only of ``config``.

    The result is a list of ``{"name": group}`` entries, one per runnable of
    the past failures data for which at least one equivalence set is exactly
    ``{config}``. It is serialized, compressed, and stored under the job's
    result key.

    Args:
        config: The configuration to look for.

    Returns:
        Always ``"OK"``.
    """
    from bugbug_http.app import JobInfo

    job = JobInfo(get_config_specific_groups, config)
    LOGGER.info("Processing %s...", job)

    equivalence_sets = testselect._get_equivalence_sets(0.9)

    past_failures = test_scheduling.PastFailures("group", True)

    # The set every candidate equivalence set is compared against.
    only_this_config = {config}

    specific_groups = []
    for group in past_failures.all_runnables:
        for candidate in equivalence_sets[group]:
            if candidate == only_this_config:
                specific_groups.append({"name": group})
                # One matching equivalence set is enough for this group.
                break

    setkey(job.result_key, orjson.dumps(specific_groups), compress=True)

    return "OK"