in tensorflow_graphics/geometry/deformation_energy/as_conformal_as_possible.py [0:0]
def energy(vertices_rest_pose: type_alias.TensorLike,
           vertices_deformed_pose: type_alias.TensorLike,
           quaternions: type_alias.TensorLike,
           edges: type_alias.TensorLike,
           vertex_weight: Optional[type_alias.TensorLike] = None,
           edge_weight: Optional[type_alias.TensorLike] = None,
           conformal_energy: bool = True,
           aggregate_loss: bool = True,
           name: str = "as_conformal_as_possible_energy") -> tf.Tensor:
  """Estimates an As Conformal As Possible (ACAP) fitting energy.

  For a given mesh in rest pose, this function evaluates a variant of the ACAP
  [1] fitting energy for a batch of deformed meshes. The vertex weights and edge
  weights are defined on the rest pose.

  The method implemented here is similar to [2], but with an added free variable
  capturing a scale factor per vertex.

  [1]: Yusuke Yoshiyasu, Wan-Chun Ma, Eiichi Yoshida, and Fumio Kanehiro.
  "As-Conformal-As-Possible Surface Registration." Computer Graphics Forum. Vol.
  33. No. 5. 2014.</br>
  [2]: Olga Sorkine, and Marc Alexa.
  "As-rigid-as-possible surface modeling". Symposium on Geometry Processing.
  Vol. 4. 2007.

  Note:
    In the description of the arguments, V corresponds to
    the number of vertices in the mesh, and E to the number of edges in this
    mesh.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Note:
    For every edge (i, j) two residuals are evaluated, one using the
    quaternion of vertex i and one using the quaternion of vertex j. The
    vertex and edge weights multiply each residual vector before its squared
    norm is taken, so they contribute to the returned energy squared.

  Args:
    vertices_rest_pose: A tensor of shape `[V, 3]` containing the position of
      all the vertices of the mesh in rest pose.
    vertices_deformed_pose: A tensor of shape `[A1, ..., An, V, 3]` containing
      the position of all the vertices of the mesh in deformed pose.
    quaternions: A tensor of shape `[A1, ..., An, V, 4]` defining a rigid
      transformation to apply to each vertex of the rest pose. See Section 2
      from [1] for further details.
    edges: A tensor of shape `[E, 2]` defining indices of vertices that are
      connected by an edge.
    vertex_weight: An optional tensor of shape `[V]` defining the weight
      associated with each vertex. Defaults to a tensor of ones.
    edge_weight: An optional tensor of shape `[E]` defining the weight of
      edges. Common choices for these weights include uniform weighting, and
      cotangent weights. Defaults to a tensor of ones.
    conformal_energy: A `bool` indicating whether each vertex is associated with
      a scale factor or not. If this parameter is True, scaling information must
      be encoded in the norm of `quaternions`. If this parameter is False,
      `quaternions` are normalized first and this function implements the
      energy described in [2].
    aggregate_loss: A `bool` defining whether the returned loss should be an
      aggregate measure. When True, the mean squared error is returned. When
      False, returns two losses for every edge of the mesh.
    name: A name for this op. Defaults to "as_conformal_as_possible_energy".

  Returns:
    When aggregate_loss is `True`, returns a tensor of shape `[A1, ..., An]`
    containing the ACAP energies. When aggregate_loss is `False`, returns a
    tensor of shape `[A1, ..., An, 2*E]` containing each term of the summation
    described in the equation 7 of [2]; the first E entries use the quaternion
    of the first vertex of each edge, the last E entries the quaternion of the
    second vertex.

  Raises:
    ValueError: if the shape of `vertices_rest_pose`, `vertices_deformed_pose`,
    `quaternions`, `edges`, `vertex_weight`, or `edge_weight` is not supported.
  """
  with tf.name_scope(name):
    vertices_rest_pose = tf.convert_to_tensor(value=vertices_rest_pose)
    vertices_deformed_pose = tf.convert_to_tensor(value=vertices_deformed_pose)
    quaternions = tf.convert_to_tensor(value=quaternions)
    edges = tf.convert_to_tensor(value=edges)
    if vertex_weight is not None:
      vertex_weight = tf.convert_to_tensor(value=vertex_weight)
    if edge_weight is not None:
      edge_weight = tf.convert_to_tensor(value=edge_weight)

    # Validates the static shape of every input.
    shape.check_static(
        tensor=vertices_rest_pose,
        tensor_name="vertices_rest_pose",
        has_rank=2,
        has_dim_equals=(-1, 3))
    shape.check_static(
        tensor=vertices_deformed_pose,
        tensor_name="vertices_deformed_pose",
        has_rank_greater_than=1,
        has_dim_equals=(-1, 3))
    shape.check_static(
        tensor=quaternions,
        tensor_name="quaternions",
        has_rank_greater_than=1,
        has_dim_equals=(-1, 4))
    # The deformed vertices and the quaternions are compared on their batch
    # dimensions A1, ..., An, without broadcasting.
    shape.compare_batch_dimensions(
        tensors=(vertices_deformed_pose, quaternions),
        last_axes=(-3, -3),
        broadcast_compatible=False)
    shape.check_static(
        tensor=edges, tensor_name="edges", has_rank=2, has_dim_equals=(-1, 2))

    # Every tensor carrying a per-vertex axis must agree on V.
    tensors_with_vertices = [
        vertices_rest_pose, vertices_deformed_pose, quaternions
    ]
    names_with_vertices = [
        "vertices_rest_pose", "vertices_deformed_pose", "quaternions"
    ]
    axes_with_vertices = [-2, -2, -2]
    if vertex_weight is not None:
      shape.check_static(
          tensor=vertex_weight, tensor_name="vertex_weight", has_rank=1)
      tensors_with_vertices.append(vertex_weight)
      names_with_vertices.append("vertex_weight")
      axes_with_vertices.append(0)
    shape.compare_dimensions(
        tensors=tensors_with_vertices,
        axes=axes_with_vertices,
        tensor_names=names_with_vertices)
    if edge_weight is not None:
      shape.check_static(
          tensor=edge_weight, tensor_name="edge_weight", has_rank=1)
      shape.compare_dimensions(
          tensors=(edges, edge_weight),
          axes=(0, 0),
          tensor_names=("edges", "edge_weight"))

    # Without the conformal term the quaternions must not carry any scale.
    if not conformal_energy:
      quaternions = quaternion.normalize(quaternions)

    # Extracts the indices of vertices.
    indices_i, indices_j = tf.unstack(edges, axis=-1)

    # Extracts the vertices we need per term.
    vertices_i_rest = tf.gather(vertices_rest_pose, indices_i, axis=-2)
    vertices_j_rest = tf.gather(vertices_rest_pose, indices_j, axis=-2)
    vertices_i_deformed = tf.gather(vertices_deformed_pose, indices_i, axis=-2)
    vertices_j_deformed = tf.gather(vertices_deformed_pose, indices_j, axis=-2)

    # Extracts the weights we need per term. Default weights take their shape
    # from the per-edge index tensor instead of a static dimension, so that
    # they can still be built when the number of edges is not statically known.
    if vertex_weight is not None:
      weight_i = tf.gather(vertex_weight, indices_i)
      weight_j = tf.gather(vertex_weight, indices_j)
    else:
      weight_i = weight_j = tf.ones_like(
          indices_i, dtype=vertices_rest_pose.dtype)
    # A trailing unit axis lets the per-edge weights broadcast against the
    # 3d residual vectors.
    weight_i = tf.expand_dims(weight_i, axis=-1)
    weight_j = tf.expand_dims(weight_j, axis=-1)
    if edge_weight is not None:
      weight_ij = edge_weight
    else:
      weight_ij = tf.ones_like(indices_i, dtype=vertices_rest_pose.dtype)
    weight_ij = tf.expand_dims(weight_ij, axis=-1)

    # Extracts the rotation we need per term.
    quaternion_i = tf.gather(quaternions, indices_i, axis=-2)
    quaternion_j = tf.gather(quaternions, indices_j, axis=-2)

    # Computes the energy: each edge is evaluated once with the transformation
    # of vertex i and once, in the opposite direction, with that of vertex j.
    deformed_ij = vertices_i_deformed - vertices_j_deformed
    rotated_rest_ij = quaternion.rotate((vertices_i_rest - vertices_j_rest),
                                        quaternion_i)
    energy_ij = weight_i * weight_ij * (deformed_ij - rotated_rest_ij)
    deformed_ji = vertices_j_deformed - vertices_i_deformed
    rotated_rest_ji = quaternion.rotate((vertices_j_rest - vertices_i_rest),
                                        quaternion_j)
    energy_ji = weight_j * weight_ij * (deformed_ji - rotated_rest_ji)
    energy_ij_squared = vector.dot(energy_ij, energy_ij, keepdims=False)
    energy_ji_squared = vector.dot(energy_ji, energy_ji, keepdims=False)

    if aggregate_loss:
      # Averages over the edges, then over the two directions.
      average_energy_ij = tf.reduce_mean(
          input_tensor=energy_ij_squared, axis=-1)
      average_energy_ji = tf.reduce_mean(
          input_tensor=energy_ji_squared, axis=-1)
      return (average_energy_ij + average_energy_ji) / 2.0
    return tf.concat((energy_ij_squared, energy_ji_squared), axis=-1)