def evaluate_auto_select()

in svoice/evaluate_auto_select.py
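
This routine loads four pre-trained separation models (for 2, 3, 4 and 5 speakers),
runs all of them on every mixture in the validation set, picks one output per
mixture by counting the non-silent channels, and reports SI-SNRi, PESQ and STOI
together with how often each speaker count was selected.

The attribute names read from args are taken from the function body below; the
values in this sketch are placeholders, so treat it as an illustration of the
expected configuration rather than the repository's actual entry point:

from argparse import Namespace

args = Namespace(
    model_path_2spk="checkpoints/2spk.th",   # placeholder checkpoint paths
    model_path_3spk="checkpoints/3spk.th",
    model_path_4spk="checkpoints/4spk.th",
    model_path_5spk="checkpoints/5spk.th",
    data_dir="path/to/validation/set",       # directory read by Validset
    device="cuda",
    num_workers=2,
    sample_rate=8000,                        # placeholder sample rate
    thresh=0.001,                            # placeholder silence threshold
)
# sisnr, pesq, stoi = evaluate_auto_select(args)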


def evaluate_auto_select(args):
    total_sisnr = 0
    total_pesq = 0
    total_stoi = 0
    total_cnt = 0
    updates = 5

    models = list()
    paths = [args.model_path_2spk, args.model_path_3spk,
             args.model_path_4spk, args.model_path_5spk]
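    # Models are loaded in ascending speaker count (2 -> 5). The per-mixture loop
    # below inserts each estimate at position 0, so the estimate lists end up in
    # descending order: index 0 holds the 5-speaker output, index 3 the 2-speaker one.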

    for path in paths:
        # Load model
        pkg = torch.load(path)
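        # A checkpoint may be a dict with a serialized model under 'model' (and the
        # trained weights under 'best_state'), or just the serialized model itself.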
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        if 'best_state' in pkg:
            model.load_state_dict(pkg['best_state'])
        logger.debug(model)

        model.eval()
        model.to(args.device)
        models.append(model)

    # Load data
    dataset = Validset(args.data_dir)
    data_loader = distrib.loader(
        dataset, batch_size=1, num_workers=args.num_workers)
    sr = args.sample_rate
    y_hat = torch.zeros(4)  # how often the 2/3/4/5-speaker model was selected

    pendings = []
    with ProcessPoolExecutor(args.num_workers) as pool:
        with torch.no_grad():
            iterator = LogProgress(logger, data_loader, name="Eval estimates")
            for i, data in enumerate(iterator):
                # Get batch data
                mixture, lengths, sources = [x.to(args.device) for x in data]
                estimated_sources = list()
                reorder_estimated_sources = list()

                for model in models:
                    # Forward pass (already inside the outer no_grad context).
                    raw_estimate = model(mixture)[-1]

                    estimate = pair_wise(sources, raw_estimate)
                    sisnr_loss, snr, estimate, reorder_estimate = cal_loss(
                        sources, estimate, lengths)
                    estimated_sources.insert(0, raw_estimate)
                    reorder_estimated_sources.insert(0, reorder_estimate)

                # =================== DETECT NUM. NON-ACTIVE CHANNELS ============== #
                # Start from the 5-speaker estimate (index 0), measure the peak-
                # normalized per-channel energy, and count the channels above the
                # threshold; re-select until the chosen model agrees with that count.
                selected_idx = 0
                thresh = args.thresh
                max_spk = 5
                mix_spk = 2  # the smallest model in the ensemble separates 2 speakers
                ground = max_spk - mix_spk
                while selected_idx <= ground:
                    est = estimated_sources[selected_idx]
                    vals = torch.mean((est / est.abs().max()) ** 2, dim=2)
                    new_selected_idx = max_spk - len(vals[vals > thresh])
                    if new_selected_idx == selected_idx:
                        break
                    selected_idx = new_selected_idx
                # Clamp to the valid model-index range [0, ground].
                if selected_idx < 0:
                    selected_idx = 0
                elif selected_idx > ground:
                    selected_idx = ground

                # Bin 0 counts 2-speaker selections, ..., bin 3 counts 5-speaker ones.
                y_hat[ground - selected_idx] += 1
                reorder_estimate = reorder_estimated_sources[selected_idx].cpu()
                sources = sources.cpu()
                mixture = mixture.cpu()

                pendings.append(
                    pool.submit(_run_metrics, sources, reorder_estimate, mixture, None,
                                sr=sr))
                total_cnt += sources.shape[0]

            for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
                sisnr_i, pesq_i, stoi_i = pending.result()
                total_sisnr += sisnr_i
                total_pesq += pesq_i
                total_stoi += stoi_i

    metrics = [total_sisnr, total_pesq, total_stoi]
    sisnr, pesq, stoi = distrib.average(
        [m/total_cnt for m in metrics], total_cnt)
    logger.info(bold(f'Test set performance: SISNRi={sisnr:.2f} '
                     f'PESQ={pesq:.2f}, STOI={stoi:.2f}.'))
    logger.info(f'Two spks prob: {y_hat[0]/(total_cnt)}')
    logger.info(f'Three spks prob: {y_hat[1]/(total_cnt)}')
    logger.info(f'Four spks prob: {y_hat[2]/(total_cnt)}')
    logger.info(f'Five spks prob: {y_hat[3]/(total_cnt)}')
    return sisnr, pesq, stoi
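
For reference, a minimal, self-contained sketch of the silent-channel count used in
the selection loop above. It only mirrors the thresholding logic (peak-normalized
per-channel energy compared against thresh); the helper name count_active_channels
is illustrative and does not exist in the repository.

import torch


def count_active_channels(estimate, thresh=0.001):
    # estimate: [batch, num_speakers, time] output of one separation model.
    # Normalize by the global peak, average the squared signal over time, and
    # count the channels whose energy exceeds the threshold.
    normalized = estimate / estimate.abs().max()
    energy = torch.mean(normalized ** 2, dim=2)
    return int((energy > thresh).sum())


# Toy check: a "3-speaker" estimate whose third channel is near silence.
est = torch.randn(1, 3, 8000)
est[:, 2] *= 1e-4
print(count_active_channels(est))  # typically prints 2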