def compute_ordinal_depth_loss()

in phosa/global_opt.py [0:0]


    def compute_ordinal_depth_loss(self, masks, silhouettes, depths):
        """
        Ordinal depth loss over ordered pairs of instances.

        For every ordered pair (i, j) whose rendered silhouettes overlap, a
        pixel is penalized when the ground-truth masks put instance i in front
        (``masks[i]`` set and ``masks[j]`` not set) but the rendered depths put
        instance j in front (``depths[j] < depths[i]``). The penalty per pixel
        is ``softplus(clamp(depths[i] - depths[j], 0, 2))``. The summed penalty
        is divided by the number of ordered pairs whose silhouettes overlap.

        Args:
            masks: Indexable sequence of ground-truth masks, one per instance.
                Assumed to be boolean tensors, since ``~`` is used as logical
                NOT; verify against callers.
            silhouettes: Indexable sequence of rendered silhouettes, one per
                instance. Assumed boolean and the same shape as ``masks``.
            depths: Indexable sequence of rendered depth maps, one per instance.
                Presumably float tensors of the same shape as ``masks``.

        Returns:
            dict: ``{"loss_depth": loss}`` where ``loss`` is a 0-dim float32
            tensor. It is exactly zero when no pair of silhouettes overlaps,
            including the case of empty input, rather than NaN.
        """
        # Follow the inputs' device instead of forcing CUDA; fall back only when
        # there is no input tensor to take the device from.
        if len(depths) > 0:
            device = depths[0].device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        loss = torch.zeros((), dtype=torch.float32, device=device)

        num_pairs = 0
        num_instances = len(silhouettes)
        for i in range(num_instances):
            for j in range(num_instances):
                has_pred = silhouettes[i] & silhouettes[j]
                if not has_pred.any():
                    continue
                # NOTE(review): self-pairs (i == j) are counted here whenever
                # silhouettes[i] is non-empty. They never contribute loss,
                # because masks[i] & ~masks[i] is empty, but they enlarge the
                # normalizer. Kept to preserve the existing loss scale.
                num_pairs += 1

                front_i_gt = masks[i] & (~masks[j])  # GT: i occludes j
                front_j_pred = depths[j] < depths[i]  # prediction: j is closer
                violated = front_i_gt & front_j_pred & has_pred
                if not violated.any():
                    continue
                # Mask first so the softplus is only evaluated on penalized
                # pixels; the sum and gradients match masking afterwards.
                dists = torch.clamp(
                    (depths[i] - depths[j])[violated], min=0.0, max=2.0
                )
                loss = loss + torch.nn.functional.softplus(dists).sum()

        if num_pairs == 0:
            # Nothing to compare: avoid 0 / 0 = NaN.
            return {"loss_depth": loss}
        return {"loss_depth": loss / num_pairs}