in phosa/global_opt.py [0:0]
def compute_ordinal_depth_loss(self, masks, silhouettes, depths):
    """Penalize pixels whose rendered depth order contradicts the masks.

    For every ordered pair ``(i, j)`` of instances, a pixel is selected when
    all of the following hold:

    * ``masks[i]`` is set and ``masks[j]`` is not;
    * both ``silhouettes[i]`` and ``silhouettes[j]`` are set;
    * ``depths[j] < depths[i]``.

    Each selected pixel contributes
    ``log(1 + exp(clamp(depths[i] - depths[j], 0, 2)))``. The summed penalty
    is divided by the number of ordered pairs ``(i, j)``, including
    ``i == j``, whose silhouettes overlap in at least one pixel.

    The local names (``front_i_gt`` / ``front_j_pred``) suggest that
    ``masks`` are ground-truth instance masks, with "i visible, j not"
    meaning i is in front of j, and that ``silhouettes``/``depths`` are
    renderings of the current estimate. NOTE(review): confirm with caller.

    Args:
        masks: Per-instance masks indexed by instance. Presumably boolean
            image-shaped tensors (TODO confirm); combined with ``&`` and ``~``.
        silhouettes: Per-instance silhouettes, same indexing and shape as
            ``masks``; combined with ``&``.
        depths: Per-instance depth maps, same indexing and shape as
            ``masks``.

    Returns:
        dict: ``{"loss_depth": loss}``, where ``loss`` is a scalar tensor on
        the device of ``depths``. It is 0 when no silhouettes overlap.
    """
    # Follow the inputs' device instead of hard-coding ``.cuda()``; fall
    # back to CPU only when there is nothing to read a device from.
    if len(depths) > 0:
        device = depths[0].device
    else:
        device = torch.device("cpu")
    loss = torch.zeros((), dtype=torch.float32, device=device)
    num_pairs = 0
    for i, sil_i in enumerate(silhouettes):
        for j, sil_j in enumerate(silhouettes):
            has_pred = sil_i & sil_j
            if not has_pred.any():
                continue
            # i == j pairs are counted although they can never contribute
            # (masks[i] & ~masks[i] is empty); kept so the normalizer, and
            # hence the effective loss weight, is unchanged.
            num_pairs += 1
            front_i_gt = masks[i] & ~masks[j]
            front_j_pred = depths[j] < depths[i]
            m = front_i_gt & front_j_pred & has_pred
            if not m.any():
                continue
            dists = torch.clamp(depths[i] - depths[j], min=0.0, max=2.0)
            # softplus(x) == log(1 + exp(x)); only masked pixels are needed.
            loss = loss + torch.nn.functional.softplus(dists[m]).sum()
    if num_pairs == 0:
        # Dividing by zero here would produce 0/0 = NaN and poison the
        # optimization; with no overlapping pairs there is nothing to penalize.
        return {"loss_depth": loss}
    return {"loss_depth": loss / num_pairs}