def update()

in ssl/real-dataset/byol_trainer.py
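
This method computes one training loss for a pair of augmented views. The online network encodes both views; depending on configuration the result goes through a learned predictor, a partition-based predictor assembled from accumulated correlation statistics, or no predictor at all. The EMA target network produces the regression targets under torch.no_grad(), and the method returns either the symmetric BYOL regression loss or, when use_order_of_variance is set, a variance-ordering loss that bypasses the predictor.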


    def update(self, batch_view_1, batch_view_2):
        # compute query feature
        if self.predictor:
            before_predictor_1 = self.online_network(batch_view_1)
            before_predictor_2 = self.online_network(batch_view_2)

            if self.dyn_bn:
                before_predictor_1 = self.bn_before_online(before_predictor_1)
                before_predictor_2 = self.bn_before_online(before_predictor_2)

            # Detach once up front: every correlation statistic below (including the
            # "directcopy" / order-of-variance branches further down) reads these,
            # and none of them should backpropagate into the online network.
            before_detach_1 = before_predictor_1.detach()
            before_detach_2 = before_predictor_2.detach()

            if self.predictor_reg in ["minimal_space", "solve", "corr", "partition"] or self.corr_collect:
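                # Accumulate per-view feature means and second moments E[f f^T]; these
                # running statistics drive the correlation-based predictor regularizers.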
                corrs = []
                means = []

                corrs_per_partition_pos = [[] for _ in range(self.n_corr)]
                corrs_per_partition_neg = [[] for _ in range(self.n_corr)]
                partitions = []

                for b in (before_detach_1, before_detach_2):
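                    # Per-sample outer products f f^T, shape (n_batch, d, d).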
                    corr = torch.bmm(b.unsqueeze(2), b.unsqueeze(1))
                    corrs.append(corr)
                    means.append(b)

                    if self.predictor_reg == "partition":
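                        # Split the batch along n_corr projection directions, thresholding
                        # each projection at its batch mean.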
                        pred = b @ self.partition_w
                        thres = pred.mean(dim=0, keepdim=True)
                        # n_batch x n_corr
                        partition = pred >= thres
                        partitions.append(partition)
                        for i in range(self.n_corr):
                            corrs_per_partition_pos[i].append(corr[partition[:,i], :, :])
                            corrs_per_partition_neg[i].append(corr[~partition[:,i], :, :])
                            cnt = partition[:,i].sum().item()
                            self.counts_pos[i] += cnt
                            self.counts_neg[i] += partition.size(0) - cnt

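                    # Debugging aid: drop into pdb if the correlation picked up NaNs.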
                    if torch.any(torch.isnan(corr)).item():
                        import pdb
                        pdb.set_trace()

                self.cum_mean1.add_list(means)
                self.cum_corr.add_list(corrs)

                if self.predictor_reg == "partition":
                    if not self.predictor_signaling_2:
                        log.info(f"Partition corr matrix. n_corr: {self.n_corr}, counts_pos: {self.counts_pos}, counts_neg: {self.counts_neg}")
                    for i in range(self.n_corr):
                        self.cum_corrs_pos[i].add_list(corrs_per_partition_pos[i])
                        self.cum_corrs_neg[i].add_list(corrs_per_partition_neg[i])
            
            if self.predictor_reg != "partition":
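                # Standard BYOL path: run both views through the learned predictor.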
                predictions_from_view_1 = self.predictor(before_predictor_1)
                predictions_from_view_2 = self.predictor(before_predictor_2)
            else:
                # Partition path: skip the learned predictor and instead apply
                # per-sample predictor weights assembled from the partition statistics.
                ws_pos = torch.zeros(self.n_corr, 128, 128, device=before_predictor_1.device)
                ws_neg = torch.zeros(self.n_corr, 128, 128, device=before_predictor_1.device)
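                # Build one positive/negative predictor weight matrix per partition
                # direction from the accumulated correlation statistics.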
                for i in range(self.n_corr):
                    M_pos = self.cum_corrs_pos[i].get()
                    M_neg = self.cum_corrs_neg[i].get()

                    check_valid(M_pos)
                    check_valid(M_neg)

                    ws_pos[i,:,:] = self.compute_w_corr(M_pos)
                    ws_neg[i,:,:] = self.compute_w_corr(M_neg)

                # Mix the positive/negative weight matrices per sample according to its
                # partition membership, averaging over the n_corr directions.
                ws_sel_1 = (partitions[0].float() @ ws_pos.view(self.n_corr, -1) + (1 - partitions[0].float()) @ ws_neg.view(self.n_corr, -1)) / self.n_corr
                predictions_from_view_1 = torch.bmm(ws_sel_1.view(-1, 128, 128), before_predictor_1.unsqueeze(2)).squeeze(2)

                ws_sel_2 = (partitions[1].float() @ ws_pos.view(self.n_corr, -1) + (1 - partitions[1].float()) @ ws_neg.view(self.n_corr, -1)) / self.n_corr
                predictions_from_view_2 = torch.bmm(ws_sel_2.view(-1, 128, 128), before_predictor_2.unsqueeze(2)).squeeze(2)
        else:
            predictions_from_view_1 = self.online_network(batch_view_1)
            predictions_from_view_2 = self.online_network(batch_view_2)

        # compute key features
        with torch.no_grad():
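            # Key features come from the EMA target network; no gradients flow here.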
            self.target_network.projetion.set_adj_grad(False)

            targets_to_view_2 = self.target_network(batch_view_1)
            targets_to_view_1 = self.target_network(batch_view_2)

            self.target_network.projetion.set_adj_grad(True)

            if self.dyn_bn:
                targets_to_view_2 = self.bn_before_target(targets_to_view_2)
                targets_to_view_1 = self.bn_before_target(targets_to_view_1)

            if self.predictor_reg in ["minimal_space", "solve", "corr", "directcopy"] or self.corr_collect or self.use_order_of_variance:
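                # Batch-averaged cross-correlation between target features and the
                # detached online features, symmetrized over the two views.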
                cross_corr1 = torch.bmm(targets_to_view_1.unsqueeze(2), before_detach_1.unsqueeze(1)).mean(dim=0)
                cross_corr2 = torch.bmm(targets_to_view_2.unsqueeze(2), before_detach_2.unsqueeze(1)).mean(dim=0)
                cross_corr = (cross_corr1 + cross_corr2) / 2

                mean_f_ema = (targets_to_view_1.mean(dim=0) + targets_to_view_2.mean(dim=0)) / 2

                if torch.any(torch.isnan(cross_corr)).item():
                    import pdb
                    pdb.set_trace()

                self.cum_mean2.add(mean_f_ema)
                self.cum_cross_corr.add(cross_corr)

        if self.use_order_of_variance:
            if not self.predictor_signaling_2:
                log.info("Using order of variance!")

            # Skip the predictor completely.
            M = self.cum_corr.get()
            M2 = self.cum_cross_corr.get()
            # Only their diagonals (per-dimension second moments) are needed.
            Mdiag = M.diag()
            M2diag = M2.diag()
            # Per-dimension ratio of cross-correlation variance to online-feature variance.
            var_ratio = M2diag / (Mdiag + 1e-5)
            # Ideally we want low variance in M2 but high variance in M, i.e. a small ratio.
            _, indices = var_ratio.sort()
            # Then set up the objective: the lowest-ratio third of dimensions are "good".
            d = indices.size(0)
            d_partial = d // 3
            good_indices = indices[:d_partial]
            bad_indices = indices[d_partial:]

            # Compute variance.
            before_predictor = torch.cat([before_predictor_1, before_predictor_2], dim=0)
            before_predictor_normalized = before_predictor / before_predictor.norm(dim=1, keepdim=True)
            variances = before_predictor_normalized.var(dim=0)
            # Minimize the bad variances (suppress those features) while maximizing the good ones (boost those features).
            loss = variances[bad_indices].mean() - variances[good_indices].mean()
        else:
            loss = self.regression_loss(predictions_from_view_1, targets_to_view_1, l2_normalized=self.use_l2_normalization)
            loss += self.regression_loss(predictions_from_view_2, targets_to_view_2, l2_normalized=self.use_l2_normalization)

        return loss.mean()
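
The helpers above (compute_w_corr, check_valid, the cum_* accumulators) are defined elsewhere in byol_trainer.py and are not part of this excerpt. For orientation, compute_w_corr in the DirectPred line of work typically rebuilds the predictor weights from an eigendecomposition of the accumulated correlation matrix; the sketch below is a hypothetical reconstruction under that assumption (the eps floor and sqrt-spectrum normalization are assumptions, not taken from this file):

    import torch

    def compute_w_corr_sketch(M, eps=0.01):
        # Hypothetical DirectPred-style weights from a symmetric correlation
        # matrix M: W = U diag(sqrt(s) / sqrt(s_max) + eps) U^T.
        eigvals, eigvecs = torch.linalg.eigh(M)        # ascending eigenvalues
        eigvals = eigvals.clamp(min=0.0)               # guard tiny negative values
        scale = eigvals.sqrt() / eigvals.sqrt().max()  # normalized sqrt spectrum
        return eigvecs @ torch.diag(scale + eps) @ eigvecs.t()

Likewise, the cum_* members expose add / add_list / get. A minimal stand-in consistent with their usage here, assuming an exponential moving average (the real class and its decay schedule may differ):

    class EMAStatSketch:
        # Hypothetical accumulator behind cum_corr / cum_mean1 / cum_cross_corr.
        def __init__(self, alpha=0.1):
            self.alpha = alpha
            self.value = None

        def add(self, x):
            # x is an already batch-reduced statistic, e.g. shape (d,) or (d, d).
            x = x.detach()
            if self.value is None:
                self.value = x.clone()
            else:
                self.value = (1 - self.alpha) * self.value + self.alpha * x

        def add_list(self, xs):
            # xs holds per-sample tensors with a leading batch dimension.
            for x in xs:
                self.add(x.mean(dim=0))

        def get(self):
            return self.value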