easy_rec/python/inference/client/easyrec_request.py (33 lines of code) (raw):
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from eas_prediction.request import Request
from easy_rec.python.protos.predict_pb2 import PBRequest
from easy_rec.python.protos.predict_pb2 import PBResponse
# from eas_prediction.request import Response
class EasyrecRequest(Request):
    """Request for EasyRec services whose input data is a protobuf message.

    This class provides methods to fill / generate a ``PBRequest`` and to
    parse the ``PBResponse`` returned by the service.
    """

    def __init__(self, signature_name=None):
        """Create an empty request.

        Args:
          signature_name: signature name of the model, optional
        """
        self.request_data = PBRequest()
        self.signature_name = signature_name

    def __str__(self):
        """Return the string form of the underlying ``PBRequest``."""
        # __str__ must return a str; returning the message object itself
        # makes str(request) / print(request) raise TypeError.
        return str(self.request_data)

    def set_signature_name(self, singature_name):
        """Set the signature name of the model.

        Args:
          singature_name: signature name of the model (the parameter name
            keeps its historical spelling so keyword callers still work)
        """
        self.signature_name = singature_name

    def add_feed(self, data, dbg_lvl=0):
        """Load the request payload and set its debug level.

        Args:
          data: either a ``PBRequest`` instance, which is stored directly
            (not copied, so its ``debug_level`` is overwritten), or a
            serialized ``PBRequest`` that is parsed into the current
            request message
          dbg_lvl: value assigned to the request's ``debug_level`` field
        """
        if isinstance(data, PBRequest):
            self.request_data = data
        else:
            self.request_data.ParseFromString(data)
        self.request_data.debug_level = dbg_lvl

    def add_user_fea_flt(self, k, v):
        """Set user feature ``k`` to ``v`` converted with ``float``.

        Args:
          k: key into the request's ``user_features``
          v: feature value, anything accepted by ``float()``
        """
        self.request_data.user_features[k].float_feature = float(v)

    def add_user_fea_s(self, k, v):
        """Set user feature ``k`` to ``v`` converted with ``str``.

        Args:
          k: key into the request's ``user_features``
          v: feature value, converted with ``str()``
        """
        self.request_data.user_features[k].string_feature = str(v)

    def set_faiss_neigh_num(self, neigh_num):
        """Set the ``faiss_neigh_num`` field of the request.

        Args:
          neigh_num: value assigned to ``faiss_neigh_num``
        """
        self.request_data.faiss_neigh_num = neigh_num

    def keep_one_item_ids(self):
        """Truncate ``item_ids`` so that only its first entry remains.

        An empty ``item_ids`` is left unchanged instead of raising
        ``IndexError``.
        """
        # Slice deletion drops every entry after the first in place and is
        # a no-op when there are zero or one entries.
        del self.request_data.item_ids[1:]

    def to_string(self):
        """Serialize the request for transmission.

        Returns:
          the result of ``PBRequest.SerializeToString()``
        """
        return self.request_data.SerializeToString()

    def parse_response(self, response_data):
        """Parse the serialized response data into a ``PBResponse`` object.

        The parsed message is also stored on ``self.response``.

        Args:
          response_data: the service response data in serialized form

        Returns:
          the ``PBResponse`` object related to the request
        """
        self.response = PBResponse()
        self.response.ParseFromString(response_data)
        return self.response