mysqloperator/controller/utils.py (150 lines of code) (raw):
# Copyright (c) 2020, 2022, Oracle and/or its affiliates.
#
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
#
import datetime
import time
import os
import string
import random
import base64
import threading
import json
import hashlib
from . import config
def b64decode(s: str) -> str:
    """Decode the base64 input ``s`` and return the payload decoded as UTF-8 text."""
    raw = base64.b64decode(s)
    return raw.decode("utf8")
def b64encode(s: str) -> str:
    """Return the base64 encoding of the UTF-8 bytes of ``s`` as an ASCII string."""
    encoded = base64.b64encode(bytes(s, "utf8"))
    return encoded.decode("ascii")
def sha256(s: str) -> str:
    """Return the hex SHA-256 digest of the UTF-8 encoding of ``s``."""
    hasher = hashlib.sha256()
    hasher.update(bytes(s, "utf8"))
    return hasher.hexdigest()
class EphemeralState:
    """Lock-protected, in-memory key/value store scoped per object.

    State that's not persisted between operator restarts.
    Use only if get() returning None is interpreted as "skip optimization".

    Entries are keyed by "<obj.namespace>/<obj.name>/<key>". For each entry the
    store keeps the value, a caller-supplied context string and the (naive,
    local) ``datetime.datetime.now()`` of the write.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.data = {}     # composite key -> stored value
        self.context = {}  # composite key -> context string given at write time
        self.time = {}     # composite key -> datetime of the last write

    @staticmethod
    def _make_key(obj, key: str) -> str:
        # Scope the caller's key by the object's namespace and name.
        return "/".join((obj.namespace, obj.name, key))

    def _store(self, full_key: str, value, context: str) -> None:
        # Caller must already hold self.lock.
        self.data[full_key] = value
        self.context[full_key] = context
        self.time[full_key] = datetime.datetime.now()

    def get(self, obj, key: str):
        """Return the value stored for (obj, key), or None if there is none."""
        full_key = self._make_key(obj, key)
        with self.lock:
            return self.data.get(full_key)

    def testset(self, obj, key: str, value, context: str):
        """Store ``value``/``context`` only if the current value is None/missing.

        Returns the previous ``(value, context, time)`` triple (each element
        None when absent). The read and the conditional write happen under
        the same lock.
        """
        full_key = self._make_key(obj, key)
        with self.lock:
            previous = (
                self.data.get(full_key),
                self.context.get(full_key),
                self.time.get(full_key),
            )
            if previous[0] is None:
                self._store(full_key, value, context)
            return previous

    def set(self, obj, key: str, value, context: str) -> None:
        """Unconditionally store ``value``/``context`` for (obj, key)."""
        full_key = self._make_key(obj, key)
        with self.lock:
            self._store(full_key, value, context)
# Module-level EphemeralState shared by all importers of this module; being
# in-memory only, its contents do not survive an operator restart.
g_ephemeral_pod_state: EphemeralState = EphemeralState()
def isotime() -> str:
    """Return the current UTC time as ISO-8601 text with a trailing "Z".

    Microseconds are dropped, e.g. ``2022-03-01T12:34:56Z``.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Use an aware UTC
    # datetime, then strip tzinfo so isoformat() does not append "+00:00".
    now = datetime.datetime.now(datetime.timezone.utc)
    return now.replace(microsecond=0, tzinfo=None).isoformat() + "Z"
def timestamp(dash: bool = True, four_digit_year: bool = True) -> str:
    """Return the current UTC time as a compact timestamp string.

    Format is ``YYYYmmdd-HHMMSS``. The dash is omitted when ``dash`` is false,
    and the year uses two digits when ``four_digit_year`` is false.
    """
    dash_str = "-" if dash else ""
    year_str = "%Y" if four_digit_year else "%y"
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC datetime
    # formats identically because the format contains no %z/%Z directive.
    now = datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
    return now.strftime(f"{year_str}%m%d{dash_str}%H%M%S")
