in fairnr/modules/renderer.py [0:0]
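# Note: this excerpt relies on renderer.py's module-level imports
# (math, torch, and collections.defaultdict).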
def forward_chunk(
        self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states,
        gt_depths=None, output_types=['sigma', 'texture'], global_weights=None,
):
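    """
    March one chunk of rays through their sampled points.

    The sample dimension is swept block by block so that each call to
    forward_once evaluates at most roughly `chunk_size` hit samples; rays whose
    accumulated free energy (negative log transmittance) exceeds the tolerance
    are stopped early. The per-sample field outputs are then composited into
    depth, color, and normal with standard volume-rendering weights.
    """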
    if self.trace_normal:
        # Concatenate rather than += so the mutable default list is not modified in place.
        output_types = output_types + ['normal']
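    # Per-ray sample tensors; a voxel index of -1 marks padded / missed samples.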
    sampled_depth = samples['sampled_point_depth']
    sampled_idx = samples['sampled_point_voxel_idx'].long()
    original_depth = samples.get('original_point_depth', None)
    tolerance = self.raymarching_tolerance
    chunk_size = self.chunk_size if self.training else self.valid_chunk_size
    early_stop = None
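    # Work in free-energy space: transmittance T = exp(-accumulated free energy),
    # so T < tolerance  <=>  accumulated free energy > -log(tolerance).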
    if tolerance > 0:
        tolerance = -math.log(tolerance)
    hits = sampled_idx.ne(-1).long()
    outputs = defaultdict(lambda: [])
    size_so_far, start_step = 0, 0
    accumulated_free_energy = 0
    accumulated_evaluations = 0
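    # Sweep the sample columns, flushing [start_step, i) to forward_once whenever
    # the accumulated number of hits would exceed chunk_size (or at the last column).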
    for i in range(hits.size(1) + 1):
        if ((i == hits.size(1)) or (size_so_far + hits[:, i].sum() > chunk_size)) and (i > start_step):
            _outputs, _evals = self.forward_once(
                input_fn, field_fn,
                ray_start, ray_dir,
                {name: s[:, start_step: i] for name, s in samples.items()},
                encoder_states,
                early_stop=early_stop,
                output_types=output_types)
            if _outputs is not None:
                accumulated_evaluations += _evals

                if 'free_energy' in _outputs:
                    accumulated_free_energy += _outputs['free_energy'].sum(1)
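                    # Rays whose accumulated free energy exceeds the tolerance are
                    # effectively opaque already; zero their hit mask so later chunks
                    # can skip them via early_stop.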
                    if tolerance > 0:
                        early_stop = accumulated_free_energy > tolerance
                        hits[early_stop] *= 0
                for key in _outputs:
                    outputs[key] += [_outputs[key]]
            else:
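                # forward_once skipped every point in this range; pad each existing
                # output with zeros so the chunks still concatenate to the full width.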
                for key in outputs:
                    outputs[key] += [outputs[key][-1].new_zeros(
                        outputs[key][-1].size(0),
                        sampled_depth[:, start_step: i].size(1),
                        *outputs[key][-1].size()[2:]
                    )]
            start_step, size_so_far = i, 0

        if i < hits.size(1):
            size_so_far += hits[:, i].sum()
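    # Re-join the per-chunk outputs along the sample dimension.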
    outputs = {key: torch.cat(outputs[key], 1) for key in outputs}

    results = {}
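    # Standard volume-rendering weights from per-sample free energy (sigma * delta):
    # w_i = (1 - exp(-fe_i)) * exp(-sum_{j<i} fe_j).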
    if 'free_energy' in outputs:
        free_energy = outputs['free_energy']
        # shift by one step so the cumulative sum excludes the current sample
        shifted_free_energy = torch.cat([free_energy.new_zeros(sampled_depth.size(0), 1), free_energy[:, :-1]], dim=-1)
        a = 1 - torch.exp(-free_energy.float())                            # probability that this sample is occupied
        b = torch.exp(-torch.cumsum(shifted_free_energy.float(), dim=-1))  # probability that all preceding samples are empty (transmittance)
        probs = (a * b).type_as(free_energy)                               # probability that the ray terminates at this sample
    else:
        # no density available: fall back to a uniform distribution over the samples
        probs = outputs['sample_mask'].type_as(sampled_depth) / sampled_depth.size(-1)
    if global_weights is not None:
        probs = probs * global_weights
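    # Expected ray depth under the termination distribution; `missed` is the
    # probability mass that never hit anything along the ray.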
    depth = (sampled_depth * probs).sum(-1)
    missed = 1 - probs.sum(-1)
    results.update({
        'probs': probs, 'depths': depth,
        'max_depths': sampled_depth.masked_fill(hits.eq(0), -1).max(1).values,
        'min_depths': sampled_depth.min(1).values,
        'missed': missed, 'ae': accumulated_evaluations
    })
    if original_depth is not None:
        results['z'] = (original_depth * probs).sum(-1)

    if 'texture' in outputs:
        results['colors'] = (outputs['texture'] * probs.unsqueeze(-1)).sum(-2)
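    # Composite normals with the same weights. The eikonal term regularizes them:
    # (||n|| - 1)^2 when normals are predicted by the field, a log squared-norm
    # penalty when they are traced, restricted to samples that actually hit a voxel.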
    if 'normal' in outputs:
        results['normal'] = (outputs['normal'] * probs.unsqueeze(-1)).sum(-2)
        if not self.trace_normal:
            results['eikonal-term'] = (outputs['normal'].norm(p=2, dim=-1) - 1) ** 2
        else:
            results['eikonal-term'] = torch.log((outputs['normal'] ** 2).sum(-1) + 1e-6)
        results['eikonal-term'] = results['eikonal-term'][sampled_idx.ne(-1)]
    if 'feat_n2' in outputs:
        results['feat_n2'] = (outputs['feat_n2'] * probs).sum(-1)
        results['regz-term'] = outputs['feat_n2'][sampled_idx.ne(-1)]

    return results
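
# Rough shape of the returned dictionary, inferred from the reductions above
# (the tensor shapes are assumptions, not documented guarantees):
#   probs      [n_rays, n_samples]  per-sample termination weights
#   depths     [n_rays]             expected depth
#   max_depths [n_rays]             deepest hit sample, -1 where a ray hit nothing
#   min_depths [n_rays]             shallowest sampled depth
#   missed     [n_rays]             probability mass that escaped the scene
#   ae                              total number of field evaluations
# plus, when available: 'z', 'colors', 'normal', 'eikonal-term', 'feat_n2', 'regz-term'.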