def main()

in torchbenchmark/models/demucs/demucs/separate.py [0:0]


def main():
    """Command-line entry point: separate each given track into its sources.

    Parses ``sys.argv``, locates the requested model file under ``--models``
    (downloading it first when ``--dl`` is given and it is a known pretrained
    model), verifies its sha256 when one is known, and then, for every track,
    writes one file per source ("drums", "bass", "other", "vocals") into
    ``<out>/<model name>/<track name without extension>/``.

    Exits with status 1 when the model file is missing and cannot be
    obtained. Tracks that do not exist are reported on stderr and skipped.
    """
    parser = argparse.ArgumentParser("demucs.separate",
                                     description="Separate the sources for the given tracks")
    parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks')
    parser.add_argument("-n",
                        "--name",
                        default="demucs",
                        help="Model name. See README.md for the list of pretrained models. "
                             "Default is demucs.")
    parser.add_argument("-Q", "--quantized", action="store_true", dest="quantized", default=False,
                        help="Load the quantized model rather than the non-quantized version. "
                             "Quantized model is about 4 times smaller but might slightly "
                             "worsen quality.")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("-o",
                        "--out",
                        type=Path,
                        default=Path("separated"),
                        help="Folder where to put extracted tracks. A subfolder "
                        "with the model name will be created.")
    parser.add_argument("--models",
                        type=Path,
                        default=Path("models"),
                        help="Path to trained models. "
                        "Also used to store downloaded pretrained models")
    parser.add_argument("--dl",
                        action="store_true",
                        help="Automatically download model if missing.")
    parser.add_argument("-d",
                        "--device",
                        default="cuda" if th.cuda.is_available() else "cpu",
                        help="Device to use, default is cuda if available else cpu")
    parser.add_argument("--shifts",
                        default=0,
                        type=int,
                        help="Number of random shifts for equivariant stabilization. "
                        "Increases separation time but improves quality for Demucs. 10 was used "
                        "in the original paper.")
    parser.add_argument("--nosplit",
                        action="store_false",
                        default=True,
                        dest="split",
                        help="Apply the model to the entire input at once rather than "
                        "first splitting it in chunks of 10 seconds. Will OOM with Tasnet "
                        "but will work fine for Demucs if you have at least 16GB of RAM.")
    parser.add_argument("--float32",
                        action="store_true",
                        help="Convert the output wavefile to use pcm f32 format instead of s16. "
                        "This should not make a difference if you just plan on listening to the "
                        "audio but might be needed to compute exactly metrics like SDR etc.")
    # Shares dest "float32" with --float32; argparse keeps the default of the
    # first action registered for a dest, so float32 defaults to False.
    parser.add_argument("--int16",
                        action="store_false",
                        dest="float32",
                        help="Opposite of --float32, here for compatibility.")
    parser.add_argument("--mp3", action="store_true",
                        help="Convert the output wavs to mp3 with 320 kb/s rate.")

    args = parser.parse_args()

    # Resolve the on-disk model file: "<name>.th", or "<name>.th.gz" when quantized.
    model_filename = args.name + ".th"
    if args.quantized:
        model_filename += ".gz"

    model_path = args.models / model_filename
    sha256 = PRETRAINED_MODELS.get(model_filename)
    if not model_path.is_file():
        if sha256 is None:
            print(f"No pretrained model {args.name}", file=sys.stderr)
            sys.exit(1)
        if not args.dl:
            print(
                f"Could not find model {model_path}, however a matching pretrained model exists, "
                "to download it, use --dl",
                file=sys.stderr)
            sys.exit(1)
        args.models.mkdir(exist_ok=True, parents=True)
        url = BASE_URL + model_filename
        print("Downloading pre-trained model weights, this could take a while...")
        download_file(url, model_path)
    # Only known pretrained models have a reference checksum to verify against.
    if sha256 is not None:
        verify_file(model_path, sha256)
    model = load_model(model_path).to(args.device)

    if args.quantized:
        args.name += "_quantized"
    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    source_names = ["drums", "bass", "other", "vocals"]
    print(f"Separated tracks will be stored in {out.resolve()}")

    for track in args.tracks:
        if not track.exists():
            print(
                f"File {track} does not exist. If the path contains spaces, "
                "please try again after surrounding the entire path with quotes \"\".",
                file=sys.stderr)
            continue
        print(f"Separating track {track}")
        wav = AudioFile(track).read(streams=0, samplerate=44100, channels=2).to(args.device)
        # Round to nearest short integer for compatibility with how MusDB loads audio with stempeg.
        wav = (wav * 2**15).round() / 2**15
        # Standardize the input with the statistics of ``wav.mean(0)`` (presumably
        # the mono mix if dim 0 is channels -- TODO confirm), and undo it afterwards.
        ref = wav.mean(0)
        wav = (wav - ref.mean()) / ref.std()
        sources = apply_model(model, wav, shifts=args.shifts, split=args.split, progress=True)
        sources = sources * ref.std() + ref.mean()

        # Strip only the final extension. Splitting on the first "." turned
        # e.g. "01. Intro.mp3" into "01", so distinct tracks could share
        # (and overwrite) the same output folder.
        track_folder = out / track.stem
        track_folder.mkdir(exist_ok=True)
        for source, source_name in zip(sources, source_names):
            if args.mp3 or not args.float32:
                # Scale to 16-bit PCM; clamp first so out-of-range samples
                # saturate instead of wrapping around on the integer cast.
                source = (source * 2**15).clamp_(-2**15, 2**15 - 1).short()
            source = source.cpu().transpose(0, 1).numpy()
            stem = str(track_folder / source_name)
            if args.mp3:
                encode_mp3(source, stem + ".mp3", verbose=args.verbose)
            else:
                wavfile.write(stem + ".wav", 44100, source)