in vissl/data/collators/cutmixup_collator.py [0:0]
def cutmixup_collator(batch, **kwargs):
    """
    Collate a batch of multi-copy samples and apply CutMix
    (https://arxiv.org/abs/1905.04899) and/or MixUp
    (https://arxiv.org/abs/1710.09412) to every copy position.

    kwargs:
        :ssl_method (str): optional. It is removed from kwargs before the
                           Mixup object is built. When it equals "moco" or
                           "simclr", the mixed output is restructured with
                           data_back_to_input_form and handed to
                           moco_collator / simclr_collator respectively.

        All remaining kwargs are forwarded unchanged to Mixup:
        :mixup_alpha (float): mixup alpha value, mixup is active if > 0.
        :cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
        :cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is
                                      active and uses this vs alpha if not None.
        :prob (float): probability of applying mixup or cutmix per batch or
                       element
        :switch_prob (float): probability of switching to cutmix instead of
                              mixup when both are active
        :mode (str): how to apply mixup/cutmix params (per 'batch', 'pair'
                     (pair of elements), 'elem' (element))
        :correct_lam (bool): apply lambda correction when cutmix bbox is
                             clipped by image borders
        :label_smoothing (float): apply label smoothing to the mixed target
                                  tensor
        :num_classes (int): number of classes for target

    Input (assuming k copies of each image):
        batch = [
            {"data": [img1_0, ..., img1_k], "label": ..., "data_valid": ...,
             "data_idx": ...},
            {"data": [img2_0, ..., img2_k], ...},
            ...
        ]
        "label", "data_valid" and "data_idx" are converted with torch.tensor
        and then indexed by copy position.

    Returns:
        For ssl_method "moco" / "simclr": whatever the respective collator
        returns. Otherwise:
        output = {
            "data": [[mixed_copy_0, ..., mixed_copy_k]],
            "label": [concatenation of the mixed targets of every copy],
            "data_valid": [stacked validity entries, copy-major order],
            "data_idx": [stacked index entries, copy-major order],
        }
    """
    first_sample = batch[0]
    assert "data" in first_sample, "data not found in sample"
    assert "label" in first_sample, "label not found in sample"

    images = [sample["data"] for sample in batch]
    targets = [torch.tensor(sample["label"]) for sample in batch]
    valid_flags = [torch.tensor(sample["data_valid"]) for sample in batch]
    sample_ids = [torch.tensor(sample["data_idx"]) for sample in batch]
    num_copies = len(images[0])

    # "ssl_method" only steers this collator; it must not reach Mixup.
    ssl_method = kwargs.pop("ssl_method", None)

    # CutMix + MixUp transform built from the remaining kwargs.
    # TODO: switch to classy_cutmixup.Mixup(**kwargs) in a future update when
    # calling via ClassyVision.
    mixer = Mixup(**kwargs)

    mixed_data, mixed_labels = [], []
    valid_out, index_out = [], []
    for copy_pos in range(num_copies):
        # Gather the copy_pos-th view of every sample in the batch.
        copy_images = [views[copy_pos] for views in images]
        copy_targets = [target[copy_pos] for target in targets]
        valid_out.extend(flags[copy_pos] for flags in valid_flags)
        index_out.extend(ids[copy_pos] for ids in sample_ids)

        # Mixup is called with a dict holding batched "input" and "target".
        mixed = mixer(
            {
                "input": torch.stack(copy_images),
                "target": torch.tensor(copy_targets),
            }
        )
        mixed_data.append(mixed["input"])
        mixed_labels.append(mixed["target"])

    # For moco / simclr, put the data back into the form in which it was
    # originally input and delegate to that method's own collator.
    if ssl_method in ("moco", "simclr"):
        restored_batch = data_back_to_input_form(
            mixed_data, mixed_labels, valid_out, index_out
        )
        ssl_collator = moco_collator if ssl_method == "moco" else simclr_collator
        return ssl_collator(restored_batch)

    return {
        "data": [mixed_data],
        "label": [torch.cat(mixed_labels)],
        "data_valid": [torch.stack(valid_out)],
        "data_idx": [torch.stack(index_out)],
    }