in models/trunks/spconv/models/res16unet.py [0:0]
def forward(self, x, out_feat_keys=None):
    """Run the encoder/decoder and return the requested intermediate features.

    Every stage output is recorded in a local ``end_points`` dict under the
    key ``"<name>_features"``.  The names recorded here are ``en0`` .. ``en4``
    and ``plane4`` .. ``plane7``.

    Args:
        x: Network input, passed straight to ``self.conv0p1s1``
            (presumably a sparse tensor -- TODO confirm against the caller).
        out_feat_keys: Iterable of endpoint names *without* the
            ``"_features"`` suffix, e.g. ``["en3", "plane7"]``.  ``None`` is
            treated as "no endpoints requested".

    Returns:
        A list with one entry per requested key, in request order.  Each
        entry is the ``.F`` attribute of the endpoint after ``self.maxpool``
        and, when ``self.use_mlp`` is truthy, ``self.head``.

    Raises:
        KeyError: If a requested name is not a recorded endpoint.
    """
    end_points = {}

    # ---- Stem ----
    out = self.conv0p1s1(x)
    out = self.bn0(out)
    out_p1 = self.relu(out)

    # ---- Encoder ----
    # NOTE(review): each "enN_features" entry stores the activation *before*
    # the residual block, while the skip connection (out_bXpY) carries the
    # block output -- confirm this is intended.
    out = self.conv1p1s2(out_p1)
    out = self.bn1(out)
    out = self.relu(out)
    out_b1p2 = self.block1(out)
    end_points["en0_features"] = out  # 32 (original note)

    out = self.conv2p2s2(out_b1p2)
    out = self.bn2(out)
    out = self.relu(out)
    out_b2p4 = self.block2(out)
    end_points["en1_features"] = out  # 32 (original note)

    out = self.conv3p4s2(out_b2p4)
    out = self.bn3(out)
    out = self.relu(out)
    out_b3p8 = self.block3(out)
    end_points["en2_features"] = out  # 64 (original note)

    # pixel_dist=16
    out = self.conv4p8s2(out_b3p8)
    out = self.bn4(out)
    out = self.relu(out)
    end_points["en3_features"] = out  # 128 (original note)
    out = self.block4(out)

    # ---- Decoder ----
    # pixel_dist=8
    out = self.convtr4p16s2(out)
    out = self.bntr4(out)
    out = self.relu(out)
    # Despite the "en" prefix this is the output of the first transposed conv.
    end_points["en4_features"] = out  # 256 (original note)
    out = me.cat(out, out_b3p8)
    out = self.block5(out)

    # pixel_dist=4
    out = self.convtr5p8s2(out)
    out = self.bntr5(out)
    out = self.relu(out)
    end_points["plane4_features"] = out
    out = me.cat(out, out_b2p4)
    out = self.block6(out)

    # pixel_dist=2
    out = self.convtr6p4s2(out)
    out = self.bntr6(out)
    out = self.relu(out)
    end_points["plane5_features"] = out
    out = me.cat(out, out_b1p2)
    out = self.block7(out)

    # pixel_dist=1
    out = self.convtr7p2s2(out)
    out = self.bntr7(out)
    out = self.relu(out)
    end_points["plane6_features"] = out
    out = me.cat(out, out_p1)
    out = self.block8(out)
    end_points["plane7_features"] = out

    # ---- Collect requested endpoints ----
    # The signature defaults to None, so treat it like an empty request
    # instead of failing on len(None).
    requested = () if out_feat_keys is None else out_feat_keys

    # Appending in iteration order keeps results aligned with the request even
    # when a key appears more than once (a first-match index lookup would
    # leave later duplicates unfilled).
    out_feats = []
    for key in requested:
        feat = self.maxpool(end_points[key + "_features"])
        if self.use_mlp:
            feat = self.head(feat)
        out_feats.append(feat.F)  # Just use smlp (original note)
    return out_feats