def load_test_prompts()

in benchmarks/benchmark/tools/locust-load-inference/locust-docker/locust-tasks/load_data.py


import logging
import time

from google.cloud import storage
from transformers import PreTrainedTokenizerBase


def load_test_prompts(gcs_path: str, tokenizer: PreTrainedTokenizerBase, max_prompt_len: int):
    # Strip the "gs://" prefix and split the remainder into bucket name and object path.
    split_path = gcs_path[5:].split('/', 1)
    bucket_name = split_path[0]
    object_name = split_path[1]
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(object_name)

    if not bucket.exists():
        raise ValueError(
            f"Cannot access gs://{bucket_name}, it may not exist or you may not have access to this bucket.")
    if not blob.exists():
        raise ValueError(
            f"Cannot access {gcs_path}, it may not exist or you may not have access to this object.")

    test_data = []
    start = time.time()
    with blob.open("r") as f:
        for prompt in f:
            prompt_token_ids = tokenizer(prompt).input_ids
            prompt_len = len(prompt_token_ids)
            if prompt_len < 4:
                # Prune sequences that are too short: TGI raises errors when the
                # input or output length is too small.
                continue
            if prompt_len > max_prompt_len:
                # Prune sequences that exceed the maximum prompt length.
                continue
            test_data.append(prompt)
    end = time.time()
    total_time = end - start
    logging.info(f"Filtered test prompts after {total_time} seconds.")
    return test_data
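
A minimal usage sketch, assuming the GCS object is a plain-text file with one prompt per line; the bucket path, tokenizer model, and max_prompt_len value below are illustrative, not part of the repository:

from transformers import AutoTokenizer

# Any Hugging Face tokenizer works; "gpt2" is just a stand-in for the model under test.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

prompts = load_test_prompts(
    gcs_path="gs://my-bucket/test-prompts.txt",  # hypothetical dataset location
    tokenizer=tokenizer,
    max_prompt_len=1024,
)
logging.info(f"{len(prompts)} prompts available for the load test.")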