in fairnr/criterions/rendering_loss.py [0:0]
def compute_loss(self, model, net_output, sample, reduce=True):
    """Compute the weighted sum of rendering losses for one batch.

    Every enabled term is stored in ``losses`` as ``(value, weight)``. The
    returned loss is ``sum(value * weight)`` plus ``model.dummy_loss`` and
    ``self.dummy_loss * 0.``. The zero-weighted term presumably keeps that
    tensor attached to the autograd graph; confirm.

    Args:
        model: object exposing a ``dummy_loss`` tensor.
        net_output: dict of model outputs. Keys read: ``'colors'``,
            ``'missed'``, ``'depths'``, ``'eikonal-term'`` and
            ``'regz-term'``, each only when the matching term is enabled or
            present. The optional ``'other_logs'`` dict is merged into the
            logging outputs.
        sample: batch dict. Keys read:
            - ``'sampled_uv'`` (S, V, 2, N, P1, P2);
            - ``'size'``;
            - ``'colors'`` (required);
            - ``'alpha'`` (when ``args.no_background_loss``);
            - optionally ``'depths'`` and ``'mask'``.
        reduce: accepted but not used by this implementation.

    Returns:
        ``(loss, logging_outputs)``. ``logging_outputs`` maps each loss name
        to ``item(value)`` and also contains every entry of ``other_logs``.

    Raises:
        AssertionError: if ground-truth colors are missing, if the VGG loss
            is enabled without patch sampling, or if a background-only depth
            loss is requested without any mask source.
    """
    losses, other_logs = {}, {}

    # ---- prepare targets ---------------------------------------------------
    sampled_uv = sample['sampled_uv']  # S, V, 2, N, P1, P2 (P = patch size)
    S, V, _, N, P1, P2 = sampled_uv.size()
    # NOTE(review): h divides the first uv coordinate and w the second. This
    # only matters when h != w; confirm the ordering against the dataset.
    H, W, h, w = sample['size'][0, 0].long().cpu().tolist()
    L = N * P1 * P2  # number of rendered pixels per view
    flatten_uv = sampled_uv.view(S, V, 2, L)
    # Row-major flat pixel index, shape (S, V, L), into a W-wide image.
    flatten_index = (flatten_uv[:, :, 0] // h + flatten_uv[:, :, 1] // w * W).long()

    assert 'colors' in sample and sample['colors'] is not None, "ground-truth colors not provided"
    target_colors = sample['colors']  # assumed (S, V, num_pixels, 3)
    masks = (sample['alpha'] > 0) if self.args.no_background_loss else None
    # Targets cover more pixels than were rendered, so keep only the sampled ones.
    if L < target_colors.size(2):
        target_colors = target_colors.gather(2, flatten_index.unsqueeze(-1).repeat(1, 1, 1, 3))
        # Masks are per pixel (rank 3), so they must be gathered with the
        # flat pixel index. The 4-D uv tensor has the wrong rank for gather.
        masks = masks.gather(2, flatten_index) if masks is not None else None

    if 'other_logs' in net_output:
        other_logs.update(net_output['other_logs'])

    # ---- loss terms --------------------------------------------------------
    if self.args.color_weight > 0:
        color_loss = utils.rgb_loss(
            net_output['colors'], target_colors,
            masks, self.args.L1)
        losses['color_loss'] = (color_loss, self.args.color_weight)

    if self.args.alpha_weight > 0:
        # log1p(a * (1 - a) / 0.11) is 0 at a in {0, 1} and largest at a = 0.5.
        # For values in [0, 1] it therefore pushes `missed` toward 0 or 1.
        _alpha = net_output['missed'].reshape(-1)
        alpha_loss = torch.log1p(
            1. / 0.11 * _alpha.float() * (1 - _alpha.float())
        ).mean().type_as(_alpha)
        losses['alpha_loss'] = (alpha_loss, self.args.alpha_weight)

    if self.args.depth_weight > 0:
        target_depths = sample.get('depths')
        if target_depths is not None:
            # Select the rendered pixels exactly as for the color targets.
            if L < target_depths.size(2):
                target_depths = target_depths.gather(2, flatten_index)
            depth_mask = target_depths > 0  # only supervise valid depths
            if masks is not None:  # masks is None unless no_background_loss
                depth_mask = masks & depth_mask
            depth_loss = utils.depth_loss(net_output['depths'], target_depths, depth_mask)
        else:
            # No depth map is provided, so the depth loss is applied only to
            # background pixels, with a constant args.max_depth target.
            max_depth_target = self.args.max_depth * torch.ones_like(net_output['depths'])
            if sample.get('mask') is not None:
                background = (1 - sample['mask']).bool()
            else:
                assert masks is not None, \
                    "depth loss without ground-truth depths requires sample['mask'] or args.no_background_loss"
                background = ~masks
            depth_loss = utils.depth_loss(net_output['depths'], max_depth_target, background)

        depth_weight = self.args.depth_weight
        if self.args.depth_weight_decay is not None:
            # The factor falls linearly from 1 and reaches final_factor at
            # final_steps updates.
            # NOTE(review): the factor keeps decreasing after final_steps and
            # is clamped at 0, not at final_factor. Confirm this is intended.
            # NOTE(review): eval() of a config string that must produce a
            # 2-tuple. ast.literal_eval is safer if only literals are needed.
            final_factor, final_steps = eval(self.args.depth_weight_decay)
            depth_weight *= max(0, 1 - (1 - final_factor) * self.task._num_updates / final_steps)
            other_logs['depth_weight'] = depth_weight
        losses['depth_loss'] = (depth_loss, depth_weight)

    if self.args.vgg_weight > 0:
        assert P1 * P2 > 1, "we have to use a patch-based sampling for VGG loss"
        # Reshape (..., 3) to (B, 3, P1, P2) image patches. The * .5 + .5 maps
        # presumably [-1, 1] colors to [0, 1].
        vgg_target = target_colors.reshape(-1, P1, P2, 3).permute(0, 3, 1, 2) * .5 + .5
        vgg_output = net_output['colors'].reshape(-1, P1, P2, 3).permute(0, 3, 1, 2) * .5 + .5
        vgg_loss = self.vgg(vgg_output, vgg_target)
        losses['vgg_loss'] = (vgg_loss, self.args.vgg_weight)

    if self.args.eikonal_weight > 0:
        losses['eik_loss'] = (net_output['eikonal-term'].mean(), self.args.eikonal_weight)

    # The regularizer is included (and logged) whenever the model emits it,
    # even at weight 0. It is required (KeyError if missing) only when
    # regz_weight > 0.
    if self.args.regz_weight > 0 or 'regz-term' in net_output:
        losses['reg_loss'] = (net_output['regz-term'].mean(), self.args.regz_weight)

    loss = sum(value * weight for value, weight in losses.values())
    # add a dummy loss
    loss = loss + model.dummy_loss + self.dummy_loss * 0.
    # `item` is assumed to be a tensor -> Python-scalar helper defined elsewhere.
    logging_outputs = {key: item(value) for key, (value, _) in losses.items()}
    logging_outputs.update(other_logs)
    return loss, logging_outputs