in models/decoders/voxel1.py [0:0]
def __init__(self, **kwargs):
    """Build the affine-mixture warp heads and the canonical sampling grid.

    Submodules are created in a fixed order. Construction order determines the
    ``state_dict`` key order and the sequence of RNG draws made by the default
    ``nn.Linear`` initialisation, so it must not be shuffled:

    * ``quat``: a ``models.utils.Quaternion`` helper.
    * ``warps``, ``warpr``, ``warpt``: ``Linear(256, 128) -> LeakyReLU(0.2)
      -> Linear(128, k * 16)`` with ``k`` = 3, 4 and 3 respectively
      (presumably per-warp scale, quaternion rotation and translation for
      16 warps -- TODO confirm against ``forward``).
    * ``weightbranch``: ``Linear(256, 64) -> LeakyReLU(0.2)
      -> Linear(64, 16 * 32 * 32 * 32)``.

    After all four heads exist, each is passed to ``models.utils.initseq``.

    Finally a float32 buffer ``grid`` of shape ``(1, 32, 32, 32, 3)`` is
    registered. ``grid[0, i, j, k]`` equals ``(t[k], t[j], t[i])``, where
    ``t = linspace(-1, 1, 32)``; i.e. the last axis holds (x, y, z) while
    the spatial axes are ordered (z, y, x).

    Args:
        **kwargs: accepted for call-site compatibility; not used.
    """
    super(AffineMixWarp, self).__init__()

    nwarps, res, feat = 16, 32, 256

    def head(hidden, nout):
        # Two-layer MLP reading the 256-dim input feature.
        return nn.Sequential(
            nn.Linear(feat, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, nout),
        )

    self.quat = models.utils.Quaternion()
    self.warps = head(128, 3 * nwarps)
    self.warpr = head(128, 4 * nwarps)
    self.warpt = head(128, 3 * nwarps)
    self.weightbranch = head(64, nwarps * res ** 3)

    # Re-initialise only once every head has been built, so the RNG draws
    # happen in the same sequence as construction-then-init.
    for seq in (self.warps, self.warpr, self.warpt, self.weightbranch):
        models.utils.initseq(seq)

    ticks = np.linspace(-1.0, 1.0, res)
    # 'ij' indexing: zz varies along axis 0, yy along axis 1, xx along axis 2.
    zz, yy, xx = np.meshgrid(ticks, ticks, ticks, indexing='ij')
    coords = np.stack((xx, yy, zz), axis=-1)[np.newaxis].astype(np.float32)
    self.register_buffer("grid", torch.tensor(coords))