def train_step()

in fairnr/tasks/neural_rendering.py

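Per-update training hook for the neural rendering task. Before delegating to the base task's train_step, it runs any scheduled maintenance on the sparse voxel grid: pruning empty voxels, splitting voxels in half, periodically rendering and saving images from the current batch, and reducing the ray-marching step size. Each action is guarded via self._num_updates so it runs at most once per update, even when a step is replayed after a trainer reset (see ResetTrainerException below).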

    def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
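        # Voxel pruning is due either every pruning_every_steps updates
        # or at one of the explicit steps listed in steps_to_prune_voxels.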
        pruning_due = (
            (self.pruning_every_steps is not None
             and update_num % self.pruning_every_steps == 0
             and update_num > 0)
            or (self.steps_to_prune_voxels is not None
                and update_num in self.steps_to_prune_voxels)
        )
        if pruning_due and update_num > self._num_updates['pv'] \
                and hasattr(model, 'prune_voxels'):
            model.eval()
            if getattr(self.args, "pruning_rerun_train_set", False):
                with torch.no_grad():
                    model.clean_caches(reset=True)
                    progress = progress_bar.progress_bar(
                        self._unique_trainitr.next_epoch_itr(shuffle=False),
                        prefix=f"pruning based statiscs over training set",
                        tensorboard_logdir=None, 
                        default_log_format=self.args.log_format if self.args.log_format is not None else "tqdm")
                    for step, inner_sample in enumerate(progress):
                        outs = model(**self._trainer._prepare_sample(self.filter_dummy(inner_sample)))
                        progress.log(stats=outs['other_logs'], tag='track', step=step)

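            # Prune voxels using the configured threshold, optionally based on
            # the statistics gathered above.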
            model.prune_voxels(self.pruning_th, train_stats=getattr(self.args, "pruning_with_train_stats", False))
            self.update_step(update_num, 'pv')

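        # Scheduled voxel splitting: halve the voxel size at the explicit
        # steps listed in steps_to_half_voxels.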
        if self.steps_to_half_voxels is not None and \
            (update_num in self.steps_to_half_voxels) and \
            (update_num > self._num_updates['sv']):
            
            model.split_voxels()
            self.update_step(update_num, 'sv')
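            # Splitting changes the parameter set; raising here presumably lets
            # the outer training loop rebuild the trainer and optimizer state.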
            raise ResetTrainerException

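        # Periodic rendering: every rendering_every_steps updates, render the
        # current batch and save the resulting images.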
        if self.rendering_every_steps is not None and \
            (update_num % self.rendering_every_steps == 0) and \
            (update_num > 0) and \
            self.renderer is not None and \
            (update_num > self._num_updates['re']):

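            # Work on a cloned copy of the batch so the rendering pass cannot
            # modify the training sample in place.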
            sample_clone = {key: sample[key].clone() if sample[key] is not None else None for key in sample}
            outputs = self.inference_step(self.renderer, [model], [sample_clone, 0])[1]
            if getattr(self.args, "distributed_rank", -1) == 0:  # save only for master
                self.renderer.save_images(outputs, update_num)
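            # Housekeeping: drop the current step from the pending voxel-splitting schedule.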
            if self.steps_to_half_voxels is not None:
                self.steps_to_half_voxels = [a for a in self.steps_to_half_voxels if a != update_num]

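        # Scheduled reduction of the ray-marching step size at the explicit
        # steps listed in steps_to_reduce_step.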
        if self.steps_to_reduce_step is not None and \
            update_num in self.steps_to_reduce_step and \
            (update_num > self._num_updates['rs']):

            model.reduce_stepsize()
            self.update_step(update_num, 'rs')
        
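        # Record this update and run the standard fairseq training step.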
        self.update_step(update_num, 'step')
        return super().train_step(sample, model, criterion, optimizer, update_num, ignore_grad)
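
The guard pattern above (an action is "due" and update_num > self._num_updates[key]) is what keeps each action from firing twice. Below is a minimal, self-contained sketch of that idea (hypothetical names MaintenanceSchedule and maybe_run, not part of fairnr), showing why replaying a step, e.g. after ResetTrainerException forces a trainer rebuild, is harmless:

    # Minimal sketch (hypothetical, for illustration) of the run-once-per-update guard.
    class MaintenanceSchedule:
        def __init__(self, actions):
            self.actions = actions                    # e.g. {'pv': prune_fn, 'sv': split_fn}
            self.last_run = {k: 0 for k in actions}   # plays the role of self._num_updates

        def maybe_run(self, key, update_num, due):
            # 'due' encodes the schedule (a modulo test or an explicit step list);
            # the strict inequality makes a replayed update a no-op.
            if due and update_num > self.last_run[key]:
                self.actions[key]()
                self.last_run[key] = update_num       # plays the role of self.update_step

    # Even if update 100 is replayed after a trainer reset, pruning runs exactly once:
    schedule = MaintenanceSchedule({'pv': lambda: print("prune")})
    schedule.maybe_run('pv', 100, due=True)   # runs the action
    schedule.maybe_run('pv', 100, due=True)   # skipped: 100 is not > 100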