load_dataset()

Excerpt from fairnr/tasks/neural_rendering.py [0:0]


    def load_dataset(self, split, **kwargs):
        """
        Load a given dataset split (train, valid, test)
        """
        # Resolve the per-split configuration up front instead of inlining
        # chained conditional expressions in the constructor call below.
        # Each branch is only evaluated for its own split, matching the
        # short-circuit behavior of the original conditional expressions.
        if split == 'train':
            data = self.train_data
            views = self.train_views
            num_view = self.args.view_per_batch
            resolution = self.args.view_resolution
        elif split == 'valid':
            data = self.val_data
            views = self.valid_views
            num_view = self.args.valid_view_per_batch
            resolution = getattr(self.args, "valid_view_resolution", self.args.view_resolution)
        else:
            data = self.test_data
            views = self.test_views
            num_view = 1
            resolution = getattr(self.args, "render_resolution", self.args.view_resolution)

        is_test = (split == 'test')

        # Depth/mask loading, preloading and binarization are disabled at test
        # time; they follow the corresponding args otherwise.
        self.datasets[split] = ShapeViewDataset(
            data,
            views=views,
            num_view=num_view,
            resolution=resolution,
            subsample_valid=self.args.subsample_valid if split == 'valid' else -1,
            train=(split == 'train'),
            load_depth=self.args.load_depth and not is_test,
            load_mask=self.args.load_mask and not is_test,
            repeat=self.repeat_dataset(split),
            preload=(not getattr(self.args, "no_preload", False)) and not is_test,
            binarize=(not getattr(self.args, "no_load_binary", False)) and not is_test,
            bg_color=getattr(self.args, "transparent_background", "1,1,1"),
            min_color=getattr(self.args, "min_color", -1),
            ids=self.object_ids
        )

        if split == 'train':
            # Wrap the training set so it repeats indefinitely: either for a
            # fixed "virtual epoch" step budget, or an effectively unbounded
            # number of models otherwise.
            max_step = getattr(self.args, "virtual_epoch_steps", None)
            if max_step is None:
                total_num_models = 10000000
            else:
                total_num_models = max_step * self.args.distributed_world_size * self.args.max_sentences

            if getattr(self.args, "pruning_rerun_train_set", False):
                # Keep a finite, stream-style copy of the training data (and a
                # batch iterator over it) so pruning can re-run the train set.
                self._unique_trainset = ShapeViewStreamDataset(copy.deepcopy(self.datasets[split]))  # backup
                self._unique_trainitr = self.get_batch_iterator(
                    self._unique_trainset, max_sentences=self.args.max_sentences_valid, seed=self.args.seed,
                    num_shards=self.args.distributed_world_size, shard_id=self.args.distributed_rank,
                    num_workers=self.args.num_workers)
            self.datasets[split] = InfiniteDataset(self.datasets[split], total_num_models)

        if split == 'valid':
            # Validation is consumed as a stream over views.
            self.datasets[split] = ShapeViewStreamDataset(self.datasets[split])