pkg/eas/tf_request.go (132 lines of code) (raw):

package eas import ( "github.com/golang/protobuf/proto" "github.com/alibaba/pairec/v2/pkg/eas/types/tf_predict_protos" ) // TFRequest class for tensorflow data and requests type TFRequest struct { RequestData tf_predict_protos.PredictRequest } // SetSignatureName set signature name for TensorFlow request func (tr *TFRequest) SetSignatureName(sigName string) { tr.RequestData.SignatureName = sigName } // AddFeedFloat32 function adds float values input data for TFRequest func (tr *TFRequest) AddFeedFloat32(inputName string, shape []int64, content []float32) { requestProto := tf_predict_protos.ArrayProto{ Dtype: TfType_DT_FLOAT, ArrayShape: &tf_predict_protos.ArrayShape{ Dim: shape, }, FloatVal: content, } if tr.RequestData.Inputs == nil { tr.RequestData.Inputs = make(map[string]*tf_predict_protos.ArrayProto) } tr.RequestData.Inputs[inputName] = &requestProto } // AddFeedFloat64 function adds double values input data for TFRequest func (tr *TFRequest) AddFeedFloat64(inputName string, shape []int64, content []float64) { requestProto := tf_predict_protos.ArrayProto{ Dtype: TfType_DT_DOUBLE, ArrayShape: &tf_predict_protos.ArrayShape{ Dim: shape, }, DoubleVal: content, } if tr.RequestData.Inputs == nil { tr.RequestData.Inputs = make(map[string]*tf_predict_protos.ArrayProto) } tr.RequestData.Inputs[inputName] = &requestProto } // AddFeedInt32 function adds int values input data for TFRequest func (tr *TFRequest) AddFeedInt32(inputName string, shape []int64, content []int32) { requestProto := tf_predict_protos.ArrayProto{ Dtype: TfType_DT_INT32, ArrayShape: &tf_predict_protos.ArrayShape{ Dim: shape, }, IntVal: content, } if tr.RequestData.Inputs == nil { tr.RequestData.Inputs = make(map[string]*tf_predict_protos.ArrayProto) } tr.RequestData.Inputs[inputName] = &requestProto } // AddFeedInt64 function adds int64 values input data for TFRequest func (tr *TFRequest) AddFeedInt64(inputName string, shape []int64, content []int64) { requestProto := tf_predict_protos.ArrayProto{ Dtype: TfType_DT_INT64, ArrayShape: &tf_predict_protos.ArrayShape{ Dim: shape, }, Int64Val: content, } if tr.RequestData.Inputs == nil { tr.RequestData.Inputs = make(map[string]*tf_predict_protos.ArrayProto) } tr.RequestData.Inputs[inputName] = &requestProto } // AddFeedBool function adds boolean values input data for TFRequest func (tr *TFRequest) AddFeedBool(inputName string, shape []int64, content []bool) { requestProto := tf_predict_protos.ArrayProto{ Dtype: TfType_DT_BOOL, ArrayShape: &tf_predict_protos.ArrayShape{ Dim: shape, }, BoolVal: content, } if tr.RequestData.Inputs == nil { tr.RequestData.Inputs = make(map[string]*tf_predict_protos.ArrayProto) } tr.RequestData.Inputs[inputName] = &requestProto } // AddFeedString function adds string values input data for TFRequest func (tr *TFRequest) AddFeedString(inputName string, shape []int64, content [][]byte) { requestProto := tf_predict_protos.ArrayProto{ Dtype: TfType_DT_STRING, ArrayShape: &tf_predict_protos.ArrayShape{ Dim: shape, }, StringVal: content, } if tr.RequestData.Inputs == nil { tr.RequestData.Inputs = make(map[string]*tf_predict_protos.ArrayProto) } tr.RequestData.Inputs[inputName] = &requestProto } // AddFetch adds output filter (outname) for TensorFlow request func (tr *TFRequest) AddFetch(outName string) { tr.RequestData.OutputFilter = append(tr.RequestData.OutputFilter, outName) } // ToString for interface func (tr TFRequest) ToString() (string, error) { reqdata, err := proto.Marshal(&tr.RequestData) if err != nil { return "", NewPredictError(-1, "", err.Error()) } return string(reqdata), nil } // TFResponse class for Pytf predicted results type TFResponse struct { Response tf_predict_protos.PredictResponse } // GetTensorShape returns []int64 slice as shape of tensor outindexed func (tresp *TFResponse) GetTensorShape(outputName string) []int64 { // return tresp.PredictResponse.Outputs[outputName].ArrayShape.Dim return tresp.Response.Outputs[outputName].ArrayShape.Dim } // GetFloatVal returns []float32 slice as output data func (tresp *TFResponse) GetFloatVal(outputName string) []float32 { return tresp.Response.Outputs[outputName].GetFloatVal() } // GetDoubleVal returns []float64 slice as output data func (tresp *TFResponse) GetDoubleVal(outputName string) []float64 { return tresp.Response.Outputs[outputName].GetDoubleVal() } // GetIntVal returns []int32 slice as output data func (tresp *TFResponse) GetIntVal(outputName string) []int32 { return tresp.Response.Outputs[outputName].GetIntVal() } // GetInt64Val returns []int64 slice as output data func (tresp *TFResponse) GetInt64Val(outputName string) []int64 { return tresp.Response.Outputs[outputName].GetInt64Val() } // GetBoolVal returns []bool slice as output data func (tresp *TFResponse) GetBoolVal(outputName string) []bool { return tresp.Response.Outputs[outputName].GetBoolVal() } // GetStringVal returns []string slice as output data func (tresp *TFResponse) GetStringVal(outputName string) [][]byte { return tresp.Response.Outputs[outputName].GetStringVal() } // Unmarshal for interface func (tresp *TFResponse) unmarshal(body []byte) error { bd := &tf_predict_protos.PredictResponse{} err := proto.Unmarshal(body, bd) if err != nil { return err } tresp.Response = *bd return nil }