pyiceberg/utils/config.py (105 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import os
from typing import List, Optional

import strictyaml

from pyiceberg.typedef import UTF8, FrozenDict, RecursiveDict
from pyiceberg.types import strtobool

PYICEBERG = "pyiceberg_"
DEFAULT = "default"
CATALOG = "catalog"
DEFAULT_CATALOG = f"{DEFAULT}-{CATALOG}"
PYICEBERG_HOME = "PYICEBERG_HOME"
PYICEBERG_YML = ".pyiceberg.yaml"

logger = logging.getLogger(__name__)


def merge_config(lhs: RecursiveDict, rhs: RecursiveDict) -> RecursiveDict:
    """Merge the right-hand side into (a shallow copy of) the left-hand side.

    Nested dictionaries present on both sides are merged recursively. For any other
    key present on both sides, the rhs value wins unless it is falsy, in which case
    the lhs value is kept. Keys only present in rhs are added as-is.

    Args:
        lhs: Base configuration; its top level is not mutated.
        rhs: Configuration whose values take precedence.

    Returns:
        A new dictionary containing the merged configuration.
    """
    new_config = lhs.copy()
    for rhs_key, rhs_value in rhs.items():
        if rhs_key in new_config:
            lhs_value = new_config[rhs_key]
            if isinstance(lhs_value, dict) and isinstance(rhs_value, dict):
                # If they are both dicts, then we have to go deeper
                new_config[rhs_key] = merge_config(lhs_value, rhs_value)
            else:
                # Take the non-null value, with precedence on rhs
                new_config[rhs_key] = rhs_value or lhs_value
        else:
            # New key
            new_config[rhs_key] = rhs_value
    return new_config


def _lowercase_dictionary_keys(input_dict: RecursiveDict) -> RecursiveDict:
    """Lowers all the keys of a dictionary in a recursive manner, to make the lookup case-insensitive."""
    return {k.lower(): _lowercase_dictionary_keys(v) if isinstance(v, dict) else v for k, v in input_dict.items()}


class Config:
    """PyIceberg configuration assembled from a YAML file and PYICEBERG_* environment variables."""

    config: RecursiveDict

    def __init__(self) -> None:
        """Load the first configuration file found, then overlay PYICEBERG_* environment variables on top of it."""
        config = self._from_configuration_files() or {}
        config = merge_config(config, self._from_environment_variables(config))
        self.config = FrozenDict(**config)

    @staticmethod
    def _from_configuration_files() -> Optional[RecursiveDict]:
        """Load the first configuration file that it finds.

        Looks for `.pyiceberg.yaml` in the directory named by the PYICEBERG_HOME
        environment variable, then in the user's home directory, and finally in the
        current working directory. Keys are lower-cased recursively.

        Returns:
            The parsed configuration, or None if no (non-empty) file was found.
        """

        def _load_yaml(directory: Optional[str]) -> Optional[RecursiveDict]:
            if directory:
                path = os.path.join(directory, PYICEBERG_YML)
                if os.path.isfile(path):
                    with open(path, encoding=UTF8) as f:
                        yml_str = f.read()
                    file_config = strictyaml.load(yml_str).data
                    file_config_lowercase = _lowercase_dictionary_keys(file_config)
                    return file_config_lowercase
            return None

        # Directories to search for the configuration file
        # The current search order is: PYICEBERG_HOME, home directory, then current directory
        search_dirs = [os.environ.get(PYICEBERG_HOME), os.path.expanduser("~"), os.getcwd()]

        for directory in search_dirs:
            # A file that parses to an empty config is skipped in favour of the next directory.
            if config := _load_yaml(directory):
                return config

        # Didn't find a config
        return None

    @staticmethod
    def _from_environment_variables(config: RecursiveDict) -> RecursiveDict:
        """Read the environment variables, to check if there are any prepended by PYICEBERG_.

        The remainder of the variable name (lower-cased) is split on `__` into at most
        three path segments; single underscores become dashes, and any `__` left in the
        last segment becomes a dot. For example `PYICEBERG_CATALOG__PROD__S3__ACCESS_KEY_ID`
        sets `config["catalog"]["prod"]["s3.access-key-id"]`.

        Args:
            config: Existing configuration that's being amended with configuration from environment variables.
                It is modified in place.

        Returns:
            Amended configuration.

        Raises:
            ValueError: If an environment variable needs to nest under a key that already holds a non-dict value.
        """

        def set_property(_config: RecursiveDict, path: List[str], config_value: str) -> None:
            """Set `config_value` at the nested location `path`, creating intermediate dicts as needed."""
            if not path:
                return
            *parents, leaf = path
            for element in parents:
                if element not in _config:
                    _config[element] = {}
                child = _config[element]
                if not isinstance(child, dict):
                    # Report the full dotted path (not just the unconsumed suffix) so the
                    # offending environment variable can be identified.
                    raise ValueError(
                        f"Incompatible configurations, merging dict with a value: {'.'.join(path)}, value: {config_value}"
                    )
                _config = child  # type: ignore
            _config[leaf] = config_value

        for env_var, config_value in os.environ.items():
            # Make it lowercase to make it case-insensitive
            env_var_lower = env_var.lower()
            if env_var_lower.startswith(PYICEBERG.lower()):
                key = env_var_lower[len(PYICEBERG) :]
                parts = key.split("__", maxsplit=2)
                parts_normalized = [part.replace("__", ".").replace("_", "-") for part in parts]
                set_property(config, parts_normalized, config_value)

        return config

    def get_default_catalog_name(self) -> str:
        """Return the default catalog name.

        Returns:
            The value of `default-catalog` (from the config file or environment).
            Returns `default` when the key is missing or its value is empty.

        Raises:
            ValueError: If `default-catalog` is set to something other than a string.
        """
        if default_catalog_name := self.config.get(DEFAULT_CATALOG):
            if not isinstance(default_catalog_name, str):
                raise ValueError(f"Default catalog name should be a str: {default_catalog_name}")
            return default_catalog_name
        return DEFAULT

    def get_catalog_config(self, catalog_name: str) -> Optional[RecursiveDict]:
        """Return the configuration of a named catalog, looked up case-insensitively.

        Args:
            catalog_name: Name of the catalog under the `catalog` section.

        Returns:
            The catalog's configuration, or None if there is no `catalog` section or no
            entry for this name.

        Raises:
            ValueError: If the `catalog` section is not a mapping.
        """
        if CATALOG in self.config:
            catalog_name_lower = catalog_name.lower()
            catalogs = self.config[CATALOG]
            if not isinstance(catalogs, dict):
                raise ValueError(f"Catalog configurations needs to be an object: {catalog_name}")
            if catalog_name_lower in catalogs:
                catalog_conf = catalogs[catalog_name_lower]
                # NOTE(review): `assert` is stripped under `python -O`; kept to preserve the
                # AssertionError callers may currently observe.
                assert isinstance(catalog_conf, dict), f"Configuration path catalogs.{catalog_name_lower} needs to be an object"
                return catalog_conf
        return None

    def get_int(self, key: str) -> Optional[int]:
        """Return the top-level `key` converted to int, or None when it is unset.

        Raises:
            ValueError: If the value cannot be converted to an integer.
        """
        if (val := self.config.get(key)) is not None:
            try:
                return int(val)  # type: ignore
            except ValueError as err:
                raise ValueError(f"{key} should be an integer or left unset. Current value: {val}") from err
        return None

    def get_bool(self, key: str) -> Optional[bool]:
        """Return the top-level `key` parsed with `strtobool`, or None when it is unset.

        Raises:
            ValueError: If the value is not a recognised boolean string.
        """
        if (val := self.config.get(key)) is not None:
            try:
                return strtobool(val)  # type: ignore
            except ValueError as err:
                raise ValueError(f"{key} should be a boolean or left unset. Current value: {val}") from err
        return None