in demucs/hdemucs.py [0:0]
def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
             freq=True, dconv=True, norm=True, context=0, dconv_kw=None, pad=True,
             rewrite=True):
    """Encoder layer. This is used by both the time and the frequency branches.

    Args:
        chin: number of input channels.
        chout: number of output channels.
        kernel_size: kernel size of the main convolution. When `freq` is
            true the convolution is 2D with kernel `[kernel_size, 1]`.
        stride: stride of the main convolution. When `freq` is true the
            stride is `[stride, 1]`.
        norm_groups: number of groups for group norm.
        empty: used to make a layer with just the first conv. this is used
            before merging the time and freq. branches.
        freq: this is acting on frequencies (uses `nn.Conv2d` instead of
            `nn.Conv1d`).
        dconv: insert DConv residual branches.
        norm: use GroupNorm (otherwise `nn.Identity` is used in its place).
        context: context size for the 1x1 conv. The rewrite conv gets a
            kernel of `1 + 2 * context` and a padding of `context`.
        dconv_kw: dict of kwargs for the DConv class. `None` (the default)
            means no extra kwargs.
        pad: pad the input. Padding is done so that the output size is
            always the input size / stride. The amount is `kernel_size // 4`.
        rewrite: add 1x1 conv at the end of the layer.
    """
    super().__init__()
    # `None` sentinel instead of a mutable `{}` default, which would be one
    # dict object shared by every instantiation.
    if dconv_kw is None:
        dconv_kw = {}

    def norm_fn(channels):
        # Factory for the normalization applied after each convolution.
        if norm:
            return nn.GroupNorm(norm_groups, channels)
        return nn.Identity()

    # Keep the boolean `pad` argument and the numeric padding amount apart.
    padding = kernel_size // 4 if pad else 0

    self.freq = freq
    self.kernel_size = kernel_size
    self.stride = stride
    self.empty = empty
    self.norm = norm
    self.pad = padding

    if freq:
        # 2D convolution whose kernel/stride/padding only act on the first
        # of the two trailing dimensions; the second one uses 1 / 1 / 0.
        klass = nn.Conv2d
        conv_args = ([kernel_size, 1], [stride, 1], [padding, 0])
    else:
        klass = nn.Conv1d
        conv_args = (kernel_size, stride, padding)

    # NOTE: submodules are assigned in the same order as before (conv, norm1,
    # rewrite, norm2, dconv) so that `state_dict` ordering is unchanged.
    self.conv = klass(chin, chout, *conv_args)
    if self.empty:
        # An empty layer only has the first convolution.
        return
    self.norm1 = norm_fn(chout)
    self.rewrite = None
    if rewrite:
        # Doubles the channel count; `norm2` only exists together with it.
        self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
        self.norm2 = norm_fn(2 * chout)
    self.dconv = None
    if dconv:
        self.dconv = DConv(chout, **dconv_kw)