def get_mask_func()

in activemri/experimental/cvpr19_models/data/masking_utils.py [0:0]


def get_mask_func(mask_type, which_dataset, rnl_params=None):
    """Construct the mask function named by ``mask_type``.

    ``mask_type`` is matched by substring against the families listed below,
    in order, and the first match wins. Order matters: "symmetric_basic" is
    tried before "basic", and "symmetric_grid" before "grid", because the
    shorter key is a substring of the longer one.

    A ``mask_type`` ending in "_rnl" sets ``random_num_lines=True`` for the
    "symmetric_basic", "basic" and "low_to_high" families. The two grid
    families always pass ``random_num_lines=True`` and two empty lists.

    Args:
        mask_type: Family name, optionally suffixed with "_rnl".
        which_dataset: Passed through as the third positional argument of the
            selected mask class.
        rnl_params: Passed through as the ``rnl_params`` keyword argument.

    Returns:
        A new instance of the selected mask class.

    Raises:
        ValueError: If ``mask_type`` contains none of the known family names.
    """
    use_rnl = mask_type[-4:] == "_rnl"

    # Keyword arguments shared by the constructor calls of each group.
    fixed_kwargs = {"random_num_lines": use_rnl, "rnl_params": rnl_params}
    grid_kwargs = {"random_num_lines": True, "rnl_params": rnl_params}

    # Ordered (substring, log message, builder) entries. Builders are lambdas,
    # so a mask class is only looked up and instantiated after its family has
    # matched and the message has been logged. The [0.125] / [4] lists
    # (presumably center fraction / acceleration -- confirm against the mask
    # classes) are created fresh on every call.
    families = (
        (
            "symmetric_basic",
            f"Mask is symmetric uniform choice with random_num_lines={use_rnl}.",
            lambda: SymmetricUniformChoiceMaskFunc(
                [0.125], [4], which_dataset, **fixed_kwargs
            ),
        ),
        (
            "basic",
            f"Mask is fixed acceleration mask with random_num_lines={use_rnl}.",
            # First two parameters are ignored if `random_num_lines` is True
            lambda: BasicMaskFunc([0.125], [4], which_dataset, **fixed_kwargs),
        ),
        (
            "low_to_high",
            f"Mask is symmetric low to high with random_num_lines={use_rnl}.",
            lambda: SymmetricLowToHighMaskFunc(
                [0.125], [4], which_dataset, **fixed_kwargs
            ),
        ),
        (
            "symmetric_grid",
            "Mask is symmetric grid.",
            lambda: SymmetricUniformGridMaskFunc([], [], which_dataset, **grid_kwargs),
        ),
        (
            "grid",
            "Mask is grid (not symmetric).",
            lambda: UniformGridMaskFunc([], [], which_dataset, **grid_kwargs),
        ),
    )

    for key, message, build in families:
        if key in mask_type:
            logging.info(message)
            return build()
    raise ValueError(f"Invalid mask type: {mask_type}.")