decisionai_plugin/common/plugin_service.py (297 lines of code) (raw):
import asyncio
import threading
import atexit
import json
import os
import shutil
import time
import traceback
import uuid
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from os import environ
import yaml
from apscheduler.schedulers.background import BackgroundScheduler
from flask import jsonify, make_response
from .tsanaclient import TSANAClient
from .util.constant import InferenceState
from .util.constant import ModelState
from .util.constant import STATUS_SUCCESS, STATUS_FAIL
from .util.constant import INSTANCE_ID_KEY
from .util.context import Context
from .util.meta import insert_or_update_meta, get_meta, update_state, get_model_list, clear_state_when_necessary
from .util.model import upload_model, download_model
from .util.monitor import init_monitor, run_monitor, stop_monitor
from .util.timeutil import str_to_dt
from .util.kafka_operator_confluent import send_message, consume_loop
from .util.job_record import JobRecord
import zlib
import base64
import gc
#async infras
#executor = ProcessPoolExecutor(max_workers=2)
#ThreadPool easy for debug
#executor = ThreadPoolExecutor(max_workers=2)
#monitor infras
sched = BackgroundScheduler()
from telemetry import log
def load_config(path):
try:
with open(path, 'r') as config_file:
config_yaml = yaml.safe_load(config_file)
Config = namedtuple('Config', sorted(config_yaml))
config = Config(**config_yaml)
return config
except Exception:
return None
class PluginService():
def __init__(self, trainable=True):
config_file = environ.get('SERVICE_CONFIG_FILE')
config = load_config(config_file)
if config is None:
log.error("No configuration '%s', or the configuration is not in JSON format. " % (config_file))
exit()
self.config = config
self.tsanaclient = TSANAClient()
self.trainable = trainable
if self.trainable:
init_monitor(config)
sched.add_job(func=lambda: run_monitor(config), trigger="interval", seconds=10)
sched.start()
atexit.register(lambda: stop_monitor(config))
atexit.register(lambda: sched.shutdown())
self.training_topic = self.__class__.__name__ + '-training'
training_thread = threading.Thread(target=consume_loop, args=(self.train_wrapper, self.training_topic), daemon=True)
training_thread.start()
self.inference_topic = self.__class__.__name__ + '-inference'
inference_thread = threading.Thread(target=consume_loop, args=(self.inference_wrapper, self.inference_topic), daemon=True)
inference_thread.start()
# verify parameters
# Parameters:
# parameters: a dict object which should includes
# apiEndpoint: api endpoint for specific user
# apiKey: api key for specific user
# groupId: groupId in TSANA, which is copied from inference request, or from the entity
# series_sets: Array of series set
# context: request context include subscription and model_id
# Return:
# STATUS_SUCCESS/STATUS_FAIL, error_message
def do_verify(self, parameters, context:Context):
return STATUS_SUCCESS, ''
# check if need to retrain model this time
# Parameters:
# current_series_set: series set used in instance now
# current_params: params used in instance now
# new_series_set: series set used in this request
# new_params: params used in this request
# context: request context include subscription and model_id
# Return:
# True/False
def need_retrain(self, current_series_set, current_params, new_series_set, new_params, context:Context):
return True
# train model
# Parameters:
# model_dir: output dir for model training result, framework will handle model storage
# parameters: training request body which include
# apiEndpoint: api endpoint for specific user
# apiKey: api key for specific user
# groupId: groupId in TSANA
# seriesSets: Array of series set
# startTime: start timestamp
# endTime: end timestamp
# instance: an info dict for this instance which includes
# instanceId: UUID for this instance
# params: training parameters for this request
# series: an array of Series object or None if config.auto_data_retrieving is False
# Series include
# series_id: UUID
# dim: dimension dict for this series
# fields: 1-d string array, ['time', '__VAL__', '__FIELD__.ExpectedValue', '__FIELD__.IsAnomaly', '__FIELD__.PredictionValue', '__FIELD__.PredictionModelScore', '__FIELD__.IsSuppress', '__FIELD__.Period', '__FIELD__.CostPoint', '__FIELD__.Mean', '__FIELD__.STD', '__FIELD__.TrendChangeAnnotate', '__FIELD__.TrendChang...tateIgnore', '__FIELD__.AnomalyAnnotate', ...]
# value: 2-d array, [['2020-10-12T17:55:00Z', 1.0, None, None, None, None, None, None, None, None, None, None, None, None, ...]]
# context: request context include subscription and model_id
# Return:
# STATUS_SUCCESS/STATUS_FAIL, error_message
def do_train(self, model_dir, parameters, series, context:Context):
return STATUS_SUCCESS, ''
# inference model
# Parameters:
# model_dir: input dir for model inference, model has been download and unpacked to this dir
# parameters: inference request body which include
# apiEndpoint: api endpoint for specific user
# apiKey: api key for specific user
# groupId: groupId in TSANA
# seriesSets: Array of series set
# startTime: start timestamp
# endTime: end timestamp
# instance: an info dict for this instance which includes
# instanceId: UUID for this instance
# params: inference parameters for this request
# target: a dict for inference result which include
# dimensions: dimension name list for target, defined when register plugin
# metrics: metric name list for target, defined when register plugin
# granularityName: granularity name for target, defined when register plugin
# hookIds: hook id list, defined when register plugin
# series: an array of Series object or None if config.auto_data_retrieving is False
# Series include
# series_id: UUID
# dim: dimension dict for this series
# fields: 1-d string array, ['time', '__VAL__', '__FIELD__.ExpectedValue', '__FIELD__.IsAnomaly', '__FIELD__.PredictionValue', '__FIELD__.PredictionModelScore', '__FIELD__.IsSuppress', '__FIELD__.Period', '__FIELD__.CostPoint', '__FIELD__.Mean', '__FIELD__.STD', '__FIELD__.TrendChangeAnnotate', '__FIELD__.TrendChang...tateIgnore', '__FIELD__.AnomalyAnnotate', ...]
# value: 2-d array, [['2020-10-12T17:55:00Z', 1.0, None, None, None, None, None, None, None, None, None, None, None, None, ...]]
# context: request context include subscription and model_id
# Return:
# result: STATUS_SUCCESS/STATUS_FAIL
# values: a list of value dict or None if you do not need framework to handle inference result storge, this value dict should include
# metricId: UUID, comes from metrics segment of target of request body
# dimension: dimension dict for this series, dimension names come from target segment of request body
# timestamps: string timestamps list
# values: double type value list, matching timestamps
# fields: field names list, optional
# fieldValues: 2-d array which include a value list for each field, optional
# message: error message
def do_inference(self, model_dir, parameters, series, context:Context):
return STATUS_SUCCESS, None, ''
def do_delete(self, parameters, model_id):
return STATUS_SUCCESS, ''
def get_data_time_range(self, parameters, is_training=False):
return str_to_dt(parameters['startTime']), str_to_dt(parameters['endTime'])
def train_wrapper(self, message):
start = time.time()
subscription = message['subscription']
model_id = message['model_id']
task_id = message['job_id']
parameters = message['params']
model_dir = None
log.info("Start train wrapper for model %s by %s " % (model_id, subscription))
try:
self.tsanaclient.save_training_status(task_id, parameters, ModelState.Pending.name)
model_dir = os.path.join(self.config.model_dir, subscription + '_' + model_id + '_' + str(time.time()))
os.makedirs(model_dir, exist_ok=True)
series = None
if self.config.auto_data_retrieving:
start_time, end_time = self.get_data_time_range(parameters, True)
series = self.tsanaclient.get_timeseries_gw(parameters, parameters['seriesSets'], start_time, end_time)
update_state(self.config, subscription, model_id, ModelState.Training)
self.tsanaclient.save_training_status(task_id, parameters, ModelState.Training.name)
result, message = self.do_train(model_dir, parameters, series, Context(subscription, model_id, task_id))
if result == STATUS_SUCCESS:
self.train_callback(subscription, model_id, task_id, model_dir, parameters, ModelState.Ready, None)
else:
raise Exception(message)
except Exception as e:
self.train_callback(subscription, model_id, task_id, None, parameters, ModelState.Failed, str(e))
result = STATUS_FAIL
finally:
if model_dir is not None:
shutil.rmtree(model_dir, ignore_errors=True)
total_time = (time.time() - start)
log.duration("training_task_duration", total_time, model_id=model_id, task_id=task_id, result=result, endpoint=parameters['apiEndpoint'], group_id=parameters['groupId'], group_name=parameters['groupName'].replace(' ', '_'), instance_id=parameters['instance']['instanceId'], instance_name=parameters['instance']['instanceName'].replace(' ', '_'))
log.count("training_task_count", 1, model_id=model_id, task_id=task_id, result=result, endpoint=parameters['apiEndpoint'], group_id=parameters['groupId'], group_name=parameters['groupName'].replace(' ', '_'), instance_id=parameters['instance']['instanceId'], instance_name=parameters['instance']['instanceName'].replace(' ', '_'))
gc.collect()
return STATUS_SUCCESS, ''
# inference_window: 30
# endTime: endtime
def inference_wrapper(self, message):
start = time.time()
subscription = message['subscription']
model_id = message['model_id']
task_id = message['job_id']
parameters = message['params']
log.info("Start inference wrapper %s by %s " % (model_id, subscription))
try:
self.tsanaclient.save_inference_status(task_id, parameters, InferenceState.Pending.name)
result, message = self.do_verify(parameters, Context(subscription, model_id, task_id))
if result != STATUS_SUCCESS:
raise Exception('Verify failed! ' + message)
model_dir = os.path.join(self.config.model_dir, subscription + '_' + model_id + '_' + str(time.time()))
os.makedirs(model_dir, exist_ok=True)
if self.trainable:
download_model(self.config, subscription, model_id, model_dir)
start_time, end_time = self.get_data_time_range(parameters)
if self.config.auto_data_retrieving:
series = self.tsanaclient.get_timeseries_gw(parameters, parameters['seriesSets'], start_time, end_time)
else:
series = None
self.tsanaclient.save_inference_status(task_id, parameters, InferenceState.Running.name)
result, values, message = self.do_inference(model_dir, parameters, series, Context(subscription, model_id, task_id))
self.inference_callback(subscription, model_id, task_id, parameters, result, values, message)
except Exception as e:
self.inference_callback(subscription, model_id, task_id, parameters, STATUS_FAIL, None, str(e))
finally:
shutil.rmtree(model_dir, ignore_errors=True)
total_time = (time.time() - start)
log.duration("inference_task_duration", total_time, model_id=model_id, task_id=task_id, result=result, endpoint=parameters['apiEndpoint'], group_id=parameters['groupId'], group_name=parameters['groupName'].replace(' ', '_'), instance_id=parameters['instance']['instanceId'], instance_name=parameters['instance']['instanceName'].replace(' ', '_'))
log.count("inference_task_count", 1, model_id=model_id, task_id=task_id, result=result, endpoint=parameters['apiEndpoint'], group_id=parameters['groupId'], group_name=parameters['groupName'].replace(' ', '_'), instance_id=parameters['instance']['instanceId'], instance_name=parameters['instance']['instanceName'].replace(' ', '_'))
gc.collect()
return STATUS_SUCCESS, ''
def train_callback(self, subscription, model_id, task_id, model_dir, parameters, model_state, last_error=None):
try:
meta = get_meta(self.config, subscription, model_id)
if meta is None or meta['state'] == ModelState.Deleted.name:
return STATUS_FAIL, 'Model is not found! '
if model_state == ModelState.Ready:
result, message = upload_model(self.config, subscription, model_id, model_dir)
if result != STATUS_SUCCESS:
model_state = ModelState.Failed
last_error = 'Model storage failed! ' + message
except Exception as e:
model_state = ModelState.Failed
last_error = str(e)
raise e
finally:
update_state(self.config, subscription, model_id, model_state, None, last_error)
self.tsanaclient.save_training_status(task_id, parameters, model_state.name, last_error)
self.tsanaclient.save_training_result(parameters, model_id, model_state.name, last_error)
error_message = last_error + '\n' + traceback.format_exc() if model_state != ModelState.Ready else None
log.info("Training callback by %s, model_id = %s, task_id = %s, state = %s, last_error = %s" % (subscription, model_id, task_id, model_state, error_message if error_message is not None else ''))
def inference_callback(self, subscription, model_id, task_id, parameters, result, values, last_error=None):
try:
if result == STATUS_SUCCESS and values != None:
for value in values:
result, last_error = self.tsanaclient.save_data_points(parameters, value['metricId'], value['dimension'], value['timestamps'], value['values'],
value['fields'] if 'fields' in value else None, value['fieldValues'] if 'fieldValues' in value else None)
if result != STATUS_SUCCESS:
break
except Exception as e:
result = STATUS_FAIL
last_error = str(e)
raise e
finally:
if result == STATUS_SUCCESS:
self.tsanaclient.save_inference_status(task_id, parameters, InferenceState.Ready.name)
else:
self.tsanaclient.save_inference_status(task_id, parameters, InferenceState.Failed.name, last_error)
error_message = last_error + '\n' + traceback.format_exc() if result != STATUS_SUCCESS else None
log.info("Inference callback by %s, model_id = %s, task_id = %s, result = %s, last_error = %s" % (subscription, model_id, task_id, result, error_message if error_message is not None else ''))
def train(self, request):
request_body = json.loads(request.data)
instance_id = request_body['instance']['instanceId']
if not self.trainable:
return make_response(jsonify(dict(instanceId=instance_id, modelId='', taskId='', result=STATUS_SUCCESS, message='Model is not trainable', modelState=ModelState.Ready.name)), 200)
subscription = request.headers.get('apim-subscription-id', 'Official')
request_body[INSTANCE_ID_KEY] = subscription
result, message = self.do_verify(request_body, Context(subscription, '', ''))
if result != STATUS_SUCCESS:
return make_response(jsonify(dict(instanceId=instance_id, modelId='', taskId='', result=STATUS_FAIL, message='Verify failed! ' + message, modelState=ModelState.Deleted.name)), 400)
models_in_train = []
for model in get_model_list(self.config, subscription):
if 'instanceId' in model and model['instanceId'] == request_body['instance']['instanceId'] and (model['state'] == ModelState.Training.name or model['state'] == ModelState.Pending.name):
models_in_train.append(model['modelId'])
if len(models_in_train) >= self.config.models_in_training_limit_per_instance:
return make_response(jsonify(dict(instanceId=instance_id, modelId='', taskId='', result=STATUS_FAIL, message='Models in training limit reached! Abort training this time.', modelState=ModelState.Deleted.name)), 400)
log.info('Create training task')
try:
task_id = str(uuid.uuid1())
if 'modelId' in request_body and request_body['modelId']:
model_id = request_body['modelId']
else:
model_id = str(uuid.uuid1())
insert_or_update_meta(self.config, subscription, model_id, request_body)
job = JobRecord(task_id, JobRecord.MODE_TRAINING, self.__class__.__name__, model_id, subscription, request_body)
send_message(self.training_topic, dict(job))
log.count("training_task_throughput_in", 1, topic_name=self.training_topic, model_id=model_id, endpoint=request_body['apiEndpoint'], group_id=request_body['groupId'], group_name=request_body['groupName'].replace(' ', '_'), instance_id=request_body['instance']['instanceId'], instance_name=request_body['instance']['instanceName'].replace(' ', '_'))
return make_response(jsonify(dict(instanceId=instance_id, modelId=model_id, taskId=task_id, result=STATUS_SUCCESS, message='Training task created', modelState=ModelState.Training.name)), 201)
except Exception as e:
meta = get_meta(self.config, subscription, model_id)
error_message = str(e)
if meta is not None:
update_state(self.config, subscription, model_id, ModelState.Failed, None, error_message)
log.error("Create training task failed! subscription = %s, model_id = %s, task_id = %s, last_error = %s" % (subscription, model_id, task_id, error_message + '\n' + traceback.format_exc()))
return make_response(jsonify(dict(instanceId=instance_id, modelId=model_id, taskId=task_id, result=STATUS_FAIL, message='Fail to create new task ' + error_message, modelState=ModelState.Failed.name)), 400)
def inference(self, request, model_id):
request_body = json.loads(request.data)
instance_id = request_body['instance']['instanceId']
subscription = request.headers.get('apim-subscription-id', 'Official')
request_body[INSTANCE_ID_KEY] = subscription
if self.trainable:
meta = get_meta(self.config, subscription, model_id)
if meta is None:
return make_response(jsonify(dict(instanceId=instance_id, modelId=model_id, taskId='', result=STATUS_FAIL, message='Model is not found!', modelState=ModelState.Deleted.name)), 400)
if meta['state'] != ModelState.Ready.name:
return make_response(jsonify(dict(instanceId=instance_id, modelId=model_id, taskId='', result=STATUS_FAIL, message='Cannot do inference right now, status is ' + meta['state'], modelState=meta['state'])), 400)
try:
series_set = json.loads(meta['series_set'])
except:
series_set = json.loads(zlib.decompress(base64.b64decode(meta['series_set'].encode("ascii"))).decode('utf-8'))
para = json.loads(meta['para'])
current_set = json.dumps(series_set, sort_keys=True)
current_params = json.dumps(para, sort_keys=True)
new_set = json.dumps(request_body['seriesSets'], sort_keys=True)
new_params = json.dumps(request_body['instance']['params'], sort_keys=True)
if current_set != new_set or current_params != new_params:
if self.need_retrain(series_set, para, request_body['seriesSets'], request_body['instance']['params'], Context(subscription, model_id, '')):
return make_response(jsonify(dict(instanceId=instance_id, modelId=model_id, taskId='', result=STATUS_FAIL, message='Inconsistent series sets or params!', modelState=meta['state'])), 400)
log.info('Create inference task')
task_id = str(uuid.uuid1())
job = JobRecord(task_id, JobRecord.MODE_INFERENCE, self.__class__.__name__, model_id, subscription, request_body)
send_message(self.inference_topic, dict(job))
log.count("inference_task_throughput_in", 1, topic_name=self.inference_topic, model_id=model_id, endpoint=request_body['apiEndpoint'], group_id=request_body['groupId'], group_name=request_body['groupName'].replace(' ', '_'), instance_id=request_body['instance']['instanceId'], instance_name=request_body['instance']['instanceName'].replace(' ', '_'))
return make_response(jsonify(dict(instanceId=instance_id, modelId=model_id, taskId=task_id, result=STATUS_SUCCESS, message='Inference task created', modelState=ModelState.Ready.name)), 201)
def state(self, request, model_id):
if not self.trainable:
return make_response(jsonify(dict(instanceId='', modelId=model_id, taskId='', result=STATUS_SUCCESS, message='Model is not trainable', modelState=ModelState.Ready.name)), 200)
try:
subscription = request.headers.get('apim-subscription-id', 'Official')
request_body = json.loads(request.data)
request_body[INSTANCE_ID_KEY] = subscription
meta = get_meta(self.config, subscription, model_id)
if meta == None:
return make_response(jsonify(dict(instanceId='', modelId=model_id, taskId='', result=STATUS_FAIL, message='Model is not found!', modelState=ModelState.Deleted.name)), 400)
meta = clear_state_when_necessary(self.config, subscription, model_id, meta)
return make_response(jsonify(dict(instanceId='', modelId=model_id, taskId='', result=STATUS_SUCCESS, message=meta['last_error'] if 'last_error' in meta else '', modelState=meta['state'])), 200)
except Exception as e:
error_message = str(e)
log.error("Get model state failed! subscription = %s, model_id = %s, last_error = %s" % (subscription, model_id, error_message + '\n' + traceback.format_exc()))
return make_response(jsonify(dict(instanceId='', modelId=model_id, taskId='', result=STATUS_FAIL, message=error_message, modelState=ModelState.Failed.name)), 400)
def list_models(self, request):
subscription = request.headers.get('apim-subscription-id', 'Official')
return make_response(jsonify(get_model_list(self.config, subscription)), 200)
def delete(self, request, model_id):
if not self.trainable:
return make_response(jsonify(dict(instanceId='', modelId=model_id, taskId='', result=STATUS_SUCCESS, message='Model is not trainable')), 200)
try:
subscription = request.headers.get('apim-subscription-id', 'Official')
request_body = json.loads(request.data)
request_body[INSTANCE_ID_KEY] = subscription
instance_id = request_body['instance']['instanceId']
result, message = self.do_delete(request_body, model_id)
if result == STATUS_SUCCESS:
update_state(self.config, subscription, model_id, ModelState.Deleted)
return make_response(jsonify(dict(instanceId=instance_id, modelId=model_id, taskId='', result=STATUS_SUCCESS, message='Model {} has been deleted'.format(model_id), modelState=ModelState.Deleted.name)), 200)
else:
raise Exception(message)
except Exception as e:
error_message = str(e)
log.error("Delete model failed! subscription = %s, model_id = %s, last_error = %s" % (subscription, model_id, error_message + '\n' + traceback.format_exc()))
return make_response(jsonify(dict(instanceId='', modelId=model_id, taskId='', result=STATUS_FAIL, message=error_message, modelState=ModelState.Failed.name)), 400)
def verify(self, request):
request_body = json.loads(request.data)
instance_id = request_body['instance']['instanceId']
subscription = request.headers.get('apim-subscription-id', 'Official')
request_body[INSTANCE_ID_KEY] = subscription
try:
result, message = self.do_verify(request_body, Context(subscription, '', ''))
if result != STATUS_SUCCESS:
return make_response(jsonify(dict(instanceId=instance_id, modelId='', taskId='', result=STATUS_FAIL, message='Verify failed! ' + message, modelState=ModelState.Deleted.name)), 400)
else:
return make_response(jsonify(dict(instanceId=instance_id, modelId='', taskId='', result=STATUS_SUCCESS, message='Verify successfully! ' + message, modelState=ModelState.Deleted.name)), 200)
except Exception as e:
error_message = str(e)
log.error("Verify parameters failed! subscription = %s, instance_id = %s, last_error = %s" % (subscription, instance_id, error_message + '\n' + traceback.format_exc()))
return make_response(jsonify(dict(instanceId=instance_id, modelId='', taskId='', result=STATUS_FAIL, message='Verify failed! ' + error_message, modelState=ModelState.Deleted.name)), 400)