def __init__()

in criterions/avid.py [0:0]


    def __init__(self, num_data, embedding_dim,
                 num_negatives=4096,
                 momentum=0.9,
                 xModal_coeff=1.,
                 wModal_coeff=0.,
                 checkpoint=None,
                 device=0):
        '''
        AVID criterion.

        This module receives the output embeddings of the video
        and audio models, computes their non-linear projections,
        manages the memory bank and computes the final loss.

        Args:
        - num_data: number of instances in the training set.
        - embedding_dim: output dimension of the non-linear projection.
        - num_negatives: number of negatives to draw from memory bank to compute the NCE loss.
        - momentum: memory bank EMA momentum parameter.
        - xModal_coeff: coefficient for the cross modal loss. (Cross-AVID: 1.0 | Self-AVID: 0.0 | Joint-AVID: 1.0)
        - wModal_coeff: coefficient for the within modal loss. (Cross-AVID: 0.0 | Self-AVID: 1.0 | Joint-AVID: 1.0)
        - checkpoint: optionally specify a checkpoint path to restore the memory bank and partition function.
        - device: device index passed to the memory bank; the memory bank is also moved to this CUDA device.

        Raises:
        - ValueError: if either coefficient is negative, or both are zero.
        - KeyError: if the checkpoint lacks the memory banks or any 'avg_exp_score' entry.
        '''
        super(AVID, self).__init__()

        # The coefficients are normalised by their sum below, and the memory bank's
        # xModal/wModal flags are set only for coefficients > 0, so reject values
        # that would divide by zero or make flags and weights disagree.
        if xModal_coeff < 0 or wModal_coeff < 0:
            raise ValueError(
                f'xModal_coeff and wModal_coeff must be non-negative, '
                f'got {xModal_coeff} and {wModal_coeff}')
        sum_coeff = xModal_coeff + wModal_coeff
        if sum_coeff == 0:
            raise ValueError(
                'at least one of xModal_coeff and wModal_coeff must be positive')

        self.nce_average = AVIDSimilarityMemoryBank(
            memory_size=num_data,
            embedding_dim=embedding_dim,
            num_negatives=num_negatives,
            momentum=momentum,
            xModal=xModal_coeff > 0.,
            wModal=wModal_coeff > 0.,
            device=device
        )
        self.nce_average = self.nce_average.cuda(device)

        # Store the loss weights normalised so they sum to 1.
        self.xModal_coeff = xModal_coeff / sum_coeff
        self.wModal_coeff = wModal_coeff / sum_coeff
        self.criterion = NCECriterion(num_data)

        # Restore memory bank and partition function if necessary
        if checkpoint is not None:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            ckp = torch.load(checkpoint, map_location='cpu')['train_criterion']
            state_dict = self.state_dict()

            # Restore memory banks
            for key in ('nce_average.view1_mem', 'nce_average.view2_mem'):
                state_dict[key] = ckp[key]

            # Restore partition function: average every saved 'avg_exp_score'
            # value and write that single average into each matching entry.
            saved_scores = [ckp[k] for k in ckp if 'avg_exp_score' in k]
            if not saved_scores:
                raise KeyError(
                    f"checkpoint '{checkpoint}' has no 'avg_exp_score' entries "
                    f"to restore the partition function from")
            Z = torch.stack(saved_scores).mean()
            for k in state_dict:
                if 'avg_exp_score' in k:
                    # Z is a 0-dim tensor; expand it to the target's shape so the
                    # shape check in load_state_dict passes for non-scalar buffers.
                    state_dict[k] = Z.expand_as(state_dict[k])
            self.load_state_dict(state_dict)