sync/base.py (554 lines of code) (raw):

from __future__ import annotations import json import sys import weakref from collections import defaultdict from collections.abc import Mapping import git import pygit2 from . import log from . import commit as sync_commit from .env import Environment from .lock import MutGuard, RepoLock, mut, constructor from .repos import pygit2_get from typing import (Any, DefaultDict, Iterator, Optional, Set, Tuple, TYPE_CHECKING) from git.refs.reference import Reference from git.repo.base import Repo if TYPE_CHECKING: from pygit2.repository import Repository from pygit2 import Commit as PyGit2Commit, TreeEntry from sync.commit import Commit from sync.landing import LandingSync from sync.lock import SyncLock from sync.sync import SyncPointName ProcessNameIndexData = DefaultDict[str, DefaultDict[str, DefaultDict[str, Set]]] ProcessNameKey = Tuple[str, str, str, str] env = Environment() logger = log.get_logger(__name__) class IdentityMap(type): """Metaclass for objects that implement an identity map. Identity mapped objects return the same object when created with the same input data, so that there can only be a single object with given properties active at a time. Typically one would expect the identity in the identity map to be determined by the full set of arguments to the constructor. However we have various situations where either global singletons are passed in via the constructor (e.g. the repos), or associated class data (notably the commit type). To avoid refactoring everything when introducing this feature, we instead define the following protocol: A class implementing this metaclass must define a _cache_key method that is called with the constructor arguments. It returns some hashable object that is used as a key in the instance cache. If an existing instance of the same class exists with the same key that is returned, otherwise a new instance is constructed and cached. A class may also define a _cache_verify property. If this method exists it is called after an instance is retrieved from the cache and may be used to check that any arguments that are not part of the constructor key are consistent between the provided values and the instance values. This is clearly a hack; in general the code should be refactored so that such associated values are determined based on data that forms part of the instance key rather than passed in explicitly.""" _cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary() def __init__(cls, name, bases, cls_dict): if not hasattr(cls, "_cache_key"): raise ValueError("Class is missing _cache_key method") super().__init__(name, bases, cls_dict) def __call__(cls, *args, **kwargs): cache = IdentityMap._cache cache_key = cls._cache_key(*args, **kwargs) if cache_key is None: raise ValueError key = (cls, cache_key) value = cache.get(key) if value is None: value = super().__call__(*args, **kwargs) cache[key] = value if hasattr(value, "_cache_verify") and not value._cache_verify(*args, **kwargs): raise ValueError("Cached instance didn't match non-key arguments") return value def iter_tree(pygit2_repo: Repository, root_path: str = "", rev: PyGit2Commit | None = None, ) -> Iterator[tuple[tuple[str, ...], TreeEntry]]: """Iterator over all paths in a tree""" if rev is None: ref_name = env.config["sync"]["ref"] ref = pygit2_repo.references[ref_name] rev_obj = ref.peel() else: rev_obj = pygit2_repo[rev.id] assert isinstance(rev_obj, pygit2.Commit) root_obj = rev_obj.tree if root_path: root_tree = root_obj[root_path] else: root_tree = root_obj stack = [] stack.append((root_path, root_tree)) while stack: path, tree = stack.pop() assert isinstance(tree, pygit2.Tree) for item in tree: item_path = f"{path}/{item.name}" if isinstance(item, pygit2.Tree): stack.append((item_path, item)) else: name = tuple(item for item in item_path[len(root_path):].split("/") if item) yield name, item def iter_process_names(pygit2_repo: Repository, kind: list[str] = ["sync", "try"], ) -> Iterator[ProcessName]: """Iterator over all ProcessName objects""" ref = pygit2_repo.references[env.config["sync"]["ref"]] root = ref.peel().tree stack = [] for root_path in kind: try: tree = root[root_path] except KeyError: continue stack.append((root_path, tree)) while stack: path, tree = stack.pop() for item in tree: item_path = f"{path}/{item.name}" if isinstance(item, pygit2.Tree): stack.append((item_path, item)) else: process_name = ProcessName.from_path(item_path) if process_name is not None: yield process_name class ProcessNameIndex(metaclass=IdentityMap): def __init__(self, repo: Repo) -> None: self.repo = repo self.pygit2_repo = pygit2_get(repo) self.reset() @classmethod def _cache_key(cls, repo: Repo) -> tuple[Repo]: return (repo,) def reset(self) -> None: self._all: set[ProcessName] = set() self._data: ProcessNameIndexData = defaultdict( lambda: defaultdict( lambda: defaultdict(set))) self._built = False def build(self) -> None: for process_name in iter_process_names(self.pygit2_repo): self.insert(process_name) self._built = True def insert(self, process_name: ProcessName) -> None: self._all.add(process_name) self._data[ process_name.obj_type][ process_name.subtype][ process_name.obj_id].add(process_name) def has(self, process_name: ProcessName) -> bool: if not self._built: self.build() return process_name in self._all def get(self, obj_type: str, subtype: str | None = None, obj_id: str | None = None) -> set[ProcessName]: if not self._built: self.build() target = self._data for key in [obj_type, subtype, obj_id]: assert isinstance(key, str) if key is None: break target = target[key] # type: ignore rv: set[ProcessName] = set() stack = [target] while stack: item = stack.pop() if isinstance(item, set): rv |= item else: stack.extend(item.values()) # type: ignore return rv class ProcessName(metaclass=IdentityMap): """Representation of a name that is used to identify a sync operation. This has the general form <obj type>/<subtype>/<obj_id>[/<seq_id>]. Here <obj type> represents the type of process e.g upstream or downstream, <obj_id> is an identifier for the sync, typically either a bug number or PR number, and <seq_id> is an optional id to cover cases where we might have multiple processes with the same obj_id. """ def __init__(self, obj_type: str, subtype: str, obj_id: str, seq_id: str | int) -> None: assert obj_type is not None assert subtype is not None assert obj_id is not None assert seq_id is not None self._obj_type = obj_type self._subtype = subtype self._obj_id = str(obj_id) self._seq_id = str(seq_id) @classmethod def _cache_key(cls, obj_type: str, subtype: str, obj_id: str, seq_id: str | int, ) -> tuple[str, str, str, str]: return (obj_type, subtype, str(obj_id), str(seq_id)) def __str__(self) -> str: data = "%s/%s/%s/%s" % self.as_tuple() if sys.version_info[0] == 2: data = data.encode("utf8") return data def key(self) -> tuple[str, str, str, str]: return self._cache_key(self._obj_type, self._subtype, self._obj_id, self._seq_id) def path(self) -> str: return "%s/%s/%s/%s" % self.as_tuple() def __eq__(self, other: Any) -> bool: if self is other: return True if self.__class__ != other.__class__: return False return self.as_tuple() == other.as_tuple() def __hash__(self) -> int: return hash(self.key()) @property def obj_type(self) -> str: return self._obj_type @property def subtype(self) -> str: return self._subtype @property def obj_id(self) -> str: return self._obj_id @property def seq_id(self) -> int: return int(self._seq_id) def as_tuple(self) -> tuple[str, str, str, int]: return (self.obj_type, self.subtype, self.obj_id, self.seq_id) @classmethod def from_path(cls, path: str) -> ProcessName | None: return cls.from_tuple(path.split("/")) @classmethod def from_tuple(cls, parts: list[str]) -> ProcessName | None: if parts[0] not in ["sync", "try"]: return None if len(parts) != 4: return None return cls(*parts) @classmethod def with_seq_id(cls, repo: Repo, obj_type: str, subtype: str, obj_id: str) -> ProcessName: existing = ProcessNameIndex(repo).get(obj_type, subtype, obj_id) last_id = -1 for process_name in existing: if (process_name.seq_id is not None and int(process_name.seq_id) > last_id): last_id = process_name.seq_id seq_id = last_id + 1 return cls(obj_type, subtype, obj_id, str(seq_id)) class VcsRefObject(metaclass=IdentityMap): """Representation of a named reference to a git object associated with a specific process_name. This is typically either a tag or a head (i.e. branch), but can be any git object.""" ref_prefix: str | None = None def __init__(self, repo: Repo, name: ProcessName | SyncPointName, commit_cls: type = sync_commit.Commit) -> None: self.repo = repo self.pygit2_repo = pygit2_get(repo) if not self.get_path(name) in self.pygit2_repo.references: raise ValueError("No ref found in %s with path %s" % (repo.working_dir, self.get_path(name))) self.name = name self.commit_cls = commit_cls self._lock = None def as_mut(self, lock: SyncLock) -> MutGuard: return MutGuard(lock, self) @property def lock_key(self) -> tuple[str, str]: return (self.name.subtype, self.name.obj_id) @classmethod def _cache_key(cls, repo: Repo, process_name: ProcessName | SyncPointName, commit_cls: type = sync_commit.Commit, ) -> tuple[Repo, ProcessNameKey | tuple[str, str]]: return (repo, process_name.key()) def _cache_verify(self, repo: Repo, process_name: ProcessName | SyncPointName, commit_cls: type = sync_commit.Commit) -> bool: return commit_cls == self.commit_cls @classmethod @constructor(lambda args: (args["name"].subtype, args["name"].obj_id)) def create(cls, lock: SyncLock, repo: Repo, name: ProcessName, obj: str, commit_cls: type = sync_commit.Commit, force: bool = False) -> VcsRefObject: path = cls.get_path(name) logger.debug("Creating ref %s" % path) pygit2_repo = pygit2_get(repo) if path in pygit2_repo.references: if not force: raise ValueError(f"Ref {path} exists") pygit2_repo.references.create(path, pygit2_repo.revparse_single(obj).id, force=force) return cls(repo, name, commit_cls) def __str__(self) -> str: return self.path def delete(self) -> None: self.pygit2_repo.references[self.path].delete() @classmethod def get_path(cls, name: ProcessName | SyncPointName) -> str: return f"refs/{cls.ref_prefix}/{name.path()}" @property def path(self) -> str: return self.get_path(self.name) @property def ref(self) -> Reference | None: if self.path in self.pygit2_repo.references: return git.Reference(self.repo, self.path) return None @property def commit(self) -> Commit | None: ref = self.ref if ref is not None: commit = self.commit_cls(self.repo, ref.commit) return commit return None @commit.setter # type: ignore @mut() def commit(self, commit: Commit | str) -> None: if isinstance(commit, sync_commit.Commit): sha1 = commit.sha1 else: sha1 = commit sha1 = self.pygit2_repo.revparse_single(sha1).id self.pygit2_repo.references[self.path].set_target(sha1) class BranchRefObject(VcsRefObject): ref_prefix = "heads" class CommitBuilder: def __init__(self, repo: Repo, message: str, ref: str | None = None, commit_cls: type = sync_commit.Commit, initial_empty: bool = False ) -> None: """Object to be used as a context manager for commiting changes to the repo. This class provides low-level access to the git repository in order to make commits without requiring a checkout. It also enforces locking so that only one process may make a commit at a time. In order to use the object, one initalises it and then invokes it as a context manager e.g. with CommitBuilder(repo, "Some commit message" ref=ref) as commit_builder: # Now we have acquired the lock so that the commit ref points to is fixed commit_builder.add_tree({"some/path": "Some file data"}) commit_bulder.delete(["another/path"]) # On exiting the context, the commit is created, the ref updated to point at the # new commit and the lock released # To get the created commit we call get commit = commit_builder.get() The class may be used reentrantly. This is to support a pattern where a method may be called either with an existing commit_builder instance or create a new instance, and in either case use a with block. In order to improve the performance of the low-level access here, we use libgit2 to access the repository. """ # Class state self.repo = repo self.pygit2_repo = pygit2_get(repo) self.message = message if message is not None else "" self.commit_cls = commit_cls self.initial_empty = initial_empty if not ref: self.ref = None else: self.ref = ref self._count = 0 # State set for the life of the context manager self.lock = RepoLock(repo) self.parents: list[str] | None = None self.commit = None self.index: pygit2.Index = None self.has_changes = False def __enter__(self) -> CommitBuilder: self._count += 1 if self._count != 1: return self self.lock.__enter__() # First we create an empty index self.index = pygit2.Index() if self.ref is not None: try: ref = self.pygit2_repo.lookup_reference(self.ref) except KeyError: self.parents = [] else: self.parents = [ref.peel().id] if not self.initial_empty: self.index.read_tree(ref.peel().tree) else: self.parents = [] return self def __exit__(self, *args: Any, **kwargs: Any) -> None: self._count -= 1 if self._count != 0: return if not self.has_changes: if self.parents: sha1 = self.parents[0] else: return None else: tree_id = self.index.write_tree(self.pygit2_repo) sha1 = self.pygit2_repo.create_commit(self.ref, self.pygit2_repo.default_signature, self.pygit2_repo.default_signature, self.message.encode("utf8"), tree_id, self.parents) self.lock.__exit__(*args, **kwargs) self.commit = self.commit_cls(self.repo, sha1) def add_tree(self, tree: dict[str, bytes]) -> None: self.has_changes = True for path, data in tree.items(): blob = self.pygit2_repo.create_blob(data) index_entry = pygit2.IndexEntry(path, blob, pygit2.GIT_FILEMODE_BLOB) self.index.add(index_entry) def delete(self, delete: list[str]) -> None: self.has_changes = True if delete: for path in delete: self.index.remove(path) def get(self) -> Any | None: return self.commit class ProcessData(metaclass=IdentityMap): obj_type: str = "" def __init__(self, repo: Repo, process_name: ProcessName) -> None: assert process_name.obj_type == self.obj_type self.repo = repo self.pygit2_repo = pygit2_get(repo) self.process_name = process_name self.ref = git.Reference(repo, env.config["sync"]["ref"]) self.path = self.get_path(process_name) self._data = self._load() self._lock = None self._updated: set[str] = set() self._deleted: set[str] = set() self._delete = False def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.process_name}>" def __hash__(self) -> int: return hash(self.process_name) def __eq__(self, other: Any) -> bool: if type(self) is not type(other): return False return self.repo == other.repo and self.process_name == other.process_name def as_mut(self, lock: SyncLock) -> MutGuard: return MutGuard(lock, self) def exit_mut(self) -> None: message = "Update %s\n\n" % self.path with CommitBuilder(self.repo, message=message, ref=self.ref.path) as commit: from . import index if self._delete: self._delete_data("Delete %s" % self.path) self._delete = False elif self._updated or self._deleted: message_parts = [] if self._updated: message_parts.append("Updated: {}\n".format(", ".join(self._updated))) if self._deleted: message_parts.append("Deleted: {}\n".format(", ".join(self._deleted))) self._save(self._data, message=" ".join(message_parts), commit_builder=commit) self._updated = set() self._deleted = set() for idx_cls in index.indicies: idx = idx_cls(self.repo) idx.save(commit_builder=commit) @classmethod @constructor(lambda args: (args["process_name"].subtype, args["process_name"].obj_id)) def create(cls, lock: SyncLock, repo: Repo, process_name: ProcessName, data: dict[str, Any], message: str = "Sync data", commit_builder: Optional[CommitBuilder] = None ) -> ProcessData: assert process_name.obj_type == cls.obj_type path = cls.get_path(process_name) ref = git.Reference(repo, env.config["sync"]["ref"]) try: ref.commit.tree[path] except KeyError: pass else: raise ValueError(f"{cls.__name__} already exists at path {path}") if commit_builder is None: commit_builder = CommitBuilder(repo, message, ref=ref.path) else: assert commit_builder.ref == ref.path with commit_builder as commit: commit.add_tree({path: json.dumps(data).encode("utf8")}) ProcessNameIndex(repo).insert(process_name) return cls(repo, process_name) @classmethod def _cache_key(cls, repo: Repo, process_name: ProcessName) -> tuple[Repo, ProcessNameKey]: return (repo, process_name.key()) @classmethod def get_path(self, process_name: ProcessName) -> str: return process_name.path() @classmethod def load_by_obj(cls, repo: Repo, subtype: str, obj_id: int, seq_id=None # Type: Optional[int] ) -> set[ProcessData]: process_names = ProcessNameIndex(repo).get(cls.obj_type, subtype, str(obj_id)) if seq_id is not None: process_names = {item for item in process_names if item.seq_id == seq_id} return {cls(repo, process_name) for process_name in process_names} @classmethod def load_by_status(cls, repo, subtype, status): from . import index process_names = index.SyncIndex(repo).get((cls.obj_type, subtype, status)) rv = set() for process_name in process_names: rv.add(cls(repo, process_name)) return rv def _save(self, data: dict[str, Any], message: str, commit_builder: CommitBuilder | None = None) -> Any | None: if commit_builder is None: commit_builder = CommitBuilder(self.repo, message=message, ref=self.ref.path) else: commit_builder.message += message tree = {self.path: json.dumps(data).encode("utf8")} with commit_builder as commit: commit.add_tree(tree) return commit.get() def _delete_data(self, message: str, commit_builder: CommitBuilder | None = None) -> None: if commit_builder is None: commit_builder = CommitBuilder(self.repo, message=message, ref=self.ref.path) with commit_builder as commit: commit.delete([self.path]) def _load(self) -> dict[str, Any]: ref = self.pygit2_repo.references[self.ref.path] repo = self.pygit2_repo try: data = repo[repo[ref.peel().tree.id][self.path].id].data except KeyError: return {} return json.loads(data) @property def lock_key(self) -> tuple[str, str]: return (self.process_name.subtype, self.process_name.obj_id) def __getitem__(self, key: str) -> Any: return self._data[key] def __contains__(self, key: str) -> bool: return key in self._data def get(self, key: str, default: Any = None) -> Any: return self._data.get(key, default) def items(self) -> Iterator[tuple[str, Any]]: yield from self._data.items() @mut() def __setitem__(self, key: str, value: Any) -> None: if key not in self._data or self._data[key] != value: self._data[key] = value self._updated.add(key) @mut() def __delitem__(self, key: str) -> None: if key in self._data: del self._data[key] self._deleted.add(key) @mut() def delete(self) -> None: self._delete = True class FrozenDict(Mapping): def __init__(self, **kwargs: Any) -> None: self._data = {} for key, value in kwargs.items(): self._data[key] = value def __getitem__(self, key: str) -> Any: return self._data[key] def __contains__(self, key: Any) -> bool: return key in self._data def copy(self, **kwargs: Any) -> FrozenDict: new_data = self._data.copy() for key, value in kwargs.items(): new_data[key] = value return self.__class__(**new_data) def __iter__(self) -> Iterator[str]: yield from self._data def __len__(self) -> int: return len(self._data) def as_dict(self) -> dict[str, Any]: return self._data.copy() class entry_point: def __init__(self, task): self.task = task def __call__(self, f): def inner(*args: Any, **kwargs: Any) -> LandingSync | None: logger.info(f"Called entry point {f.__module__}.{f.__name__}") logger.debug(f"Called args {args!r} kwargs {kwargs!r}") if self.task in env.config["sync"]["enabled"]: return f(*args, **kwargs) logger.debug("Skipping disabled task %s" % self.task) return None inner.__name__ = f.__name__ inner.__doc__ = f.__doc__ return inner