in fairnr/modules/field.py [0:0]
def build_feature_field(self, args):
    """Build the SDF/feature MLP and apply geometric (sphere) initialization.

    Creates ``self.feature_field``, an ``ImplicitField`` mapping the density
    input (width ``sum(self.den_input_dims)``) to ``self.tex_input_dims[0] + 1``
    outputs. Output channel 0 is the SDF value; the remaining channels are
    features. Also sets ``self.dropout_z`` to an ``nn.Dropout`` when
    ``args.dropout_z > 0`` and to ``None`` otherwise.

    The weights are then initialized following the geometric initialization
    of https://arxiv.org/pdf/1911.10414.pdf, so that at the start of training
    the SDF output approximates ``|x| - 1`` (the signed distance to a unit
    sphere) and every feature output is exactly 0.

    Args:
        args: config namespace. Reads ``feature_layers`` (only when not
            ``self.nerf_style``), ``feature_embed_dim``, and optionally
            ``hypernetwork`` and ``dropout_z``.

    Raises:
        AssertionError: if a hypernetwork is requested, or if the density
            input is wider than 3 but does not come from a single ``'pos'``
            filter with ``cat_input`` set.
    """
    den_feat_dim = self.tex_input_dims[0]
    den_input_dim = sum(self.den_input_dims)

    assert not getattr(args, "hypernetwork", False), "does not support hypernetwork for now"
    # The initialization below treats the LAST 3 input columns as xyz, so any
    # wider input must come from the 'pos' filter alone (message: "cat pos in the end").
    assert (den_input_dim == 3) or (
        self.den_filters['pos'].cat_input and len(self.den_filters) == 1), "cat pos in the end"

    num_layers = 8 if self.nerf_style else args.feature_layers + 2
    skips = [4] if self.nerf_style else self.skips
    last_idx = num_layers - 1

    self.feature_field = ImplicitField(
        den_input_dim, den_feat_dim + 1, args.feature_embed_dim,  # +1 is for SDF values
        num_layers, with_ln=False, skips=skips, outmost_linear=True, spec_init=False)

    # Read the rate from the same `args` object the guard inspects.
    dropout_z = getattr(args, "dropout_z", 0.0)
    self.dropout_z = nn.Dropout(p=dropout_z) if dropout_z > 0.0 else None

    # Geometric initialization (https://arxiv.org/pdf/1911.10414.pdf): makes
    # the network start out approximating the SDF of a sphere,
    #     f(x; theta) ~= |x| - radius
    radius = 1.0
    for layer_idx in range(num_layers):
        linear = self.feature_field.net[layer_idx]
        if layer_idx < last_idx:
            # Hidden entries wrap their linear layer as `.net[0]`; the final
            # entry is used directly (outmost_linear=True).
            linear = linear.net[0]

        # Hidden-layer weight std: sqrt(2 / out_features).
        hidden_std = math.sqrt(2) / math.sqrt(linear.weight.size(0))

        if layer_idx == last_idx:
            # Near-constant positive weights plus a bias of -radius make the
            # output approximately |x| - radius.
            torch.nn.init.normal_(
                linear.weight,
                mean=math.sqrt(math.pi) / math.sqrt(linear.weight.size(1)),
                std=0.0001)
            torch.nn.init.constant_(linear.bias, -radius)
        elif layer_idx == 0:
            torch.nn.init.constant_(linear.bias, 0.0)
            if den_input_dim > 3:
                # Only the trailing xyz columns contribute initially.
                torch.nn.init.constant_(linear.weight[:, :-3], 0.0)
            torch.nn.init.normal_(linear.weight[:, -3:], 0.0, hidden_std)
        elif (layer_idx - 1) in skips:
            # The layer after a skip receives the raw input again; zero its
            # non-xyz input columns. NOTE(review): assumes ImplicitField puts
            # the raw input first in the skip concatenation — confirm.
            torch.nn.init.constant_(linear.bias, 0.0)
            torch.nn.init.normal_(linear.weight, 0.0, hidden_std)
            torch.nn.init.constant_(linear.weight[:, :den_input_dim - 3], 0.0)
        else:
            torch.nn.init.constant_(linear.bias, 0.0)
            torch.nn.init.normal_(linear.weight, 0.0, hidden_std)

    # Force the initial features to 0: zero every output row except row 0 (SDF)
    # of the FINAL layer. Index it via num_layers rather than a hard-coded 7,
    # which is only the final layer when num_layers == 8.
    output_layer = self.feature_field.net[last_idx]
    with torch.no_grad():
        output_layer.weight[1:].zero_()
        output_layer.bias[1:].zero_()