def evaluate()

in torchbenchmark/models/demucs/demucs/test.py [0:0]


def evaluate(model,
             musdb_path,
             eval_folder,
             workers=2,
             device="cpu",
             rank=0,
             save=False,
             shifts=0,
             split=False,
             check=True,
             world_size=1):
    """
    Evaluate model using museval. Run the model
    on a single GPU, the bottleneck being the call to museval.

    Tracks of the MUSDB "test" subset are sharded across processes as
    ``range(rank, len(test_set), world_size)``. For every track that does not
    already have a result file, the model is applied to the normalized
    mixture and the estimates are scored with ``museval.evaluate``, either in
    a process pool (``workers > 0``) or inline (``workers == 0``). Per-track
    scores are written to ``eval_folder / "results/test" / "<track>.json.gz"``.

    Args:
        model: separation model; its parameters are frozen in place
            (``requires_grad = False`` and ``grad = None``).
        musdb_path: root of the MUSDB dataset, passed to ``musdb.DB``.
        eval_folder: output directory; must support ``/`` and ``mkdir``
            (presumably a ``pathlib.Path`` -- confirm against callers).
        workers: number of processes used for ``museval.evaluate``; 0 runs
            the evaluation in this process.
        device: device the mixture is moved to before ``apply_model``.
        rank: index of this process among ``world_size`` processes.
        save: if true, also write each estimated source as a wav file with a
            44100 Hz rate under ``eval_folder / "wav/test" / <track name>``.
        shifts: forwarded to ``apply_model``.
        split: forwarded to ``apply_model``.
        check: not used by this function.
        world_size: number of processes sharing the evaluation; when greater
            than 1, all processes synchronize on ``distributed.barrier()``
            before returning.
    """

    source_names = ["drums", "bass", "other", "vocals"]
    output_dir = eval_folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = eval_folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    # we load tracks from the original musdb set
    test_set = musdb.DB(musdb_path, subsets=["test"])

    # Inference only: freeze the weights and drop any stale gradients.
    for p in model.parameters():
        p.requires_grad = False
        p.grad = None

    pendings = []
    with futures.ProcessPoolExecutor(workers or 1) as pool:
        for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
            track = test_set.tracks[index]

            # Results already on disk (e.g. from an earlier run) are reused.
            out = json_folder / f"{track.name}.json.gz"
            if out.exists():
                continue

            mix = th.from_numpy(track.audio).t().float()
            ref = mix.mean(dim=0)  # mono mixture
            ref_mean, ref_std = ref.mean(), ref.std()
            # Normalize the input with the mono-mixture statistics, then undo
            # that normalization on the model output.
            mix = (mix - ref_mean) / ref_std

            estimates = apply_model(model, mix.to(device), shifts=shifts, split=split)
            estimates = estimates * ref_std + ref_mean

            # NOTE(review): presumably brings estimates to the same
            # (source, time, channel) layout as the stacked references below.
            estimates = estimates.transpose(1, 2)
            references = th.stack(
                [th.from_numpy(track.targets[name].audio) for name in source_names])
            references = references.numpy()
            estimates = estimates.cpu().numpy()
            if save:
                folder = eval_folder / "wav/test" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(source_names, estimates):
                    wavfile.write(str(folder / (name + ".wav")), 44100, estimate)

            if workers:
                pendings.append((track.name, pool.submit(museval.evaluate, references, estimates)))
            else:
                pendings.append((track.name, museval.evaluate(references, estimates)))
            # Release the large per-track arrays before loading the next track.
            del references, mix, estimates, track

        for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
            if workers:
                pending = pending.result()
            sdr, isr, sir, sar = pending
            track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
            for idx, target in enumerate(source_names):
                values = {
                    "SDR": sdr[idx].tolist(),
                    "SIR": sir[idx].tolist(),
                    "ISR": isr[idx].tolist(),
                    "SAR": sar[idx].tolist()
                }

                track_store.add_target(target_name=target, values=values)
            # Write once, after every target has been added, and close the
            # file so the gzip stream is complete. A file holding only some
            # of the sources would otherwise be treated as finished by the
            # ``out.exists()`` skip check on a later run.
            json_path = json_folder / f"{track_name}.json.gz"
            with gzip.open(json_path, "w") as json_file:
                json_file.write(track_store.json.encode('utf-8'))
    if world_size > 1:
        distributed.barrier()