def cutmixup_collator()

in vissl/data/collators/cutmixup_collator.py [0:0]


def cutmixup_collator(batch, **kwargs):
    """
    This collator implements CutMix (https://arxiv.org/abs/1905.04899) and/or
    MixUp (https://arxiv.org/abs/1710.09412) via ClassyVision's
    implementation (link when publicly available).

    kwargs:
    :ssl_method (str): optional. Removed from kwargs before the remaining
    kwargs are forwarded to Mixup. If "moco" or "simclr", the mixed batch is
    restructured into per-sample form and handed to the matching collator;
    otherwise (or if absent) the mixed batch is returned directly.
    :mixup_alpha (float): mixup alpha value, mixup is active if > 0.
    :cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
    :cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active
    and uses this vs alpha if not None.
    :prob (float): probability of applying mixup or cutmix per batch or element
    :switch_prob (float): probability of switching to cutmix instead of mixup
    when both are active
    :mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of
    elements), 'elem' (element))
    :correct_lam (bool): apply lambda correction when cutmix bbox clipped by
    image borders
    :label_smoothing (float): apply label smoothing to the mixed target tensor
    :num_classes (int): number of classes for target


    The collator collates the batch for the following input (assuming k-copies
    of image). Every sample must provide "data", "label", "data_valid" and
    "data_idx", each holding one entry per copy:

    Input:
        batch: Example
                batch = [
                    {"data" : [img1_0, ..., img1_k], ..},
                    {"data" : [img2_0, ..., img2_k], ...},
                    ...
                ]

    Returns: Example output (when ssl_method is neither "moco" nor "simclr"):
                output = {
                            "data": [torch.tensor([img1_0, ..., imgN_0],
                                [img1_k, ..., imgN_k])],
                            "label": [<mixed targets, same order as data>],
                            "data_valid": [<stacked data_valid entries>],
                            "data_idx": [<stacked data_idx entries>],
                         }
             For "moco" / "simclr" the return value is whatever
             moco_collator / simclr_collator returns.

    Raises:
        AssertionError: if the first sample has no "data" or "label" key.
    """
    assert "data" in batch[0], "data not found in sample"
    assert "label" in batch[0], "label not found in sample"

    data = [x["data"] for x in batch]
    labels = [torch.tensor(x["label"]) for x in batch]
    data_valid = [torch.tensor(x["data_valid"]) for x in batch]
    data_idx = [torch.tensor(x["data_idx"]) for x in batch]
    # The number of copies per sample is taken from the first sample.
    num_duplicates = len(data[0])

    # Determine ssl method and adjust collator output accordingly. It must be
    # removed from kwargs because everything left is passed to Mixup.
    ssl_method = kwargs.pop("ssl_method", None)

    # Instantiate CutMix + Mixup (CutMixUp!) object
    cutmixup_transform_obj = Mixup(**kwargs)
    # TODO: Uncomment in future update when calling via ClassyVision
    # cutmixup_transform_obj = classy_cutmixup.Mixup(**kwargs)

    output_data, output_label, output_data_valid, output_data_idx = [], [], [], []
    # Mix each copy position independently: the transform sees one batch made
    # of the pos-th copy of every sample. All outputs are therefore ordered
    # copy-major (all samples for copy 0, then all samples for copy 1, ...).
    for pos in range(num_duplicates):
        output_data_valid.extend(valid[pos] for valid in data_valid)
        output_data_idx.extend(index[pos] for index in data_idx)
        # Get data and labels into format accepted by Mixup
        cutmixup_output = cutmixup_transform_obj(
            {
                "input": torch.stack([copies[pos] for copies in data]),
                "target": torch.tensor([label[pos] for label in labels]),
            }
        )
        output_data.append(cutmixup_output["input"])
        output_label.append(cutmixup_output["target"])

    # If using moco or simclr, first restructure the data back into the form
    # in which it was originally input, then call the collator for that ssl
    # method
    if ssl_method in ("moco", "simclr"):
        output_batch = data_back_to_input_form(
            output_data, output_label, output_data_valid, output_data_idx
        )
        if ssl_method == "moco":
            return moco_collator(output_batch)
        return simclr_collator(output_batch)

    # Otherwise return the mixed batch as is. "data" is concatenated exactly
    # like "label" so both are single tensors in the same copy-major order
    # (previously "data" was left as a Python list of per-copy tensors).
    return {
        "data": [torch.cat(output_data)],
        "label": [torch.cat(output_label)],
        "data_valid": [torch.stack(output_data_valid)],
        "data_idx": [torch.stack(output_data_idx)],
    }