in models.py [0:0]
def __init__(self, in_channels, num_features, kernel_size, dropout):
    """Build the convolution, fully-connected and normalization sub-modules.

    Args:
        in_channels: Channel count of the 2-D convolution, which maps
            ``in_channels`` channels to the same number of channels.
        num_features: Multiplied by ``in_channels`` to give the width of
            both ``Linear`` layers and of the two instance norms.
        kernel_size: Width of the ``(1, kernel_size)`` convolution kernel.
            The kernel height is 1, so the convolution only slides along
            the last input axis.
        dropout: Drop probability shared by every ``Dropout`` layer.
    """
    super(TDSBlock, self).__init__()
    # Stored for later use; presumably forward() uses them to reshape between
    # a flattened (in_channels * num_features) axis and a 2-D layout -- not
    # visible here, confirm.
    self.in_channels = in_channels
    self.num_features = num_features
    width = in_channels * num_features

    nn = torch.nn
    # Note on ordering: the conv must be created before the Linear layers.
    # Construction order decides which random draws initialize which weights.
    # The position of each child inside the Sequential/ModuleList containers
    # defines its state_dict key.

    # Pad only the last axis. The output width is
    # W + 2 * (kernel_size // 2) - kernel_size + 1.
    # That equals W for an odd kernel_size and W + 1 for an even one.
    # NOTE(review): if forward() adds this output back onto its input, only
    # odd kernel sizes line up -- confirm callers never pass an even size.
    time_conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=in_channels,
        kernel_size=(1, kernel_size),
        padding=(0, kernel_size // 2),
    )
    self.conv = nn.Sequential(time_conv, nn.ReLU(), nn.Dropout(dropout))

    # Constant-width two-layer MLP; there is no activation after the second
    # Linear, only dropout.
    self.fc = nn.Sequential(
        nn.Linear(width, width),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(width, width),
        nn.Dropout(dropout),
    )

    # Two independent InstanceNorm1d modules over ``width`` channels, each
    # with its own learnable affine parameters.
    self.instance_norms = nn.ModuleList(
        [nn.InstanceNorm1d(width, affine=True) for _ in range(2)]
    )