in fairnr/tasks/neural_rendering.py [0:0]
def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
    """Run any scheduled model maintenance for ``update_num``, then train.

    Before delegating to the parent class's ``train_step``, four independent
    schedules are checked, and each action whose schedule matches is performed:

    * voxel pruning (``pruning_every_steps`` / ``steps_to_prune_voxels``),
      only when the model defines ``prune_voxels``;
    * voxel splitting (``steps_to_half_voxels``), which raises
      ``ResetTrainerException`` right afterwards;
    * rendering a cloned copy of the current batch with ``self.renderer``
      every ``rendering_every_steps`` updates;
    * step-size reduction (``steps_to_reduce_step``).

    Each action is also guarded by ``update_num > self._num_updates[key]``.
    Each action is followed by ``self.update_step(update_num, key)``. The
    guard is meant to keep an action from firing twice for the same update
    number. This assumes ``update_step`` stores ``update_num`` under ``key``;
    that method is not visible here (TODO confirm).

    Args:
        sample: The training batch, passed on to the parent ``train_step``.
            For rendering, every non-None value is copied with ``.clone()``
            (presumably torch tensors -- confirm).
        model: The model being trained.
        criterion: Forwarded unchanged to the parent ``train_step``.
        optimizer: Forwarded unchanged to the parent ``train_step``.
        update_num: The update counter all schedules are keyed on.
        ignore_grad: Forwarded unchanged to the parent ``train_step``.

    Returns:
        Whatever the parent class's ``train_step`` returns.

    Raises:
        ResetTrainerException: Immediately after ``model.split_voxels()``.
    """
    # --- Voxel pruning: periodic and/or at explicitly listed updates. ---
    periodic_prune = (
        self.pruning_every_steps is not None
        and update_num % self.pruning_every_steps == 0
        and update_num > 0
    )
    scheduled_prune = (
        self.steps_to_prune_voxels is not None
        and update_num in self.steps_to_prune_voxels
    )
    if ((periodic_prune or scheduled_prune)
            and update_num > self._num_updates['pv']
            and hasattr(model, 'prune_voxels')):
        # Eval mode is not undone here; presumably the parent train_step
        # switches the model back to training mode (TODO confirm).
        model.eval()
        if getattr(self.args, "pruning_rerun_train_set", False):
            # Reset the model caches, then forward the whole training set
            # once without gradients. Presumably this refreshes the statistics
            # prune_voxels relies on when `pruning_with_train_stats` is set
            # (TODO confirm against the model implementation).
            with torch.no_grad():
                model.clean_caches(reset=True)
                progress = progress_bar.progress_bar(
                    self._unique_trainitr.next_epoch_itr(shuffle=False),
                    prefix="pruning based statiscs over training set",
                    tensorboard_logdir=None,
                    default_log_format=(
                        self.args.log_format if self.args.log_format is not None else "tqdm"
                    ),
                )
                for step, inner_sample in enumerate(progress):
                    outs = model(**self._trainer._prepare_sample(self.filter_dummy(inner_sample)))
                    progress.log(stats=outs['other_logs'], tag='track', step=step)
        model.prune_voxels(
            self.pruning_th,
            train_stats=getattr(self.args, "pruning_with_train_stats", False),
        )
        self.update_step(update_num, 'pv')

    # --- Voxel splitting: signalled to the caller via ResetTrainerException
    # (presumably so the trainer can be rebuilt for the new parameters). ---
    if (self.steps_to_half_voxels is not None
            and update_num in self.steps_to_half_voxels
            and update_num > self._num_updates['sv']):
        model.split_voxels()
        self.update_step(update_num, 'sv')
        raise ResetTrainerException

    # --- Periodic rendering of the current batch. ---
    if (self.rendering_every_steps is not None
            and update_num % self.rendering_every_steps == 0
            and update_num > 0
            and self.renderer is not None
            and update_num > self._num_updates['re']):
        # Inference runs on clones, so the objects later handed to the
        # parent train_step are not the ones passed to inference_step.
        sample_clone = {
            key: value.clone() if value is not None else None
            for key, value in sample.items()
        }
        outputs = self.inference_step(self.renderer, [model], [sample_clone, 0])[1]
        if getattr(self.args, "distributed_rank", -1) == 0:  # save only for master
            self.renderer.save_images(outputs, update_num)
        # Drop this update from the split schedule. The guard is needed
        # because steps_to_half_voxels may be None (see the split check
        # above); iterating None would raise TypeError.
        if self.steps_to_half_voxels is not None:
            self.steps_to_half_voxels = [
                milestone for milestone in self.steps_to_half_voxels
                if milestone != update_num
            ]
        # Record the render like every other scheduled action, so the
        # `_num_updates['re']` guard above can suppress a second render for
        # the same update_num.
        self.update_step(update_num, 're')

    # --- Scheduled step-size reduction. ---
    if (self.steps_to_reduce_step is not None
            and update_num in self.steps_to_reduce_step
            and update_num > self._num_updates['rs']):
        model.reduce_stepsize()
        self.update_step(update_num, 'rs')

    self.update_step(update_num, 'step')
    return super().train_step(sample, model, criterion, optimizer, update_num, ignore_grad)