def forward_chunk()

in fairnr/modules/renderer.py [0:0]


    def forward_chunk(
        self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states,
        gt_depths=None, output_types=None, global_weights=None,
        ):
        """Evaluate the field along rays chunk by chunk and composite the result.

        The sample axis (dim 1 of every tensor in ``samples``) is split into
        consecutive slices so that the number of valid samples handed to
        ``self.forward_once`` stays around the chunk budget. Per-slice outputs
        are concatenated back along dim 1 and turned into per-sample
        probabilities, which weight depths, colors, normals and features.

        Args:
            input_fn, field_fn, ray_start, ray_dir, encoder_states: forwarded
                unchanged to ``self.forward_once``.
            samples: dict of tensors. Must contain ``'sampled_point_depth'``
                and ``'sampled_point_voxel_idx'`` (``-1`` marks a sample that
                hit nothing); may contain ``'original_point_depth'``. Every
                value is sliced as ``s[:, start:stop]``.
            gt_depths: accepted for interface compatibility; not used here.
            output_types: names of the field outputs to request. Defaults to
                ``['sigma', 'texture']``. The argument is never modified;
                ``'normal'`` is added to a private copy when
                ``self.trace_normal`` is set.
            global_weights: optional multiplier applied to the probabilities.

        Returns:
            dict with ``'probs'``, ``'depths'``, ``'max_depths'``,
            ``'min_depths'``, ``'missed'`` and ``'ae'`` (accumulated
            evaluation count), plus ``'z'``, ``'colors'``, ``'normal'``,
            ``'eikonal-term'``, ``'feat_n2'`` and ``'regz-term'`` when the
            corresponding inputs/outputs are available.
        """
        # Work on a private copy: extending the argument in place would grow
        # the shared default (and the caller's list) on every call.
        output_types = ['sigma', 'texture'] if output_types is None else list(output_types)
        if self.trace_normal and 'normal' not in output_types:
            output_types.append('normal')

        sampled_depth = samples['sampled_point_depth']
        sampled_idx = samples['sampled_point_voxel_idx'].long()
        original_depth = samples.get('original_point_depth', None)
        valid_mask = sampled_idx.ne(-1)  # True where the sample lies in a voxel

        tolerance = self.raymarching_tolerance
        chunk_size = self.chunk_size if self.training else self.valid_chunk_size
        early_stop = None
        if tolerance > 0:
            # Compare in free-energy space: accumulated energy > -log(tol)
            # is the same as exp(-accumulated energy) < tol.
            tolerance = -math.log(tolerance)

        # `.long()` yields a new tensor, so zeroing rows of `hits` below
        # leaves `valid_mask` untouched.
        hits = valid_mask.long()
        num_steps = hits.size(1)
        outputs = defaultdict(list)
        size_so_far, start_step = 0, 0
        accumulated_free_energy = 0
        accumulated_evaluations = 0

        # NOTE(review): if the very first slice returns None there is nothing
        # to take a padding shape from, so no padding is emitted for it; when
        # every slice returns None, `outputs` stays empty and the
        # 'sample_mask' lookup below raises KeyError. Confirm whether callers
        # can reach this with rays that miss everything.
        for i in range(num_steps + 1):
            # Flush the pending slice [start_step, i) at the end of the ray,
            # or when adding step i would exceed the chunk budget.
            if i > start_step and (
                    i == num_steps
                    or size_so_far + hits[:, i].sum() > chunk_size):
                _outputs, _evals = self.forward_once(
                    input_fn, field_fn,
                    ray_start, ray_dir,
                    {name: s[:, start_step: i] for name, s in samples.items()},
                    encoder_states,
                    early_stop=early_stop,
                    output_types=output_types)

                if _outputs is not None:
                    accumulated_evaluations += _evals

                    if 'free_energy' in _outputs:
                        accumulated_free_energy += _outputs['free_energy'].sum(1)
                        if tolerance > 0:
                            # Rays that are already opaque enough stop
                            # counting towards later chunk budgets.
                            early_stop = accumulated_free_energy > tolerance
                            hits[early_stop] *= 0

                    for key in _outputs:
                        outputs[key].append(_outputs[key])
                else:
                    # Nothing was evaluated for this slice: pad every known
                    # output with zeros so the sample axis stays aligned.
                    for key in outputs:
                        last = outputs[key][-1]
                        outputs[key].append(last.new_zeros(
                            last.size(0),
                            sampled_depth[:, start_step: i].size(1),
                            *last.size()[2:]
                        ))
                start_step, size_so_far = i, 0

            if i < num_steps:
                # Re-read `hits` here rather than caching the sum from the
                # condition above: early stopping may have zeroed rows since.
                size_so_far += hits[:, i].sum()

        outputs = {key: torch.cat(chunks, 1) for key, chunks in outputs.items()}
        results = {}

        if 'free_energy' in outputs:
            free_energy = outputs['free_energy']
            shifted_free_energy = torch.cat(
                [free_energy.new_zeros(sampled_depth.size(0), 1), free_energy[:, :-1]],
                dim=-1)  # shift one step
            # probability that this sample is not empty
            a = 1 - torch.exp(-free_energy.float())
            # probability that everything before this sample is empty
            b = torch.exp(-torch.cumsum(shifted_free_energy.float(), dim=-1))
            # probability that the ray hits something at this sample
            probs = (a * b).type_as(free_energy)
        else:
            # assuming a uniform distribution over the samples
            probs = outputs['sample_mask'].type_as(sampled_depth) / sampled_depth.size(-1)

        if global_weights is not None:
            probs = probs * global_weights

        depth = (sampled_depth * probs).sum(-1)
        missed = 1 - probs.sum(-1)

        results.update({
            'probs': probs, 'depths': depth,
            # `hits` includes the early-stop zeroing, so stopped rays get -1.
            'max_depths': sampled_depth.masked_fill(hits.eq(0), -1).max(1).values,
            'min_depths': sampled_depth.min(1).values,
            'missed': missed, 'ae': accumulated_evaluations
        })
        if original_depth is not None:
            results['z'] = (original_depth * probs).sum(-1)

        if 'texture' in outputs:
            results['colors'] = (outputs['texture'] * probs.unsqueeze(-1)).sum(-2)

        if 'normal' in outputs:
            results['normal'] = (outputs['normal'] * probs.unsqueeze(-1)).sum(-2)
            if not self.trace_normal:
                eikonal = (outputs['normal'].norm(p=2, dim=-1) - 1) ** 2
            else:
                eikonal = torch.log((outputs['normal'] ** 2).sum(-1) + 1e-6)
            # keep only the samples that fall inside a voxel
            results['eikonal-term'] = eikonal[valid_mask]

        if 'feat_n2' in outputs:
            results['feat_n2'] = (outputs['feat_n2'] * probs).sum(-1)
            results['regz-term'] = outputs['feat_n2'][valid_mask]

        return results