in fairnr/tasks/neural_rendering.py [0:0]
def load_dataset(self, split, **kwargs):
    """Build the dataset for ``split`` and register it in ``self.datasets``.

    Args:
        split: ``'train'``, ``'valid'`` or ``'test'``. Any other value takes
            the ``'test'`` choices for data, views, views-per-item and
            resolution.
        **kwargs: accepted for interface compatibility; unused here.

    Side effects:
        * ``self.datasets[split]`` is set to a ``ShapeViewDataset``. For
          ``'train'`` that dataset is then wrapped in ``InfiniteDataset``;
          for ``'valid'`` it is wrapped in ``ShapeViewStreamDataset``.
        * For ``'train'`` with ``args.pruning_rerun_train_set`` set, a deep
          copy of the unwrapped training dataset (wrapped in
          ``ShapeViewStreamDataset``) is stored in ``self._unique_trainset``,
          and a batch iterator over it in ``self._unique_trainitr``.
    """
    args = self.args
    is_train = split == 'train'
    is_valid = split == 'valid'

    # Resolve split-specific inputs branch by branch so that attributes
    # belonging to other splits (e.g. ``args.valid_view_per_batch``) are only
    # read when that split is actually requested.
    if is_train:
        data, views = self.train_data, self.train_views
        num_view = args.view_per_batch
        resolution = args.view_resolution
    elif is_valid:
        data, views = self.val_data, self.valid_views
        num_view = args.valid_view_per_batch
        resolution = getattr(args, "valid_view_resolution", args.view_resolution)
    else:
        data, views = self.test_data, self.test_views
        num_view = 1
        resolution = getattr(args, "render_resolution", args.view_resolution)
    subsample_valid = args.subsample_valid if is_valid else -1

    # Depth/mask loading, preloading and binary loading are all switched
    # off for the 'test' split.
    not_test = split != 'test'
    base = ShapeViewDataset(
        data,
        views=views,
        num_view=num_view,
        resolution=resolution,
        subsample_valid=subsample_valid,
        train=is_train,
        load_depth=args.load_depth and not_test,
        load_mask=args.load_mask and not_test,
        repeat=self.repeat_dataset(split),
        preload=(not getattr(args, "no_preload", False)) and not_test,
        binarize=(not getattr(args, "no_load_binary", False)) and not_test,
        bg_color=getattr(args, "transparent_background", "1,1,1"),
        min_color=getattr(args, "min_color", -1),
        ids=self.object_ids,
    )
    self.datasets[split] = base

    if is_train:
        max_step = getattr(args, "virtual_epoch_steps", None)
        # Length handed to InfiniteDataset: `virtual_epoch_steps` batches of
        # `max_sentences` on each of `distributed_world_size` workers, or a
        # fixed large count when no virtual epoch length is configured.
        if max_step is None:
            total_num_models = 10000000
        else:
            total_num_models = (
                max_step * args.distributed_world_size * args.max_sentences
            )

        if getattr(args, "pruning_rerun_train_set", False):
            # Back up an unwrapped copy of the training set (before the
            # InfiniteDataset wrap below) with its own batch iterator, which
            # uses the validation batch size.
            self._unique_trainset = ShapeViewStreamDataset(copy.deepcopy(base))
            self._unique_trainitr = self.get_batch_iterator(
                self._unique_trainset,
                max_sentences=args.max_sentences_valid,
                seed=args.seed,
                num_shards=args.distributed_world_size,
                shard_id=args.distributed_rank,
                num_workers=args.num_workers,
            )
        self.datasets[split] = InfiniteDataset(base, total_num_models)
    elif is_valid:
        self.datasets[split] = ShapeViewStreamDataset(base)