def build_feature_field()

in fairnr/modules/field.py [0:0]


    def build_feature_field(self, args):
        """Build the SDF/feature MLP and apply a geometric initialization.

        Sets two attributes on ``self``:

        * ``self.feature_field`` -- an ``ImplicitField`` taking
          ``sum(self.den_input_dims)`` inputs and producing
          ``self.tex_input_dims[0] + 1`` outputs (the extra channel is for
          the SDF value).
        * ``self.dropout_z`` -- ``nn.Dropout(args.dropout_z)`` when that
          option is positive, otherwise ``None``.

        The network depth is ``args.feature_layers + 2`` with ``self.skips``
        unless ``self.nerf_style`` is set, in which case it is 8 layers with
        a skip at layer 4.

        Args:
            args: option namespace; reads ``feature_layers``,
                ``feature_embed_dim`` and, optionally, ``hypernetwork`` and
                ``dropout_z``.

        Raises:
            AssertionError: if ``args.hypernetwork`` is truthy, or if the
                density input is neither exactly 3-dimensional nor a single
                ``'pos'`` filter with ``cat_input`` enabled.
        """
        den_feat_dim = self.tex_input_dims[0]
        den_input_dim = sum(self.den_input_dims)

        assert not getattr(args, "hypernetwork", False), "does not support hypernetwork for now"
        assert (den_input_dim == 3) or (
            self.den_filters['pos'].cat_input and len(self.den_filters) == 1), "cat pos in the end"

        num_layers = args.feature_layers + 2 if not self.nerf_style else 8
        skips = self.skips if not self.nerf_style else [4]

        self.feature_field = ImplicitField(
            den_input_dim, den_feat_dim + 1, args.feature_embed_dim,  # +1 is for SDF values
            num_layers, with_ln=False, skips=skips, outmost_linear=True, spec_init=False)

        # Read the rate from the same ``args`` object that was just checked.
        dropout_rate = getattr(args, "dropout_z", 0.0)
        self.dropout_z = nn.Dropout(p=dropout_rate) if dropout_rate > 0.0 else None

        # Geometric initialization from https://arxiv.org/pdf/1911.10414.pdf
        # This pushes the freshly initialized model to approximate the SDF of
        # a unit sphere:  f(x; theta) ~= |x| - 1
        bias = 1.0
        last_idx = num_layers - 1
        for idx in range(num_layers):
            layer = self.feature_field.net[idx]

            if idx == last_idx:
                # The output layer is indexed directly (built with
                # outmost_linear=True), unlike the hidden layers below.
                lin = layer
                torch.nn.init.normal_(
                    lin.weight,
                    mean=math.sqrt(math.pi) / math.sqrt(lin.weight.size(1)),
                    std=0.0001)
                torch.nn.init.constant_(lin.bias, -bias)
                continue

            # Hidden layers keep their linear module at ``.net[0]``.
            lin = layer.net[0]
            std = math.sqrt(2) / math.sqrt(lin.weight.size(0))
            torch.nn.init.constant_(lin.bias, 0.0)

            if idx == 0:
                # Only the last three input columns (the concatenated xyz
                # position, per the assertion above) get non-zero weights.
                if den_input_dim > 3:
                    torch.nn.init.constant_(lin.weight[:, :-3], 0.0)
                torch.nn.init.normal_(lin.weight[:, -3:], 0.0, std)
            elif (idx - 1) in skips:
                # Layer right after a skip connection: random init, then zero
                # the first ``den_input_dim - 3`` input columns.
                # NOTE(review): presumably these columns are the non-position
                # part of the re-injected input -- confirm against
                # ImplicitField's concatenation order.
                torch.nn.init.normal_(lin.weight, 0.0, std)
                torch.nn.init.constant_(lin.weight[:, :den_input_dim - 3], 0.0)
            else:
                torch.nn.init.normal_(lin.weight, 0.0, std)

        # Force the initial features to 0: every output row except row 0 is
        # zeroed on the final layer.  Index by ``num_layers - 1`` rather than
        # a fixed 7 so this also works when the network is not 8 layers deep.
        out_layer = self.feature_field.net[last_idx]
        out_layer.weight.data[1:].zero_()
        out_layer.bias.data[1:].zero_()