in svoice/separate.py [0:0]
def separate(args, model=None, local_out_dir=None):
    """Run the separation model over a set of mixtures and save the estimates.

    Args:
        args: configuration namespace. Fields read directly here:
            ``model_path`` (only when ``model`` is None), ``device``,
            ``out_dir`` (only when ``local_out_dir`` is falsy), ``batch_size``
            and ``sample_rate``. ``get_mix_paths(args)`` may read more.
        model: an already-constructed separation model. When None, the model
            is loaded from the checkpoint at ``args.model_path``.
        local_out_dir: output directory that overrides ``args.out_dir`` when
            truthy.

    Returns:
        None. Returns early, after logging an error, when neither a mix
        directory nor a mix json is available.
    """
    mix_dir, mix_json = get_mix_paths(args)
    if not mix_json and not mix_dir:
        logger.error("Must provide mix_dir or mix_json! "
                     "When providing mix_dir, mix_json is ignored.")
        # Nothing to separate: bail out instead of building a dataset from
        # missing paths.
        return

    # Load model
    if model is None:
        # map_location="cpu" lets a checkpoint saved on a GPU load on a
        # CPU-only machine; the model is moved to args.device just below.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        pkg = torch.load(args.model_path, map_location="cpu")
        # The checkpoint is either a package holding the serialized model
        # under 'model', or the serialized model itself.
        model = pkg['model'] if 'model' in pkg else pkg
        model = deserialize_model(model)
        logger.debug(model)
    model.eval()
    model.to(args.device)

    out_dir = local_out_dir if local_out_dir else args.out_dir

    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=args.batch_size,
        sample_rate=args.sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)

    # Only rank 0 creates the output directory; every rank then waits at the
    # barrier so no process writes before the directory exists.
    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        for data in tqdm.tqdm(eval_loader, ncols=120):
            # Get batch data
            mixture, lengths, filenames = data
            mixture = mixture.to(args.device)
            lengths = lengths.to(args.device)
            # Forward: the model output is indexed with [-1], i.e. only its
            # last element is kept as the source estimate.
            # NOTE(review): presumably one estimate per model stage -- confirm.
            estimate_sources = model(mixture)[-1]
            # save wav files
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=args.sample_rate)