in ax/utils/testing/core_stubs.py [0:0]
def get_batch_trial_with_repeated_arms(num_repeated_arms: int) -> BatchTrial:
    """Create a batch containing new arms plus N arms from the last trial.

    A fresh experiment is built with ``get_experiment_with_batch_trial()``.
    The new batch on that experiment gets the first ``num_repeated_arms`` arms
    of the experiment's last trial, followed by the arms from
    ``get_arm_weights2()``.

    Args:
        num_repeated_arms: How many arms to repeat, taken from the start of
            the previous trial's arm list. If the previous trial has fewer
            arms than this, a warning is logged and all of its arms are
            repeated.

    Returns:
        The new ``BatchTrial``. It has a ``SyntheticRunner`` attached and its
        first arm set as status quo with weight 0.5.

    Raises:
        ValueError: If ``num_repeated_arms`` is negative.
        Exception: If the experiment has no previous trials.
    """
    if num_repeated_arms < 0:
        # A negative value would slice arms off the *end* of the previous
        # trial (``arms[:-k]``) rather than repeating arms from its start.
        raise ValueError(
            f"num_repeated_arms must be non-negative, got {num_repeated_arms}."
        )
    experiment = get_experiment_with_batch_trial()
    if not experiment.trials:
        raise Exception(
            "There are no previous trials in this experiment. Thus the new "
            "batch was not created as no repeated arms could be added."
        )
    # Get last (previous) trial.
    # NOTE(review): indexing by ``len - 1`` assumes trial indices are
    # contiguous starting at 0 -- confirm against ``Experiment.trials``.
    prev_trial = experiment.trials[len(experiment.trials) - 1]
    if len(prev_trial.arms) < num_repeated_arms:
        logger.warning(
            "There are less arms in the previous trial than the value of "
            "input parameter 'num_repeated_arms'. Thus all the arms from "
            "the last trial will be repeated in the new trial."
        )
    # Take the first N arms, where N is num_repeated_arms; slicing caps N at
    # the number of arms actually available.
    prev_arms = prev_trial.arms[:num_repeated_arms]
    if isinstance(prev_trial, BatchTrial):
        prev_weights = prev_trial.weights[:num_repeated_arms]
    else:
        # Non-batch trials carry no per-arm weights here, so use unit weights.
        # Float literals keep this list type-consistent with the float
        # weights it is concatenated with below.
        prev_weights = [1.0] * len(prev_arms)
    # Create new (next) arms and their weights.
    next_arms = get_arms_from_dict(get_arm_weights2())
    next_weights = get_weights_from_dict(get_arm_weights2())
    # Repeated arms come first, followed by the new arms.
    arms = prev_arms + next_arms
    weights = prev_weights + next_weights
    batch = experiment.new_batch_trial()
    batch.add_arms_and_weights(arms=arms, weights=weights, multiplier=1)
    batch.runner = SyntheticRunner()
    batch.set_status_quo_with_weight(status_quo=arms[0], weight=0.5)
    return batch