def main()

in src/run_on_files.py [0:0]


def main(args):
    """Resynthesize audio files through every selected WaveNet decoder.

    Loads one encoder plus one decoder per matching checkpoint, mu-law encodes
    and encodes every input file, decodes the encodings with each decoder, and
    writes one ``.wav`` per (decoder, input file) pair.

    Args:
        args: Parsed command-line namespace. Fields read here:
            ``checkpoint`` (Path; its parent holds ``args.pth`` and the
            ``<name>_*.pth`` checkpoints), ``decoders`` (ids to keep, compared
            with ``extract_id`` of each checkpoint), ``py`` (use
            ``WavenetGenerator`` instead of ``NVWavenetGenerator``),
            ``batch_size``, ``rate``, ``split_size`` (chunk length along the
            last axis of the encodings), ``files`` (Paths, or one directory
            searched recursively for ``.wav``/``.h5``), ``skip_filter``,
            ``sample_len`` (truncate inputs to this many samples; falsy means
            "use the first file's length"), ``output_next_to_orig`` and
            ``output``.

    Outputs are written as ``<stem>_<decoder_id>.wav`` next to the source file
    when ``output_next_to_orig`` is set, else as
    ``<output>/<decoder_id>/<stem>.wav``.

    Raises:
        ValueError: if not exactly one of ``output_next_to_orig`` / ``output``
            is set, if no input files remain after filtering, if a file has an
            unsupported suffix, or if the inputs end up with different lengths.
        FileNotFoundError: if no checkpoint matches ``args.decoders``.
    """
    print('Starting')
    matplotlib.use('agg')

    # Validate cheap arguments before any expensive model loading. Explicit
    # raises (not asserts) so the checks survive ``python -O``.
    if bool(args.output_next_to_orig) == (args.output is not None):
        raise ValueError('Specify exactly one of output_next_to_orig or output.')

    if len(args.files) == 1 and args.files[0].is_dir():
        top = args.files[0]
        file_paths = sorted(top.glob('**/*.wav')) + sorted(top.glob('**/*.h5'))
    else:
        file_paths = args.files

    if not args.skip_filter:
        # Skip names containing '_' -- presumably previous outputs of this
        # script (saved as '<stem>_<decoder_id>.wav' next to the originals).
        file_paths = [f for f in file_paths if '_' not in f.name]

    if not file_paths:
        raise ValueError('No input files to process.')

    # Sorted so the encoder's source checkpoint and decoder order are deterministic.
    checkpoints = sorted(
        c for c in args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
        if extract_id(c) in args.decoders
    )
    if not checkpoints:
        raise FileNotFoundError(
            f'No checkpoints found matching {args.checkpoint.name}_*.pth in '
            f'{args.checkpoint.parent} for decoders {args.decoders}.')

    model_args = torch.load(args.checkpoint.parent / 'args.pth')[0]

    # A single encoder (weights from the first checkpoint) feeds every decoder.
    encoder = wavenet_models.Encoder(model_args)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    decoders = []
    decoder_ids = []
    for checkpoint in checkpoints:
        decoder = WaveNet(model_args)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()
        if args.py:
            decoder = WavenetGenerator(decoder, args.batch_size, wav_freq=args.rate)
        else:
            decoder = NVWavenetGenerator(decoder, args.rate * (args.split_size // 20), args.batch_size, 3)

        decoders.append(decoder)
        decoder_ids.append(extract_id(checkpoint))

    # Local copy: do not overwrite the caller's args.sample_len.
    sample_len = args.sample_len
    xs = []
    for file_path in file_paths:
        if file_path.suffix == '.wav':
            # NOTE(review): .wav inputs are resampled to 16 kHz regardless of
            # args.rate, while outputs are saved at args.rate -- confirm intended.
            data, rate = librosa.load(file_path, sr=16000)
            assert rate == 16000
            data = utils.mu_law(data)
        elif file_path.suffix == '.h5':
            # Close the HDF5 file as soon as the waveform has been read.
            with h5py.File(file_path, 'r') as h5_file:
                raw = h5_file['wav'][:]
            # Presumably int16 PCM; dividing by 2**15 rescales it to [-1, 1).
            data = utils.mu_law(raw / (2 ** 15))
            # Drop trailing samples so the length is a whole multiple of args.rate.
            usable = data.shape[-1] - data.shape[-1] % args.rate
            data = data[..., :usable]
            assert data.shape[-1] % args.rate == 0
            print(data.shape)
        else:
            raise ValueError(f'Unsupported filetype {file_path}')

        if sample_len:
            data = data[:sample_len]
        else:
            # No explicit length: the first file's length becomes the cap for
            # all later files, so they can be stacked into one batch.
            sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    lengths = {x.shape[-1] for x in xs}
    if len(lengths) > 1:
        raise ValueError(
            f'Input files have different lengths {sorted(lengths)}; pass a '
            f'sample_len no longer than the shortest file.')

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_ix, filepath):
        """Invert mu-law on generated codes ``x`` and write them as a .wav at args.rate."""
        wav = utils.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        if args.output_next_to_orig:
            save_audio(wav.squeeze(), filepath.parent / f'{filepath.stem}_{decoder_ix}.wav', rate=args.rate)
        else:
            out_dir = args.output / str(decoder_ix)
            # save_audio's directory handling is not visible here; make sure the
            # per-decoder directory exists (no-op if it already does).
            out_dir.mkdir(parents=True, exist_ok=True)
            save_audio(wav.squeeze(), out_dir / filepath.with_suffix('.wav').name, rate=args.rate)

    yy = {}
    with torch.no_grad():
        zz = torch.cat([encoder(xs_batch) for xs_batch in torch.split(xs, args.batch_size)], dim=0)

        with utils.timeit("Generation timer"):
            for i, decoder_id in enumerate(decoder_ids):
                decoder = decoders[i]
                # Drop the list's reference too; otherwise `del decoder` below
                # would not let this decoder be freed once it is done.
                decoders[i] = None
                batches = []
                for zz_batch in torch.split(zz, args.batch_size):
                    print(zz_batch.shape)
                    splits = torch.split(zz_batch, args.split_size, -1)
                    audio_data = []
                    # Chunks of one batch are generated back to back (presumably
                    # continuing the generator's state); reset before each batch.
                    decoder.reset()
                    for cond in tqdm.tqdm(splits):
                        audio_data.append(decoder.generate(cond).cpu())
                    batches.append(torch.cat(audio_data, -1))
                yy[decoder_id] = torch.cat(batches, dim=0)
                del decoder

    # Rows of each result follow the same order as file_paths (xs was built in that order).
    for decoder_ix, decoder_result in yy.items():
        for sample_result, filepath in zip(decoder_result, file_paths):
            save(sample_result, decoder_ix, filepath)