pai/predictor.py
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import functools
import json
import posixpath
import time
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from io import IOBase
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib.parse import urlencode
import aiohttp
import requests
from .common.consts import FrameworkTypes
from .common.docker_utils import ContainerRun
from .common.logging import get_logger
from .common.utils import http_user_agent, is_package_available
from .exception import PredictionException
from .serializers import (
JsonSerializer,
PyTorchSerializer,
SerializerBase,
TensorFlowSerializer,
)
from .session import Session, get_default_session
if is_package_available("openai"):
from openai import OpenAI
logger = get_logger(__name__)
_PAI_SERVICE_CONSOLE_URI_PATTERN = (
"https://pai.console.aliyun.com/?regionId={region_id}&workspaceId={workspace_id}#"
"/eas/serviceDetail/{service_name}/detail"
)
_QUEUE_SERVICE_REQUEST_ID_HEADER = "X-Eas-Queueservice-Request-Id"
_QUEUE_SERVICE_SINK_PATH = "sink"
_DEFAULT_ASYNC_WORKER_COUNT = 30
class ServiceStatus(object):
"""All EAS inference service status."""
Running = "Running"
Waiting = "Waiting"
Scaling = "Scaling"
Stopped = "Stopped"
Failed = "Failed"
DeleteFailed = "DeleteFailed"
@classmethod
def completed_status(cls):
return [
cls.Running,
cls.Stopped,
cls.Failed,
cls.DeleteFailed,
]
class EndpointType(object):
# Public Internet Endpoint
INTERNET = "INTERNET"
# VPC Endpoint
INTRANET = "INTRANET"
class ServiceType(object):
Standard = "Standard"
Async = "Async"
class PredictorBase(ABC):
@abstractmethod
def predict(self, *args, **kwargs) -> Any:
"""Perform inference on the provided data and return prediction result."""
@abstractmethod
def raw_predict(
self,
data: Any = None,
path: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
method: str = "POST",
timeout: Optional[Union[float, Tuple[float, float]]] = None,
**kwargs,
):
pass
class RawResponse(object):
"""Response object returned by the predictor.raw_predict."""
def __init__(self, status_code: int, headers: Dict[str, str], content: bytes):
"""Initialize a RawResponse object.
        Args:
            status_code (int): HTTP status code of the response.
            headers (dict): HTTP headers of the response.
            content (bytes): Raw content of the response body.
"""
self.status_code = status_code
self.headers = headers
self.content = content
def json(self):
"""Returns the json-encoded content of a response
Returns:
Dict[str, Any]: The json-encoded content of a response.
"""
return json.loads(self.content)
class _ServicePredictorMixin(object):
def __init__(
self,
service_name: str,
session: Optional[Session] = None,
endpoint_type: str = EndpointType.INTERNET,
serializer: Optional[SerializerBase] = None,
):
self.service_name = service_name
self.session = session or get_default_session()
self._service_api_object = self.describe_service()
self.endpoint_type = endpoint_type
self.serializer = serializer or self._get_default_serializer()
self._request_session = requests.Session()
def __repr__(self):
return "{}(service_name={}, endpoint_type={})".format(
type(self).__name__,
self.service_name,
self.endpoint_type,
)
def __del__(self):
self._request_session.close()
def refresh(self):
self._service_api_object = self.describe_service()
@property
def endpoint(self):
if self.endpoint_type == EndpointType.INTRANET:
return self._service_api_object["IntranetEndpoint"]
else:
return self._service_api_object["InternetEndpoint"]
@property
def intranet_endpoint(self):
return self._service_api_object["IntranetEndpoint"]
@property
def internet_endpoint(self):
return self._service_api_object["InternetEndpoint"]
@property
def service_status(self):
"""Returns the status of the service."""
return self._service_api_object["Status"]
@property
def access_token(self) -> str:
"""Access token of the service."""
return self._service_api_object["AccessToken"]
@property
def labels(self) -> Dict[str, str]:
"""Labels of the service."""
labels = {
item["LabelKey"]: item["LabelValue"]
for item in self._service_api_object.get("Labels", [])
}
return labels
@property
def console_uri(self):
"""Returns the console URI of the service."""
return _PAI_SERVICE_CONSOLE_URI_PATTERN.format(
workspace_id=self.session.workspace_id,
region_id=self.session.region_id,
service_name=self.service_name,
)
def _get_default_serializer(self):
"""Get default serializer for the predictor by inspecting the service config."""
from pai.model._model import _BuiltinProcessor
service_config = json.loads(self._service_api_object["ServiceConfig"])
processor_code = service_config.get("processor")
# If the prediction service is serving with custom processor or custom
# container, use JsonSerializer as default serializer.
if not processor_code:
return JsonSerializer()
if processor_code in (
_BuiltinProcessor.PMML,
_BuiltinProcessor.XGBoost,
):
return JsonSerializer()
elif processor_code.startswith(FrameworkTypes.TensorFlow.lower()):
serializer = TensorFlowSerializer()
return serializer
elif processor_code.startswith(FrameworkTypes.PyTorch.lower()):
return PyTorchSerializer()
else:
return JsonSerializer()
def _post_init_serializer(self):
"""Post-initialize the serializer by invoking serializer.inspect_from_service"""
if not hasattr(self.serializer, "__post_init_serializer_flag") and hasattr(
self.serializer, "inspect_from_service"
):
self.serializer.inspect_from_service(
self.service_name, session=self.session
)
setattr(self.serializer, "__post_init_serializer_flag", 1)
def inspect_model_signature_def(self):
"""Get SignatureDef of the serving model.
.. note::
Only the service using the TensorFlow processor supports getting the
model signature_definition.
Returns:
Dict[str, Any]: A dictionary representing the signature definition of the
serving model.
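        Examples::
            # Illustrative sketch only: assumes the predictor points to a service
            # that is served with the TensorFlow processor.
            predictor = Predictor(service_name="example_tf_service")
            signature_def = predictor.inspect_model_signature_def()
            print(signature_def)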
"""
service_config = json.loads(self._service_api_object["ServiceConfig"])
processor_code = service_config.get("processor")
if processor_code and processor_code.startswith("tensorflow"):
return TensorFlowSerializer.inspect_model_signature_def(
self.service_name, session=self.session
)
raise RuntimeError(
"Only the online prediction service using the TensorFlow processor supports"
" getting the signature_definition"
)
def describe_service(self) -> Dict[str, Any]:
"""Describe the service that referred by the predictor.
Returns:
Dict[str, Any]: Response from PAI API service.
"""
return self.session.service_api.get(self.service_name)
def start_service(self, wait=True):
"""Start the stopped service."""
self.session.service_api.start(name=self.service_name)
if wait:
status = ServiceStatus.Running
unexpected_status = ServiceStatus.completed_status()
unexpected_status.remove(status)
type(self)._wait_for_status(
service_name=self.service_name,
status=status,
unexpected_status=unexpected_status,
session=self.session,
)
self.refresh()
def stop_service(self, wait=True):
"""Stop the running service."""
self.session.service_api.stop(name=self.service_name)
if wait:
status = ServiceStatus.Stopped
unexpected_status = ServiceStatus.completed_status()
unexpected_status.remove(status)
unexpected_status.remove(ServiceStatus.Running)
type(self)._wait_for_status(
service_name=self.service_name,
status=status,
unexpected_status=unexpected_status,
session=self.session,
)
self.refresh()
def delete_service(self):
"""Delete the service."""
self.session.service_api.delete(name=self.service_name)
def wait_for_ready(self):
"""Wait until the service enter running status.
Raises:
RuntimeError: Raise if the service terminated unexpectedly.
"""
if self.service_status == ServiceStatus.Running:
return
logger.info(
"Service waiting for ready: service_name={}".format(self.service_name)
)
unexpected_status = ServiceStatus.completed_status()
unexpected_status.remove(ServiceStatus.Running)
type(self)._wait_for_status(
service_name=self.service_name,
status=ServiceStatus.Running,
unexpected_status=unexpected_status,
session=self.session,
)
# hack: PAI-EAS gateway may not be ready when the service is ready.
self._wait_for_gateway_ready()
self.refresh()
def wait(self):
"""Wait for the service to be ready."""
return self.wait_for_ready()
def _wait_for_gateway_ready(self, attempts: int = 60, interval: int = 2):
"""Hacky way to wait for the service gateway to be ready.
Args:
attempts (int): Number of attempts to wait for the service gateway to be
ready.
interval (int): Interval between each attempt.
"""
def _is_gateway_ready():
# can't use HEAD method to check gateway status because the service will
# block the request until timeout.
resp = self._send_request(method="GET")
logger.debug(
"Check gateway status result: status_code=%s content=%s",
resp.status_code,
resp.content,
)
res = not (
# following status code and content indicates the gateway is not ready
(
resp.status_code == 503
and (b"no healthy upstream" in resp.content or not resp.content)
)
or (resp.status_code == 404 and not resp.content)
)
return res
err_count_threshold = 3
err_count = 0
while attempts > 0:
attempts -= 1
try:
if _is_gateway_ready():
break
except requests.exceptions.RequestException as e:
err_count += 1
if err_count >= err_count_threshold:
logger.warning("Failed to check gateway status: %s", e)
break
time.sleep(interval)
else:
logger.warning("Timeout waiting for gateway to be ready.")
@classmethod
def _wait_for_status(
cls,
service_name: str,
status: str,
unexpected_status: List[str],
interval: int = 3,
session: Optional[Session] = None,
):
session = session or get_default_session()
service_api_object = session.service_api.get(service_name)
last_status = service_api_object["Status"]
last_msg = service_api_object["Message"]
time.sleep(interval)
while True:
service_api_object = session.service_api.get(service_name)
# Check the service status
cur_status = service_api_object["Status"]
if cur_status == status:
return status
elif unexpected_status and cur_status in unexpected_status:
# Unexpected terminated status
raise RuntimeError(
f"The Service terminated unexpectedly: "
f"name={service_api_object['ServiceName']} "
f"status={service_api_object['Status']} "
f"reason={service_api_object['Reason']} "
f"message={service_api_object['Message']}."
)
elif (
last_status == cur_status and service_api_object["Message"] == last_msg
) and cur_status != ServiceStatus.Waiting:
# If service.status and service.message have not changed and
# service.status is not 'Waiting', do not print the service
# status/message.
pass
else:
logger.info(
f"Refresh Service status: "
f"name={service_api_object['ServiceName']} "
f"id={service_api_object['ServiceId']} "
f"status={service_api_object['Status']} "
f"reason={service_api_object['Reason']} "
f"message={service_api_object['Message']}."
)
last_status = service_api_object["Status"]
last_msg = service_api_object["Message"]
time.sleep(interval)
def switch_version(self, version: int):
"""Switch service to target version.
Args:
version (int): Target version
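        Examples::
            # Illustrative only: roll the service back to a previously deployed
            # version (assumes version 1 exists and is not the current version).
            predictor = Predictor(service_name="example_service")
            predictor.switch_version(version=1)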
"""
service_api_object = self.describe_service()
current_version = service_api_object["CurrentVersion"]
latest_version = service_api_object["LatestVersion"]
        if current_version == version:
            raise ValueError("Target version is the same as the current version.")
        if version > latest_version:
            raise ValueError("Target version is greater than the latest version.")
self.session.service_api.update_version(self.service_name, version=version)
@classmethod
def deploy(
cls,
config: Dict[str, Any],
session: Optional[Session] = None,
endpoint_type: str = EndpointType.INTERNET,
serializer: Optional[SerializerBase] = None,
wait: bool = True,
) -> PredictorBase:
"""Deploy an online prediction service using given configuration.
Args:
config (Dict[str, Any]): A dictionary of service configuration.
session (:class:`pai.session.Session`, optional): An optional
session object. If not provided, a default session will be used.
serializer: An optional serializer object. If not provided, a
default serializer will be used.
endpoint_type: The type of endpoint to use.
wait: Whether to wait for the service to be ready before returning.
Returns:
:class:`pai.predictor.PredictorBase`: A Predictor object for the deployed
online prediction service.
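        Examples::
            # A minimal sketch of deploying a custom-container service; the image
            # URI, port, and service name below are placeholders, and the accepted
            # keys follow the EAS service configuration for the target service type.
            config = {
                "name": "example_service",
                "containers": [
                    {"image": "registry.example.com/example:latest", "port": 8000}
                ],
                "metadata": {"instance": 1},
            }
            predictor = Predictor.deploy(config=config)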
"""
session = session or get_default_session()
name = session.service_api.create(config=config)
if wait:
# Wait until the service is ready
unexpected_status = ServiceStatus.completed_status()
unexpected_status.remove(ServiceStatus.Running)
Predictor._wait_for_status(
service_name=name,
status=ServiceStatus.Running,
unexpected_status=unexpected_status,
session=session,
)
service_api_obj = session.service_api.get(name)
if service_api_obj["ServiceType"] == ServiceType.Async:
p = AsyncPredictor(
service_name=name,
endpoint_type=endpoint_type,
serializer=serializer,
)
else:
p = Predictor(
service_name=name,
endpoint_type=endpoint_type,
serializer=serializer,
)
return p
def _build_url(
self, path: Optional[str] = None, params: Dict[str, str] = None
) -> str:
url = self.endpoint
if path:
if path.startswith("/"):
path = path[1:]
url = posixpath.join(url, path)
# Add params to URL
url = url + "?" + urlencode(params) if params else url
return url
def _build_headers(self, headers: Dict[str, str] = None) -> Dict[str, str]:
headers = headers or dict()
headers["Authorization"] = self.access_token
headers["User-Agent"] = http_user_agent(headers.get("User-Agent"))
return headers
def _handle_input(self, data):
return self.serializer.serialize(data) if self.serializer else data
def _handle_output(self, content: bytes):
return self.serializer.deserialize(content) if self.serializer else content
def _handle_raw_input(self, data):
if isinstance(data, (IOBase, bytes, str)):
# if data is a file-like object, bytes, or string, it will be sent as
# request body
json_data, data = None, data
else:
# otherwise, it will be treated as a JSON serializable object and sent as
# JSON.
json_data, data = data, None
return json_data, data
def _handle_raw_output(self, status_code: int, headers: dict, content: bytes):
return RawResponse(status_code, headers, content)
def _send_request(
self,
data=None,
path=None,
method="POST",
json=None,
headers=None,
params=None,
**kwargs,
):
url = self._build_url(path)
resp = self._request_session.request(
url=url,
json=json,
data=data,
headers=self._build_headers(headers),
method=method,
params=params,
**kwargs,
)
return resp
async def _send_request_async(
self,
data=None,
path=None,
method="POST",
json=None,
headers=None,
params=None,
**kwargs,
):
url = self._build_url(path=path, params=params)
headers = self._build_headers(headers)
async with aiohttp.ClientSession() as session:
return await session.request(
method=method,
url=url,
headers=headers,
data=data,
json=json,
**kwargs,
)
class Predictor(PredictorBase, _ServicePredictorMixin):
"""Predictor is responsible for making prediction to an online service.
The `predictor.predict` method sends the input data to the online prediction service
and returns the prediction result. The serializer object of the predictor is
responsible for data transformation when the `predict` method is invoked. The input
data is serialized using the `serializer.serialize` method before it is sent, and
    the response is deserialized using the `serializer.deserialize` method before the
    prediction result is returned.
Examples::
# Initialize a predictor object from an existing service using PyTorch
# processor.
torch_predictor = Predictor(service_name="example_torch_service")
result = torch_predictor.predict(numpy.asarray([[22,33,44], [19,22,33]]))
assert isinstance(result, numpy.ndarray)
"""
def __init__(
self,
service_name: str,
endpoint_type: str = EndpointType.INTERNET,
serializer: Optional[SerializerBase] = None,
session: Optional[Session] = None,
):
"""Construct a `Predictor` object using an existing prediction service.
Args:
service_name (str): Name of the existing prediction service.
endpoint_type (str): Selects the endpoint used by the predictor, which
should be one of `INTERNET` or `INTRANET`. The `INTERNET` endpoint type
means that the predictor calls the service over a public endpoint, while
the `INTRANET` endpoint type is over a VPC endpoint.
            serializer (SerializerBase, optional): A serializer object that transforms
                the input Python object for data transmission and deserializes the
                response data into a Python object.
session (Session, optional): A PAI session object used for communicating
with PAI service.
"""
super(Predictor, self).__init__(
service_name=service_name,
session=session or get_default_session(),
endpoint_type=endpoint_type,
serializer=serializer,
)
self._check()
def _check(self):
config = json.loads(self._service_api_object["ServiceConfig"])
if config.get("metadata", {}).get("type") == ServiceType.Async:
            logger.warning(
                "Using Predictor to call an asynchronous prediction service is not"
                " recommended; please use AsyncPredictor instead."
            )
def predict(self, data):
"""Make a prediction with the online prediction service.
The serializer object for the predictor is responsible for data transformation
when the 'predict' method is invoked. The input data is serialized using the
`serializer.serialize` method before it is sent, and the response is
deserialized using the `serializer.deserialize` method before the prediction
result returns.
Args:
data: The input data for the prediction. It will be serialized using the
serializer of the predictor before transmitted to the prediction
service.
Returns:
object: Prediction result.
Raises:
PredictionException: Raise if status code of the prediction response does
not equal 2xx.
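        Examples::
            # A minimal sketch: assumes an existing service whose processor accepts
            # and returns JSON (the default JsonSerializer is used); the service
            # name and payload are placeholders.
            predictor = Predictor(service_name="example_service")
            result = predictor.predict({"features": [1.0, 2.0, 3.0]})
            print(result)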
"""
self._post_init_serializer()
data = self._handle_input(data)
resp = self._send_request(
data,
)
if resp.status_code // 100 != 2:
raise PredictionException(resp.status_code, resp.content)
return self._handle_output(
resp.content,
)
def raw_predict(
self,
data: Any = None,
path: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
method: str = "POST",
timeout: Optional[Union[float, Tuple[float, float]]] = None,
**kwargs,
) -> RawResponse:
"""Make a prediction with the online prediction service.
Args:
data (Any): Input data to be sent to the prediction service. If it is a
file-like object, bytes, or string, it will be sent as the request body.
Otherwise, it will be treated as a JSON serializable object and sent as
JSON.
path (str, optional): Path for the request to be sent to. If it is provided,
it will be appended to the endpoint URL (Default None).
headers (dict, optional): Request headers.
method (str, optional): Request method, default to 'POST'.
            timeout (float, tuple(float, float), optional): Timeout setting for the
                request in seconds (Default None).
**kwargs: Additional keyword arguments for the request.
Returns:
RawResponse: Prediction response from the service.
Raises:
PredictionException: Raise if status code of the prediction response does
not equal 2xx.
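        Examples::
            # A minimal sketch: the payload and path are illustrative and depend on
            # the HTTP API exposed by the deployed service.
            predictor = Predictor(service_name="example_service")
            resp = predictor.raw_predict(data={"prompt": "hello"})
            print(resp.json())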
"""
json_data, data = self._handle_raw_input(data)
resp = self._send_request(
data=data,
json=json_data,
method=method,
path=path,
headers=headers,
timeout=timeout,
**kwargs,
)
if resp.status_code // 100 != 2:
raise PredictionException(resp.status_code, resp.content)
resp = RawResponse(
status_code=resp.status_code,
content=resp.content,
headers=dict(resp.headers),
)
return resp
def openai(self, url_suffix: str = "v1", **kwargs) -> "OpenAI":
"""Initialize an OpenAI client from the predictor.
Only used for OpenAI API compatible services, such as Large Language Model
service from PAI QuickStart.
Args:
url_suffix (str, optional): URL suffix that will be appended to the
EAS service endpoint to form the base URL for the OpenAI client.
(Default "v1").
**kwargs: Additional keyword arguments for the OpenAI client.
Returns:
OpenAI: An OpenAI client object.
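        Examples::
            # A minimal sketch: assumes the service exposes an OpenAI-compatible
            # API under the "v1" suffix; the service and model names are placeholders.
            predictor = Predictor(service_name="example_llm_service")
            client = predictor.openai()
            completion = client.chat.completions.create(
                model="default",
                messages=[{"role": "user", "content": "Hello!"}],
            )
            print(completion.choices[0].message.content)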
"""
if not is_package_available("openai"):
raise ImportError(
"openai package is not installed, install it with `pip install openai`."
)
if url_suffix.startswith("/"):
default_base_url = posixpath.join(self.endpoint, url_suffix[1:])
else:
default_base_url = posixpath.join(self.endpoint, url_suffix)
base_url = kwargs.pop("base_url", default_base_url)
api_key = kwargs.pop("api_key", self.access_token)
return OpenAI(base_url=base_url, api_key=api_key, **kwargs)
class WaitConfig(object):
"""WaitConfig is used to set polling configurations for waiting for asynchronous
requests to complete."""
def __init__(self, max_attempts: int = 0, interval: int = 5):
if interval <= 0:
raise ValueError("interval must be positive integer.")
self.max_attempts = max_attempts
self.interval = interval
class AsyncTask(object):
"""AsyncTask is a wrapper class for `concurrent.futures.Future` object that represents
a prediction call submitted to an async prediction service.
"""
def __init__(
self,
future: Future,
):
self.future = future
super(AsyncTask, self).__init__()
def result(self, timeout: Optional[float] = None):
"""
Returns the prediction result of the call.
Args:
timeout (float, optional): Timeout setting (Default None).
Returns:
The result of the prediction call.
"""
return self.future.result(timeout=timeout)
def done(self):
return self.future.done()
def exception(self, timeout: Optional[float] = None) -> Optional[Exception]:
        return self.future.exception(timeout=timeout)
def running(self):
return self.future.running()
def cancel(self):
return self.future.cancel()
def cancelled(self):
return self.future.cancelled()
class AsyncPredictor(PredictorBase, _ServicePredictorMixin):
"""A class that facilitates making predictions to asynchronous prediction service.
Examples::
# Initialize an AsyncPredictor object using the name of a running service.
async_predictor = AsyncPredictor(service_name="example_service")
# Make a prediction with the service and get the prediction result.
resp = async_predictor.predict(data="YourPredictionData")
result = resp.wait()
# Make a prediction with async API.
import asyncio
result = asyncio.run(async_predictor.predict_async(data="YourPredictionData"))
"""
def __init__(
self,
service_name: str,
max_workers: Optional[int] = None,
endpoint_type: str = EndpointType.INTERNET,
serializer: Optional[SerializerBase] = None,
session: Optional[Session] = None,
):
"""Construct a `AsyncPredictor` object using an existing async prediction service.
Args:
service_name (str): Name of the existing prediction service.
max_workers (int): The maximum number of threads that can be used to
execute the given prediction calls.
endpoint_type (str): Selects the endpoint used by the predictor, which
should be one of `INTERNET` or `INTRANET`. The `INTERNET` endpoint type
means that the predictor calls the service over a public endpoint, while
the `INTRANET` endpoint type is over a VPC endpoint.
            serializer (SerializerBase, optional): A serializer object that transforms
                the input Python object for data transmission and deserializes the
                response data into a Python object.
session (Session, optional): A PAI session object used for communicating
with PAI service.
"""
super(AsyncPredictor, self).__init__(
service_name=service_name,
session=session or get_default_session(),
endpoint_type=endpoint_type,
serializer=serializer,
)
self._max_workers = max_workers
self.executor = ThreadPoolExecutor(max_workers=self._max_workers)
self._check()
@property
def max_workers(self):
return self._max_workers
@max_workers.setter
def max_workers(self, n: int):
if hasattr(self, "executor"):
logger.info("Waiting for all submitted tasks in the queue to complete...")
self.executor.shutdown()
self._max_workers = n
self.executor = ThreadPoolExecutor(max_workers=self._max_workers)
def __del__(self):
"""wait for all pending tasks to complete before exit."""
if hasattr(self, "executor"):
logger.info("Waiting for all pending tasks to complete...")
self.executor.shutdown()
super(AsyncPredictor, self).__del__()
def _check(self):
config = json.loads(self._service_api_object["ServiceConfig"])
if config.get("metadata", {}).get("type") != ServiceType.Async:
            logger.warning(
                "Using AsyncPredictor to call a standard (synchronous) prediction"
                " service is not recommended; please use Predictor instead."
            )
def _get_result(
self, request_id: str
) -> Optional[Tuple[int, Dict[str, str], bytes]]:
resp = self._send_request(
method="GET",
path=_QUEUE_SERVICE_SINK_PATH,
params={
"requestId": request_id,
# _raw_ is false because we want to get the encapsulated prediction
# result in response body.
"_raw_": "false",
},
)
logger.debug(
"Poll prediction result: request_id=%s status_code=%s, content=%s",
request_id,
resp.status_code,
resp.content,
)
if resp.status_code == 204:
# Status code 204 means could not find prediction response for the specific
# request id.
return
# Raise exception if status code is not 2xx.
if resp.status_code // 100 != 2:
raise RuntimeError(
"Pulling prediction result failed: status_code={} content={}".format(
resp.status_code, resp.content.decode("utf-8")
)
)
return self._parse_encapsulated_response(resp.json()[0])
def _parse_encapsulated_response(self, data) -> Tuple[int, Dict[str, str], bytes]:
tags = data["tags"]
# If the status code from prediction service is not 200, a tag with
# key 'lastCode' will be added to the tags in response.
status_code = int(tags.get("lastCode", 200))
data = base64.b64decode(data["data"])
# currently, headers are not supported in async prediction service.
headers = dict()
return status_code, headers, data
async def _get_result_async(
self, request_id: str
) -> Optional[Tuple[int, Dict[str, str], bytes]]:
resp = await self._send_request_async(
method="GET",
path=_QUEUE_SERVICE_SINK_PATH,
params={
"requestId": request_id,
# _raw_ is false because we want to get the encapsulated prediction
# result in response body.
"_raw_": "false",
},
)
status_code = resp.status
content = await resp.read()
logger.debug(
"Get prediction result: request_id=%s status_code=%s, content=%s",
request_id,
status_code,
content,
)
if status_code == 204:
# Status code 204 means could not find prediction response for the specific
# request id.
return
if status_code // 100 != 2:
raise RuntimeError(
"Pulling prediction result failed: status_code={} content={}".format(
status_code, content.decode("utf-8")
)
)
data = (await resp.json())[0]
return self._parse_encapsulated_response(data)
def _poll_result(
self, request_id: str, wait_config: WaitConfig
) -> Tuple[int, Dict[str, str], bytes]:
# if max_attempts is negative or zero, then wait forever
attempts = -1 if wait_config.max_attempts <= 0 else wait_config.max_attempts
while attempts != 0:
attempts -= 1
result = self._get_result(request_id=request_id)
if not result:
time.sleep(wait_config.interval)
continue
status_code, headers, content = result
# check real prediction response
if status_code // 100 != 2:
raise PredictionException(
code=status_code,
message=f"Prediction failed: status_code={status_code}"
f" content={content.decode()}",
)
return status_code, headers, content
# Polling prediction result timeout.
raise RuntimeError(
f"Polling prediction result timeout: request_id={request_id}, "
f"total_time={wait_config.max_attempts * wait_config.interval}"
)
async def _poll_result_async(
self, request_id, wait_config: WaitConfig
) -> Tuple[int, Dict[str, str], bytes]:
# if max_attempts is negative or zero, then wait forever
attempts = -1 if wait_config.max_attempts <= 0 else wait_config.max_attempts
while attempts != 0:
attempts -= 1
result = await self._get_result_async(request_id)
if not result:
await asyncio.sleep(wait_config.interval)
continue
status_code, headers, content = result
# check real prediction response
if status_code // 100 != 2:
                raise PredictionException(
                    code=status_code,
                    message=f"Prediction failed: status_code={status_code}"
                    f" content={content.decode()}",
                )
return status_code, headers, content
# Polling prediction result timeout.
raise RuntimeError(
f"Polling prediction result timeout: request_id={request_id}, "
f"total_time={wait_config.max_attempts * wait_config.interval}"
)
def _get_request_id(self, resp: requests.models.Response) -> str:
if resp.status_code // 100 != 2:
raise RuntimeError(
f"Send prediction request failed. status_code={resp.status_code} "
f"message={resp.text}"
)
if _QUEUE_SERVICE_REQUEST_ID_HEADER not in resp.headers:
logger.error(
f"Send prediction request failed. Missing request id."
f" status_code={resp.status_code} content={resp.text}"
)
raise RuntimeError("Missing request id in response header.")
request_id = resp.headers[_QUEUE_SERVICE_REQUEST_ID_HEADER]
logger.debug(
f"Send prediction request successfully. request_id={request_id}"
f" status_code={resp.status_code}",
)
return request_id
async def _get_request_id_async(self, resp: aiohttp.ClientResponse) -> str:
content = await resp.read()
if resp.status != 200:
raise RuntimeError(
"Send request to async prediction service failed: status_code={} "
"content={}".format(resp.status, content.decode("utf-8"))
)
if _QUEUE_SERVICE_REQUEST_ID_HEADER not in resp.headers:
logger.error(
f"Send prediction request failed. Missing request id."
f" status_code={resp.status} content={content.decode()}"
)
raise RuntimeError("Missing request id in response header.")
request_id = resp.headers[_QUEUE_SERVICE_REQUEST_ID_HEADER]
logger.debug(
f"Send prediction request successfully. request_id={request_id}"
f" status_code={resp.status}",
)
return request_id
def _predict_fn(
self,
data,
):
"""Make a prediction with the async prediction service."""
# serialize input data
data = self._handle_input(data)
resp = self._send_request(data=data)
request_id = self._get_request_id(resp)
logger.debug("Async prediction RequestId: ", request_id)
# poll prediction result
status, headers, content = self._poll_result(
request_id=request_id, wait_config=WaitConfig()
)
return self._handle_output(content)
def _wrap_callback_fn(self, cb: Callable):
"""Wrap the callback function to handle the prediction result."""
@functools.wraps(cb)
def _(future: Future):
return cb(future.result())
return _
def predict(
self,
data,
callback: Optional[Union[Callable, List[Callable]]] = None,
):
"""Make a prediction with the async prediction service.
        The input data is serialized using the `serializer.serialize` method before it
        is sent, and the response body is deserialized using the
        `serializer.deserialize` method before the prediction result is returned.
Args:
data: The input data for the prediction. It will be serialized using the
serializer of the predictor before transmitted to the prediction
service.
callback (Union[Callable, List[Callable]], optional): A Callback function,
or a list of callback functions used to process the prediction result.
Returns:
AsyncTask: The task object that can be used to retrieve the prediction
result.
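        Examples::
            # A minimal sketch: assumes a running async service that accepts JSON
            # input; the service name and payload are placeholders, and the callback
            # simply prints the deserialized result.
            predictor = AsyncPredictor(service_name="example_async_service")
            task = predictor.predict(
                data={"features": [1.0, 2.0]},
                callback=lambda result: print(result),
            )
            print(task.result())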
"""
self._post_init_serializer()
future = self.executor.submit(self._predict_fn, data)
if isinstance(callback, Callable):
callback = [callback]
if callback:
for cb in callback:
future.add_done_callback(self._wrap_callback_fn(cb))
return AsyncTask(future=future)
async def predict_async(self, data, wait_config: WaitConfig = WaitConfig()):
"""Make a prediction with the async prediction service.
The serializer object for the predictor is responsible for data transformation
when the 'predict' method is invoked. The input data is serialized using the
`serializer.serialize` method before it is sent, and the response is
deserialized using the `serializer.deserialize` method before the prediction
result returns.
Args:
data: The input data for the prediction. It will be serialized using the
serializer of the predictor before transmitted to the prediction
service.
wait_config (WaitConfig): A config object that controls the behavior of
polling the prediction result.
Returns:
Prediction result.
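        Examples::
            # A minimal sketch using the coroutine API; the service name and input
            # payload are placeholders.
            import asyncio
            predictor = AsyncPredictor(service_name="example_async_service")
            result = asyncio.run(
                predictor.predict_async(data={"features": [1.0, 2.0]})
            )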
"""
self._post_init_serializer()
data = self._handle_input(data)
resp = await self._send_request_async(data=data)
request_id = await self._get_request_id_async(resp)
status_code, headers, content = await self._poll_result_async(
request_id=request_id, wait_config=wait_config
)
return self._handle_output(content)
def _raw_predict_fn(self, data, method, path, headers, **kwargs):
json_data, data = self._handle_raw_input(data)
resp = self._send_request(
path=path,
json=json_data,
data=data,
headers=self._build_headers(headers),
method=method,
**kwargs,
)
request_id = self._get_request_id(resp)
status, headers, content = self._poll_result(
request_id, wait_config=WaitConfig()
)
return RawResponse(status, headers, content)
def raw_predict(
self,
data: Any = None,
        callback: Optional[Union[Callable, List[Callable]]] = None,
method: str = "POST",
path: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
**kwargs,
) -> AsyncTask:
"""Make a prediction with the online prediction service.
Args:
data (Any): Input data to be sent to the prediction service. If it is a
file-like object, bytes, or string, it will be sent as the request body.
Otherwise, it will be treated as a JSON serializable object and sent as
JSON.
callback (Union[Callable, List[Callable]], optional): A Callback function,
or a list of callback functions used to process the prediction result.
path (str, optional): Path for the request to be sent to. If it is provided,
it will be appended to the endpoint URL (Default None).
headers (dict, optional): Request headers.
method (str, optional): Request method, default to 'POST'.
**kwargs: Additional keyword arguments for the request.
Returns:
AsyncTask: The task object that can be used to retrieve the prediction
result.
Examples:
from pai.predictor import AsyncPredictor, AsyncTask
            predictor = AsyncPredictor(service_name="example_async_service")
task: AsyncTask = predictor.raw_predict(data="YourPredictionData")
print(task.result())
"""
future = self.executor.submit(
self._raw_predict_fn, data, method, path, headers, **kwargs
)
cbs = [callback] if isinstance(callback, Callable) else callback
if cbs:
for cb in cbs:
future.add_done_callback(self._wrap_callback_fn(cb))
return AsyncTask(future=future)
async def raw_predict_async(
self,
data,
wait_config: WaitConfig = WaitConfig(),
method: str = "POST",
headers: Optional[Dict[str, str]] = None,
path: Optional[str] = None,
**kwargs,
) -> RawResponse:
"""Make a prediction with the online prediction service.
Args:
data (Any): Input data to be sent to the prediction service. If it is a
file-like object, bytes, or string, it will be sent as the request body.
Otherwise, it will be treated as a JSON serializable object and sent as
JSON.
wait_config (WaitConfig): A config object that controls the behavior of
polling the prediction result.
path (str, optional): Path for the request to be sent to. If it is provided,
it will be appended to the endpoint URL (Default None).
headers (dict, optional): Request headers.
method (str, optional): Request method, default to 'POST'.
**kwargs: Additional keyword arguments for the request.
Returns:
RawResponse: Prediction result.
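        Examples::
            # A minimal sketch: sends a raw JSON payload to an async service and
            # waits for the result; the service name and payload are placeholders.
            import asyncio
            predictor = AsyncPredictor(service_name="example_async_service")
            resp = asyncio.run(predictor.raw_predict_async(data={"prompt": "hello"}))
            print(resp.json())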
"""
if self.service_status not in ServiceStatus.completed_status():
self.wait_for_ready()
json_data, data = self._handle_raw_input(data)
resp = await self._send_request_async(
data=data,
method=method,
json=json_data,
path=path,
headers=headers,
**kwargs,
)
request_id = await self._get_request_id_async(resp)
# Polling the prediction result.
status_code, headers, content = await self._poll_result_async(
request_id=request_id, wait_config=wait_config
)
return self._handle_raw_output(status_code, headers, content)
class LocalPredictor(PredictorBase):
"""Perform prediction to a local service running with docker."""
def __init__(
self,
port: int,
container_id: Optional[str] = None,
serializer: Optional[SerializerBase] = None,
):
"""LocalPredictor initializer.
Args:
port (int): The port of the local service.
container_id (str, optional): The container id of the local service.
            serializer (SerializerBase, optional): A serializer object that transforms
                the input Python object for data transmission and deserializes the
                response data into a Python object.
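        Examples::
            # A minimal sketch: assumes a model server is already running in a
            # local docker container and listening on port 8000; the container id
            # below is a placeholder.
            local_predictor = LocalPredictor(port=8000, container_id="<container-id>")
            local_predictor.wait_for_ready()
            result = local_predictor.predict({"features": [1.0, 2.0]})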
"""
self.container_id = container_id
self.port = port
self.serializer = serializer or JsonSerializer()
self._container_run = (
self._build_container_run(container_id, port=port)
if self.container_id
else None
)
@classmethod
def _build_container_run(cls, container_id, port):
try:
import docker
except ImportError:
raise ImportError("Please install docker first: pip install docker")
client = docker.from_env()
container = client.containers.get(container_id)
return ContainerRun(container=container, port=port)
def predict(self, data) -> Any:
"""Perform prediction with the given data.
Args:
data: The data to be predicted.
"""
request_data = self.serializer.serialize(data=data)
response = requests.post(
url="http://127.0.0.1:{port}/".format(port=self._container_run.port),
data=request_data,
)
if response.status_code // 100 != 2:
raise PredictionException(
code=response.status_code,
message=response.content,
)
return self.serializer.deserialize(response.content)
def _build_headers(
self, headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
headers = headers or dict()
headers["User-Agent"] = http_user_agent(headers.get("User-Agent"))
return headers
def _build_url(self, path: Optional[str] = None):
url = "http://127.0.0.1:{}".format(self.port)
if path:
if path.startswith("/"):
path = path[1:]
url = posixpath.join(url, path)
return url
def raw_predict(
self,
data: Any = None,
path: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
method: str = "POST",
timeout: Optional[Union[float, Tuple[float, float]]] = None,
**kwargs,
) -> RawResponse:
"""Make a prediction with the online prediction service.
Args:
data (Any): Input data to be sent to the prediction service. If it is a
file-like object, bytes, or string, it will be sent as the request body.
Otherwise, it will be treated as a JSON serializable object and sent as
JSON.
path (str, optional): Path for the request to be sent to. If it is provided,
it will be appended to the endpoint URL (Default None).
headers (dict, optional): Request headers.
method (str, optional): Request method, default to 'POST'.
            timeout (float, tuple(float, float), optional): Timeout setting for the
                request in seconds (Default None).
Returns:
RawResponse: Prediction response from the service.
Raises:
PredictionException: Raise if status code of the prediction response does
not equal 2xx.
"""
if isinstance(data, (IOBase, bytes, str)):
# if data is a file-like object, bytes, or string, it will be sent as
# request body
json_data, data = None, data
else:
# otherwise, it will be treated as a JSON serializable object and sent as
# JSON.
json_data, data = data, None
header = self._build_headers(headers=headers)
url = self._build_url(path)
resp = requests.request(
url=url,
json=json_data,
data=data,
headers=header,
method=method,
timeout=timeout,
**kwargs,
)
resp = RawResponse(
status_code=resp.status_code,
content=resp.content,
headers=dict(resp.headers),
)
if resp.status_code // 100 != 2:
raise PredictionException(resp.status_code, resp.content)
return resp
def delete_service(self):
"""Delete the docker container that running the service."""
if self._container_run:
self._container_run.stop()
def wait_for_ready(self):
self._container_run.wait_for_ready()
# ensure the server is ready.
self._wait_local_server_ready()
time.sleep(5)
def wait(self):
return self.wait_for_ready()
def _wait_local_server_ready(
self,
interval: int = 5,
):
"""Wait for the local model server to be ready."""
container_run = self._container_run
while True:
try:
# Check whether the container is still running.
if not container_run.is_running():
raise RuntimeError(
"Container exited unexpectedly, status: {}".format(
container_run.status
)
)
# Make a HEAD request to the server, just test for connection.
requests.head(
f"http://127.0.0.1:{container_run.port}/",
)
break
except requests.ConnectionError:
# ConnectionError means server is not ready.
logger.debug("Waiting for the container to be ready...")
time.sleep(interval)
continue