in demucs/train.py [0:0]
def get_solver(args, model_only=False):
    """Build the model, optimizer and (optionally) data loaders, wrapped in a `Solver`.

    Args:
        args: experiment configuration. Fields read here include `seed`,
            `batch_size`, `misc.*`, `optim.*`, `dset.*` and `augment.repitch.*`.
        model_only: when True, skip dataset/loader construction and return a
            `Solver` whose loaders are `None`.

    Returns:
        A `Solver(loaders, model, optimizer, args)` instance.

    Raises:
        ValueError: if `args.optim.optim` is not a supported optimizer name, or
            if `args.batch_size` is not divisible by `distrib.world_size`.

    Side effects:
        - Calls `distrib.init()` and seeds torch via `torch.manual_seed`.
        - If `args.misc.show` is set, logs a model summary and exits the process.
        - Divides `args.batch_size` in place by `distrib.world_size`, so after this
          call it holds the per-process batch size. Calling `get_solver` twice on
          the same `args` object divides it twice.
    """
    distrib.init()

    # torch.manual_seed seeds the RNG on all devices, including CUDA when present.
    torch.manual_seed(args.seed)
    model = get_model(args)
    if args.misc.show:
        _log_model_summary(model, args)
        sys.exit(0)

    if torch.cuda.is_available():
        model.cuda()

    # Build the optimizer after the model has been moved to its device.
    optimizer = _build_optimizer(model, args)

    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if args.batch_size % distrib.world_size != 0:
        raise ValueError(
            f"batch_size ({args.batch_size}) must be divisible by "
            f"world_size ({distrib.world_size})")
    args.batch_size //= distrib.world_size

    if model_only:
        return Solver(None, model, optimizer, args)

    train_set, valid_set = _build_datasets(args)
    logger.info("train/valid set size: %d %d", len(train_set), len(valid_set))
    loaders = _build_loaders(args, train_set, valid_set)
    return Solver(loaders, model, optimizer, args)


def _log_model_summary(model, args):
    """Log the model repr, its approximate parameter size and (if available) its field."""
    logger.info(model)
    # Counts 4 bytes per parameter, i.e. assumes float32 storage.
    mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20
    logger.info('Size: %.1f MB', mb)
    if hasattr(model, 'valid_length'):
        # valid_length(1) is treated as a length in samples and converted to ms.
        field = model.valid_length(1)
        logger.info('Field: %.1f ms', field / args.dset.samplerate * 1000)


def _build_optimizer(model, args):
    """Create the optimizer named by `args.optim.optim` ('adam' or 'adamw') for `model`.

    Both optimizers receive the same learning rate, betas
    `(args.optim.momentum, args.optim.beta2)` and weight decay.

    Raises:
        ValueError: if `args.optim.optim` is not a supported optimizer name.
    """
    optimizers = {
        'adam': torch.optim.Adam,
        'adamw': torch.optim.AdamW,
    }
    try:
        optim_cls = optimizers[args.optim.optim]
    except KeyError:
        raise ValueError(f"Invalid optimizer {args.optim.optim!r}") from None
    return optim_cls(
        model.parameters(), lr=args.optim.lr,
        betas=(args.optim.momentum, args.optim.beta2),
        weight_decay=args.optim.weight_decay)


def _build_datasets(args):
    """Return `(train_set, valid_set)`: MusDB wav data, optional extra wav data, optional repitch."""
    train_set, valid_set = get_musdb_wav_datasets(args.dset)
    if args.dset.wav:
        extra_train_set, extra_valid_set = get_wav_datasets(args.dset)
        train_set = ConcatDataset([train_set, extra_train_set])
        valid_set = ConcatDataset([valid_set, extra_valid_set])

    if args.augment.repitch.proba:
        # Indices of the 'vocals' source, passed to the repitch wrapper
        # (presumably so vocals can be handled specially — confirm in RepitchedWrapper).
        vocals = []
        if 'vocals' in args.dset.sources:
            vocals.append(args.dset.sources.index('vocals'))
        else:
            logger.warning('No vocal source found')
        # Only the training set is augmented.
        train_set = RepitchedWrapper(train_set, vocals=vocals, **args.augment.repitch)
    return train_set, valid_set


def _build_loaders(args, train_set, valid_set):
    """Wrap the datasets in distributed loaders; returns {"train": ..., "valid": ...}."""
    train_loader = distrib.loader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.misc.num_workers, drop_last=True)
    if args.dset.full_cv:
        # batch_size=1 and no drop_last: presumably so every validation item
        # is evaluated — NOTE(review): confirm distrib.loader's drop_last default.
        valid_loader = distrib.loader(
            valid_set, batch_size=1, shuffle=False,
            num_workers=args.misc.num_workers)
    else:
        valid_loader = distrib.loader(
            valid_set, batch_size=args.batch_size, shuffle=False,
            num_workers=args.misc.num_workers, drop_last=True)
    return {"train": train_loader, "valid": valid_loader}