# infer_losses_config() — excerpt from vissl/utils/hydra_config.py


def infer_losses_config(cfg):
    """
    Infer settings for various self-supervised losses. Takes care of setting various loss
    parameters correctly like world size, batch size per gpu, effective global batch size,
    collator etc.
    Each loss has additional set of parameters that can be inferred to ensure smooth
    training in case user forgets to adjust all the parameters.

    Args:
        cfg: the hydra training config. Mutated in place.

    Returns:
        The same ``cfg`` object with the loss-dependent fields filled in.
    """
    train_transforms = cfg.DATA.TRAIN.TRANSFORMS
    # First training transform that declares "total_num_crops" wins; None if
    # no transform declares it (then each loss keeps its configured num_crops).
    total_num_crops = next(
        (
            transform["total_num_crops"]
            for transform in train_transforms
            if "total_num_crops" in transform
        ),
        None,
    )

    # Hoisted invariants: every loss below derives its settings from the total
    # number of training replicas and the per-replica batch size.
    world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE
    batch_size = cfg.DATA.TRAIN.BATCHSIZE_PER_REPLICA

    # some inference for the Info-NCE loss.
    # NOTE: substring match on purpose — this branch also fires for
    # "multicrop_simclr_info_nce_loss", whose buffer params are then refined
    # in the dedicated multicrop branch below.
    if "simclr_info_nce_loss" in cfg.LOSS.name:
        cfg.LOSS[cfg.LOSS.name]["buffer_params"]["world_size"] = world_size
        num_positives = 2  # simclr uses 2 copies per image
        cfg.LOSS[cfg.LOSS.name]["buffer_params"]["effective_batch_size"] = (
            num_positives * batch_size * world_size
        )

    # bce_logits_multiple_output_single_target
    if cfg.LOSS.name == "bce_logits_multiple_output_single_target":
        cfg.LOSS.bce_logits_multiple_output_single_target.world_size = world_size

    # multicrop version of simclr loss: world_size was already inferred by the
    # substring branch above; the effective batch size is overwritten WITHOUT
    # the x2 positives factor (crops are accounted for via num_crops instead)
    # and the multicrop collator is forced.
    if cfg.LOSS.name == "multicrop_simclr_info_nce_loss":
        cfg.LOSS.multicrop_simclr_info_nce_loss.buffer_params.world_size = world_size
        cfg.LOSS.multicrop_simclr_info_nce_loss.buffer_params.effective_batch_size = (
            batch_size * world_size
        )
        cfg.LOSS.multicrop_simclr_info_nce_loss.num_crops = (
            total_num_crops or cfg.LOSS.multicrop_simclr_info_nce_loss.num_crops
        )
        cfg.DATA.TRAIN.COLLATE_FUNCTION = "multicrop_collator"

    # some inference for the DeepCluster-v2 loss.
    if cfg.LOSS.name == "deepclusterv2_loss":
        cfg.LOSS.deepclusterv2_loss.DROP_LAST = cfg.DATA.TRAIN.DROP_LAST
        cfg.LOSS.deepclusterv2_loss.BATCHSIZE_PER_REPLICA = batch_size
        cfg.LOSS.deepclusterv2_loss.num_crops = (
            total_num_crops or cfg.LOSS.deepclusterv2_loss.num_crops
        )
        cfg.DATA.TRAIN.COLLATE_FUNCTION = "multicrop_collator"

    # some inference for the SwAV loss.
    if cfg.LOSS.name == "swav_loss":
        assert len(cfg.MODEL.HEAD.PARAMS) == 1
        assert cfg.MODEL.HEAD.PARAMS[0][0] in {"swav_head", "swav_head_fsdp"}
        assert cfg.DATA.TRAIN.COLLATE_FUNCTION in [
            "multicrop_collator",
            "multicrop_mixup_collator",
            "cutmixup_collator",
        ], (
            "for swav loss, use either a collator from "
            "[multicrop_collator, multicrop_mixup_collator, cutmixup_collator]"
        )
        # prototype/embedding shapes come from the (single) swav head config
        cfg.LOSS.swav_loss.num_prototypes = cfg.MODEL.HEAD.PARAMS[0][1]["num_clusters"]
        cfg.LOSS.swav_loss.embedding_dim = cfg.MODEL.HEAD.PARAMS[0][1]["dims"][-1]
        cfg.LOSS.swav_loss.num_crops = total_num_crops or cfg.LOSS.swav_loss.num_crops
        # function-scope import — presumably to avoid a circular import at
        # module load time; keep it local.
        from vissl.utils.checkpoint import get_checkpoint_folder

        cfg.LOSS.swav_loss.output_dir = get_checkpoint_folder(cfg)
        # Round the queue length down to a multiple of the global batch size,
        # then shard it evenly across replicas.
        global_batch_size = batch_size * world_size
        queue_length = cfg.LOSS.swav_loss.queue.queue_length
        queue_length -= queue_length % global_batch_size
        cfg.LOSS.swav_loss.queue.queue_length = queue_length
        cfg.LOSS.swav_loss.queue.local_queue_length = queue_length // world_size

    # some inference for the SwAV momentum loss.
    if cfg.LOSS.name == "swav_momentum_loss":
        assert len(cfg.MODEL.HEAD.PARAMS) == 1
        assert cfg.MODEL.HEAD.PARAMS[0][0] == "swav_head"
        cfg.LOSS.swav_momentum_loss.num_prototypes = cfg.MODEL.HEAD.PARAMS[0][1][
            "num_clusters"
        ]
        cfg.LOSS.swav_momentum_loss.embedding_dim = cfg.MODEL.HEAD.PARAMS[0][1]["dims"][
            -1
        ]
        cfg.LOSS.swav_momentum_loss.num_crops = (
            total_num_crops or cfg.LOSS.swav_momentum_loss.num_crops
        )
        cfg.DATA.TRAIN.COLLATE_FUNCTION = "multicrop_collator"
        # same queue-length rounding/sharding as the plain swav loss above
        global_batch_size = batch_size * world_size
        queue_length = cfg.LOSS.swav_momentum_loss.queue.queue_length
        queue_length -= queue_length % global_batch_size
        cfg.LOSS.swav_momentum_loss.queue.queue_length = queue_length
        cfg.LOSS.swav_momentum_loss.queue.local_queue_length = (
            queue_length // world_size
        )

    # some inference for DINO loss.
    if cfg.LOSS.name == "dino_loss":
        assert len(cfg.MODEL.HEAD.PARAMS) == 1
        assert cfg.MODEL.HEAD.PARAMS[0][0] == "swav_head"
        cfg.LOSS.dino_loss.output_dim = cfg.MODEL.HEAD.PARAMS[0][1]["num_clusters"][0]
        cfg.LOSS.dino_loss.num_crops = total_num_crops or cfg.LOSS.dino_loss.num_crops
        cfg.DATA.TRAIN.COLLATE_FUNCTION = "multicrop_collator"

    return cfg