in models/neurvol1.py [0:0]
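# NOTE: relies on module-level `import torch` and `import torch.nn.functional as F`.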
def forward(self, iternum, losslist, camrot, campos, focal, princpt, pixelcoords, validinput,
            fixedcamimage=None, encoding=None, keypoints=None, camindex=None,
            image=None, imagevalid=None, viewtemplate=False,
            outputlist=()):  # tuple default avoids the mutable-default-argument pitfall
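    """Render the decoded volume along per-pixel camera rays.

    Shapes are inferred from the indexing in this function rather than from
    upstream documentation: camrot is N x 3 x 3, campos is N x 3, focal and
    princpt are N x 2, and pixelcoords is N x H x W x 2. Names in outputlist
    ("irgbrec", "ialpharec", "irgbsqerr") select extra outputs; names in
    losslist select which losses are computed.
    """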
    result = {"losses": {}}

    # encode input or get encoding
    if encoding is None:
        encout = self.encoder(fixedcamimage, losslist)
        encoding = encout["encoding"]
        result["losses"].update(encout["losses"])

    # decode
    decout = self.decoder(encoding, campos, losslist)
    result["losses"].update(decout["losses"])
    raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]  # NHWC
    raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
    raydir = torch.sum(camrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2)
    raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))
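
    # Intersect each ray with the [-1, 1]^3 volume bounds (slab method):
    # tmin/tmax are the entry/exit distances along the ray; rays that miss
    # the cube get tmin = tmax = 0 and contribute nothing below.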
    # compute raymarching starting points
    with torch.no_grad():
        t1 = (-1.0 - campos[:, None, None, :]) / raydir
        t2 = ( 1.0 - campos[:, None, None, :]) / raydir
        tmin = torch.max(torch.min(t1[..., 0], t2[..., 0]),
                         torch.max(torch.min(t1[..., 1], t2[..., 1]),
                                   torch.min(t1[..., 2], t2[..., 2])))
        tmax = torch.min(torch.max(t1[..., 0], t2[..., 0]),
                         torch.min(torch.max(t1[..., 1], t2[..., 1]),
                                   torch.max(t1[..., 2], t2[..., 2])))

        intersections = tmin < tmax
        t = torch.where(intersections, tmin, torch.zeros_like(tmin)).clamp(min=0.)
        tmin = torch.where(intersections, tmin, torch.zeros_like(tmin))
        tmax = torch.where(intersections, tmax, torch.zeros_like(tmin))

        # random starting point (stratifies sample positions along the ray)
        t = t - self.dt * torch.rand_like(t)
    raypos = campos[:, None, None, :] + raydir * t[..., None]  # NHWC

    rayrgb = torch.zeros_like(raypos.permute(0, 3, 1, 2))  # NCHW
    rayalpha = torch.zeros_like(rayrgb[:, 0:1, :, :])  # NCHW

    # raymarch
    done = torch.zeros_like(t).bool()
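
    # March all rays in lockstep, sampling the decoded volume at the current
    # positions and compositing front-to-back: clamping accumulated alpha at 1
    # means each new sample contributes only the remaining transmittance.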
    while not done.all():
        validf = torch.prod(torch.gt(raypos, -1.0) * torch.lt(raypos, 1.0), dim=-1).float()

        sample_rgb, sample_alpha = self.volsampler(raypos[:, None, :, :, :], **decout, viewtemplate=viewtemplate)

        with torch.no_grad():
            step = self.dt * torch.exp(self.stepjitter * torch.randn_like(t))
            done = done | ((t + step) >= tmax)

        contrib = ((rayalpha + sample_alpha[:, :, 0, :, :] * step[:, None, :, :]).clamp(max=1.) - rayalpha) * validf[:, None, :, :]

        rayrgb = rayrgb + sample_rgb[:, :, 0, :, :] * contrib
        rayalpha = rayalpha + contrib

        raypos = raypos + raydir * step[:, :, :, None]
        t = t + step
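
    # Map integer pixel coordinates to grid_sample's [-1, 1] range; this
    # normalization corresponds to the align_corners=True convention.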
    if image is not None:
        imagesize = torch.tensor(image.size()[3:1:-1], dtype=torch.float32, device=pixelcoords.device)  # (W, H) from NCHW
        samplecoords = pixelcoords * 2. / (imagesize[None, None, None, :] - 1.) - 1.
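
    # Background images are stored per camera in self.bg, keyed by camera
    # name through self.allcameras; compositing them here assumes `image`
    # (and hence samplecoords) was provided whenever camindex is.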
    # color correction / bg
    if camindex is not None:
        rayrgb = self.colorcal(rayrgb, camindex)

        bg = torch.stack([self.bg[self.allcameras[camindex[i].item()]]
                          for i in range(campos.size(0))], dim=0)
        if pixelcoords.size()[1:3] != image.size()[2:4]:
            # align_corners=True matches the samplecoords normalization above
            bg = F.grid_sample(bg, samplecoords, align_corners=True)
        rayrgb = rayrgb + (1. - rayalpha) * bg.clamp(min=0.)
if "irgbrec" in outputlist:
result["irgbrec"] = rayrgb
if "ialpharec" in outputlist:
result["ialpharec"] = rayalpha
    # opacity prior
    if "alphapr" in losslist:
        alphaprior = torch.mean(
            torch.log(0.1 + rayalpha.view(rayalpha.size(0), -1)) +
            torch.log(0.1 + 1. - rayalpha.view(rayalpha.size(0), -1)) + 2.20727,
            dim=-1)
        result["losses"]["alphapr"] = alphaprior
    # irgb loss
    if image is not None:
        if pixelcoords.size()[1:3] != image.size()[2:4]:
            image = F.grid_sample(image, samplecoords, align_corners=True)

        # standardize
        rayrgb = (rayrgb - self.imagemean) / self.imagestd
        image = (image - self.imagemean) / self.imagestd

        # compute reconstruction loss weighting
        if imagevalid is not None:
            weight = imagevalid[:, None, None, None].expand_as(image) * validinput[:, None, None, None]
        else:
            weight = torch.ones_like(image) * validinput[:, None, None, None]

        irgbsqerr = weight * (image - rayrgb) ** 2

        if "irgbsqerr" in outputlist:
            result["irgbsqerr"] = irgbsqerr

        if "irgbmse" in losslist:
            irgbmse = torch.sum(irgbsqerr.view(irgbsqerr.size(0), -1), dim=-1)
            irgbmse_weight = torch.sum(weight.view(weight.size(0), -1), dim=-1)
            result["losses"]["irgbmse"] = (irgbmse, irgbmse_weight)

    return result