in criterions/avid_cma.py [0:0]
def __init__(self, num_data, embedding_dim,
             num_negatives=1024,
             num_negatives_within=None,
             momentum=0.5,
             xModalInstCoeff=1.,
             wModalInstCoeff=0.,
             xModalPosCoeff=0.,
             wModalPosCoeff=1.,
             sampling_args=None,
             checkpoint=None,
             resample_freq=-1,
             device=0):
    """AVID_CMA criterion.

    This module receives the output embeddings of the video
    and audio models, computes their non-linear projections,
    manages the memory bank, draws positive correspondences,
    and computes the final loss (weighted average between
    instance discrimination and positive discrimination losses).

    Args:
        - num_data: number of instances in the training set (memory bank size).
        - embedding_dim: output dimension of the non-linear projection.
        - num_negatives: number of negatives to draw from memory bank to compute the NCE loss.
        - num_negatives_within: optionally reduce the number of negatives for the within-modal loss.
        - momentum: memory bank EMA momentum parameter.
        - xModalInstCoeff: coefficient for the cross modal instance discrimination loss. (AVID-CMA: 1.0)
        - wModalInstCoeff: coefficient for the within modal instance discrimination loss. (AVID-CMA: 0.0)
        - xModalPosCoeff: coefficient for the cross modal positive discrimination loss. (AVID-CMA: 0.0)
        - wModalPosCoeff: coefficient for the within modal positive discrimination loss. (AVID-CMA: 1.0)
        - sampling_args: forwarded unchanged to AVIDSimilarityPositiveExpansion
          (presumably configures how positive correspondences are drawn -- confirm there).
        - checkpoint: optionally specify a checkpoint path to restore the memory bank and
          partition function. The file must contain a 'train_criterion' state dict.
        - resample_freq: stored as ``self.resample_freq``; presumably how often the
          correspondences are re-computed during training -- confirm with the caller.
        - device: CUDA device the similarity module is built for and moved to.

    All four coefficients are normalized by their sum. A coefficient > 0
    enables the corresponding loss term.

    Raises:
        - ValueError: if any loss coefficient is negative, or none is positive.
        - KeyError: if ``checkpoint`` lacks the memory banks or any
          'avg_exp_score' (partition function) entry.
    """
    super(AVID_CMA, self).__init__()

    # Validate the loss coefficients before allocating anything on the GPU.
    # They are normalized by their sum below, so a zero sum would divide by
    # zero. A negative coefficient would disable its loss term (the flags are
    # `coeff > 0.`) while still distorting the normalizer.
    coeffs = {
        'xModalInstCoeff': xModalInstCoeff,
        'wModalInstCoeff': wModalInstCoeff,
        'xModalPosCoeff': xModalPosCoeff,
        'wModalPosCoeff': wModalPosCoeff,
    }
    negative = [name for name, value in coeffs.items() if value < 0]
    if negative:
        raise ValueError(
            f"Loss coefficients must be non-negative; got negative values for {negative}")
    sum_coeff = sum(coeffs.values())
    if sum_coeff <= 0:
        raise ValueError("At least one loss coefficient must be positive.")

    # First set up the NCEAverage method to get the scores of the output
    # wrt. memory bank negatives. Only loss terms with a positive
    # coefficient are enabled.
    self.nce_average = AVIDSimilarityPositiveExpansion(
        memory_size=num_data,
        embedding_dim=embedding_dim,
        num_negatives=num_negatives,
        num_negatives_within=num_negatives_within,
        momentum=momentum,
        xModalInst=xModalInstCoeff > 0.,
        xModalPos=xModalPosCoeff > 0.,
        wModalInst=wModalInstCoeff > 0.,
        wModalPos=wModalPosCoeff > 0.,
        sampling_args=sampling_args,
        device=device
    )
    self.nce_average = self.nce_average.cuda(device)

    # Loss coefficients, normalized so they sum to 1 (weighted average).
    self.xModalInstCoeff = xModalInstCoeff / sum_coeff
    self.wModalInstCoeff = wModalInstCoeff / sum_coeff
    self.xModalPosCoeff = xModalPosCoeff / sum_coeff
    self.wModalPosCoeff = wModalPosCoeff / sum_coeff

    # Set up the loss function.
    self.criterion = NCECriterion(num_data)

    # Restore the memory bank and partition function from an AVID checkpoint.
    # This must happen before finding correspondences, which are computed
    # from the memory bank contents.
    if checkpoint is not None:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckp = torch.load(checkpoint, map_location='cpu')['train_criterion']
        state_dict = self.state_dict()
        # Restore memory banks.
        state_dict['nce_average.view1_mem'] = ckp['nce_average.view1_mem']
        state_dict['nce_average.view2_mem'] = ckp['nce_average.view2_mem']
        # Restore the partition function: average every saved estimate into a
        # single value, then share it across all 'avg_exp_score' buffers here.
        saved_scores = [ckp[k] for k in ckp if 'avg_exp_score' in k]
        if not saved_scores:
            raise KeyError(
                f"No 'avg_exp_score' (partition function) entries found in checkpoint {checkpoint!r}")
        Z = torch.stack(saved_scores).mean()
        for k in state_dict:
            if 'avg_exp_score' in k:
                state_dict[k] = Z
        self.load_state_dict(state_dict)

    self.resample_freq = resample_freq
    # Find CMA correspondences (after any memory bank restore above).
    self.nce_average.find_correspondences()