def __init__()

in fairnr/tasks/neural_rendering.py [0:0]


    def __init__(self, args):
        """Configure the neural-rendering task from parsed arguments.

        Resolves the train/valid/test data sources, the object-name <->
        index mapping, an optional tensorboardX image writer, the voxel
        pruning / halving / step-size schedules, an optional periodic
        renderer, and the view selections for each split.

        Args:
            args: parsed option namespace. ``data``, ``object_id_path`` and
                ``train_views`` / ``valid_views`` / ``test_views`` are read
                directly. ``save_dir`` is read when ``rendering_every_steps``
                is set. The remaining options are read with ``getattr`` and
                a fallback value.
        """
        super().__init__(args)

        self._trainer, self._dummy_batch = None, None

        def _read_object_ids(path):
            # Map each stripped line of `path` to its 0-based line index.
            # The `with` block guarantees the file handle is closed.
            with open(path) as f:
                return {line.strip(): i for i, line in enumerate(f)}

        def _parse_steps(spec):
            # "5000,25000" -> [5000, 25000]; None stays None. Empty pieces
            # (e.g. from a trailing comma) are skipped instead of raising.
            if spec is None:
                return None
            return [int(s) for s in spec.split(',') if s.strip()]

        # check dataset: every split defaults to `args.data`; when that is a
        # directory, per-split list files found inside it take precedence.
        self.train_data = self.val_data = self.test_data = args.data
        self.object_ids = None if args.object_id_path is None \
            else _read_object_ids(args.object_id_path)
        self.output_valid = getattr(args, "output_valid", None)

        if os.path.isdir(args.data):
            train_file = os.path.join(args.data, 'train.txt')
            val_file = os.path.join(args.data, 'val.txt')
            test_file = os.path.join(args.data, 'test.txt')
            ids_file = os.path.join(args.data, 'object_ids.txt')
            if os.path.exists(train_file):
                self.train_data = train_file
            if os.path.exists(val_file):
                self.val_data = val_file
            if os.path.exists(test_file):
                self.test_data = test_file
            # An explicit `object_id_path` wins over the dataset's own file.
            if self.object_ids is None and os.path.exists(ids_file):
                self.object_ids = _read_object_ids(ids_file)

        # Inverse mapping (index -> object name). Without an id file there is
        # a single object, labelled 'model'.
        if self.object_ids is not None:
            self.ids_object = {i: name for name, i in self.object_ids.items()}
        else:
            self.ids_object = {0: 'model'}

        # Image summaries are written only when a tensorboard directory is
        # configured (None or '' disables it) and only on distributed rank 0.
        tb_logdir = getattr(self.args, "tensorboard_logdir", None)
        if tb_logdir and getattr(args, "distributed_rank", -1) == 0:
            from tensorboardX import SummaryWriter
            self.writer = SummaryWriter(tb_logdir + '/images')
        else:
            self.writer = None

        # NOTE(review): the keys look like prune-voxel / split-voxel /
        # reduce-step / render counters, with -1 as "not triggered yet";
        # confirm against the code that reads `_num_updates`.
        self._num_updates = {'pv': -1, 'sv': -1, 'rs': -1, 're': -1}
        self.pruning_every_steps = getattr(self.args, "pruning_every_steps", None)
        self.pruning_th = getattr(self.args, "pruning_th", 0.5)
        self.rendering_every_steps = getattr(self.args, "rendering_every_steps", None)
        # Schedules are given as comma-separated update counts.
        self.steps_to_half_voxels = _parse_steps(getattr(self.args, "half_voxel_size_at", None))
        self.steps_to_reduce_step = _parse_steps(getattr(self.args, "reduce_step_size_at", None))
        self.steps_to_prune_voxels = _parse_steps(getattr(self.args, "prune_voxel_at", None))

        if self.rendering_every_steps is not None:
            # Defaults for the periodic renderer. Keys supplied as a JSON
            # object in `args.rendering_args` override them.
            gen_args = {
                'path': args.save_dir,
                'render_beam': 1, 'render_resolution': '512x512',
                'render_num_frames': 120, 'render_angular_speed': 3,
                'render_output_types': ["rgb"], 'render_raymarching_steps': 10,
                'render_at_vector': "(0,0,0)", 'render_up_vector': "(0,0,-1)",
                'render_path_args': "{'radius': 1.5, 'h': 0.5}",
                'render_path_style': 'circle', "render_output": None
            }
            gen_args.update(json.loads(getattr(args, 'rendering_args', '{}') or '{}'))
            self.renderer = self.build_generator(Namespace(**gen_args))
        else:
            self.renderer = None

        self.train_views = parse_views(args.train_views)
        self.valid_views = parse_views(args.valid_views)
        self.test_views = parse_views(args.test_views)