def get_musdb_wav_datasets()

in demucs/wav.py [0:0]


def get_musdb_wav_datasets(args):
    """Extract the musdb dataset from the XP arguments.

    Track metadata for ``<args.musdb>/train`` is cached as JSON under
    ``args.metadata``. The file name is ``musdb_<sig>.json``, where ``sig`` is
    the first 8 hex digits of the SHA-1 of ``str(args.musdb)``. Only rank 0
    builds and writes the cache when it is missing. When
    ``distrib.world_size > 1``, every rank waits on a barrier before reading
    it.

    Args:
        args: experiment config. This function reads ``musdb``, ``metadata``,
            ``sources``, ``train_valid``, ``full_cv``, ``segment``, ``shift``,
            ``samplerate``, ``channels`` and ``normalize``.

    Returns:
        tuple: ``(train_set, valid_set)``, both ``Wavset`` instances rooted at
        ``<args.musdb>/train``. The training set uses ``args.sources``. The
        validation set uses ``MIXTURE`` followed by ``args.sources``.
    """
    sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8]
    metadata_file = Path(args.metadata) / ('musdb_' + sig + ".json")
    root = Path(args.musdb) / "train"
    if not metadata_file.is_file() and distrib.rank == 0:
        metadata_file.parent.mkdir(exist_ok=True, parents=True)
        metadata = build_metadata(root, args.sources)
        # Write to a temporary file and rename it into place. An interrupted
        # run then never leaves a truncated cache, which later runs would
        # trust, since the is_file() check above skips rebuilding it.
        tmp_file = metadata_file.with_name(metadata_file.name + ".tmp")
        with open(tmp_file, "w") as f:
            json.dump(metadata, f)
        tmp_file.replace(metadata_file)
    if distrib.world_size > 1:
        # Other ranks must not read the cache before rank 0 has written it.
        distributed.barrier()
    with open(metadata_file) as f:
        metadata = json.load(f)

    valid_tracks = _get_musdb_valid()
    if args.train_valid:
        # Validation tracks are also kept in the training split.
        metadata_train = metadata
    else:
        metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks}
    metadata_valid = {name: meta for name, meta in metadata.items() if name in valid_tracks}
    if args.full_cv:
        # Don't pass segment/shift, so the validation Wavset uses its own
        # defaults (presumably whole tracks; confirm against Wavset).
        kw_cv = {}
    else:
        kw_cv = {'segment': args.segment, 'shift': args.shift}
    train_set = Wavset(root, metadata_train, args.sources,
                       segment=args.segment, shift=args.shift,
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize)
    valid_set = Wavset(root, metadata_valid, [MIXTURE] + list(args.sources),
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize, **kw_cv)
    return train_set, valid_set