in vissl/losses/swav_loss.py [0:0]
def forward(self, scores: torch.Tensor, head_id: int):
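# scores: (num_crops * batch_size, num_prototypes) similarities between each
# crop's embedding and the prototypes (shape implied by the slicing below)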
assert scores.shape[0] % self.num_crops == 0
bs = scores.shape[0] // self.num_crops
total_loss = 0
n_term_loss = 0
# Normally, the 2 large (global) crops are used for the assignment
for i, crop_id in enumerate(self.crops_for_assign):
# Compute the target assignments ("codes") from the scores of crop_id;
# the other crops will be trained to predict these codes
with torch.no_grad():
scores_this_crop = scores[bs * crop_id : bs * (crop_id + 1)]
# Append the representations stored in the queue (useful when the
# batch size is small: it gives Sinkhorn-Knopp more samples, making
# an even split across the prototypes possible)
if self.use_queue:
queue = getattr(self, "local_queue" + str(head_id))[i].clone()
scores_this_crop = torch.cat((scores_this_crop, queue))
# Divide by epsilon (a temperature that sharpens the distribution
# of the assignments)
if self.use_double_prec:
assignments = torch.exp(
scores_this_crop.double() / np.float64(self.epsilon)
).t()
assignments = assignments.double()
else:
assignments = scores_this_crop / self.epsilon
# Subtract the global max before exponentiating (log-sum-exp trick) for numerical stability
M = torch.max(assignments)
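# take the max across all workers so every rank subtracts the same constant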
all_reduce_max(M)
assignments -= M
assignments = torch.exp(assignments).t()
# Apply the Sinkhorn-Knopp algorithm so that the samples are
# assigned evenly across the prototypes
assignments = distributed_sinkhornknopp(
Q=assignments,
hard_assignment=self.num_iteration
< self.temp_hard_assignment_iters,
world_size=self.world_size,
num_iter=self.nmb_sinkhornknopp_iters,
use_gpu=self.use_gpu,
use_double_prec=self.use_double_prec,
)
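# Keep only the assignments of the current batch (drop the queue rows, if any)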
assignments = assignments[:bs]
# For every crop other than the one used for the target assignment,
# compute the cross-entropy between the target assignment and the
# softmax of the dot products of that crop with the prototypes
loss = 0
idx_crop_pred = np.delete(np.arange(self.num_crops), crop_id)
for p in idx_crop_pred:
if self.use_double_prec:
loss -= torch.mean(
torch.sum(
assignments
* self.log_softmax(
scores[bs * p : bs * (p + 1)].double()
/ np.float64(self.temperature)
),
dim=1,
dtype=assignments.dtype,
)
)
else:
loss -= torch.mean(
torch.sum(
assignments
* self.log_softmax(
scores[bs * p : bs * (p + 1)] / self.temperature
),
dim=1,
dtype=assignments.dtype,
)
)
# Average the contribution of each crop (an increase in the number
# of crops should not change the loss magnitude and force us to
# retune the LR)
loss /= len(idx_crop_pred)
# Average the contribution of each swapped assignment, for the same
# reason as above (the division by 'n_term_loss' happens after the loop)
total_loss += loss
n_term_loss += 1
# A NaN loss should stop training; log it and dump the tensors to help debugging
# TODO (prigoyal): extract this logic to be common to all losses:
# a debug_state() method that all losses can override
if torch.isnan(loss):
logging.info(
f"Infinite Loss or NaN. Loss value: {loss}, rank: {self.dist_rank}"
)
scores_output_file = os.path.join(
    self.output_dir, f"rank{self.dist_rank}_scores{i}.pth"
)
assignments_out_file = os.path.join(
    self.output_dir, f"rank{self.dist_rank}_assignments{i}.pth"
)
with g_pathmgr.open(scores_output_file, "wb") as fwrite:
torch.save(scores, fwrite)
with g_pathmgr.open(assignments_out_file, "wb") as fwrite:
torch.save(assignments, fwrite)
logging.info(f"Saved the scores matrix to: {scores_output_file}")
logging.info(f"Saved the assignment matrix to: {assignments_out_file}")
total_loss /= n_term_loss
return total_loss
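
For reference, here is a minimal single-process sketch of the balanced-assignment step that distributed_sinkhornknopp is expected to perform on Q. It is illustrative only: the function name, the fixed 3 iterations, and the omission of the cross-worker all-reduces, the GPU / double-precision paths, and the hard-assignment mode are simplifications, not the VISSL implementation.

import torch


def sinkhorn_knopp_sketch(Q: torch.Tensor, num_iter: int = 3) -> torch.Tensor:
    # Q: (num_prototypes, num_samples), i.e. exp(scores / epsilon) transposed,
    # matching the tensor handed to distributed_sinkhornknopp above.
    with torch.no_grad():
        Q = Q / Q.sum()  # normalize Q into a joint probability matrix
        num_prototypes, num_samples = Q.shape
        for _ in range(num_iter):
            # rows: every prototype should receive a total mass of 1 / num_prototypes
            Q = Q / Q.sum(dim=1, keepdim=True)
            Q = Q / num_prototypes
            # columns: every sample should distribute a total mass of 1 / num_samples
            Q = Q / Q.sum(dim=0, keepdim=True)
            Q = Q / num_samples
        Q = Q * num_samples  # each column now sums to 1: soft codes per sample
    # transpose back to (num_samples, num_prototypes) so rows align with the batch
    return Q.t()

Each returned row plays the role of assignments above: a probability distribution over the prototypes that the remaining crops are trained to predict through the cross-entropy terms.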