sync/index.py (423 lines of code) (raw):

from __future__ import annotations import abc import json from collections import defaultdict import git import pygit2 from . import log from .base import ProcessName, CommitBuilder, iter_tree, iter_process_names from .env import Environment from .repos import pygit2_get from typing import Any, Callable, Iterator, Tuple, Union, TYPE_CHECKING from git.repo.base import Repo from pygit2.repository import Repository if TYPE_CHECKING: from sync.sync import SyncProcess ChangeEntry = Tuple[str, str, str] IndexKey = Tuple[str, ...] IndexValue = Union[str, ProcessName] logger = log.get_logger(__name__) env = Environment() class Index(metaclass=abc.ABCMeta): name: str | None = None key_fields: tuple[str, ...] unique = False value_cls: type = tuple # Overridden in subclasses using the constructor # This provides a kind of borg pattern where all instances of # the class have the same changes data changes: dict[Any, Any] | None = None def __init__(self, repo: Repo) -> None: if self.__class__.changes is None: self.reset() self.repo = repo self.pygit2_repo = pygit2_get(repo) def reset(self) -> None: constructors: list[Callable[[], Any]] = [list] for _ in self.key_fields[:-1]: def fn() -> Callable: idx = len(constructors) - 1 constructors.append(lambda: defaultdict(constructors[idx])) return constructors[-1] fn() self.__class__.changes = defaultdict(constructors[-1]) @classmethod def create(cls, repo: Repo) -> Index: logger.info("Creating index %s" % cls.name) data = {"name": cls.name, "fields": list(cls.key_fields), "unique": cls.unique} meta_path = "index/%s/_metadata" % cls.name tree = {meta_path: json.dumps(data, indent=0).encode("utf8")} with CommitBuilder(repo, message="Create index %s" % cls.name, ref=env.config["sync"]["ref"]) as commit: commit.add_tree(tree) return cls(repo) @classmethod def get_root_path(cls) -> str: return "index/%s" % cls.name @classmethod def get_or_create(cls, repo: Repo) -> Index: ref_name = env.config["sync"]["ref"] ref = git.Reference(repo, ref_name) try: ref.commit.tree["index/%s" % cls.name] except KeyError: return cls.create(repo) return cls(repo) def get(self, key: IndexKey) -> Any: if len(key) > len(self.key_fields): raise ValueError assert all(isinstance(key_part, str) for key_part in key) items = self._read(key) if self.unique and len(key) == len(self.key_fields) and len(items) > 1: raise ValueError("Got multiple results for unique index") rv = set() for item in items: rv.add(self.load_value(item)) if self.unique: return rv.pop() if rv else None return rv def _read(self, key: IndexKey, include_local: bool = True, ) -> set[str]: path = "{}/{}".format(self.get_root_path(), "/".join(key)) data = set() for obj in iter_blobs(self.pygit2_repo, path): data |= self._load_obj(obj) if include_local: self._update_changes(key, data) return data def _read_changes(self, key: IndexKey | None) -> dict[IndexKey, list[ChangeEntry]]: target = self.changes if target is None: return {} if key: for part in key: if target is not None: target = target[part] else: key = () changes: dict[IndexKey, list[ChangeEntry]] = {} stack = [(key, target)] while stack: key, items = stack.pop() if isinstance(items, list): changes[key] = items else: if items is not None: for key_part, values in items.items(): stack.append((key + (key_part,), values)) return changes def _update_changes(self, key: IndexKey, data: set[str], ) -> None: changes = self._read_changes(key) for key_changes in changes.values(): for old_value, new_value, _ in key_changes: if new_value is None and old_value in data: data.remove(old_value) elif new_value is not None: data.add(new_value) def _load_obj(self, obj: pygit2.Blob) -> set[str]: rv = json.loads(obj.data) if isinstance(rv, list): return set(rv) return {rv} def save(self, commit_builder: CommitBuilder | None = None, message: str | None = None, overwrite: bool = False) -> None: changes = self._read_changes(None) if not changes: return if message is None: message = "Update index %s\n" % self.name for key_changes in changes.values(): for _, _, msg in key_changes: message += " %s\n" % msg if commit_builder is None: # TODO: @overload could help here assert message is not None commit_builder = CommitBuilder(self.repo, message, ref=env.config["sync"]["ref"], initial_empty=overwrite) else: assert commit_builder.initial_empty == overwrite commit_builder.message += message with commit_builder as commit: for key, key_changes in changes.items(): self._update_key(commit, key, key_changes) self.reset() def insert(self, key: IndexKey, value: IndexValue, ) -> Index: if len(key) != len(self.key_fields): raise ValueError assert all(isinstance(item, str) for item in key) value = self.dump_value(value) msg = f"Insert key {key} value {value}" target = self.changes for part in key: if target is not None: target = target[part] assert isinstance(target, list) target.append((None, value, msg)) return self def delete(self, key: IndexKey, value: IndexValue, ) -> Index: if len(key) != len(self.key_fields): raise ValueError assert all(isinstance(item, str) for item in key) value = self.dump_value(value) msg = f"Delete key {key} value {value}" target = self.changes for part in key: if target is not None: target = target[part] assert isinstance(target, list) target.append((value, None, msg)) return self def move(self, old_key: IndexKey | None, new_key: IndexKey, value: IndexValue, ) -> Index: assert old_key != new_key if old_key is not None: self.delete(old_key, value) if new_key is not None: self.insert(new_key, value) return self def _update_key(self, commit: CommitBuilder, key: IndexKey, key_changes: list[ChangeEntry], ) -> None: existing = self._read(key, False) new = existing.copy() for old_value, new_value, _ in key_changes: if new_value is None and old_value is None: new = set() elif new_value is None and old_value in new: new.remove(old_value) elif old_value is None: new.add(new_value) path_suffix = "/".join(key) path = f"{self.get_root_path()}/{path_suffix}" if new == existing: return if not new: commit.delete([path]) return if self.unique and len(new) > 1: raise ValueError(f"Tried to insert duplicate entry for unique index {key}") index_value = list(sorted(new)) commit.add_tree({path: json.dumps(index_value, indent=0).encode("utf8")}) def dump_value(self, value: IndexValue) -> str: if isinstance(value, ProcessName): return value.path() return value def load_value(self, value: str) -> IndexValue: return self.value_cls(*(value.split("/"))) def build(self, *args: Any, **kwargs: Any) -> None: # Delete all entries in existing keys for key in self.keys(): assert len(key) == len(self.key_fields) target = self.changes for part in key: if target is not None: target = target[part] assert isinstance(target, list) target.append((None, None, f"Clear key {key}")) entries, errors = self.build_entries(*args, **kwargs) for key, value in entries: self.insert(key, value) self.save(message="Build index %s" % self.name) for error in errors: logger.warning(error) def build_entries(self, *args, **kwargs): raise NotImplementedError @classmethod def make_key(cls, value: Any) -> IndexKey: return (value,) def keys(self) -> set[IndexKey]: return {key for key, _ in iter_tree(self.pygit2_repo, root_path=self.get_root_path()) if not key[-1] == "_metadata"} class TaskGroupIndex(Index): name = "taskgroup" key_fields = ("taskgroup-id-0", "taskgroup-id-1", "taskgroup-id-2") unique = True value_cls = ProcessName @classmethod def make_key(cls, value: str) -> IndexKey: return (value[:2], value[2:4], value[4:]) def build_entries(self, *args, **kwargs): from . import trypush entries = [] for try_push in trypush.TryPush.load_all(self.repo): if try_push.taskgroup_id is not None: entries.append((self.make_key(try_push.taskgroup_id), try_push.process_name)) return entries, [] class TryCommitIndex(Index): name = "try-commit" key_fields = ("commit-0", "commit-1", "commit-2", "commit-3") unique = True value_cls = ProcessName @classmethod def make_key(cls, value: str) -> IndexKey: return (value[:2], value[2:4], value[4:6], value[6:]) def build_entries(self, *args, **kwargs): entries = [] from . import trypush for try_push in trypush.TryPush.load_all(self.repo): if try_push.try_rev: entries.append((self.make_key(try_push.try_rev), try_push.process_name)) return entries, [] class SyncIndex(Index): name = "sync-id-status" key_fields = ("objtype", "subtype", "status", "obj_id") unique = False value_cls = ProcessName @classmethod def make_key(cls, value: Union[SyncProcess, Tuple[ProcessName, str]], ) -> IndexKey: if isinstance(value, tuple): process_name, status = value else: process_name = value.process_name status = value.status return (process_name.obj_type, process_name.subtype, status, str(process_name.obj_id)) def build_entries(self, git_gecko, git_wpt, **kwargs): from .downstream import DownstreamSync from .upstream import UpstreamSync from .landing import LandingSync entries = [] errors = [] for process_name in iter_process_names(self.pygit2_repo, kind=["sync"]): sync_cls = None if process_name.subtype == "upstream": sync_cls = UpstreamSync elif process_name.subtype == "downstream": sync_cls = DownstreamSync elif process_name.subtype == "landing": sync_cls = LandingSync if sync_cls: try: sync = sync_cls(git_gecko, git_wpt, process_name) except ValueError: errors.append("Corrupt process %s" % process_name) continue entries.append((self.make_key(sync), process_name)) return entries, errors class PrIdIndex(Index): name = "pr-id" key_fields = ("pr-id",) unique = True value_cls = ProcessName @classmethod def make_key(cls, value: Union[SyncProcess, ProcessName]) -> IndexKey: if isinstance(value, ProcessName): pr_id = value.obj_id else: pr_id = str(value.pr) assert pr_id is not None return (pr_id,) def build_entries(self, git_gecko, git_wpt, **kwargs): from .downstream import DownstreamSync from .upstream import UpstreamSync entries = [] errors = [] for process_name in iter_process_names(self.pygit2_repo, kind=["sync"]): cls = None if process_name.subtype == "downstream": cls = DownstreamSync elif process_name.subtype == "upstream": cls = UpstreamSync if not cls: continue try: sync = cls(git_gecko, git_wpt, process_name) except ValueError: errors.append("Corrupt process %s" % process_name) continue if sync.pr: entries.append((self.make_key(sync), process_name)) return entries, errors class BugIdIndex(Index): name = "bug-id" key_fields = ("bug-id", "status") unique = False value_cls = ProcessName @classmethod def make_key(cls, value: Union[SyncProcess, Tuple[ProcessName, str]], ) -> IndexKey: if isinstance(value, tuple): process_name, status = value bug = process_name.obj_id else: assert value.bug is not None bug = str(value.bug) status = value.status return (bug, status) def build_entries(self, git_gecko, git_wpt, **kwargs): from .downstream import DownstreamSync from .upstream import UpstreamSync from .landing import LandingSync entries = [] errors = [] for process_name in iter_process_names(self.pygit2_repo, kind=["sync"]): sync_cls = None if process_name.subtype == "upstream": sync_cls = UpstreamSync elif process_name.subtype == "downstream": sync_cls = DownstreamSync elif process_name.subtype == "landing": sync_cls = LandingSync else: assert False try: sync = sync_cls(git_gecko, git_wpt, process_name) except ValueError as e: errors.append(f"Corrupt process {process_name}:\n{e}") continue entries.append((self.make_key(sync), process_name)) return entries, errors def iter_blobs(repo: Repository, path: str) -> Iterator[pygit2.Blob]: """Iterate over all blobs under a path :param repo: pygit2 repo :param path: path to use as the root, or None for the root path """ ref = repo.references[env.config["sync"]["ref"]] root = ref.peel().tree if path is not None: if path not in root: return root_entry = root[path] if isinstance(root_entry, pygit2.Blob): yield root_entry return stack = [root_entry] while stack: tree = stack.pop() assert isinstance(tree, pygit2.Tree) for item in tree: if isinstance(item, pygit2.Blob): yield item else: stack.append(item) indicies = {item for item in globals().values() if type(item) is type(Index) and issubclass(item, Index) and item != Index}