func()

in src/psearch/serving/internal/services/spanner_service.go [144:258]


// HybridSearch runs a hybrid product search that fuses an approximate
// nearest-neighbour (ANN) vector search with a full-text search (FTS) using
// Reciprocal Rank Fusion (RRF).
//
// The query text is embedded via s.embeddings. A single Spanner query then
// builds two candidate lists of up to limit products each:
//   - ann: products ordered by APPROX_COSINE_DISTANCE between their embedding
//     and the query embedding (forced onto the products_by_embedding index);
//   - fts: products whose title_tokens match the query text, ordered by SCORE.
//
// A product's fused score is SUM(1 / (60 + rank)) over the lists it appears
// in, so scores lie in (0, 2/61] (roughly 0.0328 at most). minScore must be
// chosen on that scale; a threshold intended for cosine similarity (e.g. 0.5)
// would discard every result.
//
// alpha is currently NOT used: both lists contribute with equal weight. It is
// kept in the signature so existing callers keep compiling.
//
// Rows with NULL product_data, rows scoring below minScore, and products that
// transformToSearchResult rejects are skipped. An error is returned when limit
// is negative, embedding generation or the query fails, a row cannot be
// scanned, or product_data does not decode to a map[string]interface{}.
// The returned slice is nil when nothing matches.
func (s *SpannerService) HybridSearch(ctx context.Context, query string, limit int, minScore float64, alpha float64) ([]models.SearchResult, error) {
	startTime := time.Now()

	// A negative LIMIT is invalid in GoogleSQL; fail fast instead of paying
	// for an embedding call first.
	if limit < 0 {
		return nil, fmt.Errorf("invalid limit %d: must be non-negative", limit)
	}

	// Generate the query embedding used by the ANN leg of the search.
	embedding, err := s.embeddings.GenerateEmbedding(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to generate embedding: %w", err)
	}

	// RRF fusion of an ANN leg and an FTS leg. Each leg yields up to @limit
	// products ranked from 1 (offset + 1); a product's score is the sum of
	// 1/(60 + rank) over the legs it appears in. In GoogleSQL "/" on INT64
	// operands returns FLOAT64, so this is not integer division.
	const hybridSearchSQL = `
		@{optimizer_version=7}
		WITH ann AS (
		SELECT offset + 1 AS rank, product_id, title, product_data
		FROM UNNEST(ARRAY(
			SELECT AS STRUCT product_id, title, product_data
			FROM products @{FORCE_INDEX=products_by_embedding}
			WHERE embedding IS NOT NULL
			ORDER BY APPROX_COSINE_DISTANCE(embedding, @query_embedding,
			OPTIONS=>JSON'{"num_leaves_to_search": 10}')
			LIMIT @limit)) WITH OFFSET AS offset
		),
		fts AS (
		SELECT offset + 1 AS rank, product_id, title, product_data
		FROM UNNEST(ARRAY(
			SELECT AS STRUCT product_id, title, product_data
			FROM products
			WHERE SEARCH(title_tokens, @query_text)
			ORDER BY SCORE(title_tokens, @query_text) DESC
			LIMIT @limit)) WITH OFFSET AS offset
		)
		SELECT 
			SUM(1 / (60 + rank)) AS rrf_score, 
			product_id,
			ANY_VALUE(title) AS title,
			ANY_VALUE(product_data) AS product_data 
		FROM ((
		SELECT rank, product_id, title, product_data
		FROM ann
		)
		UNION ALL (
		SELECT rank, product_id, title, product_data
		FROM fts
		))
		GROUP BY product_id
		ORDER BY rrf_score DESC
		LIMIT @limit;
	`

	params := map[string]interface{}{
		"query_embedding": embedding,
		"query_text":      query,
		"limit":           limit,
	}

	stmt := spanner.Statement{SQL: hybridSearchSQL, Params: params}
	iter := s.client.Single().Query(ctx, stmt)
	defer iter.Stop()

	var results []models.SearchResult
	for {
		row, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("error iterating through search results: %w", err)
		}

		var (
			hybridScore     float64
			productIDInt    int64
			title           string
			productDataJSON spanner.NullJSON
		)
		// Destination order matches the SELECT list:
		// rrf_score, product_id, title, product_data.
		if err := row.Columns(&hybridScore, &productIDInt, &title, &productDataJSON); err != nil {
			return nil, fmt.Errorf("failed to scan search result: %w", err)
		}

		// Rows arrive ordered by rrf_score DESC, so once one falls below the
		// threshold every remaining row does too. Stopping here also avoids
		// decoding (and possibly failing on) rows that would be discarded.
		if hybridScore < minScore {
			break
		}

		if !productDataJSON.Valid {
			continue
		}

		productID := fmt.Sprintf("%d", productIDInt)

		// Type assert productDataJSON.Value directly to map[string]interface{}.
		productData, ok := productDataJSON.Value.(map[string]interface{})
		if !ok {
			// Log the actual type if the assertion fails.
			log.Printf("DEBUG: Unexpected type for productDataJSON.Value in search result: %T", productDataJSON.Value)
			return nil, fmt.Errorf("failed to type assert product data from NullJSON.Value for search result")
		}

		searchResult, err := s.transformToSearchResult(productID, productData, hybridScore)
		if err != nil {
			// Best-effort: a single malformed product should not fail the search.
			log.Printf("Warning: could not transform product %s: %v", productID, err)
			continue
		}

		results = append(results, searchResult)
	}

	elapsed := time.Since(startTime)
	log.Printf("Hybrid search completed in %s, found %d results", elapsed, len(results))

	return results, nil
}