common/init_deployment_config.py (269 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processes config.json for resolving default values and validation."""

import argparse
from importlib.util import module_from_spec
from importlib.util import spec_from_file_location
import json
import logging
from pathlib import Path
import sys
import typing
import uuid

from google.cloud.exceptions import BadRequest
from google.cloud.exceptions import Forbidden
from google.cloud.exceptions import ServerError
from google.cloud.exceptions import Unauthorized

# Make sure common modules are in Python path
sys.path.append(str(Path(__file__).parent.parent))

# pylint:disable=wrong-import-position
from common.py_libs import bq_helper
from common.py_libs import constants
from common.py_libs import cortex_bq_client
from common.py_libs import resource_validation_helper

_DEFAULT_CONFIG_ = "config/config.json"
_VALIDATOR_FILE_NAME_ = "config_validator"
_VALIDATE_FUNC_NAME_ = "validate"
_DEFAULT_TEST_HARNESS_PROJECT = "kittycorn-public"


def _load_config(config_file):
    """Loads config file.

    Args:
        config_file (StrOrPath): config file path

    Raises:
        FileNotFoundError: Config file not found.
        RuntimeError: Invalid file format.

    Returns:
        typing.Dict[str, any]: configuration dictionary
    """
    if not Path(config_file).is_file():
        raise FileNotFoundError(f"🛑 Config file '{config_file}' "
                                "does not exist. 🛑")
    with open(config_file, mode="r", encoding="utf-8") as cf:
        try:
            config = json.load(cf)
        except json.JSONDecodeError:
            e_msg = f"🛑 Config file '{config_file}' is malformed or empty. 🛑"
            # `from None` suppresses the noisy JSONDecodeError chain; the
            # message already identifies the offending file.
            raise RuntimeError(e_msg) from None
    return config


def _save_config(config_file, config_dict):
    """Saves config dictionary to a json file.

    Args:
        config_file (str): Config file path.
        config_dict (typing.Dict[str, any]): Config dictionary.
    """
    with open(config_file, mode="w", encoding="utf-8") as cf:
        json.dump(config_dict, cf, indent=4)


def _validate_workload(
        config: typing.Dict[str, any],  # type: ignore
        workload_path: str) -> typing.Optional[typing.Dict[str, any]]:  # type: ignore
    """Searches for {_VALIDATOR_FILE_NAME_}.py (config_validator.py) file
    in `workload_path` sub-directory of current directory,
    and calls {_VALIDATE_FUNC_NAME_} (validate) function in it.

    Args:
        config (dict): Configuration dictionary to be passed to the workload's
            `validate` function
        workload_path (str): Sub-directory path

    Returns:
        dict: Validated and processed configuration dictionary as returned
            from `validate` function.
            Returns None if config_validator.py file doesn't exists or
            of validate function is not defined in it.
    """
    repo_root = Path.cwd().absolute()
    validator_dir = repo_root.joinpath(
        workload_path) if workload_path else repo_root
    full_file_path = repo_root.joinpath(
        validator_dir, f"{_VALIDATOR_FILE_NAME_}.py").absolute()
    if not full_file_path.exists():
        logging.error("🛑 No config validator for `%s`. Missing `%s` 🛑.",
                      workload_path, full_file_path)
        return None
    logging.info("Found %s.py in %s. Running 'validate'.",
                 _VALIDATOR_FILE_NAME_, validator_dir)
    # Import the validator module dynamically from its file path.
    spec = spec_from_file_location(_VALIDATOR_FILE_NAME_, full_file_path)
    module = module_from_spec(spec)  # type: ignore
    spec.loader.exec_module(module)  # type: ignore
    if not hasattr(module, _VALIDATE_FUNC_NAME_):
        logging.error("🛑 %s doesn't have %s function. 🛑", full_file_path,
                      _VALIDATE_FUNC_NAME_)
        return None
    validate_func = getattr(module, _VALIDATE_FUNC_NAME_)
    return validate_func(config)


def _validate_config_resources(config: typing.Dict[str, typing.Any]) -> bool:
    """Validates BigQuery/GCS resources referenced by the configuration.

    Verifies that datasets can be created in the source and target projects
    (by creating and deleting a temporary dataset), that the target bucket
    exists and is writable, and that the K9 and Vertex AI datasets are
    writable if they already exist.

    Args:
        config (typing.Dict[str, typing.Any]): configuration dictionary.

    Returns:
        bool: True if all resources are available/writable, False otherwise.
    """
    source = config["projectIdSource"]
    target = config["projectIdTarget"]
    location = config["location"]

    # Checking if we can create datasets in source and target projects.
    projects = [source]
    if source != target:
        projects.append(target)
    for project in projects:
        bq_client = cortex_bq_client.CortexBQClient(project=project,
                                                    location=location)
        temp_dataset_name = f"tmp_cortex_{uuid.uuid4().hex}"
        full_temp_dataset_name = f"{project}.{temp_dataset_name}"
        try:
            bq_helper.create_dataset(bq_client, full_temp_dataset_name,
                                     location, True)
            logging.info(
                "✅ BigQuery in project `%s` is available "
                "for writing.", project)
        except (Forbidden, Unauthorized):
            logging.exception(
                "🛑 Insufficient permissions to create datasets "
                "in project `%s`. 🛑", project)
            return False
        except (BadRequest, ServerError):
            logging.exception(
                "🛑 Error when trying to create a BigQuery dataset "
                "in project `%s`. 🛑", project)
            return False
        finally:
            # Best-effort cleanup of the temporary dataset; runs even when
            # creation failed and we are returning False above.
            try:
                bq_client.delete_dataset(full_temp_dataset_name,
                                         not_found_ok=True)
            except BadRequest:
                logging.warning(
                    "⚠️ Couldn't delete temporary dataset `%s`. "
                    "Please delete it manually. ⚠️", full_temp_dataset_name)

    # targetBucket must exist and be writable
    buckets = [
        resource_validation_helper.BucketConstraints(
            str(config["targetBucket"]), True, location)
    ]
    # K9 dataset must be writable, if exist.
    # If it doesn't exist, it will be created later.
    datasets = [
        resource_validation_helper.DatasetConstraints(
            f'{source}.{config["k9"]["datasets"]["processing"]}', False, True,
            location),
        resource_validation_helper.DatasetConstraints(
            f'{target}.{config["k9"]["datasets"]["reporting"]}', False, True,
            location),
        # Vertex AI dataset must be in the same region as Vertex AI region
        resource_validation_helper.DatasetConstraints(
            f'{source}.{config["VertexAI"]["processingDataset"]}', False, True,
            config["VertexAI"]["region"])
    ]
    return resource_validation_helper.validate_resources(buckets, datasets)


def validate_config(
        config: typing.Dict[str, any],  # type: ignore
        sub_validators: typing.List[str]  # type: ignore
) -> typing.Optional[typing.Dict[str, typing.Any]]:  # type: ignore
    """Performs common config validation. Calls workload-specific validators
    and discovers all config issues. It will discover and log all issues
    before returning.

    Args:
        config (typing.Dict[str, any]): loaded config.json as a dictionary.
        sub_validators (typing.List[str]): sub-directories with
            config_validator.py to call.

    Returns:
        typing.Optional[dict]: validated and updated config dictionary
            or None if invalid config.
    """
    failed = False
    if not config.get("projectIdSource"):
        logging.error("🛑 Missing 'projectIdSource' configuration value. 🛑")
        failed = True
    if not config.get("targetBucket"):
        logging.error("🛑 Missing 'targetBucket' configuration value. 🛑")
        failed = True
    config["projectIdTarget"] = config.get("projectIdTarget", "")
    if not config["projectIdTarget"]:
        logging.warning(("⚠️ projectIdTarget is not specified."
                         "Using projectIdSource (%s) ⚠️"),
                        config["projectIdSource"])
        config["projectIdTarget"] = config["projectIdSource"]

    # Resolve defaults for all optional deployment toggles.
    config["deploySAP"] = config.get("deploySAP", False)
    config["deploySFDC"] = config.get("deploySFDC", False)
    config["deployMarketing"] = config.get("deployMarketing", False)
    config["deployOracleEBS"] = config.get("deployOracleEBS", False)
    config["deployDataMesh"] = config.get("deployDataMesh", False)
    config["testData"] = config.get("testData", False)
    config["enableTaskDependencies"] = config.get("enableTaskDependencies",
                                                  False)
    config["turboMode"] = config.get("turboMode", True)

    config["location"] = config.get("location")
    if not config["location"]:
        logging.warning("⚠️ No location specified. Using `US`. ⚠️")
        config["location"] = "us"

    if config.get("testDataProject", "") == "":
        logging.warning("testDataProject is empty. Using default project `%s`.",
                        _DEFAULT_TEST_HARNESS_PROJECT)
        config["testDataProject"] = _DEFAULT_TEST_HARNESS_PROJECT

    # K9 defaults.
    if "k9" not in config:
        logging.warning("No K9 configuration. Using defaults.")
        config["k9"] = {}
    if "datasets" not in config["k9"]:
        logging.warning(("⚠️ No K9 datasets configuration. "
                         "Using defaults. ⚠️"))
        config["k9"]["datasets"] = {}
    if not config["k9"]["datasets"].get("processing"):
        logging.warning(("⚠️ No K9 processing dataset specified. "
                         "Defaulting to K9_PROCESSING. ⚠️"))
        config["k9"]["datasets"]["processing"] = "K9_PROCESSING"
    if not config["k9"]["datasets"].get("reporting"):
        logging.warning(("⚠️ No K9 reporting dataset specified. "
                         "Defaulting to K9_REPORTING. ⚠️"))
        config["k9"]["datasets"]["reporting"] = "K9_REPORTING"

    # Vertex AI region: derive a default compatible with the BQ location,
    # or verify a user-provided region against it.
    if "VertexAI" not in config:
        config["VertexAI"] = {}
    config["VertexAI"]["region"] = config["VertexAI"].get("region", "")
    if not config["VertexAI"]["region"]:
        bq_location = config["location"].lower()
        if "-" in bq_location:
            # single region
            vertexai_region = bq_location
        elif bq_location == "eu":
            vertexai_region = "europe-west1"
        else:
            vertexai_region = "us-central1"
        logging.warning("⚠️ No Vertex AI region specified. Using `%s`. ⚠️",
                        vertexai_region)
        config["VertexAI"]["region"] = vertexai_region
    else:
        vertexai_region = config["VertexAI"]["region"].lower()
        bq_location = config["location"].lower()
        if "-" in bq_location and vertexai_region != bq_location:
            logging.error(("🛑 Vertex AI region must match BigQuery location. "
                           "It should be in `%s` region. 🛑"), bq_location)
            failed = True
        elif bq_location == "eu" and not vertexai_region.startswith("europe-"):
            logging.error(("🛑 Vertex AI region must match BigQuery location. "
                           "It should be in one of `europe-` regions. 🛑"))
            failed = True
        elif bq_location == "us" and not vertexai_region.startswith("us-"):
            logging.error(("🛑 Vertex AI region must match BigQuery location. "
                           "It should be in one of `us-` regions. 🛑"))
            failed = True
    config["VertexAI"]["processingDataset"] = config["VertexAI"].get(
        "processingDataset")
    if not config["VertexAI"]["processingDataset"]:
        vertex_pd_default = "CORTEX_VERTEX_AI_PROCESSING"
        logging.warning("⚠️ No Vertex AI dataset specified. Using `%s`. ⚠️",
                        vertex_pd_default)
        config["VertexAI"]["processingDataset"] = vertex_pd_default

    if config["deployDataMesh"] and "DataMesh" not in config:
        logging.error("🛑 Data Mesh is enabled but no options are specified. 🛑")
        failed = True

    logging.info("Fetching test harness version.")
    config["testHarnessVersion"] = config.get("testHarnessVersion",
                                              constants.TEST_HARNESS_VERSION)
    logging.info("Fetching Cortex version.")
    config["cortexVersion"] = config.get("cortexVersion",
                                         constants.CORTEX_VERSION)

    logging.info("Validating common configuration resources.")
    try:
        # Bug fix: the boolean result was previously discarded, so a
        # False return (e.g. insufficient permissions) passed silently.
        if not _validate_config_resources(config):
            logging.error("🛑 Resource validation failed. 🛑")
            failed = True
    except Exception:  # pylint:disable=broad-except
        logging.error("🛑 Resource validation failed. 🛑")
        failed = True

    if failed:
        logging.error("🛑 Common configuration is invalid. 🛑")
    else:
        logging.info("✅ Common configuration is valid. ✅")

    # Go over all sub-validator directories,
    # and call validate() in config_validator.py in every directory.
    for validator in sub_validators:
        validator_text = validator if validator else "current repository"
        logging.info("Running config validator in `%s`.", validator_text)
        try:
            validated_config = _validate_workload(config,
                                                  validator)  # type: ignore
        except Exception:  # pylint:disable=broad-except
            logging.error("🛑 Validation in `%s` failed. 🛑", validator_text)
            # TODO: make exception(groups) chained
            failed = True
        else:
            # Bug fix: _validate_workload returns None when the validator
            # file or its `validate` function is missing; previously that
            # None clobbered `config` and was logged as a success.
            if validated_config is None:
                logging.error("🛑 Validation in `%s` failed. 🛑",
                              validator_text)
                failed = True
            else:
                config = validated_config
                logging.info("✅ Validation succeeded in `%s`. ✅",
                             validator_text)

    if failed:
        return None
    else:
        return config


def main(args: typing.Sequence[str]) -> int:
    """CLI entry point: loads, validates and re-saves the config file.

    Args:
        args (typing.Sequence[str]): command line arguments
            (without the program name).

    Returns:
        int: 0 on success, 1 on invalid configuration.
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(message)s",
        level=logging.INFO,
    )
    parser = argparse.ArgumentParser(description="Cortex Config Validator.")
    parser.add_argument("--config-file",
                        type=str,
                        required=False,
                        default=_DEFAULT_CONFIG_)
    parser.add_argument("--sub-validator",
                        action="append",
                        default=[],
                        required=False)
    params = parser.parse_args(args)

    logging.info("🦄 Running config validation and processing: %s.",
                 params.config_file)
    try:
        config = _load_config(params.config_file)
    except RuntimeError:
        logging.exception("Invalid JSON")
        return 1

    if config.get("validated", False):
        # Skip re-validation when a previous run already marked the file.
        logging.info("✅ Configuration in `%s` "
                     "was previously validated.", params.config_file)
    else:
        config = validate_config(config, list(params.sub_validator))
        if not config:
            logging.error("🛑🔪 Configuration in `%s` is invalid. 🔪🛑\n\n",
                          params.config_file)
            return 1
        config["validated"] = True
        _save_config(params.config_file, config)
    logging.info(("🦄 Configuration in `%s` was "
                  "successfully validated and processed. Let's roll! 🦄\n\n"),
                 params.config_file)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))