in fairnr/data/shape_dataset.py [0:0]
def __init__(self,
             paths,
             views,
             num_view,
             subsample_valid=-1,
             resolution=None,
             load_depth=False,
             load_mask=False,
             train=True,
             preload=True,
             repeat=1,
             binarize=True,
             bg_color="1,1,1",
             min_color=-1,
             ids=None):
    """Collect per-view file lists, optionally preload them, and build view samplers.

    Args:
        paths, subsample_valid, ids: forwarded unchanged to the parent
            constructor.
        views: stored as ``self.views``; not otherwise used in this method.
        num_view: number of consecutive view indices given to each data
            entry when ``train`` is false; also stored as ``self.num_view``.
        resolution: a string of ``x``-separated integers (e.g. ``"512x512"``)
            or any other value, which is then used for both entries of
            ``self.resolution``.
        load_depth: also collect ``self.find_depth()`` under key ``'dep'``.
        load_mask: also collect ``self.find_mask()`` under key ``'mask'``.
        train: when true, the views of every shape are randomly permuted and
            the per-entry index is shuffled; otherwise the order is fixed.
        preload: when true, every shape is loaded into ``self.cache`` during
            the first repeat.
        repeat: number of passes over all shapes (also forwarded to the
            parent constructor).
        binarize: during preload, use ``self._load_binary`` instead of
            ``self._load_batch``.
        bg_color: comma-separated string of floats, or a single value.
            A single component is replicated to three.
        min_color: stored as ``self.min_color``; when it equals ``-1`` each
            background component ``b`` is remapped to ``b * 2 - 1``.

    NOTE(review): this method assumes the parent constructor has created
    ``self.data`` (a sequence of dicts with at least
    ``repeat * self.total_num_shape`` entries) and ``self.total_num_shape``
    -- confirm against the parent class.
    """
    super().__init__(paths, False, repeat, subsample_valid, ids)

    self.train = train
    self.load_depth = load_depth
    self.load_mask = load_mask
    self.views = views
    self.num_view = num_view

    if isinstance(resolution, str):
        self.resolution = [int(r) for r in resolution.split('x')]
    else:
        self.resolution = [resolution, resolution]
    self.world2camera = True
    self.cache_view = None

    # Normalise the background colour to a three-component list.
    bg_color = [float(b) for b in bg_color.split(',')] \
        if isinstance(bg_color, str) else [bg_color]
    if min_color == -1:
        bg_color = [b * 2 - 1 for b in bg_color]
    if len(bg_color) == 1:
        bg_color = bg_color + bg_color + bg_color
    self.bg_color = bg_color
    self.min_color = min_color
    # True when the first background component lies within [-1, 1]; in this
    # method it only decides the 'b' suffix of the preload phase name.
    self.apply_mask_color = (self.bg_color[0] >= -1) & (self.bg_color[0] <= 1)

    # -- load per-view data (key insertion order is kept as-is because the
    # dict is iterated below and handed to summary_view_data).
    _data_per_view = {}
    _data_per_view['rgb'] = self.find_rgb()
    _data_per_view['ext'] = self.find_extrinsics()
    # Query once and reuse the result instead of calling the finder twice.
    intrinsics_per_view = self.find_intrinsics_per_view()
    if intrinsics_per_view is not None:
        _data_per_view['ixt_v'] = intrinsics_per_view
    if self.load_depth:
        _data_per_view['dep'] = self.find_depth()
    if self.load_mask:
        _data_per_view['mask'] = self.find_mask()
    _data_per_view['view'] = self.summary_view_data(_data_per_view)

    # group the data.
    _index = 0
    for r in range(repeat):
        # HACK: making several copies to enable multi-GPU usage.
        preload_now = (r == 0 and preload)
        if preload_now:
            self.cache = []
            logger.info('pre-load the dataset into memory.')

        for shape_id in range(self.total_num_shape):
            total_num_view = len(_data_per_view['rgb'][shape_id])
            # Training draws a fresh random view order for every copy of
            # every shape; evaluation keeps the natural order.
            perm_ids = np.random.permutation(total_num_view) if train \
                else np.arange(total_num_view)
            element = {
                key: [per_shape[shape_id][i] for i in perm_ids]
                for key, per_shape in _data_per_view.items()
            }
            self.data[_index].update(element)

            if preload_now:
                phase_name = f"{'train' if self.train else 'valid'}" + \
                    f".{self.resolution[0]}x{self.resolution[1]}" + \
                    f"{'.d' if load_depth else ''}" + \
                    f"{'.m' if load_mask else ''}" + \
                    f"{'b' if not self.apply_mask_color else ''}" + \
                    "_full"
                logger.info("preload {}-{}".format(shape_id, phase_name))
                # The cache is always requested for views 0..total_num_view-1
                # in order, independent of the permutation above.
                all_views = np.arange(total_num_view)
                if binarize:
                    cache = self._load_binary(shape_id, all_views, phase_name)
                else:
                    cache = self._load_batch(self.data, shape_id, all_views)
                self.cache.append(cache)
            _index += 1

    # group the data together: one index sampler per data entry.
    self.data_index = []
    for i, d in enumerate(self.data):
        if self.train:
            # Training may sample any view of the entry, in shuffled order.
            index_list = list(range(len(d['rgb'])))
            shuffle = True
        else:
            # Each repeated copy gets its own window of `num_view`
            # consecutive views, wrapping around the available views.
            copy_id = i // self.total_num_shape
            index_list = [
                j % len(d['rgb'])
                for j in range(copy_id * num_view, copy_id * num_view + num_view)
            ]
            shuffle = False
        self.data_index.append(
            data_utils.InfIndex(index_list, shuffle=shuffle)
        )