tools/agile-machine-learning-api/main.py (367 lines of code) (raw):

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
End-to-End API Framework for Componentised Machine Learning Applications
"""
import json
import logging
import os
import uuid
from logging.handlers import RotatingFileHandler

import tensorflow as tf
import yaml
from flask import Flask, request, Response
from flask_json_schema import JsonSchema, JsonValidationError

import deploy
import predict
import train
from lime_utils import visualization, visualization_2
from schema import TRAIN_SCHEMA, PREDICT_SCHEMA, DEPLOY_SCHEMA, LIME_SCHEMA, LIME_SCHEMA_2

APP = Flask(__name__)
SCHEMA = JsonSchema(APP)


def read_yaml():
    """
    Reads the main config file.

    Returns:
        The parsed content of config/config_file.yaml
    Raise:
        IOError : If the config file cannot be opened
    """
    with open("config/config_file.yaml", 'r') as ymlfile:
        # safe_load: a bare yaml.load() can construct arbitrary Python objects
        # and is rejected outright (missing Loader) by PyYAML >= 6.
        return yaml.safe_load(ymlfile)


def get_job_link():
    """
    Reads the developer config file and extracts the job link.

    Returns:
        The value stored under the 'job_link' key of config/developer.yaml
    Raise:
        IOError : If the config file cannot be opened
    """
    with open('config/developer.yaml', 'r') as ymlfile:
        return yaml.safe_load(ymlfile)['job_link']


def _envelope(success, message, data):
    """Serialises the standard Success/Message/Data reply to a JSON string."""
    return json.dumps({"Success": success, "Message": message, "Data": data})


def _json_response(body, status):
    """Wraps an already serialised JSON body in a Flask Response."""
    return Response(body, status=status, mimetype='application/json')


def _failure(message, data, status=500):
    """Builds an unsuccessful reply; every handled error maps to HTTP 500."""
    return _json_response(_envelope(False, message, data), status)


def _optional(payload, key, default, cast=str):
    """
    Returns payload[key] (passed through cast) or default when the key is
    absent or null. Pass cast=None to forward the value untouched.
    """
    value = payload.get(key)
    if value is None:
        return default
    return value if cast is None else cast(value)


@APP.errorhandler(JsonValidationError)
def validation(error):
    """
    Handles validation message for type casting of parameters

    Arguments :
        error: object, the JsonValidationError raised by schema validation
    Returns:
        HTTP 400 response listing every individual validation message
    """
    return Response(
        json.dumps({
            'Message': error.message,
            'Data': [
                validation_error.message
                for validation_error in error.errors
            ],
            'Success': False,
        }),
        status=400,
        mimetype='application/json')


@APP.route('/train', methods=['POST'])
@SCHEMA.validate(TRAIN_SCHEMA)
def app_train():
    """
    Training API call

    Builds the bucket-relative data/export paths, generates a job id and
    submits the training job through train.post.

    Returns:
        Json Response of Training API call. HTTP 200 with the job link,
        job id and train.post response on success, HTTP 500 on any error.
    Raise:
        Validation error : If data types of input parameters is incorrect
        Access Denied to project : When the given service account key
            cannot interact with GCP.
    """
    try:
        call_id = uuid.uuid4()
        cfg = read_yaml()
        jobid = 'C' + str(call_id).replace('-', '_')
        payload = request.get_json()
        bucket = cfg['bucket_name']

        # 'train_csv_path' may be a single path or a list of paths; a list is
        # flattened into one space-separated string of bucket paths.
        if isinstance(payload['train_csv_path'], list):
            train_csv_path = ' '.join(
                os.path.join(bucket, str(path))
                for path in payload['train_csv_path'])
        else:
            train_csv_path = os.path.join(bucket, payload['train_csv_path'])
        eval_csv_path = os.path.join(bucket, payload['eval_csv_path'])
        export_dir = os.path.join(bucket, payload['export_dir'], jobid)
        APP.logger.info('[%s] Config file loaded', jobid)

        # Optional parameters are forwarded as strings; the literal 'None'
        # stands in for parameters the caller did not supply.
        response = train.post(
            cfg=cfg,
            train_csv_path=train_csv_path,
            eval_csv_path=eval_csv_path,
            task_type=payload['task_type'],
            target_var=payload['target_var'],
            data_type=_optional(payload, 'data_type', 'None'),
            column_name=_optional(payload, 'column_name', 'None'),
            na_values=_optional(payload, 'na_values', 'None'),
            condition=_optional(payload, 'condition', 'None'),
            n_classes=_optional(payload, 'n_classes', '2'),
            to_drop=_optional(payload, 'to_drop', 'None'),
            name=payload['name'],
            hidden_units=_optional(payload, 'hidden_units', '64'),
            num_layers=_optional(payload, 'num_layers', '2'),
            lin_opt=_optional(payload, 'lin_opt', 'ftrl', cast=None),
            deep_opt=_optional(payload, 'deep_opt', 'adam', cast=None),
            train_steps=_optional(payload, 'train_steps', '50000'),
            export_dir=export_dir,
            jobid=jobid)
        APP.logger.info('[%s] %s', jobid, payload)
        APP.logger.info('[%s] Training Job submitted to CMLE', jobid)
        body = _envelope(
            True,
            "{}/{}?project={}".format(
                get_job_link(), jobid, cfg['project_id']),
            {'jobid': jobid, 'response': response})
        return _json_response(body, 200)
    except IOError as err:
        APP.logger.error(str(err))
        return _failure(
            "Please check the config.yaml file",
            {"error_message": str(err)})
    except AssertionError as err:
        APP.logger.error(str(err))
        return _failure(str(err), [])
    except Exception as err:  # top-level API boundary: report, don't crash
        APP.logger.error(str(err))
        # The exception object itself is not JSON serialisable, so only its
        # text is sent back (in "Message").
        return _failure(str(err), None)


@APP.route('/deploy', methods=['POST'])
@SCHEMA.validate(DEPLOY_SCHEMA)
def app_deploy():
    """
    Deployment API call

    Returns:
        JSON response of Deployment API call. HTTP 200 with the deploy.post
        response on success, HTTP 500 on any error.
    Raise:
        Validation error : If data types of input parameters is incorrect
        Access Denied to project : When the given service account key
            cannot interact with GCP.
    """
    try:
        cfg = read_yaml()
        APP.logger.info('Config file loaded')
        payload = request.get_json()
        response = deploy.post(
            cfg=cfg,
            job_id=payload['job_id'],
            model_name=payload['model_name'],
            version_name=payload['version_name'],
            trained_model_location=payload['trained_model_location'],
            runtime_version=payload['runtime_version'])
        body = _envelope(True, "Model is successfully deployed", response)
        APP.logger.info('route /deploy has been called')
        APP.logger.info('[%s]', payload)
        APP.logger.info(body)
        return _json_response(body, 200)
    except IOError as err:
        APP.logger.error(str(err))
        APP.logger.info('Invalid config.yaml file has been loaded')
        return _failure(
            "Please check the config.yaml file",
            {"error_message": str(err)})
    except IndexError as err:
        APP.logger.error(str(err))
        APP.logger.info('Unable to locate saved model location')
        return _failure(
            "Please provide a valid 'job_id' and 'trained_model_location'",
            [])
    except AssertionError as err:
        APP.logger.error(str(err))
        return _failure(str(err), [])
    except Exception as err:  # top-level API boundary: report, don't crash
        APP.logger.error(err)
        return _failure(str(err), None)


@APP.route('/predict', methods=['POST'])
@SCHEMA.validate(PREDICT_SCHEMA)
def app_predict():
    """
    Predict function for deployed models

    Returns:
        JSON response of Prediction API call. On success "Data" holds, per
        returned point, its 'probabilities' formatted to four decimals.
        HTTP 500 on any error.
    Raise:
        Validation error : If data types of input parameters is incorrect
        Access Denied to project : When the given service account key
            cannot interact with GCP.
    """
    try:
        cfg = read_yaml()
        APP.logger.info('Config file loaded')
        payload = request.get_json()
        response = predict.post(
            cfg=cfg,
            model_name=payload['model_name'],
            instances=payload['instances'],
            version_name=payload['version_name'])
        # A point without 'probabilities' raises KeyError, handled below.
        probabilities = [
            ["%.4f" % x for x in point['probabilities']]
            for point in response
        ]
        body = _envelope(True, "Predictions done", probabilities)
        APP.logger.info('[%s]', payload)
        APP.logger.info(body)
        return _json_response(body, 200)
    except IOError as err:
        APP.logger.error(str(err))
        return _failure("Please check the config.yaml file", [])
    except AssertionError as err:
        APP.logger.error(str(err))
        return _failure(str(err), [])
    except KeyError as err:
        APP.logger.error(
            'Error in fetching the response of the predict function')
        return _failure(
            {
                "Message": "Please check prediction data-points given to "
                           "the API call",
                "Error_message": str(err),
            },
            None)
    except Exception as err:  # top-level API boundary: report, don't crash
        APP.logger.error(err)
        return _failure(str(err), None)


# NOTE: @APP.route must be the outermost decorator. Flask registers whatever
# function it receives, so a validator placed above it is never executed.
@APP.route('/predict/lime', methods=['POST'])
@SCHEMA.validate(LIME_SCHEMA)
def lime_prediction():
    """
    LIME explanation API call for the given data points

    Returns:
        JSON response whose "Message" is the string form of the
        visualization() result on success (HTTP 200); HTTP 500 on any error.
    Raise:
        Validation error : If data types of input parameters is incorrect
    """
    try:
        cfg = read_yaml()
        payload = request.get_json()
        result = visualization(
            cfg=cfg,
            job_id=payload['job_id'],
            model_dir=payload['export_dir'],
            predict_json=payload['predict_json'],
            batch_prediction=payload['batch_prediction'],
            d_points=payload['data_points'],
            name=payload['name'])
        return _json_response(_envelope(True, str(result), {}), 200)
    except IOError as err:
        APP.logger.error(str(err))
        return _failure("Please check the config.yaml file", [])
    except AssertionError as err:
        APP.logger.error(str(err))
        return _failure(str(err), [])
    except tf.errors.InvalidArgumentError as err:
        APP.logger.error(str(err.message))
        # Only the first line of the TensorFlow message goes to the client.
        return _failure(str(err.message.split('\n')[0]), {})
    except KeyError as err:
        message = str(
            'Following feature[s] missing in the data provided {}'.format(
                err))
        APP.logger.error(message)
        return _failure(message, {})
    except Exception as err:  # top-level API boundary: report, don't crash
        APP.logger.error(str(err))
        return _failure(str(err), None)


@APP.route('/predict/lime2', methods=['POST'])
@SCHEMA.validate(LIME_SCHEMA_2)
def lime_prediction_2():
    """
    LIME explanation API call (second variant, without explicit data points)

    Returns:
        JSON response whose "Message" is the visualization_2() result on
        success (HTTP 200); HTTP 500 on any error.
    Raise:
        Validation error : If data types of input parameters is incorrect
    """
    try:
        cfg = read_yaml()
        payload = request.get_json()
        result = visualization_2(
            cfg=cfg,
            job_id=payload['job_id'],
            model_dir=payload['export_dir'],
            predict_json=payload['predict_json'],
            batch_prediction=payload['batch_prediction'],
            name=payload['name'])
        return _json_response(_envelope(True, result, []), 200)
    except IOError as err:
        APP.logger.error(str(err))
        return _failure("Please check the config.yaml file", [])
    except ValueError as err:
        APP.logger.error(str(err))
        return _failure(str(err), [])
    except AssertionError as err:
        APP.logger.error(str(err))
        return _failure(str(err), [])
    except tf.errors.InvalidArgumentError as err:
        APP.logger.error(str(err.message))
        # Only the first line of the TensorFlow message goes to the client.
        return _failure(str(err.message.split('\n')[0]), {})
    except Exception as err:  # top-level API boundary: report, don't crash
        APP.logger.error(str(err))
        return _failure(str(err), [])


if __name__ == '__main__':
    HANDLER = RotatingFileHandler('api.log', maxBytes=10000, backupCount=1)
    FORMATTER = logging.Formatter('[%(asctime)s] [%(levelname)s] %(message)s')
    HANDLER.setFormatter(FORMATTER)
    APP.logger.addHandler(HANDLER)
    APP.logger.setLevel(logging.INFO)
    APP.run(host='127.0.0.1', port=8080, debug=False, threaded=True)