in point_e/evals/pointnet2_cls_ssg.py [0:0]
def __init__(self, num_class, normal_channel=True, width_mult=1):
    """Build the PointNet++ single-scale-grouping classifier.

    Three set-abstraction stages (``sa1``..``sa3``) feed an MLP head
    (``fc1`` -> ``fc2`` -> ``fc3``) with batch norm and dropout between layers.

    Args:
        num_class: number of output features of the final ``fc3`` layer.
        normal_channel: when truthy, ``sa1`` is built with 6 input channels,
            otherwise with 3.
        width_mult: factor applied to every hidden channel width
            (64/128/256/512/1024) in both the abstraction MLPs and the head.
    """
    super().__init__()
    self.width_mult = width_mult
    self.normal_channel = normal_channel

    # Hidden widths, each scaled by width_mult exactly as written inline before.
    w64, w128, w256, w512, w1024 = (
        base * width_mult for base in (64, 128, 256, 512, 1024)
    )
    sa1_in = 3 if not normal_channel else 6

    # NOTE: submodules are created in a fixed order. That order determines how
    # parameter init consumes the RNG and the key order of state_dict(), so do
    # not reorder these assignments.
    self.sa1 = PointNetSetAbstraction(
        npoint=512,
        radius=0.2,
        nsample=32,
        in_channel=sa1_in,
        mlp=[w64, w64, w128],
        group_all=False,
    )
    # The extra 3 input channels presumably carry the grouped xyz coordinates
    # that PointNetSetAbstraction concatenates to the features. TODO confirm.
    self.sa2 = PointNetSetAbstraction(
        npoint=128,
        radius=0.4,
        nsample=64,
        in_channel=w128 + 3,
        mlp=[w128, w128, w256],
        group_all=False,
    )
    self.sa3 = PointNetSetAbstraction(
        npoint=None,
        radius=None,
        nsample=None,
        in_channel=w256 + 3,
        mlp=[w256, w512, w1024],
        group_all=True,
    )

    # Classification head: w1024 -> w512 -> w256 -> num_class.
    self.fc1 = nn.Linear(w1024, w512)
    self.bn1 = nn.BatchNorm1d(w512)
    self.drop1 = nn.Dropout(0.4)
    self.fc2 = nn.Linear(w512, w256)
    self.bn2 = nn.BatchNorm1d(w256)
    self.drop2 = nn.Dropout(0.4)
    self.fc3 = nn.Linear(w256, num_class)