conftest.py (133 lines of code) (raw):

import json
import logging
import operator
import os
import sys
from pathlib import Path

import pytest


# Root directory under which per-test-file JSON baseline references are stored.
BASELINE_DIRECTORY = Path(__file__).parent.resolve() / "tests" / "baselines" / "fixture"


def walk_path(path: Path):
    """
    Recursively walk ``path``, yielding ``(directory, subdirs, files)`` tuples.

    Taken from https://stackoverflow.com/a/76236680
    Path.walk() is not available until python 3.12
    """
    subdirs = []
    files = []
    # Single pass over the directory; entries that are neither a directory
    # nor a regular file (e.g. broken symlinks) are skipped.
    for entry in path.iterdir():
        if entry.is_dir():
            subdirs.append(entry)
        elif entry.is_file():
            files.append(entry)
    yield path, subdirs, files
    for subdir in subdirs:
        yield from walk_path(subdir)


class Baseline:
    """Session-wide store of baseline references loaded from BASELINE_DIRECTORY."""

    def __init__(self, session):
        """Read the ``--rebase`` option and merge every baseline file into one dict."""
        self.rebase = session.config.option.rebase
        self.references = {}

        if BASELINE_DIRECTORY.exists():
            for root, _dirs, files in walk_path(BASELINE_DIRECTORY):
                for name in files:
                    with (root / name).open() as f:
                        self.references.update(json.load(f))

    def get_reference(self, addr, context=None):
        """
        Return the (possibly nested) reference dict for ``addr``.

        Missing entries are created on the way down, so the returned dict is
        always the one stored in ``self.references`` and may be updated in place.

        Args:
            addr: key of the test case (callers in this file pass the node id).
            context: optional sequence of nested keys below ``addr``.
        """
        reference = self.references.setdefault(addr, {})
        for c in context or ():
            reference = reference.setdefault(c, {})
        return reference

    def finalize(self):
        """When rebasing, write references back to disk, one JSON file per test file."""
        if not self.rebase:
            return

        # aggregate refs by test file
        refsbyfile = {}
        for case, ref in self.references.items():
            key = case.split("::")[0]
            reffile = BASELINE_DIRECTORY / Path(key).with_suffix(".json")
            refsbyfile.setdefault(reffile, {})[case] = ref

        # dump aggregated refs into their own files
        for reffile, refs in refsbyfile.items():
            reffile.parent.mkdir(parents=True, exist_ok=True)
            with reffile.open("w+") as f:
                json.dump(refs, f, indent=2, sort_keys=True)


class BaselineRequest:
    """Per-test handle used to compare values against the session baseline."""

    def __init__(self, request):
        """Bind to the session's ``Baseline`` and remember the requesting node id."""
        self.baseline = request.session.stash["baseline"]
        self.addr = request.node.nodeid

    def assertRef(self, compare, context=None, **kwargs):
        """
        Assert ``compare(actual, ref)`` for every ``key=actual`` keyword argument.

        When the baseline is in rebase mode the stored references are first
        overwritten with the given values.

        Args:
            compare: two-argument callable invoked as ``compare(actual, ref)``.
            context: optional sequence of nested keys below this test's entry.
            **kwargs: values to check, keyed by reference name.

        Raises:
            AssertionError: if ``compare`` returns a falsy value for any key.
        """
        # Copy into a fresh list so the caller's sequence is never shared or
        # mutated, and so tuples are accepted as well as lists.
        context = [] if context is None else list(context)
        reference = self.baseline.get_reference(self.addr, context)
        if self.baseline.rebase:
            reference.update(**kwargs)

        logger = logging.getLogger()
        for key, actual in kwargs.items():
            ref = reference.get(key, None)
            label = ".".join(context + [key])
            logger.info(f"{label}:actual = {actual}")
            logger.info(f"{label}:ref    = {ref}")
            assert compare(actual, ref)

    def assertEqual(self, context=None, **kwargs):
        """Assert that every keyword value equals its stored reference."""
        self.assertRef(operator.eq, context, **kwargs)