def merge_patch_object(base: dict, patch: dict, prefix: str = "", key: str = "", none_deletes: bool = False) -> None:
    """Recursively merge ``patch`` into ``base``, modifying ``base`` in place.

    Rules applied for each key of ``patch``:

    * ``none_deletes`` true and patch value ``None``: the key is removed from
      ``base`` (no-op if absent), whatever the type of the current value.
    * Key absent from ``base`` (or its value is ``None``): the patch value is set.
    * dict into dict: merged recursively.
    * list into list: if the base list is empty, or the patch list is empty,
      or its first element is not a dict, the patch list replaces the base
      list. Otherwise both are lists of objects matched by their ``"name"``
      field: matching objects are merged recursively, unmatched patch
      objects are appended to the base list.
    * A dict/list patch value for a scalar base value is an error; any other
      value replaces the base value.

    Patch values are stored into ``base`` by reference, not copied.

    :param base: dict that receives the merge.
    :param patch: dict whose entries are merged into ``base``.
    :param prefix: path of ``base`` used in error messages.
    :param key: not implemented; must be empty.
    :param none_deletes: treat ``None`` patch values as deletions.
    :raises ValueError: on type mismatches between ``base`` and ``patch``, or
        when an object in a merged list has no ``"name"``.
    """
    assert not key, "not implemented"  # TODO support key

    if type(base) != type(patch):
        raise ValueError(f"Invalid type in patch at {prefix}")
    if type(base) != dict:
        raise ValueError(f"Invalid type in base at {prefix}")

    def get_named_object(l, name):
        # Return the first object in l whose "name" equals name, else None.
        for o in l:
            assert type(o) == dict, f"{prefix}: {name} = {o}"
            if o["name"] == name:
                return o
        return None

    for k, v in patch.items():
        if none_deletes and v is None:
            # Deletion applies regardless of the current value's type;
            # previously dict/list values raised ValueError here instead.
            base.pop(k, None)
            continue

        ov = base.get(k)
        if ov is None:
            base[k] = v
        elif type(ov) == dict:
            if type(v) != dict:
                # TODO
                raise ValueError(f"Invalid type in {prefix}")
            merge_patch_object(ov, v, prefix+"."+k, none_deletes=none_deletes)
        elif type(ov) == list:
            if type(v) != list:
                # TODO
                raise ValueError(f"Invalid type in {prefix}")
            if not ov or not v or type(v[0]) != dict:
                # `not v` guards v[0], which previously raised IndexError
                # when a non-empty base list was patched with [].
                base[k] = v
            else:
                # When merging lists of objects, we match objects by name.
                # If there's no matching object, we append;
                # if there's a matching object, we recursively patch it.
                for i, elem in enumerate(v):
                    if type(elem) != dict:
                        raise ValueError(
                            f"Invalid type in {prefix}")
                    name = elem.get("name")
                    if not name:
                        raise ValueError(
                            "Object in list must have name")
                    o = get_named_object(ov, name)
                    if o is not None:
                        merge_patch_object(
                            o, elem, prefix+"."+k+"["+str(i)+"]",
                            none_deletes=none_deletes)
                    else:
                        ov.append(elem)
        elif type(v) in (dict, list):
            # A scalar base value cannot be replaced by a container.
            raise ValueError(f"Invalid type in {prefix}")
        else:
            base[k] = v
def generate_password() -> str:
    """Return a random password made of five 5-character groups joined by "-".

    Characters are drawn from ASCII letters, digits and ``_.=+-~``.
    """
    # Use the cryptographically secure generator from `secrets`. The previous
    # code re-seeded the global `random` module from the sub-second digits of
    # time.time(), so passwords were guessable and other users of `random`
    # had their state reset.
    import secrets

    alphabet = string.ascii_letters + string.digits + "_.=+-~"
    return "-".join(
        "".join(secrets.choice(alphabet) for _ in range(5))
        for _ in range(5))
