in svoice/evaluate_auto_select.py [0:0]
def evaluate_auto_select(args):
total_sisnr = 0
total_pesq = 0
total_stoi = 0
total_cnt = 0
updates = 5
models = list()
paths = [args.model_path_2spk, args.model_path_3spk,
args.model_path_4spk, args.model_path_5spk]
for path in paths:
# Load model
pkg = torch.load(path)
if 'model' in pkg:
model = pkg['model']
else:
model = pkg
model = deserialize_model(model)
if 'best_state' in pkg:
model.load_state_dict(pkg['best_state'])
logger.debug(model)
model.eval()
model.to(args.device)
models.append(model)
# Load data
dataset = Validset(args.data_dir)
data_loader = distrib.loader(
dataset, batch_size=1, num_workers=args.num_workers)
sr = args.sample_rate
y_hat = torch.zeros((4))
pendings = []
with ProcessPoolExecutor(args.num_workers) as pool:
with torch.no_grad():
iterator = LogProgress(logger, data_loader, name="Eval estimates")
for i, data in enumerate(iterator):
# Get batch data
mixture, lengths, sources = [x.to(args.device) for x in data]
estimated_sources = list()
reorder_estimated_sources = list()
for model in models:
# Forward
with torch.no_grad():
raw_estimate = model(mixture)[-1]
estimate = pair_wise(sources, raw_estimate)
sisnr_loss, snr, estimate, reorder_estimate = cal_loss(
sources, estimate, lengths)
estimated_sources.insert(0, raw_estimate)
reorder_estimated_sources.insert(0, reorder_estimate)
# =================== DETECT NUM. NON-ACTIVE CHANNELS ============== #
selected_idx = 0
thresh = args.thresh
max_spk = 5
mix_spk = 2
ground = (max_spk - mix_spk)
while (selected_idx <= ground):
no_sils = 0
vals = torch.mean(
(estimated_sources[selected_idx]/torch.abs(estimated_sources[selected_idx]).max())**2, axis=2)
new_selected_idx = max_spk - len(vals[vals > thresh])
if new_selected_idx == selected_idx:
break
else:
selected_idx = new_selected_idx
if selected_idx < 0:
selected_idx = 0
elif selected_idx > ground:
selected_idx = ground
y_hat[ground - selected_idx] += 1
reorder_estimate = reorder_estimated_sources[selected_idx].cpu(
)
sources = sources.cpu()
mixture = mixture.cpu()
pendings.append(
pool.submit(_run_metrics, sources, reorder_estimate, mixture, None,
sr=sr))
total_cnt += sources.shape[0]
for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
sisnr_i, pesq_i, stoi_i = pending.result()
total_sisnr += sisnr_i
total_pesq += pesq_i
total_stoi += stoi_i
metrics = [total_sisnr, total_pesq, total_stoi]
sisnr, pesq, stoi = distrib.average(
[m/total_cnt for m in metrics], total_cnt)
logger.info(bold(f'Test set performance: SISNRi={sisnr:.2f} '
f'PESQ={pesq}, STOI={stoi}.'))
logger.info(f'Two spks prob: {y_hat[0]/(total_cnt)}')
logger.info(f'Three spks prob: {y_hat[1]/(total_cnt)}')
logger.info(f'Four spks prob: {y_hat[2]/(total_cnt)}')
logger.info(f'Five spks prob: {y_hat[3]/(total_cnt)}')
return sisnr, pesq, stoi