def forward()

in vissl/losses/swav_loss.py [0:0]


    def forward(self, scores: torch.Tensor, head_id: int):
        assert scores.shape[0] % self.num_crops == 0
        bs = scores.shape[0] // self.num_crops
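        # `scores` stacks the prototype scores of all crops along dim 0:
        # rows [bs * c : bs * (c + 1)] belong to crop `c`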

        total_loss = 0
        n_term_loss = 0

        # 2 big crops are normally used for the assignment
        for i, crop_id in enumerate(self.crops_for_assign):

            # Compute the target assignments, using the scores of crop_id to
            # compute the codes onto which the other crops will be mapped
            with torch.no_grad():
                scores_this_crop = scores[bs * crop_id : bs * (crop_id + 1)]

                # Add the representations stored in the queue (useful when the
                # batch size is small: it increases the number of samples seen
                # by Sinkhorn-Knopp so that an equal partition across the
                # prototypes remains possible)
                if self.use_queue:
                    queue = getattr(self, "local_queue" + str(head_id))[i].clone()
                    scores_this_crop = torch.cat((scores_this_crop, queue))

                # Divide by epsilon (which acts as a temperature that sharpens
                # the distribution of the assignments)
                if self.use_double_prec:
                    assignments = torch.exp(
                        scores_this_crop.double() / np.float64(self.epsilon)
                    ).t()
                    assignments = assignments.double()
                else:
                    assignments = scores_this_crop / self.epsilon
                    # subtract the global max before exponentiating, for
                    # numerical stability (log-sum-exp style)
                    M = torch.max(assignments)
                    all_reduce_max(M)
                    assignments -= M
                    assignments = torch.exp(assignments).t()
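                    # (subtracting the global max only rescales exp(scores / epsilon)
                    # by a constant factor, which the equal-partition normalization
                    # below cancels out)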

                # Apply the Sinkhorn-Knopp algorithm to spread the assignments
                # equally across the prototypes
                assignments = distributed_sinkhornknopp(
                    Q=assignments,
                    hard_assignment=self.num_iteration
                    < self.temp_hard_assignment_iters,
                    world_size=self.world_size,
                    num_iter=self.nmb_sinkhornknopp_iters,
                    use_gpu=self.use_gpu,
                    use_double_prec=self.use_double_prec,
                )
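                # Keep only the assignments of the current batch: the queue rows
                # appended above only help the equal-partition normalization and
                # are not used as targets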
                assignments = assignments[:bs]

            # For each crop other than the one used for the target assignment,
            # compute the cross entropy between the target assignment and the
            # softmax of that crop's dot products with the prototypes
            loss = 0
            idx_crop_pred = np.delete(np.arange(self.num_crops), crop_id)
            for p in idx_crop_pred:
                if self.use_double_prec:
                    loss -= torch.mean(
                        torch.sum(
                            assignments
                            * self.log_softmax(
                                scores[bs * p : bs * (p + 1)].double()
                                / np.float64(self.temperature)
                            ),
                            dim=1,
                            dtype=assignments.dtype,
                        )
                    )
                else:
                    loss -= torch.mean(
                        torch.sum(
                            assignments
                            * self.log_softmax(
                                scores[bs * p : bs * (p + 1)] / self.temperature
                            ),
                            dim=1,
                            dtype=assignments.dtype,
                        )
                    )

            # Average the contribution of each crop (we don't want an increase
            # in the number of crops to change the loss magnitude and force us
            # to update the LR)
            loss /= len(idx_crop_pred)

            # Average the contribution of each swapped assignment (the
            # division by 'n_term_loss' is done at the end of the loop)
            # for the same reason as above
            total_loss += loss
            n_term_loss += 1

            # If a NaN appears in the loss, log it and dump the scores and
            # assignments to help debugging
            # TODO (prigoyal): extract this logic into a common debug_state()
            # method that all losses can override
            if torch.isnan(loss):
                logging.info(
                    f"Infinite Loss or NaN. Loss value: {loss}, rank: {self.dist_rank}"
                )
                scores_output_file = os.path.join(
                    self.output_dir,
                    "rank" + str(self.dist_rank) + "_scores" + str(i) + ".pth",
                )
                assignments_out_file = os.path.join(
                    self.output_dir,
                    "rank" + str(self.dist_rank) + "_assignments" + str(i) + ".pth",
                )
                with g_pathmgr.open(scores_output_file, "wb") as fwrite:
                    torch.save(scores, fwrite)
                with g_pathmgr.open(assignments_out_file, "wb") as fwrite:
                    torch.save(assignments, fwrite)
                logging.info(f"Saved the scores matrix to: {scores_output_file}")
                logging.info(f"Saved the assignment matrix to: {assignments_out_file}")

        total_loss /= n_term_loss
        return total_loss
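
For reference, here is a minimal single-process sketch of the equal-partition
normalization that `distributed_sinkhornknopp` applies to the exponentiated
scores above. The helper name `sinkhorn_knopp_sketch` is hypothetical, and the
real VISSL helper additionally all-reduces the marginal sums across workers and
supports hard assignment and double precision; this is only an illustration of
the core iteration.

import torch

def sinkhorn_knopp_sketch(Q: torch.Tensor, num_iter: int = 3) -> torch.Tensor:
    # Q: (num_prototypes, num_samples) matrix of exp(scores / epsilon),
    # i.e. the `assignments` tensor handed to the helper above
    K, B = Q.shape
    Q = Q / Q.sum()  # start from a joint distribution over (prototype, sample)
    for _ in range(num_iter):
        # rows: give every prototype 1/K of the total mass
        Q /= Q.sum(dim=1, keepdim=True)
        Q /= K
        # columns: give every sample 1/B of the total mass
        Q /= Q.sum(dim=0, keepdim=True)
        Q /= B
    Q *= B  # rescale so each column (one sample) sums to 1
    # transpose to (num_samples, num_prototypes), the layout sliced with [:bs]
    return Q.t()

Alternating the row and column rescalings pushes the soft assignments towards a
matrix whose columns are valid distributions over the prototypes while every
prototype receives roughly the same total mass, which is the equal partition the
comments above refer to.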