in models.py [0:0]
def __init__(self, in_channels, img_depth, kernel_size, dropout):
    """Construct the convolution, fully connected and normalisation sub-modules.

    Args:
        in_channels: Channel count of the ``Conv3d`` (used for both its input
            and its output) and the first factor of the fully connected width.
        img_depth: Second factor of the fully connected width; the linear
            layers and the instance norms all work on
            ``in_channels * img_depth`` features.
        kernel_size: Indexable with at least two entries; ``kernel_size[0]``
            and ``kernel_size[1]`` are the convolution extents over the last
            two spatial axes of the ``Conv3d`` input.
        dropout: Drop probability shared by every ``Dropout`` layer.
    """
    super(TDSBlock2d, self).__init__()
    self.in_channels = in_channels
    self.img_depth = img_depth

    kh, kw = kernel_size[0], kernel_size[1]
    width = in_channels * img_depth

    # Kernel extent 1 (and no padding) along the first spatial axis, so each
    # slice along that axis is convolved independently with shared weights.
    # Padding of k // 2 keeps the last two spatial sizes unchanged when the
    # kernel extents are odd (stride 1, dilation 1 are the Conv3d defaults).
    spatial_conv = torch.nn.Conv3d(
        in_channels=in_channels,
        out_channels=in_channels,
        kernel_size=(1, kh, kw),
        padding=(0, kh // 2, kw // 2),
    )
    self.conv = torch.nn.Sequential(
        spatial_conv, torch.nn.ReLU(), torch.nn.Dropout(dropout)
    )

    # Constant-width two-layer MLP. Default parameter init draws from the
    # global RNG, so the creation order (conv, then these two linears) is
    # kept stable for seeded reproducibility.
    first_linear = torch.nn.Linear(width, width)
    second_linear = torch.nn.Linear(width, width)
    self.fc = torch.nn.Sequential(
        first_linear,
        torch.nn.ReLU(),
        torch.nn.Dropout(dropout),
        second_linear,
        torch.nn.Dropout(dropout),
    )

    # Two independent affine instance norms over `width` channels.
    # NOTE(review): presumably one per sub-block in forward — confirm there.
    self.instance_norms = torch.nn.ModuleList(
        torch.nn.InstanceNorm2d(width, affine=True) for _ in range(2)
    )