pyrit/prompt_target/batch_helper.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
from typing import Any, Callable, Sequence

from pyrit.prompt_target import PromptTarget


def _get_chunks(*args, batch_size: int):
    """
    Helper function used during prompt batching to split arguments into chunks of the given size.

    Args:
        *args: Arguments to chunk; each argument should be a list, and all arguments must have the same length.
        batch_size (int): Batch size.
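
    Example (illustrative):
        _get_chunks([1, 2, 3], ["a", "b", "c"], batch_size=2) first yields
        [[1, 2], ["a", "b"]] and then [[3], ["c"]].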
"""
if len(args) == 0:
raise ValueError("No arguments provided to chunk.")
for arg in args[1:]:
if len(arg) != len(args[0]):
raise ValueError("All arguments must have the same length.")
for i in range(0, len(args[0]), batch_size):
yield [arg[i : i + batch_size] for arg in args]


def _validate_rate_limit_parameters(prompt_target: PromptTarget, batch_size: int):
    """
    Helper function to validate the constraint between the rate limit (requests per minute)
    and the batch size.

    Args:
        prompt_target (PromptTarget): Target to validate.
        batch_size (int): Batch size.

    Raises:
        ValueError: When a requests-per-minute rate limit is set on the target and the batch size is not 1.
    """
exc_message = "Batch size must be configured to 1 for the target requests per minute value to be respected."
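    # A per-minute rate limit can only be respected when requests are sent one at a time.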
if prompt_target and prompt_target._max_requests_per_minute and batch_size != 1:
raise ValueError(exc_message)


async def batch_task_async(
*,
prompt_target: PromptTarget,
batch_size: int,
items_to_batch: Sequence[Sequence[Any]],
task_func: Callable,
task_arguments: list[str],
**task_kwargs,
):
"""
    Performs the provided task in batches, validating the batch size against the target's rate limit first.

    Args:
        prompt_target (PromptTarget): Target whose rate-limit settings are validated against the batch size.
        batch_size (int): Batch size.
        items_to_batch (Sequence[Sequence[Any]]): Lists of items to batch, one sequence per task argument.
        task_func (Callable): Task to perform in batches.
        task_arguments (list[str]): Names of the keyword arguments that the batched items are assigned to.
        **task_kwargs: Any other keyword arguments the task needs.

    Returns:
        responses (list): List of results from the batched task function.
"""
responses = []
_validate_rate_limit_parameters(prompt_target=prompt_target, batch_size=batch_size)
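    # Basic shape checks; per-argument length consistency is enforced by _get_chunks.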
if len(items_to_batch) == 0 or len(items_to_batch[0]) == 0:
raise ValueError("No items to batch.")
if len(items_to_batch) != len(task_arguments):
raise ValueError("Number of lists of items to batch must match number of task arguments.")
for task_args in _get_chunks(*items_to_batch, batch_size=batch_size):
tasks = []
for batch_index in range(len(task_args[0])):
for arg_index, task_argument in enumerate(task_arguments):
task_kwargs[task_argument] = task_args[arg_index][batch_index]
tasks.append(task_func(**task_kwargs))
batch_results = await asyncio.gather(*tasks)
responses.extend(batch_results)
return responses
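

# Illustrative usage sketch (not part of the module API). The names below
# (_EchoTarget, _double_async, _demo) are hypothetical stand-ins; _EchoTarget
# only mimics the single attribute that the rate-limit validation inspects.
#
#     class _EchoTarget:
#         _max_requests_per_minute = None  # no rate limit, so any batch_size is allowed
#
#     async def _double_async(*, value: int) -> int:
#         return value * 2
#
#     async def _demo():
#         return await batch_task_async(
#             prompt_target=_EchoTarget(),
#             batch_size=2,
#             items_to_batch=[[1, 2, 3]],
#             task_func=_double_async,
#             task_arguments=["value"],
#         )
#
#     # asyncio.run(_demo()) -> [2, 4, 6]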