def __init__()

in criterions/avid_cma.py [0:0]


    def __init__(self, num_data, embedding_dim,
                 num_negatives=1024,
                 num_negatives_within=None,
                 momentum=0.5,
                 xModalInstCoeff=1.,
                 wModalInstCoeff=0.,
                 xModalPosCoeff=0.,
                 wModalPosCoeff=1.,
                 sampling_args=None,
                 checkpoint=None,
                 resample_freq=-1,
                 device=0):
        '''
        AVID_CMA criterion.
        This module receives the output embeddings of the video
        and audio models, computes their non-linear projections,
        manages the memory bank, draws positive correspondences,
        and computes the final loss (weighted average between
        instance discrimination and positive discrimination losses).

        Args:
        - num_data: number of instances in the training set.
        - embedding_dim: output dimension of the non-linear projection.
        - num_negatives: number of negatives to draw from memory bank to compute the NCE loss.
        - num_negatives_within: optionally reduce the number of negatives for the within-modal loss.
        - momentum: memory bank EMA momentum parameter.
        - xModalInstCoeff: coefficient for the cross modal instance discrimination loss. (AVID-CMA: 1.0)
        - wModalInstCoeff: coefficient for the within modal instance discrimination loss. (AVID-AVID: 0.0)
        - xModalPosCoeff: coefficient for the cross modal positive discrimination loss. (AVID-CMA: 0.0)
        - wModalPosCoeff: coefficient for the within modal positive discrimination loss. (AVID-AVID: 1.0)
        - sampling_args: arguments forwarded to AVIDSimilarityPositiveExpansion to control
          how positive correspondences are sampled.
        - checkpoint: optionally specify a checkpoint path to restore the memory bank and partition function
        - resample_freq: stored for the trainer; NOTE(review): -1 appears to mean "never
          re-draw CMA correspondences" — confirm against the training loop.
        - device: CUDA device index used for the NCE average module.
        '''
        # NOTE: fix — this docstring used to sit after super().__init__(), which made it a
        # dead string expression rather than the method's docstring (PEP 257).
        super(AVID_CMA, self).__init__()

        # First set up the NCEAverage method to get the scores of the output
        # wrt. memory bank negatives. The four loss terms are enabled/disabled
        # purely by whether their coefficient is positive.
        self.nce_average = AVIDSimilarityPositiveExpansion(
            memory_size=num_data,
            embedding_dim=embedding_dim,
            num_negatives=num_negatives,
            num_negatives_within=num_negatives_within,
            momentum=momentum,
            xModalInst=xModalInstCoeff>0.,
            xModalPos=xModalPosCoeff>0.,
            wModalInst=wModalInstCoeff>0.,
            wModalPos=wModalPosCoeff>0.,
            sampling_args=sampling_args,
            device=device
        )
        self.nce_average = self.nce_average.cuda(device)

        # Loss coefficients, normalized so the four terms form a weighted average.
        sum_coeff = xModalInstCoeff + wModalInstCoeff + xModalPosCoeff + wModalPosCoeff
        self.xModalInstCoeff = xModalInstCoeff / sum_coeff
        self.wModalInstCoeff = wModalInstCoeff / sum_coeff
        self.xModalPosCoeff = xModalPosCoeff / sum_coeff
        self.wModalPosCoeff = wModalPosCoeff / sum_coeff

        # Setup loss function
        self.criterion = NCECriterion(num_data)

        # Restore memory bank and partition function from AVID checkpoint.
        # Needs to be done before finding correspondences, since the
        # correspondences are drawn from the restored memory banks.
        if checkpoint is not None:
            ckp = torch.load(checkpoint, map_location='cpu')['train_criterion']
            state_dict = self.state_dict()
            # Restore memory banks
            state_dict['nce_average.view1_mem'] = ckp['nce_average.view1_mem']
            state_dict['nce_average.view2_mem'] = ckp['nce_average.view2_mem']
            # Restore partition function: average every saved avg_exp_score and
            # write the single averaged value into each local avg_exp_score slot.
            Z = torch.stack([ckp[k] for k in ckp if 'avg_exp_score' in k]).mean()
            for k in state_dict:
                if 'avg_exp_score' in k:
                    state_dict[k] = Z
            self.load_state_dict(state_dict)

        # Find CMA correspondences
        self.resample_freq = resample_freq
        self.nce_average.find_correspondences()