in sing/generate.py [0:0]
def main():
    """Generate .wav files with a pretrained SING model.

    Command-line arguments come from ``get_parser()``. If the model file at
    ``args.model`` is missing, it is downloaded when ``--dl`` is given;
    otherwise the program stops via ``utils.fatal``. Each non-empty line of
    ``args.list`` names one example of the NSynth metadata read from
    ``args.metadata``. The named examples are run through the model in
    batches of ``args.batch_size``. Each output is unpadded by
    ``args.unpad`` and written to ``args.output / (name + ".wav")`` at the
    example's ``sample_rate``.
    """
    args = get_parser().parse_args()

    if not args.model.exists():
        if args.dl:
            print("Downloading pretrained SING model")
            args.model.parent.mkdir(parents=True, exist_ok=True)
            download_pretrained_model(args.model)
        else:
            utils.fatal("No model found for path {}. To download "
                        "a pretrained model, use --dl".format(args.model))
    elif args.dl:
        print(
            "WARNING: --dl is set but {} already exists.".format(args.model),
            file=sys.stderr)

    # Map storages to the CPU so a checkpoint saved from a GPU also loads on a
    # CPU-only machine; the model is moved to the GPU below when --cuda is set.
    # NOTE(review): this unpickles a whole nn.Module, so only load trusted
    # files. torch>=2.6 defaults to weights_only=True, which may reject such a
    # checkpoint -- confirm against the torch version in use.
    model = torch.load(args.model, map_location="cpu")
    # Inference only: switch off training-time behaviour (e.g. dropout) in any
    # layers that have it.
    model.eval()
    if args.cuda:
        model.cuda()
    if args.parallel:
        model = nn.DataParallel(model)

    args.output.mkdir(exist_ok=True, parents=True)
    dataset = nsynth.NSynthMetadata(args.metadata)

    with open(args.list) as list_file:
        stripped = (line.strip() for line in list_file)
        # Skip blank lines (e.g. a trailing empty line) instead of failing
        # the metadata lookup on an empty name.
        names = [name for name in stripped if name]

    # One pass over the metadata instead of a linear list.index() per name.
    # setdefault keeps the first occurrence, matching list.index semantics.
    name_to_index = {}
    for index, name in enumerate(dataset.names):
        name_to_index.setdefault(name, index)
    missing = [name for name in names if name not in name_to_index]
    if missing:
        utils.fatal("{} name(s) listed in {} are not in the metadata {}, "
                    "e.g. {}".format(len(missing), args.list, args.metadata,
                                     ", ".join(missing[:5])))
    indexes = [name_to_index[name] for name in names]

    to_generate = DatasetSubset(dataset, indexes)
    loader = DataLoader(
        to_generate, batch_size=args.batch_size, collate_fn=collate)
    with tqdm.tqdm(total=len(to_generate), unit="ex") as bar:
        for batch in loader:
            if args.cuda:
                batch.cuda_()
            with torch.no_grad():
                # Call the module (not .forward) so registered hooks run.
                rebuilt = model(**batch.tensors)
                rebuilt = utils.unpad1d(rebuilt, args.unpad)
            for metadata, wav in zip(batch.metadata, rebuilt):
                path = args.output / (metadata['name'] + ".wav")
                wavfile.write(
                    str(path), metadata['sample_rate'],
                    dsp.float_wav_to_short(wav).cpu().detach().numpy())
            bar.update(len(batch))