algorithm/eas/tf_response.go (86 lines of code) (raw):

package eas import ( "fmt" "github.com/alibaba/pairec/v2/algorithm/response" "github.com/alibaba/pairec/v2/pkg/eas/types/tf_predict_protos" ) type tfResponse struct { score float64 scoreArr map[string]float64 multiValModule bool } func (r *tfResponse) GetScore() float64 { return r.score } func (r *tfResponse) GetScoreMap() map[string]float64 { return r.scoreArr } func (r *tfResponse) GetModuleType() bool { return r.multiValModule } func tfMutValResponseFunc(data interface{}) (ret []response.AlgoResponse, err error) { resp, ok := data.(*tf_predict_protos.PredictResponse) if !ok { err = fmt.Errorf("invalid data type, %v", data) return } var response []map[string]float64 for name, arrayProto := range resp.GetOutputs() { for i, val := range arrayProto.FloatVal { if i >= len(response) { response = append(response, map[string]float64{name: float64(val)}) } else { response[i][name] = float64(val) } } } for _, v := range response { ret = append(ret, &tfResponse{scoreArr: v, multiValModule: true}) } return } func tfResponseFunc(data interface{}) (ret []response.AlgoResponse, err error) { resp, ok := data.(*tf_predict_protos.PredictResponse) if !ok { err = fmt.Errorf("invalid data type, %v", data) return } for _, arrayProto := range resp.GetOutputs() { for _, val := range arrayProto.FloatVal { ret = append(ret, &tfResponse{score: float64(val)}) } break } return } func tfUseEmbResponseFunc(data interface{}) (ret []response.AlgoResponse, err error) { resp, ok := data.(*tf_predict_protos.PredictResponse) if !ok { err = fmt.Errorf("invalid data type, %v", data) return } for _, arrayProto := range resp.GetOutputs() { if arrayProto.GetDtype() == tf_predict_protos.ArrayDataType_DT_STRING { if len(arrayProto.GetStringVal()) > 0 { ret = append(ret, &TFUserEmbResponse{userEmb: string(arrayProto.GetStringVal()[0])}) return } } } return } type TFUserEmbResponse struct { userEmb string } func (r *TFUserEmbResponse) GetScore() float64 { return 0 } func (r *TFUserEmbResponse) GetScoreMap() map[string]float64 { return nil } func (r *TFUserEmbResponse) GetModuleType() bool { return false } func (r *TFUserEmbResponse) GetUserEmb() string { return r.userEmb }