def forward()

in loss_fn/simclr_infonce.py [0:0]


    def forward(self, output: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            output: BxC
            target: BxC or BxKxC <-- In case of MIL NCE, K is the number of
                positives for each batch element.
            Following https://github.com/google-research/simclr/blob/master/objective.py
        """
        # Normalize before the gather, so that all the gathered features are
        # already normalized
        output = nn.functional.normalize(output, dim=-1, p=2)
        target = nn.functional.normalize(target, dim=-1, p=2)
        # To support MIL-NCE inputs, fold the K positives into the batch dim of
        # the targets; the positive labels are repeat-interleaved below so each
        # output is matched against all K of its targets
        # Index of the positive kept for the reverse (target -> output) loss
        elt_for_back_loss = 0
        if target.ndim == 3:
            num_matching = target.size(1)
            target_flat = target.reshape((-1, target.size(-1)))
            # Keep the first one for the back loss
            target = target[:, elt_for_back_loss]
        else:
            num_matching = 1
            target_flat = target
        # Gather all the outputs and all the targets
        output_all = self.gather_embeddings(output)
        target_flat_all = self.gather_embeddings(target_flat)
        batch_size = output.size(0)
        replica_id = utils.get_rank()
        # One-hot positives over the gathered outputs -> (B, B_full);
        # interleaved below to (B, B_full * num_matching)
        labels_onehot = torch.zeros((batch_size, output_all.size(0)),
                                    dtype=output.dtype,
                                    device=output.device)
        extra_zeros = torch.zeros((batch_size, output_all.size(0)),
                                  dtype=output.dtype,
                                  device=output.device)
        ones_diag = torch.eye(batch_size,
                              batch_size,
                              dtype=output.dtype,
                              device=output.device)
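        # Each local sample i is its own positive at column
        # replica_id * batch_size + i of the gathered batch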
        labels_onehot[:, replica_id * batch_size:(replica_id + 1) *
                      batch_size] = ones_diag
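        # Repeat each column num_matching times so the labels line up with the
        # flattened (B_full * num_matching) gathered targets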
        labels_onehot_interleaved = labels_onehot.repeat_interleave(
            num_matching, dim=1)
        # (B, C) * (B_full, C) -> (B, B_full)
        logits_aa = torch.mm(output, output_all.t() / self.temperature)
        # (B, C) * (B_full * num_matching, C) -> (B, B_full * num_matching)
        logits_ab = torch.mm(output, target_flat_all.t() / self.temperature)
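        # Mask out each output's similarity with itself so it is effectively
        # removed from the softmax denominator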
        logits_aa = logits_aa - labels_onehot * LARGE_NUM
        loss = self.criterion(
            torch.cat([logits_ab, logits_aa], 1),
            torch.cat([labels_onehot_interleaved, extra_zeros], 1))
        if self.target_to_output_loss:
            # Keep only the first of the num_matching positives per batch
            # element, since that is the one the reverse loss is computed with
            target_all = target_flat_all[elt_for_back_loss::num_matching]
            logits_bb = torch.mm(target, target_all.t() / self.temperature)
            logits_bb = logits_bb - labels_onehot * LARGE_NUM
            logits_ba = torch.mm(target, output_all.t() / self.temperature)
            loss = loss + self.criterion(
                torch.cat([logits_ba, logits_bb], 1),
                torch.cat([labels_onehot, extra_zeros], 1))
        return loss
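
The gather_embeddings helper is not shown in this excerpt. A minimal sketch of what such a helper usually does in SimCLR-style losses, assuming the common all-gather pattern that re-inserts the local tensor so gradients still flow to this replica's slice (the name and details below are illustrative, not the repository's actual implementation):

    import torch
    import torch.distributed as dist

    def gather_embeddings_sketch(tensor: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for self.gather_embeddings: all-gather the
        # per-replica embeddings, then put the local tensor back in place so
        # autograd still reaches this replica's slice.
        if not (dist.is_available() and dist.is_initialized()):
            return tensor
        world_size = dist.get_world_size()
        gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor)
        gathered[dist.get_rank()] = tensor
        return torch.cat(gathered, dim=0)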
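
A toy, single-process check (so B_full == B) of how the interleaved labels line up with the flattened MIL-NCE targets; the sizes are made up for illustration:

    import torch

    B, K, C = 2, 3, 8                        # toy sizes, single process
    target = torch.randn(B, K, C)
    target_flat = target.reshape(-1, C)      # (B*K, C), ordered b0k0, b0k1, ..., b1k0, ...
    labels = torch.eye(B).repeat_interleave(K, dim=1)  # (B, B*K)
    # Row b has ones at columns b*K .. b*K + K - 1, i.e. all K targets of sample b
    print(labels)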