in train.py [0:0]
def run(args):
    """Build the model, data loaders and optimizer from ``args``, then train.

    Only ``args.model == "swave"`` and ``args.optim == "adam"`` are
    supported; any other value logs a fatal message and terminates the
    process with ``os._exit(1)``.

    Side effects on ``args``:
      * ``args.segment`` is replaced by a duration whose sample count the
        model accepts (via ``model.valid_length``), when the model exposes
        ``valid_length``.
      * ``args.batch_size`` is divided by ``distrib.world_size`` so that it
        becomes the per-process batch size.

    If ``args.show`` is set, the model, its parameter size in MB and (when
    available) its receptive field in ms are logged, and the function
    returns without building data loaders or training.
    """
    import torch

    from svoice import distrib
    from svoice.data.data import Trainset, Validset
    from svoice.models.swave import SWave
    from svoice.solver import Solver

    logger.info("Running on host %s", socket.gethostname())
    distrib.init(args)

    # Seed BEFORE constructing the model: parameter initialisation in a
    # torch module constructor typically draws from the global torch RNG,
    # so seeding afterwards left the initial weights unseeded.
    # torch also initialize cuda seed if available.
    torch.manual_seed(args.seed)

    if args.model == "swave":
        kwargs = dict(args.swave)
        kwargs['sr'] = args.sample_rate
        kwargs['segment'] = args.segment
        model = SWave(**kwargs)
    else:
        logger.fatal("Invalid model name %s", args.model)
        os._exit(1)

    # requires a specific number of samples to avoid 0 padding during training
    if hasattr(model, 'valid_length'):
        segment_len = int(args.segment * args.sample_rate)
        segment_len = model.valid_length(segment_len)
        args.segment = segment_len / args.sample_rate

    if args.show:
        logger.info(model)
        # Use each parameter's real element size rather than assuming float32.
        mb = sum(p.numel() * p.element_size() for p in model.parameters()) / 2**20
        logger.info('Size: %.1f MB', mb)
        if hasattr(model, 'valid_length'):
            field = model.valid_length(1)
            logger.info('Field: %.1f ms', field / args.sample_rate * 1000)
        return

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.batch_size % distrib.world_size != 0:
        logger.fatal("batch_size %s is not divisible by world_size %s",
                     args.batch_size, distrib.world_size)
        os._exit(1)
    # Global batch size -> per-process batch size.
    args.batch_size //= distrib.world_size

    # Building datasets and loaders
    tr_dataset = Trainset(
        args.dset.train, sample_rate=args.sample_rate, segment=args.segment,
        stride=args.stride, pad=args.pad)
    tr_loader = distrib.loader(
        tr_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers)
    # batch_size=1 -> use less GPU memory to do cv
    cv_dataset = Validset(args.dset.valid)
    tt_dataset = Validset(args.dset.test)
    cv_loader = distrib.loader(
        cv_dataset, batch_size=1, num_workers=args.num_workers)
    tt_loader = distrib.loader(
        tt_dataset, batch_size=1, num_workers=args.num_workers)
    data = {"tr_loader": tr_loader,
            "cv_loader": cv_loader, "tt_loader": tt_loader}

    if torch.cuda.is_available():
        model.cuda()

    # optimizer
    if args.optim == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, betas=(0.9, args.beta2))
    else:
        logger.fatal('Invalid optimizer %s', args.optim)
        os._exit(1)

    # Construct Solver
    solver = Solver(data, model, optimizer, args)
    solver.train()