# Excerpt: decode_bytes_multiple_scannet(), from
# tensorflow_graphics/projects/points_to_3Dobjects/data_preparation/extract_protos.py


def decode_bytes_multiple_scannet(serialized):
  """Extracts the contents from a VoxelSample proto to tensors.

  Decodes a serialized `giotto_occluded_primitives.MultipleObjects` proto
  into a dict of tensors: the square-padded RGB image normalized to [0, 1],
  normalized 2D ground-truth boxes, and 3D poses re-expressed relative to a
  ground-plane anchor point ("dot") computed by `tf_utils.compute_dot`.

  Args:
    serialized: A scalar string tensor holding one serialized
      `MultipleObjects` proto.

  Returns:
    A dict with keys: 'name', 'scene_filename', 'mesh_names', 'classes',
    'image', 'image_data', 'original_image_spatial_shape', 'num_boxes',
    'center2d', 'groundtruth_boxes', 'dot', 'sizes_3d', 'translations_3d',
    'rotations_3d', 'rt', 'k', 'groundtruth_valid_classes', 'shapes'.
  """
  message_type = 'giotto_occluded_primitives.MultipleObjects'
  # Proto field name -> [dtype, reshape target]; an empty shape means
  # "keep the tensor as decoded (after squeeze)".
  name_type_shape = {
      'name': [tf.string, []],
      'scene_filename': [tf.string, []],
      'image_data': [tf.string, []],
      'image_size': [tf.int32, []],
      'center2d': [tf.float32, [-1, 2]],
      'center3d': [tf.float32, [-1, 3]],
      'box_dims2d': [tf.float32, [-1, 2]],
      'box_dims3d': [tf.float32, [-1, 3]],
      'rotations_3d': [tf.float32, [-1, 3, 3]],
      'rt': [tf.float32, [3, 4]],
      'k': [tf.float32, []],
      'classes': [tf.int32, []],
      'mesh_names': [tf.string, []],
      'shapes': [tf.int32, []],
  }

  field_names = list(name_type_shape)
  output_types = [dtype for dtype, _ in name_type_shape.values()]
  _, tensors = tf.io.decode_proto(serialized, message_type,
                                  field_names, output_types)

  # Unpack tensors into a dict and reshape fields with a known target shape.
  tensor_dict = dict(zip(name_type_shape, (tf.squeeze(t) for t in tensors)))
  for name, (_, shape) in name_type_shape.items():
    if shape:
      tensor_dict[name] = tf.reshape(tensor_dict[name], shape)

  # Decode the image, normalize to [0, 1], and pad it to a `width` x `width`
  # square (assumes width >= height, i.e. landscape input — TODO confirm).
  image = tf.io.decode_image(tensor_dict['image_data'], channels=3)
  image = tf.cast(image, tf.float32) / 255.0
  width = tf.shape(image)[1]
  image = tf.image.pad_to_bounding_box(image, 0, 0, width, width)
  image.set_shape([None, None, 3])
  # Spatial shape as (width, height) of the padded image.
  original_image_spatial_shape = tf.stack(
      [tf.shape(image)[1], tf.shape(image)[0]], axis=0)
  tensor_dict['image_size'] = tf.stack([width, width, 3], axis=0)

  num_boxes = tf.shape(tensor_dict['center2d'])[0]

  # Compute ground truth boxes. (center2d,size2d) --> (ymin,xmin,ymax,xmax),
  # normalized by the padded image size.
  box_min = tensor_dict['center2d'] - tensor_dict['box_dims2d'] / 2.0
  box_max = tensor_dict['center2d'] + tensor_dict['box_dims2d'] / 2.0
  size = tf.cast(original_image_spatial_shape, tf.float32)
  size = tf.tile(tf.expand_dims(size, axis=0), [num_boxes, 1])
  box_min = tf.divide(box_min, size)
  box_max = tf.divide(box_max, size)
  groundtruth_boxes = tf.reshape(
      tf.concat([box_min[:, 1:2], box_min[:, 0:1],
                 box_max[:, 1:2], box_max[:, 0:1]], axis=1), [num_boxes, 4])

  # Compute the dot on the ground plane used as prior for 3D pose (batched).
  dot = tf.transpose(tf_utils.compute_dot(tensor_dict['image_size'],
                                          tensor_dict['k'],
                                          tensor_dict['rt'],
                                          axis=1,
                                          image_intersection=(0.5, 0.6)))

  # A second dot slightly to the right yields a ground-plane direction, from
  # which the yaw correction `angle_y` is derived.
  dot_x = tf.transpose(tf_utils.compute_dot(tensor_dict['image_size'],
                                            tensor_dict['k'],
                                            tensor_dict['rt'],
                                            axis=1,
                                            image_intersection=(0.6, 0.6)))
  dot_x -= dot
  dot_x = tf.math.l2_normalize(dot_x)
  angle_y = tf.math.atan2(dot_x[0, 2], dot_x[0, 0])

  # Map raw class ids (presumably ShapeNet synset ids — verify against the
  # dataset) to contiguous indices 0..7; unknown ids map to -1.
  table = tf.lookup.StaticHashTable(
      initializer=tf.lookup.KeyValueTensorInitializer(
          keys=tf.constant([2818832, 2747177, 2871439, 2933112,
                            3001627, 3211117, 4256520, 4379243]),
          values=tf.constant([0, 1, 2, 3, 4, 5, 6, 7])),
      default_value=tf.constant(-1),
      name='classes')
  classes = table.lookup(tensor_dict['classes'])

  rotations_3d = tensor_dict['rotations_3d']  # (K, 3, 3)
  translations_3d = tensor_dict['center3d']
  sizes_3d = tensor_dict['box_dims3d']

  # Fold the anchor translation into rt: rt <- rt @ [[I, dot], [0, 1]].
  translation_correction = tf.concat(
      [tf.concat([tf.eye(3, dtype=tf.float32), tf.transpose(dot)], axis=-1),
       [[0.0, 0.0, 0.0, 1.0]]], axis=0)
  rt = tensor_dict['rt'] @ translation_correction

  # Rotate object poses by angle_y (about the second Euler axis) into the
  # anchor-aligned frame; rt absorbs the inverse rotation so the overall
  # camera projection is unchanged.
  rotation_left = rotation_matrix_3d.from_euler([0.0, angle_y, 0.0])
  rotation_right = rotation_matrix_3d.from_euler([0.0, -angle_y, 0.0])

  rotations_3d = rotation_left @ rotations_3d
  translations_3d = tf.transpose(
      rotation_left @ tf.transpose(translations_3d - dot))

  rotation_right = tf.concat(
      [tf.concat([rotation_right, [[0.0], [0.0], [0.0]]], axis=-1),
       [[0.0, 0.0, 0.0, 1.0]]], axis=0)
  rt = rt @ rotation_right

  # NOTE(review): a per-class rotation fix-up for table/bottle/bowl
  # (fix_rotation_matrix) is intentionally not applied here.

  rotations_3d = tf.reshape(rotations_3d, [-1, 9])  # (K, 9) flattened.

  return {
      'name': tensor_dict['name'],  # e.g. 'train-0000'
      'scene_filename': tensor_dict['scene_filename'],
      'mesh_names': tf.reshape(tensor_dict['mesh_names'], [-1]),
      'classes': tf.reshape(tensor_dict['classes'], [-1]),
      'image': image,
      'image_data': tensor_dict['image_data'],
      'original_image_spatial_shape': original_image_spatial_shape,
      'num_boxes': num_boxes,
      'center2d': tensor_dict['center2d'],
      'groundtruth_boxes': groundtruth_boxes,
      # The anchor was folded into rt/translations above, so the exported
      # dot is identically zero (kept for interface compatibility).
      'dot': dot - dot,
      'sizes_3d': sizes_3d,
      'translations_3d': translations_3d,
      'rotations_3d': rotations_3d,
      'rt': rt,
      'k': tensor_dict['k'],
      'groundtruth_valid_classes': tf.reshape(classes, [-1]),
      'shapes': tf.reshape(tensor_dict['shapes'], [-1]),
  }