assets/inference/environments/mlflow-py312-inference/context/mlmonitoring/config/config.py (139 lines of code) (raw):
"""For config."""
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=global-statement
import os
import json
from ..logger import is_debug, logger
default_queue_capacity = 1500
default_worker_count = 1
default_sample_rate_percentage = 100
class MdcConfig:
"""For MdcConfig."""
# pylint: disable=too-many-instance-attributes
def __init__(
self,
enabled=False,
host="127.0.0.1",
port=50011,
debug=False,
sample_rate_percentage=default_sample_rate_percentage,
model_version=None,
queue_capacity=default_queue_capacity,
worker_disabled=False,
worker_count=default_worker_count,
local_capture=False,
compact_format=False,
):
"""For init."""
self._debug = debug
self._enabled = enabled
self._sample_rate_percentage = sample_rate_percentage
self._host = host
self._port = port
self._model_version = model_version
# queue - max length
self._queue_capacity = queue_capacity
# worker - disabled for test purpose only
self._worker_disabled = worker_disabled
self._worker_count = worker_count
# payload sender
self._local_capture = local_capture
self._compact_format = compact_format
self._collections = {}
def is_debug(self):
"""For is debug."""
return self._debug
def enabled(self):
"""For enabled."""
return self._enabled
def set_enabled(self, enabled):
"""For set enabled."""
self._enabled = enabled
def compact_format(self):
"""For compact format."""
return self._compact_format
def sample_rate_percentage(self):
"""For sample rate percentage."""
return self._sample_rate_percentage
def host(self):
"""For host."""
return self._host
def port(self):
"""For port."""
return self._port
def model_version(self):
"""For model version."""
return self._model_version
def queue_capacity(self):
"""For queue capacity."""
return self._queue_capacity
def worker_disabled(self):
"""For worker disabled."""
return self._worker_disabled
def worker_count(self):
"""For worker count."""
return self._worker_count
def local_capture(self):
"""For local capture."""
return self._local_capture
def add_collection(self, col_name, enabled=False, sample_rate_percentage=100):
"""For add collection."""
self._collections[col_name] = {
"enabled": enabled,
"sampleRatePercentage": sample_rate_percentage,
}
def collections(self):
"""For collections."""
return self._collections
def collection_enabled(self, collection_name):
"""For collection enabled."""
path = os.getenv("AZUREML_MDC_CONFIG_PATH")
if not path:
# for legacy settings, we depend on a global switch to see whether collections are enabled or not.
return self.enabled()
for n, c in self._collections.items():
if n == collection_name:
return c.get("enabled", False)
return False
def collection_sample_rate_percentage(self, collection_name):
"""For collection sample rate percentage."""
path = os.getenv("AZUREML_MDC_CONFIG_PATH")
if not path:
# for legacy settings, we take the global sample_rate_percentage.
return self.sample_rate_percentage()
for n, c in self._collections.items():
if n == collection_name:
return c.get("sampleRatePercentage", default_sample_rate_percentage)
return default_sample_rate_percentage
def loadConfig(model_version=None):
"""For loadConfig."""
debug = is_debug()
path = os.getenv("AZUREML_MDC_CONFIG_PATH")
if path:
with open(path) as f:
cfg = json.load(f)
mdc_cfg = MdcConfig(
host=os.getenv("AZUREML_MDC_HOST", "127.0.0.1"),
port=int(os.getenv("AZUREML_MDC_PORT", "50011")),
debug=debug,
model_version=model_version,
local_capture=cfg.get("runMode", "cloud") == "local"
)
collection_cfg = cfg.get("collections", {})
custom_logging_enabled = False
for col_name, c in collection_cfg.items():
col_name_lower = col_name.lower()
if c.get("enabled", False):
mdc_cfg.add_collection(col_name_lower, True, c.get("sampleRatePercentage", 100))
if col_name_lower not in ('request', 'response'):
custom_logging_enabled = True
mdc_cfg.set_enabled(custom_logging_enabled)
return mdc_cfg
enabled = os.getenv("AZUREML_MDC_ENABLED", "false")
if enabled.lower() == "true":
return MdcConfig(
enabled=True,
host=os.getenv("AZUREML_MDC_HOST", "127.0.0.1"),
port=int(os.getenv("AZUREML_MDC_PORT", "50011")),
debug=debug,
sample_rate_percentage=int(os.getenv("AZUREML_MDC_SAMPLE_RATE", str(default_sample_rate_percentage))),
queue_capacity=int(os.getenv("AZUREML_MDC_QUEUE_CAPACITY", str(default_queue_capacity))),
worker_disabled=os.getenv("AZUREML_MDC_WORKER_DISABLED", "false").lower() == "true",
worker_count=int(os.getenv("AZUREML_MDC_WORKER_COUNT", str(default_worker_count))),
compact_format=os.getenv("AZUREML_MDC_FORMAT_COMPACT", "false").lower() == "true",
local_capture=os.getenv("AZUREML_MDC_LOCAL_CAPTURE", "false").lower() == "true",
model_version=model_version,
)
return MdcConfig(enabled=False, debug=debug)
mdc_config = None
def init_config(model_version=None):
"""For init config."""
global mdc_config
mdc_config = loadConfig(model_version)
logger.info("mdc enabled: %r", mdc_config.enabled())
logger.info("mdc collections count %d", len(mdc_config.collections()))
for n, c in mdc_config.collections().items():
logger.info("mdc collection %s <enabled:%r,sample_percentage:%d>",
n,
c.get("enabled", False),
c.get("sampleRatePercentage", default_sample_rate_percentage))
if mdc_config.is_debug():
config_json = json.dumps(mdc_config.__dict__)
logger.debug("mdc config: %s", config_json)
def teardown_config():
"""For teardown config."""
global mdc_config
mdc_config = None
def get_config():
"""For get config."""
global mdc_config
return mdc_config