in vissl/utils/hydra_config.py [0:0]
def infer_losses_config(cfg):
    """
    Infer settings for various self-supervised losses. Takes care of setting various loss
    parameters correctly like world size, batch size per gpu, effective global batch size,
    collator etc.
    Each loss has additional set of parameters that can be inferred to ensure smooth
    training in case user forgets to adjust all the parameters.

    Args:
        cfg: the hydra config (attribute-style access); mutated in place.

    Returns:
        The same cfg object with the loss-specific settings filled in.
    """
    train_transforms = cfg.DATA.TRAIN.TRANSFORMS
    # The first train transform that declares "total_num_crops" (None if absent)
    # is the source of truth for the crop count used by the multi-crop losses.
    total_num_crops = next(
        (
            transform["total_num_crops"]
            for transform in train_transforms
            if "total_num_crops" in transform
        ),
        None,
    )
    # Values shared by several branches below: total number of training
    # processes across all nodes, and the resulting global batch size.
    world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE
    batch_size = cfg.DATA.TRAIN.BATCHSIZE_PER_REPLICA
    global_batch_size = batch_size * world_size
    # some inference for the Info-NCE loss.
    # NOTE: substring match on purpose — this branch also applies to
    # "multicrop_simclr_info_nce_loss".
    if "simclr_info_nce_loss" in cfg.LOSS.name:
        cfg.LOSS[cfg.LOSS.name]["buffer_params"]["world_size"] = world_size
        num_positives = 2  # simclr uses 2 copies per image
        cfg.LOSS[cfg.LOSS.name]["buffer_params"]["effective_batch_size"] = (
            num_positives * global_batch_size
        )
    # bce_logits_multiple_output_single_target
    if cfg.LOSS.name == "bce_logits_multiple_output_single_target":
        cfg.LOSS.bce_logits_multiple_output_single_target.world_size = world_size
    # multicrop version of simclr loss. The Info-NCE branch above already ran
    # (substring match); here we overwrite effective_batch_size because
    # multicrop does not use the fixed num_positives=2 factor.
    if cfg.LOSS.name == "multicrop_simclr_info_nce_loss":
        # Set world_size from DISTRIBUTED directly instead of reading back the
        # value the Info-NCE branch stored in the config: same number, but no
        # fragile implicit ordering dependency between the two branches.
        cfg.LOSS.multicrop_simclr_info_nce_loss.buffer_params.world_size = world_size
        cfg.LOSS.multicrop_simclr_info_nce_loss.buffer_params.effective_batch_size = (
            global_batch_size
        )
        cfg.LOSS.multicrop_simclr_info_nce_loss.num_crops = (
            total_num_crops or cfg.LOSS.multicrop_simclr_info_nce_loss.num_crops
        )
        cfg.DATA.TRAIN.COLLATE_FUNCTION = "multicrop_collator"
    # some inference for the DeepCluster-v2 loss.
    if cfg.LOSS.name == "deepclusterv2_loss":
        cfg.LOSS.deepclusterv2_loss.DROP_LAST = cfg.DATA.TRAIN.DROP_LAST
        cfg.LOSS.deepclusterv2_loss.BATCHSIZE_PER_REPLICA = batch_size
        cfg.LOSS.deepclusterv2_loss.num_crops = (
            total_num_crops or cfg.LOSS.deepclusterv2_loss.num_crops
        )
        cfg.DATA.TRAIN.COLLATE_FUNCTION = "multicrop_collator"
    # some inference for the SwAV loss.
    if cfg.LOSS.name == "swav_loss":
        assert len(cfg.MODEL.HEAD.PARAMS) == 1
        assert cfg.MODEL.HEAD.PARAMS[0][0] in {"swav_head", "swav_head_fsdp"}
        # message now lists all three collators the assertion actually accepts
        assert cfg.DATA.TRAIN.COLLATE_FUNCTION in [
            "multicrop_collator",
            "multicrop_mixup_collator",
            "cutmixup_collator",
        ], (
            "for swav loss, use either a collator from "
            "[multicrop_collator, multicrop_mixup_collator, cutmixup_collator]"
        )
        cfg.LOSS.swav_loss.num_prototypes = cfg.MODEL.HEAD.PARAMS[0][1]["num_clusters"]
        cfg.LOSS.swav_loss.embedding_dim = cfg.MODEL.HEAD.PARAMS[0][1]["dims"][-1]
        cfg.LOSS.swav_loss.num_crops = total_num_crops or cfg.LOSS.swav_loss.num_crops
        # local import — presumably avoids a circular import at module load;
        # kept function-scoped as in the original
        from vissl.utils.checkpoint import get_checkpoint_folder

        cfg.LOSS.swav_loss.output_dir = get_checkpoint_folder(cfg)
        # round the queue length down to a multiple of the global batch size so
        # the queue always holds whole batches
        queue_length = cfg.LOSS.swav_loss.queue.queue_length
        queue_length -= queue_length % global_batch_size
        cfg.LOSS.swav_loss.queue.queue_length = queue_length
        cfg.LOSS.swav_loss.queue.local_queue_length = queue_length // world_size
    # some inference for the SwAV momentum loss.
    if cfg.LOSS.name == "swav_momentum_loss":
        assert len(cfg.MODEL.HEAD.PARAMS) == 1
        assert cfg.MODEL.HEAD.PARAMS[0][0] == "swav_head"
        cfg.LOSS.swav_momentum_loss.num_prototypes = cfg.MODEL.HEAD.PARAMS[0][1][
            "num_clusters"
        ]
        cfg.LOSS.swav_momentum_loss.embedding_dim = cfg.MODEL.HEAD.PARAMS[0][1]["dims"][
            -1
        ]
        cfg.LOSS.swav_momentum_loss.num_crops = (
            total_num_crops or cfg.LOSS.swav_momentum_loss.num_crops
        )
        cfg.DATA.TRAIN.COLLATE_FUNCTION = "multicrop_collator"
        # same queue rounding as the swav_loss branch above
        queue_length = cfg.LOSS.swav_momentum_loss.queue.queue_length
        queue_length -= queue_length % global_batch_size
        cfg.LOSS.swav_momentum_loss.queue.queue_length = queue_length
        cfg.LOSS.swav_momentum_loss.queue.local_queue_length = (
            queue_length // world_size
        )
    # some inference for DINO loss.
    if cfg.LOSS.name == "dino_loss":
        assert len(cfg.MODEL.HEAD.PARAMS) == 1
        assert cfg.MODEL.HEAD.PARAMS[0][0] == "swav_head"
        # the DINO output dimension is the first (only) prototype head size
        cfg.LOSS.dino_loss.output_dim = cfg.MODEL.HEAD.PARAMS[0][1]["num_clusters"][0]
        cfg.LOSS.dino_loss.num_crops = total_num_crops or cfg.LOSS.dino_loss.num_crops
        cfg.DATA.TRAIN.COLLATE_FUNCTION = "multicrop_collator"
    return cfg