def amass(split, hparams)

in tensorflow_graphics/projects/nasa/lib/datasets.py

# Assumed module-level imports: the function body uses the TF1-style API
# (tf.parse_single_example, tf.gfile), so TensorFlow is presumably imported
# through the compat.v1 shim.
from os import path

import tensorflow.compat.v1 as tf


def amass(split, hparams):
  """Construct an AMASS data loader."""

  def _input_fn(params):  # pylint: disable=unused-argument
    # Dataset constants.
    n_bbox = 100000  # Points sampled in the shape's bounding box.
    n_surf = 100000  # Points sampled on the mesh surface.
    n_points = n_bbox + n_surf
    n_vert = 6890  # Vertex count of the SMPL template mesh used by AMASS.
    n_frames = 1
    # Parse parameters for global configurations.
    n_dims = hparams.n_dims
    data_dir = hparams.data_dir
    sample_bbox = hparams.sample_bbox
    sample_surf = hparams.sample_surf
    batch_size = hparams.batch_size
    subject = hparams.subject
    motion = hparams.motion
    n_parts = hparams.n_parts

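    # On-disk layout, inferred from the parsing and slicing below rather than
    # from a separate spec: each example stores n_frames frames; within a
    # frame, 'point' packs the n_bbox bounding-box samples first and the
    # n_surf surface samples after them, and 'transform' holds one
    # (n_dims + 1) x (n_dims + 1) homogeneous matrix per part.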
    def _parse_tfrecord(serialized_example):
      """Parses one serialized example into dense, per-frame tensors."""
      fs = tf.parse_single_example(
          serialized_example,
          features={
              'point':
                  tf.FixedLenFeature([n_frames * n_points * n_dims],
                                     tf.float32),
              'label':
                  tf.FixedLenFeature([n_frames * n_points * 1], tf.float32),
              'vert':
                  tf.FixedLenFeature([n_frames * n_vert * n_dims], tf.float32),
              'weight':
                  tf.FixedLenFeature([n_frames * n_vert * n_parts], tf.float32),
              'transform':
                  tf.FixedLenFeature(
                      [n_frames * n_parts * (n_dims + 1) * (n_dims + 1)],
                      tf.float32),
              'joint':
                  tf.FixedLenFeature([n_frames * n_parts * n_dims], tf.float32),
              'name':
                  tf.FixedLenFeature([], tf.string),
          })
      fs['point'] = tf.reshape(fs['point'], [n_frames, n_points, n_dims])
      fs['label'] = tf.reshape(fs['label'], [n_frames, n_points, 1])
      fs['vert'] = tf.reshape(fs['vert'], [n_frames, n_vert, n_dims])
      fs['weight'] = tf.reshape(fs['weight'], [n_frames, n_vert, n_parts])
      fs['transform'] = tf.reshape(fs['transform'],
                                   [n_frames, n_parts, n_dims + 1, n_dims + 1])
      fs['joint'] = tf.reshape(fs['joint'], [n_frames, n_parts, n_dims])
      return fs

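    # Note that the sampler draws indices uniformly at random and gathers with
    # them, i.e. points are sampled with replacement.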
    def _sample_frame_points(fs):
      """Randomly subsamples bbox and surface points (and labels) per frame."""
      feature = dict(fs)
      points = feature['point'][0]
      labels = feature['label'][0]
      sample_points = []
      sample_labels = []
      if sample_bbox > 0:
        indices_bbox = tf.random.uniform([sample_bbox],
                                         minval=0,
                                         maxval=n_bbox,
                                         dtype=tf.int32)
        bbox_samples = tf.gather(points[:n_bbox], indices_bbox, axis=0)
        bbox_labels = tf.gather(labels[:n_bbox], indices_bbox, axis=0)
        sample_points.append(bbox_samples)
        sample_labels.append(bbox_labels)
      if sample_surf > 0:
        indices_surf = tf.random.uniform([sample_surf],
                                         minval=0,
                                         maxval=n_surf,
                                         dtype=tf.int32)
        surf_samples = tf.gather(
            points[n_bbox:n_bbox + n_surf], indices_surf, axis=0)
        surf_labels = tf.gather(
            labels[n_bbox:n_bbox + n_surf], indices_surf, axis=0)
        sample_points.append(surf_samples)
        sample_labels.append(surf_labels)
      points = tf.concat(sample_points, axis=0)
      point_labels = tf.concat(sample_labels, axis=0)
      feature['point'] = tf.expand_dims(points, axis=0)
      feature['label'] = tf.expand_dims(point_labels, axis=0)
      return feature

    def _sample_eval_points(fs):
      """Keeps all bounding-box points (no subsampling) for evaluation."""
      feature = {}

      feature['transform'] = fs['transform']

      # Only the bounding-box points are evaluated, under pluralized keys.
      feature['points'] = fs['point'][:, :n_bbox]
      feature['labels'] = fs['label'][:, :n_bbox]

      feature['name'] = fs['name']
      feature['vert'] = fs['vert']
      feature['weight'] = fs['weight']
      feature['joint'] = fs['joint']

      return feature

    # Both splits read 'train'-prefixed shards named
    # '{split}-{subject:02d}-{motion:02d}-*': training takes every motion of
    # the subject except the held-out one, which the eval split uses.
    data_split = 'train'
    all_motions = list(range(10))
    if split == 'train':
      file_pattern = [
          path.join(data_dir,
                    '{0}-{1:02d}-{2:02d}-*'.format(data_split, subject, x))
          for x in all_motions if x != motion
      ]
    else:
      file_pattern = [
          path.join(data_dir,
                    '{0}-{1:02d}-{2:02d}-*'.format(data_split, subject, motion))
      ]
    data_files = tf.gfile.Glob(file_pattern)
    if not data_files:
      raise IOError('{} did not match any files'.format(file_pattern))
    filenames = tf.data.Dataset.list_files(file_pattern, shuffle=True)
    # Read the matched shards in parallel, interleaving their records.
    data = filenames.interleave(
        lambda x: tf.data.TFRecordDataset([x]),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Parse each record once and cache the decoded tensors in memory.
    data = data.map(
        _parse_tfrecord,
        num_parallel_calls=tf.data.experimental.AUTOTUNE).cache()
    if split == 'train':
      data = data.map(
          _sample_frame_points,
          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    else:
      data = data.map(
          _sample_eval_points, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if split == 'train':
      # Small shuffle buffer; repeat(-1) loops the dataset indefinitely.
      data = data.shuffle(int(batch_size * 2.5)).repeat(-1)
    else:
      batch_size = 1  # Evaluate one example at a time.

    return data.batch(
        batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)

  return _input_fn
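
A minimal usage sketch, assuming `amass` is imported from this module and the
TF1 imports shown above. The SimpleNamespace stands in for whatever hparams
object the project actually uses; every field value and the data directory
below are illustrative, not taken from the project's configs.

import types

hparams = types.SimpleNamespace(
    n_dims=3,                         # 3-D points and transforms
    data_dir='/tmp/amass_tfrecords',  # hypothetical shard directory
    sample_bbox=1024,                 # bbox points drawn per training frame
    sample_surf=1024,                 # surface points drawn per training frame
    batch_size=12,
    subject=0,                        # subject id in the shard file names
    motion=0,                         # motion id held out from training
    n_parts=24,                       # e.g. the 24 parts of an SMPL skeleton
)

input_fn = amass('train', hparams)
dataset = input_fn(params=None)  # `params` is accepted but unused.
batch = tf.data.make_one_shot_iterator(dataset).get_next()
# batch['point']: [batch_size, n_frames, sample_bbox + sample_surf, n_dims].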