in benchmarks/benchmark/tools/locust-load-inference/locust-docker/locust-tasks/load_data.py [0:0]
def load_test_prompts(gcs_path: str, tokenizer: PreTrainedTokenizerBase, max_prompt_len: int):
    """Load and length-filter test prompts from a newline-delimited GCS object.

    Args:
        gcs_path: Full GCS URI of the prompt file, e.g. "gs://bucket/prompts.txt".
        tokenizer: Tokenizer used to measure each prompt's token length.
        max_prompt_len: Maximum allowed prompt length in tokens; longer
            prompts are dropped.

    Returns:
        A list of prompt strings (one per input line, newline included) whose
        token count is in [4, max_prompt_len].

    Raises:
        ValueError: If gcs_path is not a well-formed "gs://bucket/object" URI,
            or the bucket/object does not exist or is not accessible.
    """
    if not gcs_path.startswith("gs://"):
        raise ValueError(
            f"Invalid GCS path {gcs_path!r}: expected a URI of the form gs://bucket/object.")
    # Strip the "gs://" scheme and split into bucket and object parts.
    split_path = gcs_path[5:].split('/', 1)
    # Guard against "gs://bucket" (no object) or empty components, which would
    # otherwise raise an uninformative IndexError below.
    if len(split_path) != 2 or not split_path[0] or not split_path[1]:
        raise ValueError(
            f"Invalid GCS path {gcs_path!r}: expected a URI of the form gs://bucket/object.")
    bucket_name, object_name = split_path
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(object_name)
    if not bucket.exists():
        raise ValueError(
            f"Cannot access gs://{bucket_name}, it may not exist or you may not have access to this bucket.")
    if not blob.exists():
        raise ValueError(
            f"Cannot access {gcs_path}, it may not exist or you may not have access to this object.")
    test_data = []
    start = time.time()
    # Stream the object line-by-line; each line is one candidate prompt.
    with blob.open("r") as f:
        for prompt in f:
            prompt_len = len(tokenizer(prompt).input_ids)
            if prompt_len < 4:
                # Prune too-short sequences: TGI errors when the input or
                # output length is too short.
                continue
            if prompt_len > max_prompt_len:
                # Prune sequences longer than the configured maximum.
                continue
            test_data.append(prompt)
    total_time = time.time() - start
    # Lazy %-style args avoid formatting when INFO logging is disabled.
    logging.info("Filtered test prompts after %s seconds.", total_time)
    return test_data