in models/src/wavegrad/downsampling.py [0:0]
def __init__(self, in_channels, out_channels, factor, dilations):
    """Build the main and residual branches of the downsampling block.

    Args:
        in_channels: channel count of the input this block receives.
        out_channels: channel count produced by both branches.
        factor: ``scale_factor`` handed to each ``InterpolationBlock``
            (constructed with ``downsample=True``).
        dilations: sized sequence with one dilation per ``ConvolutionBlock``
            in the main branch, applied in order.
    """
    super(DownsamplingBlock, self).__init__()

    # Main branch: interpolation (downsample=True) followed by one
    # ConvolutionBlock per dilation. Only the first convolution changes the
    # channel count (in_channels -> out_channels); later ones keep
    # out_channels. Layers are created in the same order as before so any
    # random weight initialisation consumes the RNG identically.
    main_layers = [
        InterpolationBlock(
            scale_factor=factor,
            mode="linear",
            align_corners=False,
            downsample=True,
        )
    ]
    for position, dilation in enumerate(dilations):
        conv_in_channels = out_channels if position else in_channels
        main_layers.append(
            ConvolutionBlock(conv_in_channels, out_channels, dilation)
        )
    self.main_branch = torch.nn.Sequential(*main_layers)

    # Residual branch: 1x1 convolution projecting to out_channels, then the
    # same interpolation settings as the main branch.
    projection = Conv1dWithInitialization(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=1,
        stride=1,
    )
    resample = InterpolationBlock(
        scale_factor=factor,
        mode="linear",
        align_corners=False,
        downsample=True,
    )
    self.residual_branch = torch.nn.Sequential(projection, resample)