in dib/predefined/cnn.py [0:0]
def apply_convs(self, X):
    """Run `X` through the U-Net style stack of convolutional blocks.

    The blocks in `self.conv_blocks` are consumed in order: an optional
    input block, the down blocks (each followed by `self.pooling`), one
    bottleneck block, the up blocks (each preceded by an upsampling and a
    channel-wise concatenation with the matching down-path output) and an
    optional output block. The input / output blocks are only applied when
    `self.is_skip_resize` is false.

    Parameters
    ----------
    X : torch.Tensor
        Input with the batch on dimension 0 and the channels on dimension 1.

    Returns
    -------
    X : torch.Tensor
        Output of the last applied block.
    representation : torch.Tensor
        Bottleneck output averaged over every dimension after the first two,
        computed before the optional averaging between batch halves.

    Raises
    ------
    ValueError
        If the bottleneck is forced to be shared (training mode with
        `self.is_force_same_bottleneck`) and the batch size is odd, because
        the two batch halves could not be paired.
    """
    if self.is_skip_resize:
        n_tmp_blocks = self.n_blocks
        start_block = 0
    else:
        # The first and last blocks are reserved for the input / output.
        n_tmp_blocks = self.n_blocks - 2
        # Input block
        X = self.conv_blocks[0](X)
        start_block = 1

    n_down_blocks = n_tmp_blocks // 2
    residuals = []

    # Down
    for i in range(n_down_blocks):
        X = self.conv_blocks[start_block + i](X)
        residuals.append(X)
        X = self.pooling(X)

    # Bottleneck: offset by `start_block` like every other block, otherwise
    # the last down block would be applied twice when an input block exists.
    X = self.conv_blocks[start_block + n_down_blocks](X)

    # Representation before forcing same bottleneck
    representation = X.view(*X.shape[:2], -1).mean(-1)

    if self.is_force_same_bottleneck and self.training:
        # Forces the u-net to use the bottleneck by giving additional
        # information there, i.e. taking the average between bottlenecks of
        # different samples of the same function. The bottleneck should be a
        # global representation => should not depend on the chosen sample.
        batch_size = X.size(0)
        if batch_size % 2 != 0:
            raise ValueError(
                "is_force_same_bottleneck requires an even batch size, "
                "got {}.".format(batch_size)
            )
        half = batch_size // 2
        X_mean = (X[:half, ...] + X[half:, ...]) / 2
        X = torch.cat([X_mean, X_mean], dim=0)

    # `align_corners` is only accepted by the interpolating modes; passing it
    # with e.g. "nearest" makes `F.interpolate` raise.
    interpolating_modes = ("linear", "bilinear", "bicubic", "trilinear")
    align_corners = True if self.upsample_mode in interpolating_modes else None

    # Up
    for i in range(n_down_blocks + 1, n_tmp_blocks):
        X = F.interpolate(
            X,
            mode=self.upsample_mode,
            scale_factor=self.pooling_size,
            align_corners=align_corners,
        )
        # Negative index: walks the down-path outputs from last to first.
        X = torch.cat((X, residuals[n_down_blocks - i]), dim=1)  # concat on channels
        X = self.conv_blocks[start_block + i](X)

    if not self.is_skip_resize:
        # Output block
        X = self.conv_blocks[-1](X)

    return X, representation