in ssl/real-dataset/byol_trainer.py [0:0]
def update(self, batch_view_1, batch_view_2):
    # Compute query features with the online network.
    if self.predictor:
        before_predictor_1 = self.online_network(batch_view_1)
        before_predictor_2 = self.online_network(batch_view_2)
        if self.dyn_bn:
            before_predictor_1 = self.bn_before_online(before_predictor_1)
            before_predictor_2 = self.bn_before_online(before_predictor_2)
        if self.predictor_reg in ["minimal_space", "solve", "corr", "partition"] or self.corr_collect:
            corrs = []
            means = []
            corrs_per_partition_pos = [[] for _ in range(self.n_corr)]
            corrs_per_partition_neg = [[] for _ in range(self.n_corr)]
            partitions = []
            before_detach_1 = before_predictor_1.detach()
            before_detach_2 = before_predictor_2.detach()
            for b in (before_detach_1, before_detach_2):
                corr = torch.bmm(b.unsqueeze(2), b.unsqueeze(1))
                corrs.append(corr)
                means.append(b)
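                # Shape note (assuming b is (n_batch, d)): b.unsqueeze(2) is
                # (n_batch, d, 1) and b.unsqueeze(1) is (n_batch, 1, d), so the
                # bmm above yields per-sample outer products b_i b_i^T of shape
                # (n_batch, d, d); cum_corr later accumulates these into an
                # (uncentered) feature correlation matrix.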
                if self.predictor_reg == "partition":
                    pred = b @ self.partition_w
                    thres = pred.mean(dim=0, keepdim=True)
                    # Boolean mask of shape n_batch x n_corr: True if the sample
                    # scores above the batch mean along that direction.
                    partition = pred >= thres
                    partitions.append(partition)
                    for i in range(self.n_corr):
                        corrs_per_partition_pos[i].append(corr[partition[:, i], :, :])
                        corrs_per_partition_neg[i].append(corr[~partition[:, i], :, :])
                        cnt = partition[:, i].sum().item()
                        self.counts_pos[i] += cnt
                        self.counts_neg[i] += partition.size(0) - cnt
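                # At this point each of the n_corr directions in partition_w has
                # split the batch in two, and the per-sample correlation matrices
                # have been routed to the matching positive/negative accumulator
                # lists; counts_pos/counts_neg track the running split sizes.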
                if torch.any(torch.isnan(corr)).item():
                    # Drop into the debugger if NaNs appear in the correlations.
                    import pdb
                    pdb.set_trace()
            self.cum_mean1.add_list(means)
            self.cum_corr.add_list(corrs)
            if self.predictor_reg == "partition":
                if not self.predictor_signaling_2:
                    log.info(f"Partition corr matrix. n_corr: {self.n_corr}, counts_pos: {self.counts_pos}, counts_neg: {self.counts_neg}")
                for i in range(self.n_corr):
                    self.cum_corrs_pos[i].add_list(corrs_per_partition_pos[i])
                    self.cum_corrs_neg[i].add_list(corrs_per_partition_neg[i])
        if self.predictor_reg != "partition":
            predictions_from_view_1 = self.predictor(before_predictor_1)
            predictions_from_view_2 = self.predictor(before_predictor_2)
        else:
            # Partition mode: bypass the learned predictor and instead build
            # per-sample predictor weights from the accumulated per-partition
            # correlation statistics.
            ws_pos = torch.zeros(self.n_corr, 128, 128, device=before_predictor_1.device)
            ws_neg = torch.zeros(self.n_corr, 128, 128, device=before_predictor_1.device)
            for i in range(self.n_corr):
                M_pos = self.cum_corrs_pos[i].get()
                M_neg = self.cum_corrs_neg[i].get()
                check_valid(M_pos)
                check_valid(M_neg)
                ws_pos[i, :, :] = self.compute_w_corr(M_pos)
                ws_neg[i, :, :] = self.compute_w_corr(M_neg)
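            # compute_w_corr is defined elsewhere in this file; presumably it
            # maps a correlation matrix M to closed-form predictor weights via
            # an eigendecomposition (cf. DirectPred-style W = sum_i f(lambda_i)
            # u_i u_i^T). That is an assumption; the definition is not shown in
            # this excerpt.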
            # Mix the weights according to each sample's partition membership.
            ws_sel_1 = (partitions[0].float() @ ws_pos.view(self.n_corr, -1)
                        + (1 - partitions[0].float()) @ ws_neg.view(self.n_corr, -1)) / self.n_corr
            predictions_from_view_1 = torch.bmm(ws_sel_1.view(-1, 128, 128), before_predictor_1.unsqueeze(2)).squeeze(2)
            ws_sel_2 = (partitions[1].float() @ ws_pos.view(self.n_corr, -1)
                        + (1 - partitions[1].float()) @ ws_neg.view(self.n_corr, -1)) / self.n_corr
            predictions_from_view_2 = torch.bmm(ws_sel_2.view(-1, 128, 128), before_predictor_2.unsqueeze(2)).squeeze(2)
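            # For each sample, ws_sel averages over the n_corr directions,
            # picking ws_pos[i] when the sample falls in the positive half of
            # direction i and ws_neg[i] otherwise; the result is one 128 x 128
            # predictor matrix per sample, applied by a batched matrix-vector
            # product.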
    else:
        predictions_from_view_1 = self.online_network(batch_view_1)
        predictions_from_view_2 = self.online_network(batch_view_2)

    # Compute key features with the target network; no gradient flows into it.
    with torch.no_grad():
        self.target_network.projetion.set_adj_grad(False)
        targets_to_view_2 = self.target_network(batch_view_1)
        targets_to_view_1 = self.target_network(batch_view_2)
        self.target_network.projetion.set_adj_grad(True)
        if self.dyn_bn:
            targets_to_view_2 = self.bn_before_target(targets_to_view_2)
            targets_to_view_1 = self.bn_before_target(targets_to_view_1)
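    # Naming convention: targets_to_view_1 is computed from batch_view_2 and
    # later serves as the regression target for predictions_from_view_1 (and
    # vice versa), i.e. each view predicts the target embedding of the other.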
if self.predictor_reg in ["minimal_space", "solve", "corr", "directcopy"] or self.corr_collect or self.use_order_of_variance:
cross_corr1 = torch.bmm(targets_to_view_1.unsqueeze(2), before_detach_1.unsqueeze(1)).mean(dim=0)
cross_corr2 = torch.bmm(targets_to_view_2.unsqueeze(2), before_detach_2.unsqueeze(1)).mean(dim=0)
cross_corr = (cross_corr1 + cross_corr2) / 2
mean_f_ema = (targets_to_view_1.mean(dim=0) + targets_to_view_2.mean(dim=0)) / 2
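        # cross_corr estimates E[f_target f_online^T]; together with cum_corr's
        # E[f_online f_online^T] this is presumably what the closed-form
        # predictor solvers consume, since the least-squares optimal linear
        # predictor is W* = E[f_target f_online^T] E[f_online f_online^T]^{-1}.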
        if torch.any(torch.isnan(cross_corr)).item():
            import pdb
            pdb.set_trace()
        self.cum_mean2.add(mean_f_ema)
        self.cum_cross_corr.add(cross_corr)
    if self.use_order_of_variance:
        if not self.predictor_signaling_2:
            log.info("Use order of variance!")
        # Skip the predictor completely.
        M = self.cum_corr.get()
        M2 = self.cum_cross_corr.get()
        # Only the diagonals (per-dimension variances) are needed.
        Mdiag = M.diag()
        M2diag = M2.diag()
        # Ratio of cross-variance to self-variance per feature dimension.
        var_ratio = M2diag / (Mdiag + 1e-5)
        # Ideally we want low variance in M2 but high variance in M, so sort
        # ascending: the leading indices are the "good" directions.
        _, indices = var_ratio.sort()
        # Then set up the goal.
        d = indices.size(0)
        d_partial = d // 3
        good_indices = indices[:d_partial]
        bad_indices = indices[d_partial:]
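        # E.g. with d = 128 (the feature width used above), d_partial = 42: the
        # 42 dimensions with the lowest cross/self variance ratio are treated as
        # good, the remaining 86 as bad.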
        # Compute per-dimension variance of the l2-normalized online features.
        before_predictor = torch.cat([before_predictor_1, before_predictor_2], dim=0)
        before_predictor_normalized = before_predictor / before_predictor.norm(dim=1, keepdim=True)
        variances = before_predictor_normalized.var(dim=0)
        # Minimize the bad variances (suppress those features) while maximizing
        # the good variances (boost those features).
        loss = variances[bad_indices].mean() - variances[good_indices].mean()
    else:
        loss = self.regression_loss(predictions_from_view_1, targets_to_view_1, l2_normalized=self.use_l2_normalization)
        loss += self.regression_loss(predictions_from_view_2, targets_to_view_2, l2_normalized=self.use_l2_normalization)
    return loss.mean()
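
# For reference, a minimal sketch of what `regression_loss` above is assumed to
# compute: the standard BYOL regression loss, optionally on l2-normalized
# features. The actual definition lives elsewhere in byol_trainer.py, so treat
# the name and body below as illustrative, not authoritative.
def _regression_loss_sketch(x, y, l2_normalized=True):
    # x, y: (n_batch, d) prediction / target batches.
    if l2_normalized:
        x = x / x.norm(dim=1, keepdim=True)
        y = y / y.norm(dim=1, keepdim=True)
        # For unit vectors, ||x - y||^2 = 2 - 2 <x, y>.
        return 2 - 2 * (x * y).sum(dim=1)
    return ((x - y) ** 2).sum(dim=1)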