def gen_random_num()

in benchmark/embedding/benchmark.py [0:0]


def gen_random_num(
    param: int, count: int, distribution: str
) -> "int | np.ndarray":
    """Generate random integer values drawn from the specified distribution.

    Args:
        param: Upper bound for the generated values (e.g. a max length or a
            batch size). For "fixed", every generated value equals ``param``.
        count: Number of values to generate.
        distribution: Type of distribution - "uniform", "normal", or "fixed".

    Returns:
        A built-in ``int`` if ``count == 1``, otherwise a numpy array holding
        ``count`` values (integer-typed when ``param`` is an int).

    Raises:
        ValueError: If an invalid distribution is specified.
    """
    if distribution == "uniform":
        # randint's upper bound is exclusive, so param + 1 yields [1, param].
        values = np.random.randint(1, param + 1, count)
    elif distribution == "normal":
        # Mean param/2, standard deviation param/4. Samples falling outside
        # [1, param] are clipped to the nearest bound, then truncated to int.
        values = np.random.normal(param / 2, param / 4, count)
        values = np.clip(values, 1, param).astype(int)
    elif distribution == "fixed":
        # Constant array: every entry is param.
        values = np.full(count, param)
    else:
        raise ValueError(f"Invalid distribution: {distribution}")

    if count == 1:
        # Unwrap to a built-in int rather than a numpy scalar: the docstring
        # promises a single integer, and numpy scalars are not int subclasses
        # (they fail strict isinstance checks and JSON serialization).
        return int(values[0])
    return values