def get_perturbation_transform()

in src/fmeval/eval_algorithms/semantic_robustness_utils.py [0:0]


def get_perturbation_transform(config: SemanticRobustnessConfig) -> SemanticPerturbation:
    """Build the semantic perturbation transform selected by `config`.

    `config.perturbation_type` picks the transform class:
    `BUTTER_FINGER` -> ButterFinger, `RANDOM_UPPER_CASE` -> RandomUppercase,
    and any other value -> AddRemoveWhitespace.
    Every transform reads from the model-input column and writes
    `config.num_perturbations` output columns.

    :param config: A config holding the perturbation type and the
        type-specific parameters used to construct the transform.
    :returns: A SemanticPerturbation instance configured from `config`.
    """
    kind = config.perturbation_type
    source_column = DatasetColumns.MODEL_INPUT.value.name
    count = config.num_perturbations

    def _keys_for(perturbation_cls: type) -> list:
        # One output column per perturbation, named after the transform class.
        return [create_output_key(perturbation_cls.__name__, source_column, idx) for idx in range(count)]

    if kind == BUTTER_FINGER:
        return ButterFinger(
            input_key=source_column,
            output_keys=_keys_for(ButterFinger),
            num_perturbations=count,
            perturbation_prob=config.butter_finger_perturbation_prob,
        )

    if kind == RANDOM_UPPER_CASE:
        return RandomUppercase(
            input_key=source_column,
            output_keys=_keys_for(RandomUppercase),
            num_perturbations=count,
            uppercase_fraction=config.random_uppercase_corrupt_proportion,
        )

    # Every other perturbation type falls through to whitespace add/remove.
    # NOTE(review): this assumes `perturbation_type` was validated upstream;
    # an unknown value silently maps here.
    return AddRemoveWhitespace(
        input_key=source_column,
        output_keys=_keys_for(AddRemoveWhitespace),
        num_perturbations=count,
        add_prob=config.whitespace_add_prob,
        remove_prob=config.whitespace_remove_prob,
    )