in Damage Assessment Visualizer/utils/models.py [0:0]
def forward(self, x1, x2):
    """Segment two inputs with a shared UNet and predict the change between them.

    The same UNet modules (encoder1..4, bottleneck, upconv1..4, decoder1..4)
    are applied to ``x1`` and to ``x2``, so both segmentation outputs share
    weights (Siamese setup). A separate change decoder (the ``*_c`` modules)
    is fed, at every scale from the bottleneck upwards, the feature
    difference ``features(x2) - features(x1)``.

    Args:
        x1: First input batch. NOTE(review): presumably the pre-event image,
            as an NCHW tensor whose spatial size survives four pooling stages
            (e.g. divisible by 16) — confirm against the data pipeline.
        x2: Second input batch, expected to have the same shape as ``x1``
            (presumably the post-event image — TODO confirm).

    Returns:
        Tuple ``(seg_1, seg_2, change)``: ``conv_s`` applied to the UNet
        decoder output for ``x1`` and for ``x2``, and ``conv_c`` applied to
        the change-decoder output.
    """
    # Both passes use the very same submodules, i.e. shared weights.
    skips_1, bottleneck_1, dec_1 = self._unet_branch(x1)
    skips_2, bottleneck_2, dec_2 = self._unet_branch(x2)
    enc1_1, enc2_1, enc3_1, enc4_1 = skips_1
    enc1_2, enc2_2, enc3_2, enc4_2 = skips_2

    # Change decoder, deepest scale first. Each stage upsamples the running
    # features, concatenates (along dim=1) the elementwise encoder-feature
    # difference of the same scale, and fuses them with the matching
    # conv*_c block. The differences keep the encoder features' shape; the
    # actual channel counts are fixed in __init__, which is not shown here.
    change = self.upconv4_c(bottleneck_2 - bottleneck_1)
    change = self.conv4_c(torch.cat((enc4_2 - enc4_1, change), dim=1))
    change = self.upconv3_c(change)
    change = self.conv3_c(torch.cat((enc3_2 - enc3_1, change), dim=1))
    change = self.upconv2_c(change)
    change = self.conv2_c(torch.cat((enc2_2 - enc2_1, change), dim=1))
    change = self.upconv1_c(change)
    change = self.conv1_c(torch.cat((enc1_2 - enc1_1, change), dim=1))

    return self.conv_s(dec_1), self.conv_s(dec_2), self.conv_c(change)

def _unet_branch(self, x):
    """Run the shared UNet on a single input.

    Args:
        x: One side of the Siamese input pair.

    Returns:
        ``(skips, bottleneck, decoded)`` where ``skips`` is the tuple
        ``(enc1, enc2, enc3, enc4)`` of encoder outputs ordered from the
        shallowest to the deepest level, ``bottleneck`` is the bottleneck
        output, and ``decoded`` is the ``decoder1`` output (before ``conv_s``).
    """
    enc1 = self.encoder1(x)
    enc2 = self.encoder2(self.pool1(enc1))
    enc3 = self.encoder3(self.pool2(enc2))
    enc4 = self.encoder4(self.pool3(enc3))
    bottleneck = self.bottleneck(self.pool4(enc4))

    # Each decoder stage upsamples, concatenates the skip connection of the
    # same level along dim=1 (the channel axis for NCHW inputs), and then
    # applies that level's decoder block.
    dec = self.decoder4(torch.cat((self.upconv4(bottleneck), enc4), dim=1))
    dec = self.decoder3(torch.cat((self.upconv3(dec), enc3), dim=1))
    dec = self.decoder2(torch.cat((self.upconv2(dec), enc2), dim=1))
    dec = self.decoder1(torch.cat((self.upconv1(dec), enc1), dim=1))
    return (enc1, enc2, enc3, enc4), bottleneck, dec