in src/psearch/serving/internal/services/embedding_service.go [56:160]
// GenerateEmbedding requests an embedding vector for text from the
// aiplatform.googleapis.com publisher-model ":predict" REST endpoint and
// returns the values of the first prediction.
//
// The endpoint URL is built from the service configuration (region, project
// ID and model name). The text is sent as a single instance with task type
// "RETRIEVAL_QUERY", and the request is bound to ctx.
//
// An error is returned when the request cannot be built or sent, when the
// endpoint answers with a non-200 status, when the response body cannot be
// read or decoded, or when the response carries no embedding values.
// Underlying errors are wrapped with %w so callers can match them with
// errors.Is / errors.As (for example context.Canceled).
func (s *EmbeddingService) GenerateEmbedding(ctx context.Context, text string) ([]float32, error) {
	// Wire types for the REST payloads. They are local to this method because
	// nothing else in view needs them; JSON keys follow the API's snake_case.
	type embeddingInstance struct {
		Content  string `json:"content"`
		TaskType string `json:"task_type"`
	}
	type embeddingRequest struct {
		Instances []embeddingInstance `json:"instances"`
	}
	type embeddingResponse struct {
		Predictions []struct {
			Embeddings struct {
				Values     []float32 `json:"values"`
				Statistics struct {
					TokenCount int  `json:"token_count"`
					Truncated  bool `json:"truncated"`
				} `json:"statistics"`
			} `json:"embeddings"`
		} `json:"predictions"`
	}
	// Standard Google API error envelope returned on non-200 responses.
	type apiErrorResponse struct {
		Error struct {
			Code    int    `json:"code"`
			Message string `json:"message"`
			Status  string `json:"status"`
		} `json:"error"`
	}

	startTime := time.Now()

	// NOTE(review): GeminiModelName must hold the embedding model ID here,
	// despite the field's name.
	endpoint := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict",
		s.config.Region,
		s.config.ProjectID,
		s.config.Region,
		s.config.GeminiModelName,
	)

	jsonBody, err := json.Marshal(embeddingRequest{
		Instances: []embeddingInstance{
			{Content: text, TaskType: "RETRIEVAL_QUERY"},
		},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal REST request body: %w", err)
	}
	log.Printf("DEBUG: Embedding Request Body: %s", jsonBody)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create REST http request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	// The service's client is presumably pre-configured with credentials;
	// no auth header is added here.
	log.Printf("DEBUG: Sending embedding request to %s", endpoint)
	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to execute REST http request: %w", err)
	}
	defer resp.Body.Close()

	// The body is read in full before the status check so that error
	// responses can be logged and parsed as well.
	responseBodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read REST response body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		log.Printf("ERROR: Embedding API request failed with status %d: %s", resp.StatusCode, string(responseBodyBytes))
		// Prefer the structured API error when the body carries one.
		var apiError apiErrorResponse
		if json.Unmarshal(responseBodyBytes, &apiError) == nil && apiError.Error.Message != "" {
			return nil, fmt.Errorf("embedding API error: %s (code %d, status %s)", apiError.Error.Message, apiError.Error.Code, apiError.Error.Status)
		}
		return nil, fmt.Errorf("embedding API request failed with status %d", resp.StatusCode)
	}

	var responsePayload embeddingResponse
	if err := json.Unmarshal(responseBodyBytes, &responsePayload); err != nil {
		log.Printf("ERROR: Failed to unmarshal embedding response: %s", string(responseBodyBytes))
		return nil, fmt.Errorf("failed to unmarshal REST response body: %w", err)
	}

	if len(responsePayload.Predictions) == 0 || len(responsePayload.Predictions[0].Embeddings.Values) == 0 {
		log.Printf("WARN: Embedding response contained no predictions or empty values: %+v", responsePayload)
		return nil, fmt.Errorf("no embeddings returned from REST API")
	}

	first := responsePayload.Predictions[0].Embeddings
	if first.Statistics.Truncated {
		// Surface the response's truncated flag instead of silently ignoring it.
		log.Printf("WARN: Embedding response reported truncated input (token count: %d)", first.Statistics.TokenCount)
	}

	embedding := first.Values
	elapsed := time.Since(startTime)
	log.Printf("Generated embedding via REST in %s (dimension: %d)", elapsed, len(embedding))
	return embedding, nil
}