in demucs/hdemucs.py [0:0]
def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
             freq=True, dconv=True, norm=True, context=1, dconv_kw=None, pad=True,
             context_freq=True, rewrite=True):
    """
    Decoder counterpart of `HEncLayer`: builds a transposed convolution
    (`conv_tr` + `norm2`), optionally preceded by a rewrite convolution
    (`rewrite` + `norm1`) and a `DConv` branch.

    Args:
        chin: number of input channels. Used as the input channels of
            `conv_tr`, of the rewrite conv (which outputs `2 * chin`) and
            of the `DConv` branch.
        chout: output channels of `conv_tr` and `norm2`.
        last: stored as `self.last`; not used here (presumably marks the
            final decoder layer — confirm in `forward`).
        kernel_size: kernel size of `conv_tr`. With `freq`, it becomes
            `[kernel_size, 1]`.
        stride: stride of `conv_tr`. With `freq`, it becomes `[stride, 1]`.
        norm_groups: number of groups for `nn.GroupNorm` when `norm` is True.
        empty: if True, only `conv_tr` and `norm2` are built. `rewrite` and
            `dconv` stay None.
        freq: if True, use 2-D convolutions (`Conv2d`/`ConvTranspose2d`)
            instead of 1-D ones.
        dconv: if True (and not `empty`), build `DConv(chin, **dconv_kw)`.
        norm: if True, normalizations are `nn.GroupNorm(norm_groups, d)`;
            otherwise they are `nn.Identity()`.
        context: the rewrite conv uses kernel `1 + 2 * context` and padding
            `context`.
        dconv_kw: extra keyword arguments forwarded to `DConv`. None means
            no extra arguments.
        pad: if truthy, `self.pad = kernel_size // 4`, else 0. Not used here
            (presumably the amount trimmed from the `conv_tr` output in
            `forward` — confirm).
        context_freq: if True, the rewrite kernel/padding apply on every
            spatial axis. If False, the kernel is `[1, 1 + 2 * context]` and
            the padding is `[0, context]`, i.e. context only along the last
            axis.
        rewrite: if True (and not `empty`), build the rewrite conv
            (`chin -> 2 * chin`) and its normalization `norm1`.
    """
    super().__init__()
    # A literal `{}` default would be one dict shared by every call; use a
    # None sentinel so each layer gets its own.
    if dconv_kw is None:
        dconv_kw = {}

    norm_fn = lambda d: nn.Identity()  # noqa
    if norm:
        norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa

    # `pad` arrives as a flag and is stored as an amount. It is derived from
    # the scalar kernel_size, before kernel_size may become a list below.
    self.pad = kernel_size // 4 if pad else 0
    self.last = last
    self.freq = freq
    self.chin = chin
    self.empty = empty
    self.stride = stride
    self.kernel_size = kernel_size
    self.norm = norm
    self.context_freq = context_freq

    klass = nn.Conv1d
    klass_tr = nn.ConvTranspose1d
    if freq:
        # 2-D convolutions whose kernel/stride only span the first spatial
        # axis (presumably frequency — confirm against callers).
        kernel_size = [kernel_size, 1]
        stride = [stride, 1]
        klass = nn.Conv2d
        klass_tr = nn.ConvTranspose2d
    self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
    self.norm2 = norm_fn(chout)

    # Declare the optional sub-modules before the early return, so that
    # empty layers also expose `rewrite` / `dconv` (as None).
    self.rewrite = None
    self.dconv = None
    if self.empty:
        return

    if rewrite:
        # Output has 2 * chin channels (presumably halved again by a GLU in
        # `forward` — confirm).
        if context_freq:
            self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
        else:
            # NOTE(review): the two-element kernel/padding assumes a 2-D conv
            # (freq=True). With freq=False this is handed to Conv1d — confirm
            # that combination is never used.
            self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                 [0, context])
        self.norm1 = norm_fn(2 * chin)

    if dconv:
        self.dconv = DConv(chin, **dconv_kw)