in models/decoders/nv.py [0:0]
def __init__(self,
        volradius,
        dectype="conv",
        primsize=128,
        chstart=256,
        penultch=None,
        condsize=0,
        warptype="conv",
        warpprimsize=32,
        sharedrgba=False,
        norm=None,
        mod=False,
        elr=True,
        notplateact=False,
        postrainstart=-1,
        alphatrainstart=-1,
        renderoptions=None,
        **kwargs):
    """
    Build the encoding layer, the content (RGB/alpha) decoder(s) and the
    optional warp field decoder.

    Parameters
    ----------
    volradius : float
        radius of bounding volume of scene
    dectype : string
        type of content decoder, passed as the first argument to
        `get_dec` (default "conv").
        NOTE(review): earlier docs listed "slab2d", "slab2d3d" and
        "slab2d3dv2" as options -- confirm against `get_dec`.
    primsize : int
        primitive size forwarded to `get_dec` as `primsize` for the
        content decoder(s) (default 128).
        NOTE(review): earlier docs described this as
        Tuple[int, int, int] -- confirm which form `get_dec` expects.
    chstart : int
        forwarded to `get_dec` as `chstart` for the separate rgb, alpha
        and warp decoders; it is NOT forwarded to the shared rgba decoder
    penultch : int or None
        forwarded to `get_dec` as `penultch` for the content decoder(s)
        only (not the warp decoder)
    condsize : int
        number of extra input features of the encoding layer, which is
        built as `LinearELR(256 + condsize, 256)`
    warptype : string or None
        type of warp field decoder, passed as the first argument to
        `get_dec`; None disables the warp decoder (`self.warpdec` is
        then None).
        NOTE(review): earlier docs mentioned "same" as an option --
        confirm against `get_dec`.
    warpprimsize : int
        primitive size forwarded to `get_dec` for the warp decoder
    sharedrgba : bool
        True to use 1 branch to output rgba, False to use 1 branch for rgb
        and 1 branch for alpha
    norm, mod, elr
        forwarded unchanged to every `get_dec` call
    notplateact : bool
        stored on the instance; not otherwise used by the constructor
    postrainstart : int
        training iterations to start learning position, rotation, and
        scaling (i.e., primitives stay frozen until this iteration number)
    alphatrainstart : int
        stored on the instance; presumably the iteration at which alpha
        training starts -- TODO confirm against the forward pass
    renderoptions : dict or None
        optional flags applied to the content decoder(s) only:
        "half" converts them with `.half()`, "chlastconv" moves them to
        `torch.channels_last` memory format. None is treated as an empty
        dict.
    **kwargs
        forwarded unchanged to every `get_dec` call.
        NOTE(review): earlier docs mentioned a `motiontype` option
        ("linear" / "deconv"); it is not a named parameter here, so it
        could only arrive through these extra keyword arguments.
    """
    super(Decoder, self).__init__()

    # A None sentinel replaces the former `{}` default so no dict object is
    # shared between calls, and an explicit `renderoptions=None` is accepted.
    if renderoptions is None:
        renderoptions = {}

    def _apply_renderoptions(dec):
        # Optional half-precision / channels-last conversion, applied to
        # content decoders only (the warp decoder is left as built).
        if renderoptions.get("half", False):
            dec = dec.half()
        if renderoptions.get("chlastconv", False):
            dec = dec.to(memory_format=torch.channels_last)
        return dec

    self.volradius = volradius
    self.postrainstart = postrainstart
    self.alphatrainstart = alphatrainstart
    self.primsize = primsize
    self.warpprimsize = warpprimsize
    self.notplateact = notplateact

    self.enc = LinearELR(256 + condsize, 256)

    # slab decoder (RGBA)
    if sharedrgba:
        # NOTE(review): unlike the other get_dec calls, `chstart` is not
        # forwarded here, so this decoder does not receive the
        # constructor's `chstart`. Left as-is because changing it could
        # alter layer shapes -- confirm whether this is intentional.
        self.rgbadec = _apply_renderoptions(
            get_dec(dectype, primsize=primsize,
                    inch=256+3, outch=4, norm=norm, mod=mod, elr=elr,
                    penultch=penultch, **kwargs))
    else:
        # The rgb branch takes 3 more input channels than the alpha branch.
        rgbdec = get_dec(dectype, primsize=primsize,
                inch=256+3, outch=3, chstart=chstart, norm=norm, mod=mod,
                elr=elr, penultch=penultch, **kwargs)
        alphadec = get_dec(dectype, primsize=primsize,
                inch=256, outch=1, chstart=chstart, norm=norm, mod=mod,
                elr=elr, penultch=penultch, **kwargs)
        self.rgbdec = rgbdec
        self.alphadec = alphadec
        self.rgbadec = None
        self.rgbdec = _apply_renderoptions(self.rgbdec)
        self.alphadec = _apply_renderoptions(self.alphadec)

    # warp field decoder
    if warptype is not None:
        self.warpdec = get_dec(warptype, primsize=warpprimsize,
                inch=256, outch=3, chstart=chstart, norm=norm, mod=mod,
                elr=elr, **kwargs)
    else:
        self.warpdec = None