bugbug/db.py
# -*- coding: utf-8 -*-
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
import gzip
import io
import logging
import os
import pickle
from contextlib import contextmanager
from urllib.parse import urljoin

import orjson
import requests
import zstandard

from bugbug import utils

DATABASES = {}

logger = logging.getLogger(__name__)


class LastModifiedNotAvailable(Exception):
    pass
def register(path, url, version, support_files=[]):
    DATABASES[path] = {"url": url, "version": version, "support_files": support_files}

    # Create DB parent directory.
    os.makedirs(os.path.abspath(os.path.dirname(path)), exist_ok=True)

    if not os.path.exists(f"{path}.version"):
        with open(f"{path}.version", "w") as f:
            f.write(str(version))
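# Illustrative registration sketch (the path, URL, and version below are
# hypothetical, not values used by this module). The registered URL points at
# the zstandard-compressed archive that download() later fetches:
#
#     from bugbug import db
#
#     db.register("data/bugs.json", "https://example.com/bugs.json.zst", 1)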
def is_registered(path: str) -> bool:
    return path in DATABASES


def exists(path):
    return os.path.exists(path)
def is_different_schema(path):
    url = urljoin(DATABASES[path]["url"], f"{os.path.basename(path)}.version")
    r = utils.get_session("community-tc").get(
        url,
        headers={
            "User-Agent": utils.get_user_agent(),
        },
    )

    if not r.ok:
        logger.info("Version file is not yet available to download for %s", path)
        return True

    # Schema version of the DB that is currently published for download.
    prev_version = int(r.text)

    return DATABASES[path]["version"] != prev_version
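# Worked example of the check above (hypothetical registration): for a DB
# registered as "data/bugs.json" with url "https://example.com/bugs.json.zst",
# the version file fetched is "https://example.com/bugs.json.version"; the
# schema is considered different when that file is missing or its content does
# not match the locally registered version.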
def download_support_file(path, file_name, extract=True):
    # If a DB with the current schema is not available yet, we can't download.
    if is_different_schema(path):
        return False

    try:
        url = urljoin(DATABASES[path]["url"], file_name)
        path = os.path.join(os.path.dirname(path), file_name)

        logger.info("Downloading %s to %s", url, path)
        updated = utils.download_check_etag(url, path)

        if extract and updated and path.endswith(".zst"):
            utils.extract_file(path)
            os.remove(path)

        return True
    except requests.exceptions.HTTPError:
        logger.info(
            "%s is not yet available to download for %s", file_name, path, exc_info=True
        )
        return False
# Download and extract databases.
def download(path, support_files_too=False, extract=True):
    # If a DB with the current schema is not available yet, we can't download.
    if is_different_schema(path):
        return False

    zst_path = f"{path}.zst"

    url = DATABASES[path]["url"]
    try:
        logger.info("Downloading %s to %s", url, zst_path)
        updated = utils.download_check_etag(url, zst_path)

        if extract and updated:
            utils.extract_file(zst_path)
            os.remove(zst_path)

        successful = True
        if support_files_too:
            for support_file in DATABASES[path]["support_files"]:
                successful |= download_support_file(path, support_file, extract)

        return successful
    except requests.exceptions.HTTPError:
        logger.info("%s is not yet available to download", url, exc_info=True)
        return False
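# Minimal usage sketch (path registered as in the hypothetical example above):
#
#     if db.download("data/bugs.json", support_files_too=True):
#         ...  # "data/bugs.json" now exists locally, extracted from the archive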
def upload(path):
    support_files_paths = [
        os.path.join(os.path.dirname(path), support_file_path)
        for support_file_path in DATABASES[path]["support_files"]
    ]
    utils.upload_s3([f"{path}.zst", f"{path}.version"] + support_files_paths)
def last_modified(path):
    if is_different_schema(path):
        raise LastModifiedNotAvailable()

    url = DATABASES[path]["url"]
    last_modified = utils.get_last_modified(url)

    if last_modified is None:
        raise LastModifiedNotAvailable()

    return last_modified
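# Callers need to handle the case where the remote DB is not published yet,
# e.g. (sketch, hypothetical path):
#
#     try:
#         since = db.last_modified("data/bugs.json")
#     except db.LastModifiedNotAvailable:
#         since = None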
class Store:
    def __init__(self, fh):
        self.fh = fh


class JSONStore(Store):
    def write(self, elems):
        for elem in elems:
            self.fh.write(orjson.dumps(elem) + b"\n")

    def read(self):
        for line in io.TextIOWrapper(self.fh, encoding="utf-8"):
            yield orjson.loads(line)


class PickleStore(Store):
    def write(self, elems):
        for elem in elems:
            self.fh.write(pickle.dumps(elem))

    def read(self):
        try:
            while True:
                yield pickle.load(self.fh)
        except EOFError:
            pass


COMPRESSION_FORMATS = ["gz", "zstd"]
SERIALIZATION_FORMATS = {"json": JSONStore, "pickle": PickleStore}
@contextmanager
def _db_open(path, mode):
    parts = str(path).split(".")
    assert len(parts) > 1, "Extension needed to figure out serialization format"
    if len(parts) == 2:
        db_format = parts[-1]
        compression = None
    else:
        db_format = parts[-2]
        compression = parts[-1]

    assert compression is None or compression in COMPRESSION_FORMATS
    assert db_format in SERIALIZATION_FORMATS

    store_constructor = SERIALIZATION_FORMATS[db_format]

    if compression == "gz":
        with gzip.GzipFile(path, mode) as f:
            yield store_constructor(f)
    elif compression == "zstd":
        if "w" in mode or "a" in mode:
            cctx = zstandard.ZstdCompressor()
            with open(path, mode) as f:
                with cctx.stream_writer(f) as writer:
                    yield store_constructor(writer)
        else:
            dctx = zstandard.ZstdDecompressor()
            with open(path, mode) as f:
                with dctx.stream_reader(f) as reader:
                    yield store_constructor(reader)
    else:
        with open(path, mode) as f:
            yield store_constructor(f)
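# How the extension drives the dispatch above (illustrative paths only):
#
#     data/bugs.json            -> JSONStore, no compression
#     data/bugs.json.gz         -> JSONStore, gzip
#     data/commits.pickle.zstd  -> PickleStore, zstandard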
def read(path):
    assert path in DATABASES

    if not os.path.exists(path):
        return ()

    with _db_open(path, "rb") as store:
        for elem in store.read():
            yield elem
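# Iteration sketch (hypothetical registered path); when the file does not
# exist, the generator simply yields nothing:
#
#     for bug in db.read("data/bugs.json"):
#         ...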
def write(path, elems):
    assert path in DATABASES

    with _db_open(path, "wb") as store:
        store.write(elems)


def append(path, elems):
    assert path in DATABASES

    with _db_open(path, "ab") as store:
        store.write(elems)
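# Persistence sketch (hypothetical path and elements): write() replaces the
# whole DB, while append() adds elements to the end of an existing one:
#
#     db.write("data/bugs.json", ({"id": i} for i in range(3)))
#     db.append("data/bugs.json", [{"id": 3}])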
def delete(path, match):
    assert path in DATABASES

    dirname, basename = os.path.split(path)
    new_path = os.path.join(dirname, f"new_{basename}")

    def matching_elems(store):
        # Yield the elements to keep, i.e. those the predicate does not match.
        for elem in store.read():
            if not match(elem):
                yield elem

    try:
        with _db_open(path, "rb") as rstore:
            with _db_open(new_path, "wb") as wstore:
                wstore.write(matching_elems(rstore))
    except FileNotFoundError:
        return

    os.unlink(path)
    os.rename(new_path, path)
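# Deletion sketch (hypothetical predicate): delete() rewrites the DB into a
# temporary "new_" file without the matching elements and then swaps it in:
#
#     db.delete("data/bugs.json", lambda bug: bug["id"] == 3)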