class Secret:
    """
    Wrapper that keeps a value out of ``repr()`` / ``str()`` output.

    Taken from: https://stackoverflow.com/a/67393351
    """

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return "Secret(********)"

    def __str__(self):
        # Previously misspelled ``__str___`` (three trailing underscores), so
        # it was never used by ``str()``, which fell back to ``__repr__``.
        return "*******"


def pytest_addoption(parser):
    """Register the ``--token``, ``--rebase`` and ``--device`` command-line options."""
    parser.addoption("--token", action="store", default=None)
    parser.addoption("--rebase", action="store_true", help="rebase baseline references from current run")
    parser.addoption(
        "--device",
        "--device-context",
        action="store",
        default=None,
        help=(
            "Used to enable device specific test configurations and baselines."
            " If unspecified, the default is to auto-detect the device."
        ),
    )


@pytest.fixture
def token(request):
    """Return the ``--token`` option wrapped in a ``Secret``."""
    return Secret(request.config.option.token)


def pytest_configure(config):
    """Install bitsandbytes for bnb tests and record the detected physical device."""
    # Bitsandbytes installation for {test_bnb_qlora.py test_bnb_inference.py} tests
    # This change will be reverted shortly
    bnb_tests = any("bnb" in name for name in config.known_args_namespace.file_or_dir)
    if bnb_tests:
        import subprocess

        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@multi-backend-refactor",
            ]
        )

    name = ""
    try:
        from optimum.habana.utils import get_device_name

        name = get_device_name()
        # get_device_name() returns `gaudi` for G1
        if "gaudi" == name:
            # use "gaudi1" since this is used in tests, baselines, etc.
            name = "gaudi1"
    except ValueError:
        pass  # ignore unsupported device, we'll handle it in sessionstart
    finally:
        # Always stash a value (possibly "") so later hooks can read the key.
        config.stash["physical-device"] = name


def pytest_sessionstart(session):
    """
    Create the session baseline and resolve the device context.

    Raises:
        RuntimeError: if no device was given on the command line and none
            was auto-detected.
    """
    session.stash["baseline"] = Baseline(session)

    # User command-line option takes highest priority
    if session.config.option.device is not None:
        device = str(session.config.option.device).strip().lower()
    # Otherwise, use physical device (auto-detected)
    else:
        device = session.config.stash["physical-device"]

    if not device:
        raise RuntimeError("Expected a device context but did not detect one.")

    from tests import utils

    utils.OH_DEVICE_CONTEXT = device
    session.config.stash["device-context"] = device

    # WA: delete the imported top-level tests module so we don't overshadow
    # tests/transformers/tests module.
    # This fixes python -m pytest tests/transformers/tests/models/ -s -v
    del sys.modules["tests"]


def pytest_report_header(config):
    """Return the report header lines describing the device context in use."""
    header = []

    if "GAUDI2_CI" in os.environ:
        del os.environ["GAUDI2_CI"]  # prevent someone from trying to use it in tests
        header.append("\n!!!!!!!!!!!!!!! NOTICE !!!!!! NOTICE !!!!!!!!!!!!!!!!!!!!!!")
        header.append("!! GAUDI2_CI environment variable has been discontinued. !!")
        header.append("!! The CI device context will be auto-detected or can be !!")
        header.append("!! overridden with '--device-context' option. See --help !!")
        header.append("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")

    header.append(f" device context: {config.stash['device-context']}")
    header.append(f"physical device: {config.stash['physical-device'] or None}")

    if config.stash["device-context"] != config.stash["physical-device"]:
        header.append("\nBEWARE: The 'device context' != 'physical-device'.")
        header.append("BEWARE: It is assumed you know what you are doing.\n")

    return header


def pytest_sessionfinish(session):
    """Flush the baseline at session end (a no-op unless rebasing)."""
    session.stash["baseline"].finalize()


@pytest.fixture
def baseline(request):
    """Return a ``BaselineRequest`` bound to the requesting test."""
    return BaselineRequest(request)