torchbenchmark/models/tts_angular/angular_tts_main.py:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
import sys
import time
import traceback
import math
import torch
import torch as T
from .model import SpeakerEncoder, AngleProtoLoss
from torch.optim.optimizer import Optimizer
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.manual_seed(54321)
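# This file bundles everything the tts_angular benchmark needs in one place:
# a vendored RAdam optimizer, a training config (apparently carried over from
# the Mozilla TTS speaker-encoder recipe, judging by the output path below),
# and a TTSModel wrapper that trains/evaluates SpeakerEncoder on a synthetic
# batch instead of the real datasets.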
class AttrDict(dict):
"""A custom dict which converts dict keys
to class attributes"""
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
class RAdam(Optimizer):
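    """Rectified Adam (RAdam, Liu et al. 2019), vendored so the benchmark is
    self-contained. It rectifies the variance of the adaptive learning rate
    during the first optimization steps; while the rectification term is not
    yet defined (N_sma < 5) it either falls back to an SGD-style update
    (degenerated_to_sgd=True) or skips the update for that parameter."""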
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if eps < 0.0:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
self.degenerated_to_sgd = degenerated_to_sgd
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
for param in params:
if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
param['buffer'] = [[None, None, None] for _ in range(10)]
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
super(RAdam, self).__init__(params, defaults)
def __setstate__(self, state): # pylint: disable=useless-super-delegation
super(RAdam, self).__setstate__(state)
def step(self, closure=None):
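        """Perform a single optimization step.

        `closure` is an optional callable that re-evaluates the model and
        returns the loss, mirroring the standard torch.optim.Optimizer API.
        """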
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state['step'] += 1
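                # N_sma (rectification term) and step_size depend only on the
                # step count, so cache them in a 10-slot ring buffer keyed by
                # step % 10 and reuse them across parameters at the same step.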
buffered = group['buffer'][int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
# more conservative since it's an approximated value
if N_sma >= 5:
step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
elif self.degenerated_to_sgd:
step_size = 1.0 / (1 - beta1 ** state['step'])
else:
step_size = -1
buffered[2] = step_size
# more conservative since it's an approximated value
if N_sma >= 5:
if group['weight_decay'] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
p.data.copy_(p_data_fp32)
elif step_size > 0:
if group['weight_decay'] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
p.data.copy_(p_data_fp32)
return loss
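# Training configuration carried over from the speaker-encoder recipe. The
# dataset entries are not consumed by this benchmark (the real data loader
# setup is commented out in TTSModel._train); they are kept for reference.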
CONFIG = {
"run_name": "mueller91",
"run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ",
"audio":{
# Audio processing parameters
"num_mels": 40, # size of the mel spec frame.
"fft_size": 400, # number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 16000, # DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"win_length": 400, # stft window length in ms.
"hop_length": 160, # stft window hop-lengh in ms.
"frame_length_ms": None, # stft window length in ms.If None, 'win_length' is used.
"frame_shift_ms": None, # stft window hop-lengh in ms. If None, 'hop_length' is used.
"preemphasis": 0.98, # pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, # normalization range
"ref_level_db": 20, # reference level db, theoretically 20db is the sound of air.
"power": 1.5, # value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,# #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
# Normalization parameters
"signal_norm": True, # normalize the spec values in range [0, 1]
"symmetric_norm": True, # move normalization to range [-1, 1]
"max_norm": 4.0, # scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": True, # clip normalized values into the range.
"mel_fmin": 0.0, # minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, # maximum freq level for mel-spec. Tune for dataset!!
"do_trim_silence": True, # enable trimming of slience of audio as you load it. LJspeech (False), TWEB (False), Nancy (True)
"trim_db": 60 # threshold for timming silence. Set this according to your dataset.
},
"reinit_layers": [],
"loss": "angleproto", # "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA)
"grad_clip": 3.0, # upper limit for gradients for clipping.
"epochs": 1000, # total number of epochs to train.
"lr": 0.0001, # Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": False, # if True, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, # Noam decay steps to increase the learning rate from 0 to "lr"
"tb_model_param_stats": False, # True, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"steps_plot_stats": 10, # number of steps to plot embeddings.
"num_speakers_in_batch": 64, # Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"num_utters_per_speaker": 10, #
"num_loader_workers": 8, # number of training data loader processes. Don't set it too big. 4-8 are good values.
"wd": 0.000001, # Weight decay weight.
"checkpoint": True, # If True, it saves checkpoints per "save_step"
"save_step": 1000, # Number of training steps expected to save traning stats and checkpoints.
"print_step": 20, # Number of steps to log traning on console.
"output_path": "../../MozillaTTSOutput/checkpoints/voxceleb_librispeech/speaker_encoder/", # DATASET-RELATED: output path for all training outputs.
"model": {
"input_dim": 40,
"proj_dim": 256,
"lstm_dim": 768,
"num_lstm_layers": 3,
"use_lstm_with_projection": True
},
"storage": {
"sample_from_storage_p": 0.66,
"storage_size": 15, # the size of the in-memory storage with respect to a single batch
"additive_noise": 1e-5 # add very small gaussian noise to the data in order to increase robustness
},
"datasets":
[
{
"name": "vctk_slim",
"path": "../../../audio-datasets/en/VCTK-Corpus/",
"meta_file_train": None,
"meta_file_val": None
},
{
"name": "libri_tts",
"path": "../../../audio-datasets/en/LibriTTS/train-clean-100",
"meta_file_train": None,
"meta_file_val": None
},
{
"name": "libri_tts",
"path": "../../../audio-datasets/en/LibriTTS/train-clean-360",
"meta_file_train": None,
"meta_file_val": None
},
{
"name": "libri_tts",
"path": "../../../audio-datasets/en/LibriTTS/train-other-500",
"meta_file_train": None,
"meta_file_val": None
},
{
"name": "voxceleb1",
"path": "../../../audio-datasets/en/voxceleb1/",
"meta_file_train": None,
"meta_file_val": None
},
{
"name": "voxceleb2",
"path": "../../../audio-datasets/en/voxceleb2/",
"meta_file_train": None,
"meta_file_val": None
},
{
"name": "common_voice",
"path": "../../../audio-datasets/en/MozillaCommonVoice",
"meta_file_train": "train.tsv",
"meta_file_val": "test.tsv"
}
]
}
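# Module-level list holding the synthetic input batch shared between train()
# and eval(); TTSModel.__init__ appends one (batch, frames, mels) tensor and
# __del__ removes it again.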
SYNTHETIC_DATA = []
class TTSModel:
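    """Benchmark wrapper around SpeakerEncoder: builds the model from CONFIG,
    trains it with RAdam and AngleProtoLoss on one synthetic batch via
    train(), and runs a plain forward pass via eval()."""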
def __init__(self, device, batch_size):
self.device = device
        self.use_cuda = self.device == 'cuda'
self.c = AttrDict()
self.c.update(CONFIG)
c = self.c
self.model = SpeakerEncoder(input_dim=c.model['input_dim'],
proj_dim=c.model['proj_dim'],
lstm_dim=c.model['lstm_dim'],
num_lstm_layers=c.model['num_lstm_layers'])
self.optimizer = RAdam(self.model.parameters(), lr=c.lr)
self.criterion = AngleProtoLoss()
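        # One synthetic batch of mel features, shape (batch_size, 50 frames,
        # 40 mel bins), matching the model's input_dim of 40.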
SYNTHETIC_DATA.append(T.rand(batch_size, 50, 40).to(device=self.device))
if self.use_cuda:
self.model = self.model.cuda()
self.criterion.cuda()
self.scheduler = None
self.global_step = 0
def __del__(self):
del SYNTHETIC_DATA[0]
    def train(self, niter):
        _, self.global_step = self._train(self.model, self.criterion,
                                          self.optimizer, self.scheduler, None,
                                          self.global_step, self.c, niter)
def eval(self):
result = self.model(SYNTHETIC_DATA[0])
return result
def __call__(self, *things):
return self
def _train(self, model, criterion, optimizer, scheduler, ap, global_step, c, niter):
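        """Run `niter` passes over the synthetic batch: forward through the
        encoder, reshape the embeddings to (speakers, utterances, dim), apply
        the angular prototypical loss, and step the optimizer. Returns the
        exponentially averaged loss and the updated global step."""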
# data_loader = setup_loader(ap, is_val=False, verbose=True)
model.train()
epoch_time = 0
best_loss = float('inf')
avg_loss = 0
avg_loader_time = 0
end_time = time.time()
# for _, data in enumerate(data_loader):
start_time = time.time()
for reps in range(niter):
for _, data in enumerate(SYNTHETIC_DATA):
# setup input data
# inputs = data[0]
inputs = data
loader_time = time.time() - end_time
global_step += 1
# setup lr
# if c.lr_decay:
# scheduler.step()
optimizer.zero_grad()
# dispatch data to GPU
if self.use_cuda:
inputs = inputs.cuda(non_blocking=True)
# labels = labels.cuda(non_blocking=True)
# forward pass model
outputs = model(inputs)
# print(outputs.shape)
view = outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1)
# loss computation
loss = criterion(view)
loss.backward()
# grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()
step_time = time.time() - start_time
epoch_time += step_time
# Averaged Loss and Averaged Loader Time
                avg_loss = (0.01 * loss.item() + 0.99 * avg_loss) if avg_loss != 0 else loss.item()
                avg_loader_time = ((1 / c.num_loader_workers) * loader_time
                                   + ((c.num_loader_workers - 1) / c.num_loader_workers) * avg_loader_time
                                   if avg_loader_time != 0 else loader_time)
current_lr = optimizer.param_groups[0]['lr']
# save best model
#best_loss = save_best_model(model, optimizer, avg_loss, best_loss,
# OUT_PATH, global_step)
end_time = time.time()
# print(end_time - start_time)
return avg_loss, global_step
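
if __name__ == "__main__":
    # Minimal usage sketch, not part of the torchbenchmark harness. The batch
    # size of num_speakers_in_batch * num_utters_per_speaker is an assumption
    # chosen so the (speakers, utterances, dim) reshape in _train lines up
    # with CONFIG; adjust it if you change those settings.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = CONFIG["num_speakers_in_batch"] * CONFIG["num_utters_per_speaker"]
    benchmark_model = TTSModel(device=device, batch_size=batch)
    benchmark_model.train(niter=1)
    print(benchmark_model.eval().shape)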