in point_e/evals/pointnet2_cls_ssg.py [0:0]
def forward(self, xyz, features=False):
    """Run the classifier forward pass and return log-probabilities.

    Per the file path this is a PointNet++ (single-scale grouping)
    classifier; ``sa1``/``sa2``/``sa3`` are presumably its set-abstraction
    stages (not visible here).

    Args:
        xyz: Rank-3 tensor; dim 0 is the batch. Channels are split along
            dim 1: the first three are passed to ``self.sa1`` as
            coordinates and, when ``self.normal_channel`` is set, the
            remaining channels are passed as its second argument
            (otherwise ``None`` is passed). Assumes a ``(B, C, N)``
            layout with N points -- TODO confirm against callers.
        features: If truthy, additionally return the output of
            ``self.bn2(self.fc2(...))`` taken before its ReLU.

    Returns:
        ``(log_probs, global_feat)``, or
        ``(log_probs, global_feat, pre_act)`` when ``features`` is truthy.
        ``log_probs`` is ``log_softmax`` over the last dim of the
        ``self.fc3`` output. ``global_feat`` is the feature output of
        ``self.sa3``, returned unflattened.
    """
    # Unpacking (rather than indexing) also rejects inputs that are not rank 3.
    batch_size, _, _ = xyz.shape

    # Both slices are taken from the original tensor before rebinding `xyz`.
    if self.normal_channel:
        xyz, norm = xyz[:, :3, :], xyz[:, 3:, :]
    else:
        norm = None

    # Three stacked stages; the coordinates emitted by the last one are unused.
    stage_xyz, stage_feat = self.sa1(xyz, norm)
    stage_xyz, stage_feat = self.sa2(stage_xyz, stage_feat)
    _, global_feat = self.sa3(stage_xyz, stage_feat)

    # `.view` requires exactly 1024 * width_mult values per sample and a
    # view-compatible memory layout; otherwise it raises.
    flat = global_feat.view(batch_size, 1024 * self.width_mult)
    hidden = self.drop1(F.relu(self.bn1(self.fc1(flat))))
    pre_act = self.bn2(self.fc2(hidden))  # returned to callers when features=True
    logits = self.fc3(self.drop2(F.relu(pre_act)))
    log_probs = F.log_softmax(logits, dim=-1)

    if not features:
        return log_probs, global_feat
    return log_probs, global_feat, pre_act