# tensorflow_graphics/projects/cvxnet/lib/datasets.py
def shapenet(split, args):
  """ShapeNet Dataset.

  Args:
    split: string, the split of the dataset, either "train" or "test".
    args: tf.app.flags.FLAGS, configurations.

  Returns:
    dataset: tf.data.Dataset, the shapenet dataset.
  """
  total_points = 100000
  is_train = split == "train"
  # Point-sample counts: at eval time use every bounding-box sample and
  # skip surface samples entirely.
  n_bbx = args.sample_bbx
  n_surf = args.sample_surf
  if not is_train:
    n_bbx = total_points
    n_surf = 0
  image_h, image_w, image_d = args.image_h, args.image_w, args.image_d
  depth_h, depth_w, depth_d = args.depth_h, args.depth_w, args.depth_d
  n_views = args.n_views
  batch_size = args.batch_size if is_train else 1
  dims = args.dims

  def _decode(serialized):
    """Parses one serialized tf.Example and reshapes its flat features."""
    spec = {
        "rgb":
            tf.FixedLenFeature([n_views * image_h * image_w * image_d],
                               tf.float32),
        "depth":
            tf.FixedLenFeature([depth_d * depth_h * depth_w], tf.float32),
        "bbox_samples":
            tf.FixedLenFeature([total_points * (dims + 1)], tf.float32),
        "surf_samples":
            tf.FixedLenFeature([total_points * (dims + 1)], tf.float32),
        "name":
            tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized, features=spec)
    parsed["rgb"] = tf.reshape(parsed["rgb"],
                               [n_views, image_h, image_w, image_d])
    parsed["depth"] = tf.reshape(parsed["depth"],
                                 [depth_d, depth_h, depth_w, 1])
    # Both sample sets share one layout: xyz coordinates plus one label.
    for key in ("bbox_samples", "surf_samples"):
      parsed[key] = tf.reshape(parsed[key], [total_points, dims + 1])
    return parsed

  def _select(parsed):
    """Picks a view, rescales depth, and draws the point samples."""
    # Training draws a random view; eval always takes view 0.
    view_idx = tf.random.uniform((),
                                 minval=0,
                                 maxval=n_views if is_train else 1,
                                 dtype=tf.int32)
    image = tf.gather(parsed["rgb"], view_idx, axis=0)
    image = tf.image.resize_bilinear(
        tf.expand_dims(image, axis=0), [224, 224])
    depth = parsed["depth"] / 1000.
    points_parts = []
    labels_parts = []
    if n_bbx > 0:
      if is_train:
        # Sample with replacement from the stored bounding-box points.
        idx = tf.random.uniform([n_bbx],
                                minval=0,
                                maxval=total_points,
                                dtype=tf.int32)
        bbx = tf.gather(parsed["bbox_samples"], idx, axis=0)
      else:
        bbx = parsed["bbox_samples"]
      xyz, occ = tf.split(bbx, [3, 1], axis=-1)
      points_parts.append(xyz)
      labels_parts.append(occ)
    if n_surf > 0:
      idx = tf.random.uniform([n_surf],
                              minval=0,
                              maxval=total_points,
                              dtype=tf.int32)
      surf = tf.gather(parsed["surf_samples"], idx, axis=0)
      xyz, occ = tf.split(surf, [3, 1], axis=-1)
      points_parts.append(xyz)
      labels_parts.append(occ)
    points = tf.concat(points_parts, axis=0)
    # Raw label value <= 0 is mapped to 1.0, everything else to 0.0.
    point_labels = tf.cast(
        tf.concat(labels_parts, axis=0) <= 0., tf.float32)
    image = tf.reshape(image, [224, 224, image_d])
    # Move the depth channel axis last: (d, h, w) -> (h, w, d).
    depth = tf.transpose(
        tf.reshape(depth, [depth_d, depth_h, depth_w]), [1, 2, 0])
    points = tf.reshape(points, [n_bbx + n_surf, 3])
    point_labels = tf.reshape(point_labels, [n_bbx + n_surf, 1])
    return {
        "image": image,
        "depth": depth,
        "point": points,
        "point_label": point_labels,
        "name": parsed["name"],
    }

  data_pattern = path.join(args.data_dir,
                           "{}-{}-*".format(args.obj_class, split))
  shard_paths = tf.gfile.Glob(data_pattern)
  if not shard_paths:
    raise ValueError("{} did not match any files".format(data_pattern))
  filenames = tf.data.Dataset.list_files(data_pattern, shuffle=True)
  dataset = filenames.interleave(
      lambda f: tf.data.TFRecordDataset([f]),
      cycle_length=len(shard_paths),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.map(
      _decode, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.map(
      _select, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  if is_train:
    dataset = dataset.shuffle(batch_size * 5).repeat(-1)
  return dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)