def __init__()

in models/decoders/mvp.py [0:0]


    def __init__(self,
            vt,
            vertmean,
            vertstd,
            idxim,
            tidxim,
            barim,
            volradius,
            dectype="slab2d",
            nprims=512,
            primsize=(32, 32, 32),
            chstart=256,
            penultch=None,
            condsize=0,
            motiontype="deconv",
            warptype=None,
            warpprimsize=None,
            sharedrgba=False,
            norm=None,
            mod=False,
            elr=True,
            scalemult=2.,
            nogeo=False,
            notplateact=False,
            postrainstart=-1,
            alphatrainstart=-1,
            renderoptions=None,
            **kwargs):
        """Build the encoder trunk, the geometry / motion / content / warp
        decoders, and the mesh-topology buffers.

        Parameters
        ----------
        vt : array-like or torch.Tensor [V, 2]
            mesh vertex texture coordinates; converted with ``torch.tensor``
            if not already a tensor and stored as the non-persistent buffer
            ``vt`` (skipped when None)
        vertmean : array-like or torch.Tensor [V, 3]
            mesh vertex position average (average over time); converted to a
            tensor if needed. Its ``numel()`` sets the output size of
            ``geobranch``, so it may only be None when ``nogeo`` is True.
            Stored as the non-persistent buffer ``vertmean``.
        vertstd : float
            mesh vertex position standard deviation (over time); stored as
            ``self.vertstd``
        idxim : torch.Tensor or array-like
            texture map of triangle indices; stored as a long buffer
        tidxim : torch.Tensor or array-like
            texture map of texture triangle indices; stored as a long buffer
        barim : torch.Tensor or array-like
            texture map of barycentric coordinates; stored as-is
        volradius : float
            radius of bounding volume of scene
        dectype : string
            type of content decoder passed to ``get_dec``, options are
            "slab2d", "slab2d3d", "slab2d3dv2"
        nprims : int
            number of primitives
        primsize : Tuple[int, int, int]
            size of primitive dimensions
        chstart : int
            forwarded as ``chstart`` to the separate rgb/alpha decoders and
            to the warp decoder (not to the shared rgba decoder)
        penultch : int or None
            forwarded as ``penultch`` to the content decoders
        condsize : int
            size of an extra conditioning input; the encoder ``self.enc``
            takes ``256 + condsize`` input features
        motiontype : string
            motion model passed to ``get_motion``, options are "linear" and
            "deconv"
        warptype : string or None
            decoder type passed to ``get_dec`` for the warp field, or None
            to disable the warp decoder (``self.warpdec = None``)
        warpprimsize : Tuple[int, int, int] or None
            primitive size used by the warp decoder
        sharedrgba : bool
            True to use 1 branch to output rgba, False to use 1 branch for rgb
            and 1 branch for alpha
        norm, mod, elr :
            forwarded unchanged to ``get_motion`` and every ``get_dec`` call
        scalemult : float
            stored as ``self.scalemult`` (consumed outside this method)
        nogeo : bool
            True to skip building the vertex-geometry branch ``geobranch``
        notplateact : bool
            stored as ``self.notplateact`` (consumed outside this method)
        postrainstart : int
            training iterations to start learning position, rotation, and
            scaling (i.e., primitives stay frozen until this iteration number)
        alphatrainstart : int
            stored as ``self.alphatrainstart`` (consumed outside this method)
        renderoptions : dict or None
            optional flags: ``"half"`` converts the content decoders with
            ``.half()``; ``"chlastconv"`` moves them to
            ``torch.channels_last`` memory format. None means no options.
        **kwargs
            forwarded to ``get_motion`` and every ``get_dec`` call
        """
        super(Decoder, self).__init__()

        # None instead of a shared mutable `{}` default; same behavior.
        if renderoptions is None:
            renderoptions = {}

        # Accept numpy/array-like input like `vt` and the texture maps below;
        # a raw numpy array would otherwise fail on `.numel()` and on
        # `register_buffer`, which only accepts tensors.
        if vertmean is not None and not isinstance(vertmean, torch.Tensor):
            vertmean = torch.tensor(vertmean)

        self.volradius = volradius
        self.postrainstart = postrainstart
        self.alphatrainstart = alphatrainstart

        self.nprims = nprims
        self.primsize = primsize

        self.motiontype = motiontype
        self.nogeo = nogeo
        self.notplateact = notplateact
        self.scalemult = scalemult

        self.enc = LinearELR(256 + condsize, 256)

        # vertex output
        if not self.nogeo:
            self.geobranch = LinearELR(256, vertmean.numel(), norm=None)

        # primitive motion delta decoder
        self.motiondec = get_motion(motiontype, nprims=nprims, inch=256, outch=9,
                norm=norm, mod=mod, elr=elr, **kwargs)

        def _apply_renderoptions(module):
            # Optional precision / memory-layout conversions for content decoders.
            if renderoptions.get("half", False):
                module = module.half()
            if renderoptions.get("chlastconv", False):
                module = module.to(memory_format=torch.channels_last)
            return module

        # slab decoder (RGBA)
        if sharedrgba:
            # NOTE(review): unlike the rgb/alpha branches, `chstart` is not
            # forwarded here, so get_dec's own default is used. Left as-is to
            # avoid changing the architecture of existing checkpoints.
            self.rgbadec = _apply_renderoptions(get_dec(dectype, nprims=nprims,
                    primsize=primsize, inch=256+3, outch=4, norm=norm, mod=mod,
                    elr=elr, penultch=penultch, **kwargs))
        else:
            self.rgbdec = _apply_renderoptions(get_dec(dectype, nprims=nprims,
                    primsize=primsize, inch=256+3, outch=3, chstart=chstart,
                    norm=norm, mod=mod, elr=elr, penultch=penultch, **kwargs))
            self.alphadec = _apply_renderoptions(get_dec(dectype, nprims=nprims,
                    primsize=primsize, inch=256, outch=1, chstart=chstart,
                    norm=norm, mod=mod, elr=elr, penultch=penultch, **kwargs))
            self.rgbadec = None

        # warp field decoder
        if warptype is not None:
            self.warpdec = get_dec(warptype, nprims=nprims, primsize=warpprimsize,
                    inch=256, outch=3, chstart=chstart, norm=norm, mod=mod, elr=elr, **kwargs)
        else:
            self.warpdec = None

        # vertex/triangle/mesh topology data (non-persistent: rebuilt on
        # construction, not saved in state_dict)
        if vt is not None:
            vt = torch.tensor(vt) if not isinstance(vt, torch.Tensor) else vt
            self.register_buffer("vt", vt, persistent=False)

        if vertmean is not None:
            self.register_buffer("vertmean", vertmean, persistent=False)
        self.vertstd = vertstd

        idxim = torch.tensor(idxim) if not isinstance(idxim, torch.Tensor) else idxim
        tidxim = torch.tensor(tidxim) if not isinstance(tidxim, torch.Tensor) else tidxim
        barim = torch.tensor(barim) if not isinstance(barim, torch.Tensor) else barim
        self.register_buffer("idxim", idxim.long(), persistent=False)
        self.register_buffer("tidxim", tidxim.long(), persistent=False)
        self.register_buffer("barim", barim, persistent=False)