in denoiser/demucs.py [0:0]
def test():
    """Benchmark the streaming Demucs implementation against the offline one.

    Parses command-line options and builds a randomly initialised ``Demucs``
    model. It then runs 4 seconds of Gaussian noise through the model twice:
    once in a single offline call, and once chunk by chunk through
    ``DemucsStreamer``. Finally it prints:

    - the model size;
    - the relative L2 difference between the two outputs;
    - latency figures;
    - the real-time factor (RTF).

    Command-line options:
        --depth, --hidden, --resample: forwarded to ``Demucs``.
        --sample_rate: sample rate (Hz) used to convert samples to ms.
        --device: torch device to run on (default ``"cpu"``).
        -t/--num_threads: if given, passed to ``th.set_num_threads``.
        -f/--num_frames: forwarded to ``DemucsStreamer`` as ``num_frames``.
    """
    import argparse
    parser = argparse.ArgumentParser(
        "denoiser.demucs",
        description="Benchmark the streaming Demucs implementation, "
                    "as well as checking the delta with the offline implementation.")
    parser.add_argument("--depth", default=5, type=int)
    parser.add_argument("--resample", default=4, type=int)
    parser.add_argument("--hidden", default=48, type=int)
    parser.add_argument("--sample_rate", default=16000, type=float)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("-t", "--num_threads", type=int)
    parser.add_argument("-f", "--num_frames", type=int, default=1)
    args = parser.parse_args()
    if args.num_threads:
        th.set_num_threads(args.num_threads)
    sr = args.sample_rate
    sr_ms = sr / 1000  # samples per millisecond
    demucs = Demucs(depth=args.depth, hidden=args.hidden, resample=args.resample).to(args.device)
    x = th.randn(1, int(sr * 4)).to(args.device)
    # Offline reference over the whole signal. Run under no_grad: the output is
    # only compared numerically, so building an autograd graph would just waste
    # memory and time in a benchmark.
    with th.no_grad():
        out = demucs(x[None])[0]
    streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
    out_rt = []
    # The first chunk is `streamer.total_length` samples. After that, feed one
    # `total_stride` at a time to mimic a live input arriving incrementally.
    frame_size = streamer.total_length
    with th.no_grad():
        while x.shape[1] > 0:
            out_rt.append(streamer.feed(x[:, :frame_size]))
            x = x[:, frame_size:]
            frame_size = streamer.demucs.total_stride
        out_rt.append(streamer.flush())
    out_rt = th.cat(out_rt, 1)
    # Parameter count * 4 bytes (assumes float32 parameters), expressed in MiB.
    model_size = sum(p.numel() for p in demucs.parameters()) * 4 / 2**20
    initial_lag = streamer.total_length / sr_ms
    tpf = 1000 * streamer.time_per_frame  # reported in ms
    # Hop between processed chunks, in ms. The same value is used for the
    # printed stride and the RTF so the two figures agree. NOTE(review): the
    # original multiplied `streamer.stride` by `num_frames` here but not in the
    # RTF. DemucsStreamer is expected to already fold `num_frames` into
    # `stride`, which made the printed stride double-counted; confirm against
    # DemucsStreamer.__init__.
    stride_ms = streamer.stride / sr_ms
    print(f"model size: {model_size:.1f}MB, ", end='')
    print(f"delta batch/streaming: {th.norm(out - out_rt) / th.norm(out):.2%}")
    print(f"initial lag: {initial_lag:.1f}ms, ", end='')
    print(f"stride: {stride_ms:.1f}ms")
    print(f"time per frame: {tpf:.1f}ms, ", end='')
    print(f"RTF: {tpf / stride_ms:.2f}")
    print(f"Total lag with computation: {initial_lag + tpf:.1f}ms")