def version_to_int(version: str) -> int:
    """Convert an ``x.y.z`` or ``x.y.z.w`` version string to an int for ordering.

    Components are packed into decimal fields: ``x * 10**12 + y * 10**10 +
    z * 10**8 + w``. Integer comparison matches version ordering as long as
    ``y`` and ``z`` are below 100 and ``w`` is below 10**8, which allows the
    last component to be as long as a date value such as 20220101.
    ``x.y.z`` encodes the same as ``x.y.z.0``.

    :raises ValueError: if the string does not have 3 or 4 dot-separated
        parts, or a part is not an integer.
    """
    parts = version.split(".")
    if len(parts) > 4 or len(parts) < 3:
        raise ValueError(
            f"Invalid version number {version}. Must be n.n.n or n.n.n.n")
    numbers = [int(p) for p in parts]
    build = numbers[3] if len(numbers) > 3 else 0
    # The patch component must be *multiplied* by 10**8. It was previously
    # added, so the patch and build fields overlapped and, for example,
    # 8.0.30.1 compared lower than 8.0.29.5.
    return (numbers[0] * 1000000000000
            + numbers[1] * 10000000000
            + numbers[2] * 100000000
            + build)
def version_in_range(version: str, minimum=None, maximum=None, check_disabled=True) -> list:
    """Check whether a MySQL server version is within the supported range.

    :param version: version string in ``n.n.n`` or ``n.n.n.n`` form.
    :param minimum: lowest accepted version (inclusive); when falsy,
        ``config.MIN_SUPPORTED_MYSQL_VERSION`` is used.
    :param maximum: highest accepted version (inclusive); when falsy,
        ``config.MAX_SUPPORTED_MYSQL_VERSION`` is used.
    :param check_disabled: when true, a version listed in
        ``config.DISABLED_MYSQL_VERSION`` is rejected with the reason stored
        there, before any range check.
    :return: a two-element list ``[ok, reason]``. ``reason`` is ``None`` when
        ``ok`` is true, otherwise a message string. The previous annotation
        ``list[bool, str]`` was not a valid ``list`` parameterization and
        omitted the ``None`` case.
    :raises ValueError: propagated from ``version_to_int`` for malformed
        version strings.
    """
    if not minimum:
        minimum = config.MIN_SUPPORTED_MYSQL_VERSION
    if not maximum:
        maximum = config.MAX_SUPPORTED_MYSQL_VERSION

    # Some versions have been disabled due to major issues
    if check_disabled and version in config.DISABLED_MYSQL_VERSION:
        return [False, config.DISABLED_MYSQL_VERSION[version]]

    version_int = version_to_int(version)
    min_version = version_to_int(minimum)
    max_version = version_to_int(maximum)

    if not max_version >= version_int >= min_version:
        return [False,
                f"version {version} must be between "
                f"{minimum} and {maximum}"]
    return [True, None]
def indent(s: str, spaces: int) -> str:
    """Prefix every line of ``s`` with ``spaces`` spaces.

    Returns an empty string when ``s`` is empty or otherwise falsy.
    """
    if not s:
        return ""
    pad = " " * spaces
    return "\n".join(pad + line for line in s.split("\n"))
def log_banner(path: str, logger) -> None:
    """Log a one-line startup banner via ``logger.info``.

    The banner includes the basename of ``path``, ``config.OPERATOR_VERSION``,
    the modification time of ``path`` (local time, ISO format), the installed
    ``kopf`` distribution version and the current process uid.

    :param path: file whose mtime and basename are reported (presumably the
        calling module's ``__file__`` — confirm with callers).
    :param logger: object providing an ``info(str)`` method.
    :raises OSError: if ``path`` cannot be stat'ed.
    """
    # importlib.metadata replaces the deprecated pkg_resources API (and avoids
    # its slow import). `config` is already imported at module level.
    from importlib.metadata import version as distribution_version

    kopf_version = distribution_version("kopf")
    ts = datetime.datetime.fromtimestamp(os.stat(path).st_mtime).isoformat()
    path = os.path.basename(path)
    logger.info(
        f"MySQL Operator/{path}={config.OPERATOR_VERSION} timestamp={ts} kopf={kopf_version} uid={os.getuid()}")
def dict_to_json_string(d: dict) -> str:
    """Serialize ``d`` to JSON text pretty-printed with 4-space indentation."""
    text = json.dumps(d, indent=4)
    return text