in denoiser/demucs.py [0:0]
def _separate_frame(self, frame):
    """Run one streaming frame through the Demucs encoder, LSTM and decoder.

    Convolution outputs cached by the previous call (``self.conv_state``) are
    reused so that each encoder layer only computes the time steps that are
    not already available. The LSTM hidden state is carried across calls in
    ``self.lstm_state``.

    Args:
        frame: input chunk for this step. ``frame[None]`` adds a leading batch
            axis and ``x.shape[2]`` is then used as the time length, so
            ``frame`` is presumably ``(channels, time)``; TODO confirm against
            the caller.

    Returns:
        A tuple ``(out, extra)``, both with the batch axis removed. ``out``
        holds the output samples considered correct for this frame. ``extra``
        holds additional samples to the right, meant only as padding for the
        online resampling (see the comment before the decoder loop).

    Side effects:
        Replaces ``self.conv_state`` with the state computed during this call
        and updates ``self.lstm_state``. ``next_state`` is filled with one
        entry per encoder layer except the last, followed by one entry per
        decoder layer. The ``pop(0)`` calls below consume
        ``self.conv_state`` in that same order.
    """
    demucs = self.demucs
    skips = []  # encoder outputs, consumed deepest-first by the decoder
    next_state = []  # becomes self.conv_state at the end of this call
    first = self.conv_state is None  # no cached state yet on the first frame
    # self.stride is presumably the number of new input samples per call
    # (confirm with the caller). Scaled by demucs.resample, then divided by
    # demucs.stride once per encoder layer below, it gives that shift
    # measured in time steps of the current layer's output.
    stride = self.stride * demucs.resample
    x = frame[None]  # add a batch axis of size 1
    for idx, encode in enumerate(demucs.encoder):
        stride //= demucs.stride
        length = x.shape[2]
        if idx == demucs.depth - 1:
            # This is slightly faster for the last conv.
            # The deepest encoder layer is neither read from nor written to
            # the cached conv state; it is computed on whatever x it receives.
            x = fast_conv(encode[0], x)
            x = encode[1](x)
            x = fast_conv(encode[2], x)
            x = encode[3](x)
        else:
            if not first:
                # Cached output of this layer from the previous call. Drop its
                # first `stride` steps, presumably the ones that have slid out
                # of the window since the last call.
                prev = self.conv_state.pop(0)
                prev = prev[..., stride:]
                # Output length of a kernel_size/stride conv without padding
                # over `length` inputs, and how many of those outputs are not
                # already present in `prev`.
                tgt = (length - demucs.kernel_size) // demucs.stride + 1
                missing = tgt - prev.shape[-1]
                # Keep only the rightmost kernel_size + stride * (missing - 1)
                # inputs, i.e. exactly enough for `missing` new output steps.
                offset = length - demucs.kernel_size - demucs.stride * (missing - 1)
                x = x[..., offset:]
            x = encode[1](encode[0](x))
            x = fast_conv(encode[2], x)
            x = encode[3](x)
            if not first:
                # Prepend the reused steps to the freshly computed ones.
                x = th.cat([prev, x], -1)
            next_state.append(x)
        skips.append(x)

    # Move the time axis (dim 2) first for the LSTM, then restore the layout.
    x = x.permute(2, 0, 1)
    x, self.lstm_state = demucs.lstm(x, self.lstm_state)
    x = x.permute(1, 2, 0)
    # In the following, x contains only correct samples, i.e. the ones
    # for which each time position is covered by two windows of the upper layer.
    # extra contains extra samples to the right, and is used only as a
    # better padding for the online resampling.
    extra = None
    for idx, decode in enumerate(demucs.decoder):
        # Skip connections are consumed from the deepest encoder layer up.
        skip = skips.pop(-1)
        x += skip[..., :x.shape[-1]]
        x = fast_conv(decode[0], x)
        x = decode[1](x)

        if extra is not None:
            # Feed `extra` through the same layers, adding the skip samples
            # that lie past the current end of x.
            skip = skip[..., x.shape[-1]:]
            extra += skip[..., :extra.shape[-1]]
            extra = decode[2](decode[1](decode[0](extra)))
        x = decode[2](x)
        # Save the last demucs.stride samples as state, with decode[2]'s bias
        # removed. On the next call they are added onto the first
        # demucs.stride samples of a new output that already contains the
        # bias, so removing it here keeps the bias from being counted twice.
        next_state.append(x[..., -demucs.stride:] - decode[2].bias.view(-1, 1))
        if extra is None:
            # Note: this slice is a view sharing storage with x.
            extra = x[..., -demucs.stride:]
        else:
            # Add this layer's saved tail (bias removed) onto the start of extra.
            extra[..., :demucs.stride] += next_state[-1]
        x = x[..., :-demucs.stride]

        if not first:
            # Add the tail saved for this layer by the previous call.
            prev = self.conv_state.pop(0)
            x[..., :demucs.stride] += prev
        if idx != demucs.depth - 1:
            # decode[3] is applied to every decoder layer except the last.
            x = decode[3](x)
            extra = decode[3](extra)
    self.conv_state = next_state
    return x[0], extra[0]