in fairnr/modules/field.py [0:0]
def forward(self, inputs, outputs=('sigma', 'texture')):
    """Evaluate the field on a dict of inputs, filling in the requested outputs.

    ``inputs`` is mutated in place and returned. The stages run in order, and
    each runs only when needed:

    1. ``'feat'`` -- only if missing or ``None``. Each entry named in
       ``self.den_filters`` must be present in ``inputs``. Its last-dim size
       must equal the matching ``self.den_ori_dims`` entry. It is passed
       through its filter.
       - The filtered tensors are concatenated along the last dim and fed to
         ``self.feature_field``.
       - An entry named ``'context'`` must be the last filter and have
         leading size 1.
       - If ``self.args.hypernetwork`` is truthy, the context is passed to
         the feature field as a second argument.
       - Otherwise the context is expanded along dim 0 and concatenated to
         the other features.
    2. ``'sigma'`` -- ``self.predictor(inputs['feat'])[0]``.
    3. ``'normal'`` -- only if not already present, and either requested or
       consumed by a texture filter while texture is requested.
       - Computed as the negated autograd gradient of ``inputs['sigma']``
         w.r.t. ``inputs['pos']``.
       - L2-normalised along dim 1 unless ``self.args.no_normalize_normal``
         is truthy.
    4. ``'texture'`` -- first writes ``'feat_n2'``, the sum of squared
       ``feat`` over the last dim.
       - Each entry named in ``self.tex_filters`` is filtered and the results
         are concatenated along the last dim.
       - The result is passed to ``self.renderer``.
       - A sigmoid is applied when ``self.min_color == 0``.

    Args:
        inputs: dict mapping input names to tensors; updated in place.
        outputs: names of outputs to compute. Only membership tests are
            performed, so any container works. The default is a tuple so it
            is not a shared mutable default.

    Returns:
        The same ``inputs`` dict with the computed entries added.
    """
    # ---- 1. latent feature ----------------------------------------------
    if inputs.get('feat', None) is None:
        den_feats, context = [], None
        for i, name in enumerate(self.den_filters):
            d_in, func = self.den_ori_dims[i], self.den_filters[name]
            assert (name in inputs), "the encoder must contain target inputs"
            assert inputs[name].size(-1) == d_in, "{} dimension must match {} v.s. {}".format(
                name, inputs[name].size(-1), d_in)
            if name == 'context':
                assert (i == (len(self.den_filters) - 1)), "we force context as the last input"
                assert inputs[name].size(0) == 1, "context is object level"
                context = func(inputs[name])
            else:
                den_feats.append(func(inputs[name]))
        den_feats = torch.cat(den_feats, -1)

        if context is None:
            field_args = (den_feats,)
        elif getattr(self.args, "hypernetwork", False):
            # Hypernetwork mode: context conditions the feature field directly.
            field_args = (den_feats, context)
        else:
            # Repeat the (1, C) filtered context for every row of den_feats
            # so it can be appended as extra per-row features.
            # NOTE(review): the two-size expand() assumes a 2-D context.
            tiled = context.expand(den_feats.size(0), context.size(1))
            field_args = (torch.cat([den_feats, tiled], -1),)
        inputs['feat'] = self.feature_field(*field_args)

    # ---- 2. density ------------------------------------------------------
    if 'sigma' in outputs:
        assert 'feat' in inputs, "feature must be pre-computed"
        inputs['sigma'] = self.predictor(inputs['feat'])[0]

    # ---- 3. normal (negated density gradient) ---------------------------
    if ('normal' not in inputs) and (
            (('texture' in outputs) and ("normal" in self.tex_filters))
            or ("normal" in outputs)):
        assert 'sigma' in inputs, "sigma must be pre-computed"
        assert 'pos' in inputs, "position is used to compute sigma"
        # inputs['pos'] must be part of the autograd graph that produced sigma.
        # create_graph=True keeps the normal itself differentiable.
        grad_pos, = grad(
            outputs=inputs['sigma'], inputs=inputs['pos'],
            grad_outputs=torch.ones_like(inputs['sigma'], requires_grad=False),
            retain_graph=True, create_graph=True)
        # The negation makes the normal point toward decreasing sigma.
        # NOTE(review): the original comment here read "BUG: gradient
        # direction reversed." -- confirm this sign convention is intended.
        if getattr(self.args, "no_normalize_normal", False):
            # Unnormalised: keeps gradient magnitude (possibly informative).
            inputs['normal'] = -grad_pos
        else:
            inputs['normal'] = F.normalize(-grad_pos, p=2, dim=1)

    # ---- 4. texture ------------------------------------------------------
    if 'texture' in outputs:
        if self.zero_z == 1:
            inputs['feat'] = inputs['feat'] * 0.0  # zero-out latent feature
        inputs['feat_n2'] = (inputs['feat'] ** 2).sum(-1)

        tex_feats = []
        for name in self.tex_filters:
            func = self.tex_filters[name]
            assert (name in inputs), "the encoder must contain target inputs"
            value = inputs[name]
            if name == 'sigma':
                # Add a trailing feature axis before filtering/concatenation.
                value = value.unsqueeze(-1)
            tex_feats.append(func(value))

        inputs['texture'] = self.renderer(torch.cat(tex_feats, -1))
        if self.min_color == 0:
            inputs['texture'] = torch.sigmoid(inputs['texture'])
    return inputs