in torchbenchmark/models/demucs/demucs/separate.py [0:0]
def _build_parser():
    """Build the command-line parser for ``demucs.separate``.

    Returns:
        argparse.ArgumentParser: the parser whose ``parse_args()`` result drives ``main()``.
    """
    parser = argparse.ArgumentParser("demucs.separate",
                                     description="Separate the sources for the given tracks")
    parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks')
    parser.add_argument("-n",
                        "--name",
                        default="demucs",
                        help="Model name. See README.md for the list of pretrained models. "
                        "Default is demucs.")
    parser.add_argument("-Q", "--quantized", action="store_true", dest="quantized", default=False,
                        help="Load the quantized model rather than the full-precision version. "
                        "Quantized model is about 4 times smaller but might slightly worsen "
                        "quality.")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("-o",
                        "--out",
                        type=Path,
                        default=Path("separated"),
                        help="Folder where to put extracted tracks. A subfolder "
                        "with the model name will be created.")
    parser.add_argument("--models",
                        type=Path,
                        default=Path("models"),
                        help="Path to trained models. "
                        "Also used to store downloaded pretrained models")
    parser.add_argument("--dl",
                        action="store_true",
                        help="Automatically download model if missing.")
    parser.add_argument("-d",
                        "--device",
                        default="cuda" if th.cuda.is_available() else "cpu",
                        help="Device to use, default is cuda if available else cpu")
    parser.add_argument("--shifts",
                        default=0,
                        type=int,
                        help="Number of random shifts for equivariant stabilization. "
                        "Increases separation time but improves quality for Demucs. 10 was used "
                        "in the original paper.")
    # Stored inverted: passing --nosplit sets ``args.split`` to False.
    parser.add_argument("--nosplit",
                        action="store_false",
                        default=True,
                        dest="split",
                        help="Apply the model to the entire input at once rather than "
                        "first splitting it in chunks of 10 seconds. Will OOM with Tasnet "
                        "but will work fine for Demucs if you have at least 16GB of RAM.")
    parser.add_argument("--float32",
                        action="store_true",
                        help="Convert the output wavefile to use pcm f32 format instead of s16. "
                        "This should not make a difference if you just plan on listening to the "
                        "audio but might be needed to compute exactly metrics like SDR etc.")
    # Inverse of --float32: writes the same ``float32`` destination.
    parser.add_argument("--int16",
                        action="store_false",
                        dest="float32",
                        help="Opposite of --float32, here for compatibility.")
    parser.add_argument("--mp3", action="store_true",
                        help="Convert the output wavs to mp3 with 320 kb/s rate.")
    return parser


def _fetch_model_path(args):
    """Return the path of the model weights selected by ``args``, downloading them if allowed.

    The file is ``<args.models>/<args.name>.th`` (``.th.gz`` with ``--quantized``).
    If the file is missing, it is only downloaded (from ``BASE_URL``) when the name is listed
    in ``PRETRAINED_MODELS`` and ``--dl`` was given. Whenever ``PRETRAINED_MODELS``
    provides a checksum for the name, the file is passed to ``verify_file`` with it.

    Exits the process with status 1 if the model is unknown, or known but missing without
    ``--dl``.
    """
    name = args.name + ".th"
    if args.quantized:
        name += ".gz"
    model_path = args.models / name
    sha256 = PRETRAINED_MODELS.get(name)
    if not model_path.is_file():
        if sha256 is None:
            print(f"No pretrained model {args.name}", file=sys.stderr)
            sys.exit(1)
        if not args.dl:
            print(
                f"Could not find model {model_path}, however a matching pretrained model exists, "
                "to download it, use --dl",
                file=sys.stderr)
            sys.exit(1)
        args.models.mkdir(exist_ok=True, parents=True)
        url = BASE_URL + name
        print("Downloading pre-trained model weights, this could take a while...")
        download_file(url, model_path)
    # Verify both freshly downloaded and pre-existing files against the known checksum.
    if sha256 is not None:
        verify_file(model_path, sha256)
    return model_path


def _separate_track(model, track, out, source_names, args):
    """Separate one audio file and write one output file per source.

    Files are written to ``out/<track stem>/<source name>.wav`` (``.mp3`` with ``--mp3``).
    Output is 16-bit PCM unless ``--float32`` is set without ``--mp3``.
    """
    print(f"Separating track {track}")
    wav = AudioFile(track).read(streams=0, samplerate=44100, channels=2).to(args.device)
    # Round to nearest short integer for compatibility with how MusDB loads audio with stempeg.
    wav = (wav * 2**15).round() / 2**15
    # Normalize with statistics of ``wav`` averaged over dim 0 (presumably the 2 channels
    # requested above -- confirm against AudioFile.read); undone after separation.
    ref = wav.mean(0)
    ref_mean = ref.mean()
    # Clamp so a digitally silent track does not divide by zero and yield NaN stems.
    ref_std = ref.std().clamp(min=1e-8)
    wav = (wav - ref_mean) / ref_std
    sources = apply_model(model, wav, shifts=args.shifts, split=args.split, progress=True)
    sources = sources * ref_std + ref_mean

    # Use the full stem so "a.b.mp3" and "a.c.mp3" do not collide in the same folder "a".
    track_folder = out / track.stem
    track_folder.mkdir(exist_ok=True)
    for source, source_name in zip(sources, source_names):
        if args.mp3 or not args.float32:
            # Scale to the int16 range and clip before casting to 16-bit PCM.
            source = (source * 2**15).clamp_(-2**15, 2**15 - 1).short()
        # Swap the first two dims before writing (presumably channels-first -> samples-first).
        source = source.cpu().transpose(0, 1).numpy()
        if args.mp3:
            encode_mp3(source, str(track_folder / f"{source_name}.mp3"), verbose=args.verbose)
        else:
            wavfile.write(str(track_folder / f"{source_name}.wav"), 44100, source)


def main():
    """Command-line entry point: separate each given track into its sources.

    Parses the CLI, loads the requested model onto ``args.device`` and writes the separated
    stems under ``<args.out>/<model name>/<track stem>/``. Missing tracks are reported on
    stderr and skipped.
    """
    args = _build_parser().parse_args()
    model_path = _fetch_model_path(args)
    model = load_model(model_path).to(args.device)
    if args.quantized:
        args.name += "_quantized"
    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    # Names paired in order with the separated outputs of apply_model.
    source_names = ["drums", "bass", "other", "vocals"]
    print(f"Separated tracks will be stored in {out.resolve()}")
    for track in args.tracks:
        if not track.exists():
            print(
                f"File {track} does not exist. If the path contains spaces, "
                "please try again after surrounding the entire path with quotes \"\".",
                file=sys.stderr)
            continue
        _separate_track(model, track, out, source_names, args)