def __init__()

in models/vertex_unet.py [0:0]


    def __init__(self, classes: int = 128, heads: int = 64, n_vertices: int = 6172, mean: th.Tensor = None,
                 stddev: th.Tensor = None, model_name: str = 'vertex_unet'):
        """
        VertexUnet consumes a neutral template mesh and an expression encoding and produces an animated face mesh
        :param classes: number of classes for the categorical latent embedding
        :param heads: number of heads for the categorical latent embedding
        :param n_vertices: number of vertices in the face mesh
        :param mean: mean position of each vertex
        :param stddev: standard deviation of each vertex position
        :param model_name: name of the model, used to load and save the model
        """
        super().__init__(model_name)

        self.classes = classes
        self.heads = heads
        self.n_vertices = n_vertices

        shape = (1, 1, n_vertices, 3)
        self.register_buffer("mean", th.zeros(shape) if mean is None else mean.view(shape))
        self.register_buffer("stddev", th.ones(shape) if stddev is None else stddev.view(shape))

        # encoder layers
        self.encoder = th.nn.ModuleList([
            th.nn.Linear(n_vertices*3, 512),
            th.nn.Linear(512, 256),
            th.nn.Linear(256, 128)
        ])

        # multimodal fusion
        self.fusion = th.nn.Linear(heads * classes + 128, 128)

        # decoder layers
        self.temporal = th.nn.LSTM(input_size=128, hidden_size=128, num_layers=2, batch_first=True)
        self.decoder = th.nn.ModuleList([
            th.nn.Linear(128, 256),
            th.nn.Linear(256, 512),
            th.nn.Linear(512, n_vertices*3)
        ])
        self.vertex_bias = th.nn.Parameter(th.zeros(n_vertices * 3))