in fairnr/tasks/neural_rendering.py [0:0]
def __init__(self, args):
    """Configure dataset paths, object-id maps, logging and training schedules.

    Args:
        args: parsed option namespace. ``data``, ``object_id_path``,
            ``train_views``, ``valid_views`` and ``test_views`` are read
            directly. ``output_valid``, ``tensorboard_logdir``,
            ``distributed_rank``, ``pruning_every_steps``, ``pruning_th``,
            ``rendering_every_steps``, ``half_voxel_size_at``,
            ``reduce_step_size_at``, ``prune_voxel_at`` and
            ``rendering_args`` are optional; they are read with ``getattr``
            and fall back to defaults.
    """
    super().__init__(args)

    def _read_object_ids(path):
        # Map each stripped line of ``path`` to its 0-based line index.
        # The ``with`` block guarantees the file handle is closed.
        with open(path, encoding='utf-8') as f:
            return {line.strip(): i for i, line in enumerate(f)}

    def _parse_steps(value):
        # "1000,2000" -> [1000, 2000]; None stays None (schedule disabled).
        # Empty entries (e.g. "" or a trailing comma) are skipped instead of
        # making int('') raise ValueError.
        if value is None:
            return None
        return [int(s) for s in value.split(',') if s.strip()]

    self._trainer, self._dummy_batch = None, None

    # check dataset: every split defaults to ``args.data``; when it is a
    # directory, per-split list files found inside it take precedence.
    self.train_data = self.val_data = self.test_data = args.data
    self.object_ids = None if args.object_id_path is None else \
        _read_object_ids(args.object_id_path)
    self.output_valid = getattr(args, "output_valid", None)

    if os.path.isdir(args.data):
        train_list = os.path.join(args.data, 'train.txt')
        val_list = os.path.join(args.data, 'val.txt')
        test_list = os.path.join(args.data, 'test.txt')
        if os.path.exists(train_list):
            self.train_data = train_list
        if os.path.exists(val_list):
            self.val_data = val_list
        if os.path.exists(test_list):
            self.test_data = test_list
        # An explicit --object-id-path wins over the directory's own file.
        object_id_file = os.path.join(args.data, 'object_ids.txt')
        if self.object_ids is None and os.path.exists(object_id_file):
            self.object_ids = _read_object_ids(object_id_file)

    # Inverse lookup (index -> name); a single default entry when no id file
    # was found.
    if self.object_ids is not None:
        self.ids_object = {i: o for o, i in self.object_ids.items()}
    else:
        self.ids_object = {0: 'model'}

    # Image summaries are written only by rank 0. The truthiness check covers
    # both an empty string and None for an unset logdir.
    if getattr(args, "tensorboard_logdir", None) and \
            getattr(args, "distributed_rank", -1) == 0:
        from tensorboardX import SummaryWriter
        self.writer = SummaryWriter(args.tensorboard_logdir + '/images')
    else:
        self.writer = None

    # Per-schedule update counters, initialised to -1. The keys presumably
    # correspond to the pruning / voxel-halving / step-reduction / rendering
    # schedules configured below — TODO confirm against their consumers.
    self._num_updates = {'pv': -1, 'sv': -1, 'rs': -1, 're': -1}
    self.pruning_every_steps = getattr(args, "pruning_every_steps", None)
    self.pruning_th = getattr(args, "pruning_th", 0.5)
    self.rendering_every_steps = getattr(args, "rendering_every_steps", None)
    # Comma-separated update counts -> list[int] (None when not given).
    self.steps_to_half_voxels = _parse_steps(getattr(args, "half_voxel_size_at", None))
    self.steps_to_reduce_step = _parse_steps(getattr(args, "reduce_step_size_at", None))
    self.steps_to_prune_voxels = _parse_steps(getattr(args, "prune_voxel_at", None))

    if self.rendering_every_steps is not None:
        # Default generator options; a JSON object in --rendering-args
        # overrides individual keys.
        gen_args = {
            'path': args.save_dir,
            'render_beam': 1, 'render_resolution': '512x512',
            'render_num_frames': 120, 'render_angular_speed': 3,
            'render_output_types': ["rgb"], 'render_raymarching_steps': 10,
            'render_at_vector': "(0,0,0)", 'render_up_vector': "(0,0,-1)",
            'render_path_args': "{'radius': 1.5, 'h': 0.5}",
            'render_path_style': 'circle', "render_output": None
        }
        gen_args.update(json.loads(getattr(args, 'rendering_args', '{}') or '{}'))
        self.renderer = self.build_generator(Namespace(**gen_args))
    else:
        self.renderer = None

    self.train_views = parse_views(args.train_views)
    self.valid_views = parse_views(args.valid_views)
    self.test_views = parse_views(args.test_views)