# tools/vector-search-load-testing-framework/locust_tests/locust.py
"""Locust file for load testing Vector Search endpoints (both public HTTP and private PSC/gRPC)."""
import random
import time
from typing import Any, Callable
import google.auth
import google.auth.transport.requests
import google.auth.transport.grpc
from google.cloud.aiplatform_v1 import MatchServiceClient
from google.cloud.aiplatform_v1 import FindNeighborsRequest
from google.cloud.aiplatform_v1 import IndexDatapoint
from google.cloud.aiplatform_v1.services.match_service.transports import grpc as match_transports_grpc
import grpc
import grpc.experimental.gevent as grpc_gevent
import grpc_interceptor
import locust
from locust import env, FastHttpUser, User, task, events, wait_time, tag
import logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)s: %(message)s')
# Patch grpc so that it cooperates with gevent, which Locust uses for concurrency
grpc_gevent.init_gevent()
# gRPC channel cache
_GRPC_CHANNEL_CACHE = {}
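# Channels are cached per (host, auth) pair so the many simulated users in a
# worker share one underlying connection instead of paying TCP/TLS setup
# costs for every user.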
class LocustInterceptor(grpc_interceptor.ClientInterceptor):
"""Interceptor for Locust which captures response details."""
def __init__(self, environment, *args, **kwargs):
"""Initializes the interceptor with the specified environment."""
super().__init__(*args, **kwargs)
self.env = environment
def intercept(
self,
method: Callable[[Any, grpc.ClientCallDetails], Any],
request_or_iterator: Any,
call_details: grpc.ClientCallDetails,
) -> Any:
"""Intercepts message to store RPC latency and response size."""
        response = None
        response_or_responses = None  # Stays None if the call raises before assignment.
        exception = None
        end_perf_counter = None
        response_length = 0
start_perf_counter = time.perf_counter()
try:
# Response type
# * Unary: `grpc._interceptor._UnaryOutcome`
# * Streaming: `grpc._channel._MultiThreadedRendezvous`
response_or_responses = method(request_or_iterator, call_details)
end_perf_counter = time.perf_counter()
if isinstance(response_or_responses, grpc._channel._Rendezvous):
responses = list(response_or_responses)
# Re-write perf counter to account for time taken to receive all messages.
end_perf_counter = time.perf_counter()
                # Total length = sum of all streamed message sizes.
                response_length = 0
                for message in responses:
                    message_pb = message.__class__.pb(message)
                    response_length += message_pb.ByteSize()
# Re-write response to return the actual responses since above logic has
# consumed all responses.
def yield_responses():
for rsp in responses:
yield rsp
response_or_responses = yield_responses()
else:
response = response_or_responses
# Unary
message = response.result()
message_pb = message.__class__.pb(message)
response_length = message_pb.ByteSize()
except grpc.RpcError as e:
exception = e
end_perf_counter = time.perf_counter()
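        # Report the call to Locust's statistics: `response_time` is expected
        # in milliseconds, and a non-None `exception` marks the request as a
        # failure in the Locust UI.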
self.env.events.request.fire(
request_type='grpc',
name=call_details.method,
response_time=(end_perf_counter - start_perf_counter) * 1000,
response_length=response_length,
response=response_or_responses,
context=None,
exception=exception,
)
return response_or_responses
def _create_grpc_auth_channel(host: str) -> grpc.Channel:
"""Create a gRPC channel with SSL and auth."""
credentials, _ = google.auth.default()
request = google.auth.transport.requests.Request()
CHANNEL_OPTIONS = [
('grpc.use_local_subchannel_pool', True),
]
return google.auth.transport.grpc.secure_authorized_channel(
credentials,
request,
host,
ssl_credentials=grpc.ssl_channel_credentials(),
options=CHANNEL_OPTIONS,
)
def _cached_grpc_channel(host: str,
auth: bool,
cache: bool = True) -> grpc.Channel:
"""Return a cached gRPC channel for the given host and auth type."""
key = (host, auth)
if cache and key in _GRPC_CHANNEL_CACHE:
return _GRPC_CHANNEL_CACHE[key]
new_channel = (_create_grpc_auth_channel(host)
if auth else grpc.insecure_channel(host))
if not cache:
return new_channel
_GRPC_CHANNEL_CACHE[key] = new_channel
return _GRPC_CHANNEL_CACHE[key]
def intercepted_cached_grpc_channel(
host: str,
auth: bool,
env: locust.env.Environment,
cache: bool = True,
) -> grpc.Channel:
"""Return a intercepted gRPC channel for the given host and auth type."""
channel = _cached_grpc_channel(host, auth=auth, cache=cache)
interceptor = LocustInterceptor(environment=env)
return grpc.intercept_channel(channel, interceptor)
# Create a global config class that will be used throughout the application
class Config:
"""Singleton configuration class that loads from config file just once."""
_instance = None
def __new__(cls, config_file_path=None):
if cls._instance is None:
cls._instance = super(Config, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, config_file_path=None):
if self._initialized:
return
if config_file_path:
self._load_config(config_file_path)
self._initialized = True
# Determine endpoint access type from configuration
self._determine_endpoint_access_type()
logging.info(
f"Loaded configuration: ENDPOINT_ACCESS_TYPE={self.endpoint_access_type}, "
f"PSC_ENABLED={self.psc_enabled}, MATCH_GRPC_ADDRESS={self.match_grpc_address}, "
f"ENDPOINT_HOST={self.endpoint_host}, PROJECT_NUMBER={self.project_number}"
)
def _load_config(self, file_path):
"""Load configuration from a bash-style config file."""
self.config = {}
with open(file_path, 'r') as f:
for line in f:
# Skip comments and empty lines
line = line.strip()
if not line or line.startswith('#'):
continue
# Parse variable assignment
if '=' in line:
key, value = line.split('=', 1)
key = key.strip()
value = value.strip()
# Remove surrounding quotes if present
if (value.startswith('"') and
value.endswith('"')) or (value.startswith("'") and
value.endswith("'")):
value = value[1:-1]
self.config[key] = value
# Set attributes from the config
self.project_id = self.config.get('PROJECT_ID')
self.project_number = self.config.get('PROJECT_NUMBER', self.project_id)
self.dimensions = int(self.config.get('INDEX_DIMENSIONS', 768))
self.deployed_index_id = self.config.get('DEPLOYED_INDEX_ID')
self.index_endpoint_id = self.config.get('INDEX_ENDPOINT_ID')
self.endpoint_host = self.config.get('ENDPOINT_HOST')
# Support both old and new config formats
# New format: ENDPOINT_ACCESS_TYPE
self.endpoint_access_type = self.config.get('ENDPOINT_ACCESS_TYPE')
# Old format: PSC_ENABLED
self.psc_enabled = self.config.get('PSC_ENABLED',
'false').lower() in ('true', 'yes',
'1')
# PSC Configuration
self.match_grpc_address = self.config.get('MATCH_GRPC_ADDRESS')
self.service_attachment = self.config.get('SERVICE_ATTACHMENT')
self.psc_ip_address = self.config.get('PSC_IP_ADDRESS')
# Embedding configuration
self.sparse_embedding_num_dimensions = int(
self.config.get('SPARSE_EMBEDDING_NUM_DIMENSIONS', 0))
self.sparse_embedding_num_dimensions_with_values = int(
self.config.get('SPARSE_EMBEDDING_NUM_DIMENSIONS_WITH_VALUES', 0))
self.num_neighbors = int(self.config.get('NUM_NEIGHBORS', 20))
self.num_embeddings_per_request = int(
self.config.get('NUM_EMBEDDINGS_PER_REQUEST', 1))
self.return_full_datapoint = self.config.get(
'RETURN_FULL_DATAPOINT', 'False').lower() in ('true', 'yes', '1')
# Network configuration
self.network_name = self.config.get('NETWORK_NAME', 'default')
# If we have PSC_IP_ADDRESS but not MATCH_GRPC_ADDRESS, construct it
if self.psc_ip_address and not self.match_grpc_address:
            self.match_grpc_address = self.psc_ip_address
# Get a clean numeric ID from the full endpoint ID
self.endpoint_id_numeric = None
if self.index_endpoint_id and "/" in self.index_endpoint_id:
self.endpoint_id_numeric = self.index_endpoint_id.split("/")[-1]
else:
self.endpoint_id_numeric = self.index_endpoint_id
def _determine_endpoint_access_type(self):
"""Determine the endpoint access type from configuration."""
# If ENDPOINT_ACCESS_TYPE is directly specified, use it
if self.endpoint_access_type:
# Ensure it's one of the valid options
if self.endpoint_access_type not in [
"public", "vpc_peering", "private_service_connect"
]:
logging.warning(
f"Invalid ENDPOINT_ACCESS_TYPE '{self.endpoint_access_type}', defaulting to 'public'"
)
self.endpoint_access_type = "public"
else:
# Otherwise, derive it from PSC_ENABLED
if self.psc_enabled:
self.endpoint_access_type = "private_service_connect"
logging.info(
"Derived endpoint_access_type='private_service_connect' from PSC_ENABLED=true"
)
else:
self.endpoint_access_type = "public"
logging.info(
"Derived endpoint_access_type='public' from PSC_ENABLED=false"
)
def get(self, key, default=None):
"""Get a configuration value by key."""
return getattr(self, key.lower(), self.config.get(key, default))
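# An illustrative locust_config.env (values are placeholders, not real
# resources; only the keys read by _load_config above are meaningful):
#
#   PROJECT_ID="my-project"
#   PROJECT_NUMBER="123456789012"
#   INDEX_DIMENSIONS=768
#   DEPLOYED_INDEX_ID="my_deployed_index"
#   INDEX_ENDPOINT_ID="projects/123456789012/locations/us-central1/indexEndpoints/987654321"
#   ENDPOINT_HOST="987654321.us-central1-123456789012.vdb.vertexai.goog"
#   ENDPOINT_ACCESS_TYPE="public"   # or "vpc_peering" / "private_service_connect"
#   NUM_NEIGHBORS=20
#   RETURN_FULL_DATAPOINT="false"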
# Load the config once at startup
config = Config('./locust_config.env')
# Determine if we're using gRPC or HTTP based on endpoint_access_type
USE_GRPC = config.endpoint_access_type in [
"private_service_connect", "vpc_peering"
]
logging.info(
f"Using gRPC mode: {USE_GRPC} based on endpoint_access_type={config.endpoint_access_type}"
)
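# A typical invocation (standard Locust flags plus the custom options defined
# just below; adjust user counts and durations to your setup):
#
#   locust -f locust.py --headless -u 10 -r 2 --run-time 5m \
#       --num-neighbors 20 --qps-per-user 10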
@events.init_command_line_parser.add_listener
def _(parser):
"""Add command line arguments to the Locust environment."""
# Add user-focused test parameters
parser.add_argument(
"--num-neighbors",
type=int,
default=config.num_neighbors,
help="Number of nearest neighbors to find in each query")
# Add QPS per user control
parser.add_argument(
"--qps-per-user",
type=int,
default=10,
help=
('The QPS each user should target. Locust will try to maintain this rate, '
'but if latency is high, actual QPS may be lower.'),
)
# Advanced parameters
parser.add_argument(
"--fraction-leaf-nodes-to-search-override",
type=float,
default=0.0,
help=
"Advanced: Fraction of leaf nodes to search (0.0-1.0). Higher values increase recall but reduce performance."
)
parser.add_argument(
"--return-full-datapoint",
action="store_true",
default=config.return_full_datapoint,
help=
"Whether to return full datapoint content with search results. Increases response size but provides complete vector data."
)
@events.init.add_listener
def on_locust_init(environment, **kwargs):
"""Set up the host and tags based on configuration."""
# Determine test mode based on endpoint access type
is_grpc_mode = config.endpoint_access_type in [
"private_service_connect", "vpc_peering"
]
# Set default tags based on endpoint access type if no tags were specified
if hasattr(environment.parsed_options,
'tags') and not environment.parsed_options.tags:
if is_grpc_mode:
environment.parsed_options.tags = ['grpc']
            logging.info(
                f"Auto-setting tags to 'grpc' based on endpoint access type '{config.endpoint_access_type}'"
            )
else:
environment.parsed_options.tags = ['http']
logging.info(
f"Auto-setting tags to 'http' based on endpoint access type '{config.endpoint_access_type}'"
)
# Set host based on endpoint access type if no host was specified
if not environment.host:
if is_grpc_mode:
# PSC/gRPC mode
grpc_address = config.match_grpc_address
if grpc_address:
logging.info(
f"Auto-setting host to gRPC address: {grpc_address}")
environment.host = grpc_address
else:
logging.warning(
"No MATCH_GRPC_ADDRESS found in configuration, host must be specified manually for PSC/gRPC mode"
)
else:
# HTTP mode
endpoint_host = config.endpoint_host
if endpoint_host:
host = f"https://{endpoint_host}"
logging.info(f"Auto-setting host to HTTP endpoint: {host}")
environment.host = host
else:
logging.warning(
"No ENDPOINT_HOST found in configuration, host must be specified manually for HTTP mode"
)
# Base class with common functionality
class BaseVectorSearchUser:
"""Base class with common functionality for vector search users."""
def __init__(self, environment: env.Environment):
# Read technical parameters from config
self.deployed_index_id = config.deployed_index_id
self.index_endpoint_id = config.index_endpoint_id
self.project_id = config.project_id
self.project_number = config.project_number
self.dimensions = config.dimensions
self.endpoint_id_numeric = config.endpoint_id_numeric
# Store parsed options needed for requests
self.num_neighbors = environment.parsed_options.num_neighbors
self.fraction_leaf_nodes_to_search_override = environment.parsed_options.fraction_leaf_nodes_to_search_override
        self.return_full_datapoint = environment.parsed_options.return_full_datapoint
def generate_random_vector(self, dimensions):
"""Generate a random vector with the specified dimensions."""
return [random.randint(-1000000, 1000000) for _ in range(dimensions)]
def generate_sparse_embedding(self):
"""Generate random sparse embedding based on configuration."""
values = [
random.uniform(-1.0, 1.0)
for _ in range(config.sparse_embedding_num_dimensions_with_values)
]
dimensions = random.sample(
range(config.sparse_embedding_num_dimensions),
config.sparse_embedding_num_dimensions_with_values)
return values, dimensions
class VectorSearchHttpUser(FastHttpUser):
"""HTTP-based Vector Search user using FastHttpUser."""
    abstract = True  # This is an abstract base class
def __init__(self, environment: env.Environment):
super().__init__(environment)
# Initialize base functionality
self.base = BaseVectorSearchUser(environment)
        # Set up QPS-based wait time if specified
        user_qps = environment.parsed_options.qps_per_user
        if user_qps > 0:
            # Build the constant-throughput pacing function once and bind it
            # to this user instance
            throughput_fn = wait_time.constant_throughput(user_qps)
            self.wait_time = lambda: throughput_fn(self)
# Set up HTTP authentication
self.credentials, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"])
self.auth_req = google.auth.transport.requests.Request()
self.credentials.refresh(self.auth_req)
        self.token_refresh_time = time.time() + 3500  # Refresh after ~58 minutes
self.headers = {
"Authorization": "Bearer " + self.credentials.token,
"Content-Type": "application/json",
}
# Build the endpoint URL
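        # Note: the location is hardcoded to us-central1; change it if your
        # index endpoint is deployed in a different region.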
self.public_endpoint_url = f"/v1/projects/{self.base.project_number}/locations/us-central1/indexEndpoints/{self.base.endpoint_id_numeric}:findNeighbors"
# Build the base request
self.request = {
"deployedIndexId": self.base.deployed_index_id,
}
self.dp = {
"datapointId": "0",
}
self.query = {
"datapoint": self.dp,
"neighborCount": self.base.num_neighbors,
}
# Add optional parameters if specified
if self.base.fraction_leaf_nodes_to_search_override > 0:
self.query[
"fractionLeafNodesToSearchOverride"] = self.base.fraction_leaf_nodes_to_search_override
self.request["queries"] = [self.query]
logging.info("HTTP client initialized")
def on_start(self):
"""Called when a user starts."""
# Ensure token is valid at start
self.credentials.refresh(self.auth_req)
self.headers["Authorization"] = "Bearer " + self.credentials.token
self.token_refresh_time = time.time() + 3500
@task
@tag('http')
def http_find_neighbors(self):
"""Execute a Vector Search query using HTTP."""
# Check if token needs refreshing
if time.time() > self.token_refresh_time:
try:
self.credentials.refresh(self.auth_req)
self.headers[
"Authorization"] = "Bearer " + self.credentials.token
self.token_refresh_time = time.time() + 3500
logging.debug("OAuth token refreshed preemptively")
except Exception as e:
logging.error(f"Failed to refresh token: {str(e)}")
# Handle sparse embedding case
if (config.sparse_embedding_num_dimensions > 0 and
config.sparse_embedding_num_dimensions_with_values > 0 and
config.sparse_embedding_num_dimensions_with_values <=
config.sparse_embedding_num_dimensions):
values, dimensions = self.base.generate_sparse_embedding()
self.request["queries"][0]["datapoint"]["sparseEmbedding"] = {
"values": values,
"dimensions": dimensions
}
else:
# Standard feature vector case
self.request["queries"][0]["datapoint"][
"featureVector"] = self.base.generate_random_vector(
self.base.dimensions)
        # Set the returnFullDatapoint flag from the parsed option
        self.request["queries"][0]["returnFullDatapoint"] = self.base.return_full_datapoint
# Send the request using FastHttpUser
with self.client.request(
"POST",
url=self.public_endpoint_url,
json=self.request,
catch_response=True,
headers=self.headers,
) as response:
if response.status_code == 401:
# Refresh token on auth error
self.credentials.refresh(self.auth_req)
self.headers[
"Authorization"] = "Bearer " + self.credentials.token
self.token_refresh_time = time.time() + 3500
response.failure("Authentication failure, token refreshed")
elif response.status_code == 403:
# Log detailed error for permission issues
error_msg = f"Permission denied: {response.text}"
response.failure(error_msg)
logging.error(f"HTTP 403 error: {response.text}")
elif response.status_code != 200:
# Mark failed responses
response.failure(
f"Failed with status code: {response.status_code}, body: {response.text}"
)
class VectorSearchGrpcUser(User):
"""gRPC-based Vector Search user."""
    abstract = True  # This is an abstract base class
def __init__(self, environment: env.Environment):
super().__init__(environment)
# Initialize base functionality
self.base = BaseVectorSearchUser(environment)
        # Set up QPS-based wait time if specified
        user_qps = environment.parsed_options.qps_per_user
        if user_qps > 0:
            # Build the constant-throughput pacing function once and bind it
            # to this user instance
            throughput_fn = wait_time.constant_throughput(user_qps)
            self.wait_time = lambda: throughput_fn(self)
# Get the PSC address from the config
self.match_grpc_address = config.match_grpc_address
# Validate configuration
if not self.match_grpc_address:
raise ValueError(
"MATCH_GRPC_ADDRESS must be provided for PSC/gRPC connections")
logging.info(f"Using PSC/gRPC address: {self.match_grpc_address}")
# Create a gRPC channel with interceptor
channel = intercepted_cached_grpc_channel(
self.match_grpc_address,
auth=False, # PSC connections don't need auth
env=environment)
# Create the client
self.grpc_client = MatchServiceClient(
transport=match_transports_grpc.MatchServiceGrpcTransport(
channel=channel))
logging.info("gRPC client initialized")
@task
@tag('grpc')
def grpc_find_neighbors(self):
"""Execute a Vector Search query using gRPC."""
# Create datapoint based on embedding type
if (config.sparse_embedding_num_dimensions > 0 and
config.sparse_embedding_num_dimensions_with_values > 0 and
config.sparse_embedding_num_dimensions_with_values <=
config.sparse_embedding_num_dimensions):
# Sparse embedding case
values, dimensions = self.base.generate_sparse_embedding()
datapoint = IndexDatapoint(datapoint_id='0',
sparse_embedding={
'dimensions': dimensions,
'values': values
})
else:
# Dense embedding case
datapoint = IndexDatapoint(
datapoint_id="0",
feature_vector=self.base.generate_random_vector(
self.base.dimensions))
# Create a query
query = FindNeighborsRequest.Query(
datapoint=datapoint,
neighbor_count=self.base.num_neighbors,
)
# Add optional parameters if specified
if self.base.fraction_leaf_nodes_to_search_override > 0:
query.fraction_leaf_nodes_to_search_override = self.base.fraction_leaf_nodes_to_search_override
# Create the request - use the proper format with project number
index_endpoint = f"projects/{self.base.project_number}/locations/us-central1/indexEndpoints/{self.base.endpoint_id_numeric}"
        request = FindNeighborsRequest(
            index_endpoint=index_endpoint,
            deployed_index_id=self.base.deployed_index_id,
            queries=[query],
            return_full_datapoint=self.base.return_full_datapoint,
        )
# The interceptor will handle performance metrics automatically
try:
self.grpc_client.find_neighbors(request)
except Exception as e:
logging.error(f"Error in gRPC call: {str(e)}")
raise # The interceptor will handle the error reporting
# Concrete implementation classes that dynamically set their abstract attribute
# based on the endpoint access type (grpc vs http)
class HttpVectorSearchUser(VectorSearchHttpUser):
"""Concrete HTTP-based Vector Search user class."""
# Dynamically set abstract based on the endpoint access type
# For HTTP endpoints, set abstract=False (available)
# For gRPC endpoints, set abstract=True (unavailable)
abstract = USE_GRPC # abstract=True if using gRPC, abstract=False if using HTTP
def __init__(self, environment):
super().__init__(environment)
logging.info(
f"HttpVectorSearchUser initialized with abstract={self.abstract}")
class GrpcVectorSearchUser(VectorSearchGrpcUser):
"""Concrete gRPC-based Vector Search user class."""
# Opposite of HttpVectorSearchUser
# For gRPC endpoints, set abstract=False (available)
# For HTTP endpoints, set abstract=True (unavailable)
abstract = not USE_GRPC # abstract=True if using HTTP, abstract=False if using gRPC
def __init__(self, environment):
super().__init__(environment)
logging.info(
f"GrpcVectorSearchUser initialized with abstract={self.abstract}")
# Log which class is being used
if USE_GRPC:
logging.info(
"Using gRPC mode, GrpcVectorSearchUser is active and HttpVectorSearchUser is abstract"
)
else:
logging.info(
"Using HTTP mode, HttpVectorSearchUser is active and GrpcVectorSearchUser is abstract"
)