in models/decoders/mvp.py [0:0]
def __init__(self,
        vt,
        vertmean,
        vertstd,
        idxim,
        tidxim,
        barim,
        volradius,
        dectype="slab2d",
        nprims=512,
        primsize=(32, 32, 32),
        chstart=256,
        penultch=None,
        condsize=0,
        motiontype="deconv",
        warptype=None,
        warpprimsize=None,
        sharedrgba=False,
        norm=None,
        mod=False,
        elr=True,
        scalemult=2.,
        nogeo=False,
        notplateact=False,
        postrainstart=-1,
        alphatrainstart=-1,
        renderoptions=None,
        **kwargs):
    """
    Build the decoder's sub-networks and register mesh topology buffers.

    Creates, in order: an encoding layer (``self.enc``), an optional vertex
    branch (``self.geobranch``), a primitive motion decoder
    (``self.motiondec``), the content decoders (``self.rgbadec`` or
    ``self.rgbdec`` + ``self.alphadec``), and an optional warp-field decoder
    (``self.warpdec``).

    Parameters
    ----------
    vt : numpy.array or torch.Tensor [V, 2]
        mesh vertex texture coordinates; registered as the non-persistent
        buffer ``vt`` when not None
    vertmean : numpy.array or torch.Tensor [V, 3]
        mesh vertex position average (average over time). Its element count
        sets the output width of ``self.geobranch``. Registered as the
        non-persistent buffer ``vertmean`` when not None. Required unless
        ``nogeo`` is True.
    vertstd : float
        mesh vertex position standard deviation (over time); stored as
        ``self.vertstd``
    idxim : numpy.array or torch.Tensor
        texture map of triangle indices; stored as a long buffer
    tidxim : numpy.array or torch.Tensor
        texture map of texture triangle indices; stored as a long buffer
    barim : numpy.array or torch.Tensor
        texture map of barycentric coordinates; stored with dtype unchanged
    volradius : float
        radius of bounding volume of scene
    dectype : string
        type of content decoder passed to ``get_dec``, options are "slab2d",
        "slab2d3d", "slab2d3dv2"
    nprims : int
        number of primitives
    primsize : Tuple[int, int, int]
        size of primitive dimensions
    chstart : int
        forwarded as ``chstart`` to every decoder built with ``get_dec``
        (content and warp decoders)
    penultch : int or None
        forwarded as ``penultch`` to the content decoders
    condsize : int
        number of extra conditioning features; ``self.enc`` takes
        ``256 + condsize`` inputs
    motiontype : string
        motion model, options are "linear" and "deconv"
    warptype : string or None
        decoder type passed to ``get_dec`` for the warp-field decoder
        (3 output channels); None disables it (``self.warpdec = None``).
        NOTE(review): the value goes straight to ``get_dec``, so it must be a
        type that ``get_dec`` accepts.
    warpprimsize : Tuple[int, int, int] or None
        primitive size used for the warp-field decoder
    sharedrgba : bool
        True to use 1 branch to output rgba, False to use 1 branch for rgb
        and 1 branch for alpha
    norm, mod, elr
        forwarded to the motion, content and warp decoders
    scalemult : float
        stored as ``self.scalemult`` (consumed outside the constructor)
    nogeo : bool
        True to skip building the vertex branch ``self.geobranch``
    notplateact : bool
        stored as ``self.notplateact``
    postrainstart : int
        training iterations to start learning position, rotation, and
        scaling (i.e., primitives stay frozen until this iteration number)
    alphatrainstart : int
        stored as ``self.alphatrainstart``; presumably the iteration at
        which alpha starts being learned (TODO confirm in forward)
    renderoptions : dict or None
        ``"half": True`` converts the content decoders to half precision;
        ``"chlastconv": True`` converts them to ``torch.channels_last``
        memory format. None is equivalent to ``{}``.
    **kwargs
        forwarded to ``get_motion`` and ``get_dec``

    Raises
    ------
    ValueError
        if ``vertmean`` is None while ``nogeo`` is False
    """
    super(Decoder, self).__init__()
    # Avoid a shared mutable default; None behaves like an empty dict.
    if renderoptions is None:
        renderoptions = {}

    def _to_tensor(x):
        # Keep tensors as-is; copy anything else (e.g. numpy arrays).
        return x if isinstance(x, torch.Tensor) else torch.tensor(x)

    def _apply_renderoptions(module):
        # Optional precision / memory-format conversion of a content decoder.
        if renderoptions.get("half", False):
            module = module.half()
        if renderoptions.get("chlastconv", False):
            module = module.to(memory_format=torch.channels_last)
        return module

    # vertmean must be a tensor for .numel() and register_buffer below.
    if vertmean is not None:
        vertmean = _to_tensor(vertmean)
    elif not nogeo:
        raise ValueError("vertmean is required when nogeo is False")

    self.volradius = volradius
    self.postrainstart = postrainstart
    self.alphatrainstart = alphatrainstart
    self.nprims = nprims
    self.primsize = primsize
    self.motiontype = motiontype
    self.nogeo = nogeo
    self.notplateact = notplateact
    self.scalemult = scalemult

    # NOTE: submodules are created in the original order so parameter
    # registration (state_dict / optimizer order) and RNG use are unchanged.
    self.enc = LinearELR(256 + condsize, 256)

    # vertex output
    if not self.nogeo:
        self.geobranch = LinearELR(256, vertmean.numel(), norm=None)

    # primitive motion delta decoder
    self.motiondec = get_motion(motiontype, nprims=nprims, inch=256, outch=9,
            norm=norm, mod=mod, elr=elr, **kwargs)

    # slab decoder (RGBA)
    if sharedrgba:
        # chstart is forwarded here as well, consistent with the separate
        # rgb/alpha decoders below.
        self.rgbadec = _apply_renderoptions(
            get_dec(dectype, nprims=nprims, primsize=primsize,
                inch=256+3, outch=4, chstart=chstart, norm=norm, mod=mod,
                elr=elr, penultch=penultch, **kwargs))
    else:
        self.rgbdec = get_dec(dectype, nprims=nprims, primsize=primsize,
                inch=256+3, outch=3, chstart=chstart, norm=norm, mod=mod,
                elr=elr, penultch=penultch, **kwargs)
        self.alphadec = get_dec(dectype, nprims=nprims, primsize=primsize,
                inch=256, outch=1, chstart=chstart, norm=norm, mod=mod,
                elr=elr, penultch=penultch, **kwargs)
        self.rgbadec = None
        self.rgbdec = _apply_renderoptions(self.rgbdec)
        self.alphadec = _apply_renderoptions(self.alphadec)

    # warp field decoder
    if warptype is not None:
        self.warpdec = get_dec(warptype, nprims=nprims, primsize=warpprimsize,
                inch=256, outch=3, chstart=chstart, norm=norm, mod=mod,
                elr=elr, **kwargs)
    else:
        self.warpdec = None

    # vertex/triangle/mesh topology data (non-persistent: not in state_dict)
    if vt is not None:
        self.register_buffer("vt", _to_tensor(vt), persistent=False)
    if vertmean is not None:
        self.register_buffer("vertmean", vertmean, persistent=False)
    self.vertstd = vertstd
    self.register_buffer("idxim", _to_tensor(idxim).long(), persistent=False)
    self.register_buffer("tidxim", _to_tensor(tidxim).long(), persistent=False)
    self.register_buffer("barim", _to_tensor(barim), persistent=False)