dlvm/nvidia/webapp/server.py (111 lines of code) (raw):

# Copyright 2019 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Web server to access Prediction server in GCP.

A Flask server used for Predictions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ast import literal_eval
from flask import Flask
from flask import render_template
from flask import request
from flask import redirect
from flask import url_for
from PIL import Image
from werkzeug.utils import secure_filename

import base64
import logging
import numpy as np
import os
import requests

app = Flask(__name__, static_folder='static')

MODEL_TYPE = 'jpg'  # tensor | jpg
LOAD_BALANCER = ''  # Load Balancer IP Address (Network Services|Load Balancing)
URL = 'http://%s/v1/models/default:predict' % LOAD_BALANCER
UPLOAD_FOLDER = 'static/images/'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# Wait this long for outgoing HTTP connections to be established.
_CONNECT_TIMEOUT_SECONDS = 90
# Wait this long to read from an HTTP socket.
_READ_TIMEOUT_SECONDS = 120

app.config['MAX_CONTENT_LENGTH'] = 1 * 1024 * 1024
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER


def get_classes():
  """Download the class-index-to-label mapping used to name predictions.

  Returns:
    A dictionary with class information: {1: 'cat'}

  Raises:
    requests.exceptions.RequestException: If the download fails, times out
      or the server answers with an HTTP error status.
  """
  # Classes dictionary.
  url = ('https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw'
         '/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5'
         '/imagenet1000_clsidx_to_labels.txt')
  # This runs at import time; without a timeout a stalled connection would
  # block server start-up indefinitely.
  response = requests.get(
      url, timeout=(_CONNECT_TIMEOUT_SECONDS, _READ_TIMEOUT_SECONDS))
  response.raise_for_status()
  # literal_eval only evaluates Python literals, so the downloaded text
  # cannot execute arbitrary code.
  return literal_eval(response.text)


def convert_to_json(image_file):
  """Open image, convert it to numpy and create JSON request.

  Args:
    image_file: A `str` with file path

  Returns:
    A dictionary used to get inference using Tensors.
  """
  # The context manager releases the file handle; `resize` returns a new,
  # fully loaded image, so it stays usable after the source is closed.
  with Image.open(image_file) as img:
    resized = img.resize((240, 240))
  img_array = np.array(resized)
  return {"instances": [img_array.tolist()]}


def convert_to_base64(image_file):
  """Open image and convert it to base64.

  Args:
    image_file: A `str` with file path

  Returns:
    A `str` holding the JSON request body used to get inference using
    Image Base64.
  """
  with open(image_file, 'rb') as f:
    jpeg_bytes = base64.b64encode(f.read()).decode('utf-8')
  return '{"instances" : [{"b64": "%s"}]}' % jpeg_bytes


def conversion_helper(model_type, filename):
  """Build the prediction request for an uploaded image.

  Args:
    model_type: A `str`, either 'jpg' (base64 request) or 'tensor'
      (numeric array request).
    filename: A `str` with the file name inside the upload folder.

  Returns:
    The request produced by `convert_to_base64` or `convert_to_json`, or
    None when `model_type` is not supported.
  """
  image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
  if model_type == 'jpg':
    return convert_to_base64(image_path)
  if model_type == 'tensor':
    return convert_to_json(image_path)
  logging.error('Invalid model')
  # Return None rather than a redirect response: callers treat the result as
  # a request payload, and a response object would be POSTed to the model.
  return None


def model_predict(predict_request):
  """Sends Image for prediction.

  Args:
    predict_request: The request used for Inference; either a `str` with a
      JSON body or a dictionary to be serialized as JSON.

  Returns:
    A JSON object with inference response, or None when the prediction
    server answers with an HTTP error status.
  """
  # A dictionary passed through `data=` would be form-encoded; `json=`
  # serializes it as a JSON body instead. Strings are sent unchanged.
  if isinstance(predict_request, dict):
    payload = {'json': predict_request}
  else:
    payload = {'data': predict_request}

  # The context manager closes the session's connections when done.
  with requests.Session() as session:
    try:
      response = session.post(
          URL,
          timeout=(_CONNECT_TIMEOUT_SECONDS, _READ_TIMEOUT_SECONDS),
          allow_redirects=False,
          **payload)
      response.raise_for_status()
      return response.json()
    except requests.exceptions.HTTPError as err:
      logging.exception(err)
      if err.response.status_code == 400:
        logging.exception('Server error %s', URL)
      elif err.response.status_code == 404:
        logging.exception('Page not found %s', URL)
      return None


def allowed_file(filename):
  """Tell whether `filename` has one of the allowed image extensions.

  Args:
    filename: A `str` with the name of the uploaded file.

  Returns:
    True when the text after the last '.' is, ignoring case, one of
    ALLOWED_EXTENSIONS; False otherwise.
  """
  return '.' in filename and \
      filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


@app.route('/')
def home():
  """Render the index page."""
  return render_template('index.html')


@app.route('/upload', methods=['POST'])
def upload_file():
  """Store the uploaded image, request a prediction and render the result.

  Returns:
    The rendered index page with the prediction (or with `prediction=None`
    when no prediction could be obtained), or a redirect to the index page
    when the upload is invalid or the prediction request fails.
  """
  logging.info(request.files)
  # This route only accepts POST, so redirecting to `request.url` would make
  # the browser issue a GET that is rejected; go back to the index instead.
  if 'image' not in request.files:
    logging.error('Error. No file part')
    return redirect(url_for('home'))
  upload = request.files['image']
  if not upload.filename:
    logging.error('Not selected file')
    return redirect(url_for('home'))
  if not allowed_file(upload.filename):
    logging.error('File type not allowed')
    return redirect(url_for('home'))

  filename = secure_filename(upload.filename)
  if not filename:
    # secure_filename may strip every character of an unsafe name.
    logging.error('Not a valid file name')
    return redirect(url_for('home'))
  upload.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))

  # Call API for prediction.
  predict_request = conversion_helper(MODEL_TYPE, filename)
  prediction = None
  try:
    if predict_request:
      response = model_predict(predict_request)
      if response:
        # The model's class ids are shifted by one relative to the keys of
        # `classes`.
        prediction_class = response.get('predictions')[0].get('classes') - 1
        # Index 0 is looked up like any other class, not treated as "none".
        prediction = classes[prediction_class]
  except Exception as e:
    logging.exception('Not a valid request: {} '.format(e))
    return redirect(url_for('home'))

  if prediction is None:
    return render_template('index.html', init=True, prediction=None)
  return render_template('index.html',
                         init=True,
                         user_image=filename,
                         prediction=prediction)


# Obtain classes before Server starts.
classes = get_classes()

if __name__ == '__main__':
  if not LOAD_BALANCER:
    raise ValueError('Define Load Balancer')
  # NOTE(review): debug=True enables the interactive debugger; confirm this
  # entry point is only used for local development.
  app.run(debug=True, port=8001)