def get_batch_trial_with_repeated_arms()

in ax/utils/testing/core_stubs.py [0:0]


def get_batch_trial_with_repeated_arms(num_repeated_arms: int) -> BatchTrial:
    """Create a batch trial that mixes repeated arms with newly created arms.

    The experiment comes from ``get_experiment_with_batch_trial()``. The first
    ``num_repeated_arms`` arms of that experiment's last existing trial are
    placed at the front of the new batch, followed by the arms built from
    ``get_arm_weights2()``. The first arm of the combined list is then set as
    the status quo with weight 0.5 (when ``num_repeated_arms`` is 0, that is
    the first of the new arms).

    Args:
        num_repeated_arms: How many arms to repeat from the last trial. If it
            exceeds the number of arms in that trial, a warning is logged and
            every arm of that trial is repeated.

    Returns:
        The new batch trial, attached to the experiment, with a
        ``SyntheticRunner`` assigned as its runner.

    Raises:
        ValueError: If ``num_repeated_arms`` is negative.
        Exception: If the experiment has no trials to repeat arms from.
    """
    # A negative count would make the slices below drop arms from the END of
    # the previous trial instead of selecting the first N, so reject it early.
    if num_repeated_arms < 0:
        raise ValueError(
            f"'num_repeated_arms' must be non-negative, got {num_repeated_arms}."
        )

    experiment = get_experiment_with_batch_trial()
    trials = experiment.trials
    if len(trials) == 0:
        raise Exception(
            "There are no previous trials in this experiment. Thus the new "
            "batch was not created as no repeated arms could be added."
        )

    # Get last (previous) trial.
    prev_trial = trials[len(trials) - 1]
    if len(prev_trial.arms) < num_repeated_arms:
        logger.warning(
            "There are less arms in the previous trial than the value of "
            "input parameter 'num_repeated_arms'. Thus all the arms from "
            "the last trial will be repeated in the new trial."
        )

    # Take the first N arms, where N is num_repeated_arms. Slicing clamps to
    # the available length, which is what makes the warning case above work.
    prev_arms = prev_trial.arms[:num_repeated_arms]
    if isinstance(prev_trial, BatchTrial):
        prev_weights = prev_trial.weights[:num_repeated_arms]
    else:
        # Non-batch trials carry no per-arm weights here; default each to 1.
        prev_weights = [1] * len(prev_arms)

    # Create new (next) arms. The mapping is built once so that the arms and
    # the weights are guaranteed to come from the same object.
    new_arm_weights = get_arm_weights2()
    next_arms = get_arms_from_dict(new_arm_weights)
    next_weights = get_weights_from_dict(new_arm_weights)

    # Repeated arms go first, followed by the new ones.
    arms = prev_arms + next_arms
    # pyre-fixme[6]: Expected `List[int]` for 1st param but got `List[float]`.
    weights = prev_weights + next_weights
    batch = experiment.new_batch_trial()
    batch.add_arms_and_weights(arms=arms, weights=weights, multiplier=1)
    batch.runner = SyntheticRunner()
    batch.set_status_quo_with_weight(status_quo=arms[0], weight=0.5)
    return batch