def input_fn()

in sagemaker-voice-classification/notebook/inference.py [0:0]


def _split_s3_uri(uri):
    """Split an ``s3://bucket/key`` URI into a ``(bucket, key)`` pair.

    Args:
        uri: The S3 URI as text.

    Returns:
        Tuple ``(bucket, key)``; the key may itself contain ``/``.

    Raises:
        ValueError: If the ``s3://`` scheme, the bucket or the key is missing.
    """
    scheme = "s3://"
    # Compare case-insensitively: the previous position-based slicing did not
    # care about the case of the scheme, so "S3://..." keeps working.
    if uri[:len(scheme)].lower() != scheme:
        raise ValueError(
            "expected an S3 URI of the form s3://bucket/key, got: {!r}".format(uri))
    # Split on the first '/' only, so nested key prefixes stay in the key.
    bucket, _, key = uri[len(scheme):].partition("/")
    if not bucket or not key:
        raise ValueError(
            "S3 URI must contain both a bucket and an object key, got: {!r}".format(uri))
    return bucket, key


def _pad_and_decimate(waveform, target_len, stride):
    """Force a ``(1, n)`` waveform to ``target_len`` samples, then subsample.

    Shorter inputs are zero-padded on the right, longer ones are truncated.
    Afterwards every ``stride``-th sample is kept.

    Args:
        waveform: 2-D tensor whose first row holds the samples.
        target_len: Number of samples to pad/truncate to before subsampling.
        stride: Step between the samples that are kept.

    Returns:
        A new float tensor of shape ``(1, 1, ceil(target_len / stride))``.
    """
    fixed = torch.zeros([1, target_len])
    usable = min(waveform.shape[1], target_len)
    fixed[0, :usable] = waveform[0, :usable]

    decimated = fixed[0, ::stride]
    # Copy into a fresh buffer so the result does not alias the larger one.
    formatted = torch.zeros([1, 1, decimated.shape[0]])
    formatted[0, 0, :] = decimated
    return formatted


def input_fn(request_body, request_content_type):
    """Decode an inference request body into the model's input.

    Supported content types:

    * ``text/csv`` -- the body is an ``s3://bucket/key`` URI of an audio file.
      The file is downloaded, its first channel is resampled to 8000 Hz,
      padded/truncated to 20 s (160000 samples) and every 5th sample is kept,
      giving a tensor of shape ``(1, 1, 32000)``.
    * ``application/x-npy`` / ``application/python-pickle`` -- the body is
      loaded with ``numpy.load`` and wrapped in a tensor.
    * anything else -- a message is printed and the body is returned as is.

    Args:
        request_body: Raw request payload (text for ``text/csv``; bytes for
            the NumPy content types).
        request_content_type: MIME type describing ``request_body``.

    Returns:
        A ``torch.Tensor`` for the supported content types, otherwise the
        unmodified ``request_body``.

    Raises:
        ValueError: If a ``text/csv`` body is not a well-formed S3 URI.
    """
    if request_content_type == 'text/csv':
        new_sr = 8000        # target sample rate in Hz
        audio_len = 20       # clip length in seconds
        sampling_ratio = 5   # keep one sample out of this many
        # NOTE(review): a fixed path at the filesystem root is shared by every
        # request handled in this container; kept because code outside this
        # function may read it -- consider a per-request temporary file.
        local_path = '/audioinput.wav'

        uri = request_body
        if isinstance(uri, (bytes, bytearray)):
            # Accept undecoded payloads as well as text.
            uri = uri.decode('utf-8')
        # A trailing line terminator would otherwise end up in the object key.
        bucket, obj = _split_s3_uri(uri.rstrip('\r\n'))
        print("bucket: {}".format(bucket))
        print("object: {}".format(obj))

        s3.download_file(bucket, obj, local_path)
        print("audio input file size: {}".format(os.path.getsize(local_path)))

        waveform, sample_rate = torchaudio.load(local_path)
        # Only the first channel is used.
        mono = torchaudio.transforms.Resample(sample_rate, new_sr)(
            waveform[0, :].view(1, -1))
        return _pad_and_decimate(mono, new_sr * audio_len, sampling_ratio)

    if request_content_type in ['application/x-npy', 'application/python-pickle']:
        # SECURITY: allow_pickle=True unpickles data taken from the request,
        # which can execute arbitrary code. Only expose this endpoint to
        # trusted callers, or drop pickle support.
        return torch.tensor(np.load(BytesIO(request_body), allow_pickle=True))

    # Best effort: unknown payloads are passed through untouched.
    print("unknown request content type: {}".format(request_content_type))
    return request_body