def calculate_snr()

in data_preparation/calculate_snr.py [0:0]


def calculate_snr(sample, vad, fs=16000, noise_th=0.995, speech_th=0.8, vad_window_ms=80):
    """Estimate the SNR of a waveform from per-window VAD scores.

    The waveform is cut into consecutive windows of ``vad_window_ms``
    milliseconds. Window ``i`` is classified with score ``vad[i]``:

    * ``vad[i] < speech_th``: speech. This also re-arms a hangover so that the
      following 3 windows are counted as speech too.
    * otherwise, while the hangover is active: speech.
    * otherwise, ``vad[i] > noise_th``: noise.
    * anything else: ambiguous, used for neither estimate.

    Note the threshold direction: *low* scores mean speech and *high* scores
    mean noise. Presumably ``vad`` is a non-speech probability; confirm
    against the VAD producer.

    Args:
        sample: Raw waveform buffer, normalised via ``convert_wav_buf_f32``.
        vad: Sequence of scores, one per ``vad_window_ms`` window, starting at
            the first sample. Surplus windows or scores are silently dropped
            (``zip`` truncates to the shorter of the two).
        fs: Sample rate in Hz.
        noise_th: Scores strictly above this are treated as noise.
        speech_th: Scores strictly below this are treated as speech.
        vad_window_ms: VAD window length in milliseconds.

    Returns:
        ``[snr_db, speech_power, noise_power]``. Each power is the sum of
        squared samples divided by the duration in seconds.

        * If no noise windows were found, ``snr_db`` and ``noise_power`` are
          NaN.
        * If no speech samples were found, all three values are NaN.

        In both cases a short message is printed to stderr.
    """
    sample = convert_wav_buf_f32(sample)
    window_len = int(vad_window_ms * fs / 1000)
    # Split points must start at window_len, not 0. A leading 0 makes np.split
    # emit an empty first chunk, which would pair every window with the VAD
    # score of the window before it.
    windows = np.split(sample, range(window_len, len(sample), window_len))

    hangover_windows = 2  # heuristic, 240ms: counter values 2, 1, 0 -> 3 windows
    # NOTE(review): the counter starts armed, so the first 3 windows are always
    # counted as speech regardless of their score. This is kept from the
    # original behaviour; confirm it is intentional (e.g. VAD warm-up).
    hangover = hangover_windows
    speech_windows = []
    noise_windows = []
    for window, score in zip(windows, vad):
        if score < speech_th:
            speech_windows.append(window)
            hangover = hangover_windows
        elif hangover >= 0:
            speech_windows.append(window)
            hangover -= 1
        elif score > noise_th:
            noise_windows.append(window)
        # Scores in [speech_th, noise_th] outside the hangover are ambiguous
        # and deliberately ignored.

    speech = np.concatenate(speech_windows) if speech_windows else np.empty(0)
    if speech.size == 0:
        print("no speech?", file=sys.stderr)
        return [float('nan'), float('nan'), float('nan')]
    speech_energy = np.sum(np.power(speech, 2))
    speech_time = len(speech) / fs
    speech_power = speech_energy / speech_time

    if len(noise_windows) == 0:
        print("no noise?", file=sys.stderr)
        return [float('nan'), speech_power, float('nan')]
    noise = np.concatenate(noise_windows)
    noise_energy = np.sum(np.power(noise, 2))
    noise_time = len(noise) / fs
    noise_power = noise_energy / noise_time

    snr = 10 * np.log10(speech_power / noise_power)
    return [snr, speech_power, noise_power]