def evaluate()

in svoice/evaluate.py [0:0]


def evaluate(args, model=None, data_loader=None, sr=None):
    """Run a separation model over a dataset and log SI-SNRi, PESQ and STOI.

    The forward pass runs in this process on ``args.device``. Metric
    computation for each batch is submitted to a process pool via
    ``_run_metrics``. The accumulated totals, divided by the number of
    examples seen, are passed to ``distrib.average`` together with that
    example count.

    Args:
        args: configuration object. Always reads ``device`` and
            ``num_workers``. Reads ``model_path`` when ``model`` is None.
            Reads ``data_dir`` and ``sample_rate`` when ``data_loader`` is
            None. ``sample_rate`` is also used as a fallback when ``sr`` is
            None.
        model: an already-built model. When None, a checkpoint package is
            loaded from ``args.model_path`` and deserialized.
        data_loader: iterable yielding ``(mixture, lengths, sources)``
            batches. When None, a ``Validset`` over ``args.data_dir`` is
            wrapped in a loader with batch size 1.
        sr: sample rate forwarded to ``_run_metrics``. It is overridden by
            ``args.sample_rate`` when the default loader is built here.

    Returns:
        Tuple ``(sisnr, pesq, stoi)`` as returned by ``distrib.average``.
    """
    total_sisnr = 0
    total_pesq = 0
    total_stoi = 0
    total_cnt = 0
    updates = 5  # progress-update count passed to LogProgress for metrics

    # Load model. Compare against None explicitly: container modules such as
    # nn.Sequential define __len__, so `not model` is not a safe test.
    if model is None:
        # NOTE(security): torch.load unpickles the file. Only load
        # checkpoints from trusted sources.
        pkg = torch.load(args.model_path, map_location=args.device)
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        if 'best_state' in pkg:
            model.load_state_dict(pkg['best_state'])
    logger.debug(model)
    model.eval()
    model.to(args.device)

    # Load data. Use `is None` rather than `not data_loader`: DataLoader
    # defines __len__, so an empty loader supplied by the caller must not be
    # silently swapped for the default validation set.
    if data_loader is None:
        dataset = Validset(args.data_dir)
        data_loader = distrib.loader(
            dataset, batch_size=1, num_workers=args.num_workers)
        sr = args.sample_rate
    if sr is None:
        # A caller-supplied loader without an explicit rate would otherwise
        # forward sr=None to the metric workers.
        sr = getattr(args, 'sample_rate', None)

    # ProcessPoolExecutor rejects max_workers <= 0, while num_workers=0 is a
    # legitimate DataLoader setting. Keep at least one metric worker.
    # None is passed through (the executor picks its own default).
    pool_workers = args.num_workers
    if pool_workers is not None:
        pool_workers = max(1, pool_workers)

    pendings = []
    with ProcessPoolExecutor(pool_workers) as pool:
        with torch.no_grad():
            iterator = LogProgress(logger, data_loader, name="Eval estimates")
            for data in iterator:
                # Get batch data
                mixture, lengths, sources = [x.to(args.device) for x in data]
                # Peak-normalize the input, out of place: .to() returns the
                # loader's own tensor when it is already on args.device, and
                # it should not be mutated.
                # NOTE(review): this divides by the signed maximum, so a
                # silent or all-negative mixture gives inf/NaN or a sign
                # flip. Confirm whether abs().max() plus an epsilon was
                # intended.
                mixture = mixture / mixture.max()
                # The last element of the model output is used as the
                # estimate.
                estimate = model(mixture)[-1]
                # Only the permutation-aligned estimate is used below.
                _, _, _, reorder_estimate = cal_loss(sources, estimate, lengths)
                reorder_estimate = reorder_estimate.cpu()
                sources = sources.cpu()
                mixture = mixture.cpu()

                pendings.append(
                    pool.submit(_run_metrics, sources, reorder_estimate, mixture, None,
                                sr=sr))
                total_cnt += sources.shape[0]

            # NOTE(review): totals are later divided by the number of
            # examples. That equals a per-example mean only if _run_metrics
            # returns per-batch sums, or each batch holds one example (true
            # for the default loader). Confirm for custom loaders.
            for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
                sisnr_i, pesq_i, stoi_i = pending.result()
                total_sisnr += sisnr_i
                total_pesq += pesq_i
                total_stoi += stoi_i

    metrics = [total_sisnr, total_pesq, total_stoi]
    if total_cnt:
        means = [m / total_cnt for m in metrics]
    else:
        # No examples (e.g. an empty loader or an empty shard on this rank).
        # Avoid ZeroDivisionError but still call distrib.average so this
        # rank takes part in any cross-process reduction. Presumably the
        # zero count gives these values zero weight there — confirm
        # against distrib.average.
        logger.warning("evaluate: data loader yielded no examples.")
        means = [0.0 for _ in metrics]
    sisnr, pesq, stoi = distrib.average(means, total_cnt)
    logger.info(
        bold(f'Test set performance: SISNRi={sisnr:.2f} PESQ={pesq}, STOI={stoi}.'))
    return sisnr, pesq, stoi