in networks/deep_sdf_decoder.py [0:0]
def forward(self, input):
    """Run ``input`` through the ``lin<i>`` layers and return the result.

    ``input`` is indexed as a 2-D tensor; its last three columns are kept
    separately as ``xyz`` and everything before them is handled as the
    latent part (assumes a ``(rows, latent + 3)`` layout -- TODO confirm
    against callers).

    For every layer index ``i`` in ``range(self.num_layers - 1)``:

    * if ``i`` is in ``self.latent_in`` the original ``input`` is
      concatenated onto the running activation along dim 1; otherwise,
      when ``self.xyz_in_all`` is set and ``i != 0``, only ``xyz`` is
      concatenated;
    * ``self.lin<i>`` is applied;
    * on the final layer ``self.tanh`` is applied when ``self.use_tanh``
      is truthy; on every earlier layer the optional ``self.bn<i>``,
      ``self.relu`` and the optional dropout are applied, in that order.

    If the instance has a ``th`` attribute it is applied to the output
    before returning.
    """
    xyz = input[:, -3:]

    # Default to the untouched input; swap in a latent-dropout variant only
    # when there are columns besides xyz and latent dropout is enabled.
    x = input
    if input.shape[1] > 3 and self.latent_dropout:
        dropped_latent = F.dropout(
            input[:, :-3], p=0.2, training=self.training
        )
        x = torch.cat([dropped_latent, xyz], 1)

    final_idx = self.num_layers - 2
    for idx in range(final_idx + 1):
        linear = getattr(self, f"lin{idx}")

        # Skip connections: the raw ``input`` (not the dropout-processed
        # ``x``) is what gets re-injected at ``latent_in`` layers.
        if idx in self.latent_in:
            x = torch.cat([x, input], 1)
        elif idx != 0 and self.xyz_in_all:
            x = torch.cat([x, xyz], 1)

        x = linear(x)

        if idx == final_idx:
            # Output layer: optional tanh only, no norm / relu / dropout.
            if self.use_tanh:
                x = self.tanh(x)
            continue

        # Hidden layer: optional norm -> relu -> optional dropout.
        apply_norm = (
            self.norm_layers is not None
            and idx in self.norm_layers
            and not self.weight_norm
        )
        if apply_norm:
            x = getattr(self, f"bn{idx}")(x)

        x = self.relu(x)

        if self.dropout is not None and idx in self.dropout:
            x = F.dropout(x, p=self.dropout_prob, training=self.training)

    return self.th(x) if hasattr(self, "th") else x