in models/decoders/voxel1.py [0:0]
def __init__(self, templatetype="conv", templateres=128,
        viewconditioned=False, globalwarp=True, warptype="affinemix",
        displacementwarp=False):
    """Construct the decoder's template network(s), warp network and
    optional global-warp heads.

    Args:
        templatetype: name forwarded to ``gettemplate`` to select the
            template network.
        templateres: resolution forwarded to ``gettemplate`` as
            ``templateres``.
        viewconditioned: when True, two templates are built:
            ``template`` (``encodingsize=256+3``, 3 output channels) and
            ``templatealpha`` (``encodingsize=256``, 1 output channel).
            When False, a single ``template`` is built with
            ``gettemplate``'s default encoding size / channels.
            NOTE(review): the extra 3 encoding inputs are presumably a
            view direction -- confirm against ``forward``.
        globalwarp: when True, also build a ``models.utils.Quaternion``
            helper (``quat``) and three MLP heads mapping a 256-d input to
            3, 4 and 3 outputs (``gwarps``, ``gwarpr``, ``gwarpt``), each
            initialized with ``models.utils.initseq``.
            NOTE(review): the s/r/t suffixes presumably mean scale /
            rotation (quaternion) / translation -- confirm.
        warptype: name forwarded to ``getwarp`` to select the warp network.
        displacementwarp: forwarded to ``getwarp`` as ``displacementwarp``.
    """
    super(Decoder, self).__init__()

    # Keep the configuration on the module so other methods can branch on it.
    self.templatetype, self.templateres = templatetype, templateres
    self.viewconditioned, self.globalwarp = viewconditioned, globalwarp
    self.warptype, self.displacementwarp = warptype, displacementwarp

    # Submodule construction order is kept fixed: it determines parameter
    # registration order and the order random initializers draw from the RNG.
    if not viewconditioned:
        self.template = gettemplate(templatetype, templateres=templateres)
    else:
        # Color template gets 3 extra encoding inputs; alpha template does not.
        self.template = gettemplate(templatetype, encodingsize=256 + 3,
                                    outchannels=3, templateres=templateres)
        self.templatealpha = gettemplate(templatetype, encodingsize=256,
                                         outchannels=1, templateres=templateres)

    self.warp = getwarp(warptype, displacementwarp=displacementwarp)

    if not globalwarp:
        return

    self.quat = models.utils.Quaternion()

    def make_head(outdim):
        # Two-layer MLP: 256 -> 128 (LeakyReLU 0.2) -> outdim.
        return nn.Sequential(
            nn.Linear(256, 128), nn.LeakyReLU(0.2),
            nn.Linear(128, outdim))

    self.gwarps = make_head(3)
    self.gwarpr = make_head(4)
    self.gwarpt = make_head(3)

    for head in (self.gwarps, self.gwarpr, self.gwarpt):
        models.utils.initseq(head)