in decisionai_plugin/common/plugin_service.py [0:0]
def train_wrapper(self, message):
    """Run one training task end to end and report its outcome.

    Reads ``subscription``, ``model_id``, ``job_id`` and ``params`` from
    ``message``, then:

    1. saves the ``Pending`` training status,
    2. creates a scratch model directory under ``self.config.model_dir``,
    3. fetches the time series when ``self.config.auto_data_retrieving``
       is set,
    4. marks the model as ``Training`` and calls ``self.do_train``,
    5. calls ``self.train_callback`` with ``ModelState.Ready`` on success,
       or with ``ModelState.Failed`` and the error text on any exception.

    The scratch directory is always removed and the duration/count metrics
    are always emitted, whatever the outcome.

    Args:
        message: Mapping with the keys ``'subscription'``, ``'model_id'``,
            ``'job_id'`` and ``'params'``. ``params`` must additionally
            provide the keys read when emitting metrics (``'apiEndpoint'``,
            ``'groupId'``, ``'groupName'``, ``'instance'``).

    Returns:
        Always ``(STATUS_SUCCESS, '')``; the training outcome itself is
        reported through ``train_callback`` and the metrics, not through
        the return value.

    Raises:
        KeyError: If ``message`` lacks one of the required keys, or if
            ``params`` lacks a key needed for the metrics.
        Exception: Anything raised by the failure callback itself
            propagates to the caller after cleanup.
    """
    start = time.time()
    subscription = message['subscription']
    model_id = message['model_id']
    task_id = message['job_id']
    parameters = message['params']
    model_dir = None
    # Bind `result` up front: the `finally` block reads it for the metrics,
    # and it must exist even when we fail before `do_train` returns and the
    # failure callback itself raises.
    result = STATUS_FAIL

    log.info("Start train wrapper for model %s by %s " % (model_id, subscription))
    try:
        self.tsanaclient.save_training_status(task_id, parameters, ModelState.Pending.name)

        # The timestamp suffix keeps concurrent/repeated runs of the same
        # model in separate scratch directories.
        model_dir = os.path.join(self.config.model_dir, subscription + '_' + model_id + '_' + str(time.time()))
        os.makedirs(model_dir, exist_ok=True)

        series = None
        if self.config.auto_data_retrieving:
            start_time, end_time = self.get_data_time_range(parameters, True)
            series = self.tsanaclient.get_timeseries_gw(parameters, parameters['seriesSets'], start_time, end_time)

        update_state(self.config, subscription, model_id, ModelState.Training)
        self.tsanaclient.save_training_status(task_id, parameters, ModelState.Training.name)

        # Use a dedicated name so the incoming `message` is not shadowed.
        result, train_message = self.do_train(model_dir, parameters, series, Context(subscription, model_id, task_id))

        if result == STATUS_SUCCESS:
            self.train_callback(subscription, model_id, task_id, model_dir, parameters, ModelState.Ready, None)
        else:
            # Route a reported failure through the common failure path below.
            raise Exception(train_message)
    except Exception as e:
        # Record the failure before calling back, so the metrics stay
        # accurate even if the callback raises.
        result = STATUS_FAIL
        self.train_callback(subscription, model_id, task_id, None, parameters, ModelState.Failed, str(e))
    finally:
        if model_dir is not None:
            shutil.rmtree(model_dir, ignore_errors=True)

        total_time = (time.time() - start)
        # Both metrics carry the same tags; build them once.
        metric_tags = {
            'model_id': model_id,
            'task_id': task_id,
            'result': result,
            'endpoint': parameters['apiEndpoint'],
            'group_id': parameters['groupId'],
            'group_name': parameters['groupName'].replace(' ', '_'),
            'instance_id': parameters['instance']['instanceId'],
            'instance_name': parameters['instance']['instanceName'].replace(' ', '_'),
        }
        log.duration("training_task_duration", total_time, **metric_tags)
        log.count("training_task_count", 1, **metric_tags)
        gc.collect()

    # NOTE(review): assumed to sit after the try/finally (original
    # indentation was not available) - confirm against the file.
    return STATUS_SUCCESS, ''