in models/vertex_unet.py [0:0]
def __init__(self, classes: int = 128, heads: int = 64, n_vertices: int = 6172,
             mean: 'th.Tensor | None' = None, stddev: 'th.Tensor | None' = None,
             model_name: str = 'vertex_unet'):
    """
    VertexUnet consumes a neutral template mesh and an expression encoding and produces an animated face mesh
    :param classes: number of classes for the categorical latent embedding
    :param heads: number of heads for the categorical latent embedding
    :param n_vertices: number of vertices in the face mesh
    :param mean: mean position of each vertex; must be view-able as (1, 1, n_vertices, 3).
        Defaults to zeros when omitted.
    :param stddev: standard deviation of each vertex position; same shape contract as mean.
        Defaults to ones when omitted.
    :param model_name: name of the model, used to load and save the model
    """
    super().__init__(model_name)
    self.classes = classes
    self.heads = heads
    self.n_vertices = n_vertices
    # Normalization statistics are registered as buffers (not Parameters) so they
    # follow the module across devices and are persisted in the state dict
    # without receiving gradient updates.
    shape = (1, 1, n_vertices, 3)
    self.register_buffer("mean", th.zeros(shape) if mean is None else mean.view(shape))
    self.register_buffer("stddev", th.ones(shape) if stddev is None else stddev.view(shape))
    # encoder layers: flattened per-frame vertices (n_vertices*3) -> 128-dim geometry code
    self.encoder = th.nn.ModuleList([
        th.nn.Linear(n_vertices * 3, 512),
        th.nn.Linear(512, 256),
        th.nn.Linear(256, 128),
    ])
    # multimodal fusion: concatenated flattened categorical encoding (heads * classes)
    # and the 128-dim geometry code, projected back down to 128
    self.fusion = th.nn.Linear(heads * classes + 128, 128)
    # decoder layers: 2-layer LSTM over the sequence dimension (batch_first, so
    # inputs are (batch, seq, 128)), then linear layers back up to n_vertices*3
    self.temporal = th.nn.LSTM(input_size=128, hidden_size=128, num_layers=2, batch_first=True)
    self.decoder = th.nn.ModuleList([
        th.nn.Linear(128, 256),
        th.nn.Linear(256, 512),
        th.nn.Linear(512, n_vertices * 3),
    ])
    # learnable per-coordinate bias over the flattened vertex output
    # (consumed in forward, which is outside this view)
    self.vertex_bias = th.nn.Parameter(th.zeros(n_vertices * 3))