in connect-audio-stream-solution/kvsMLInferenceFunction/src/main/java/com/amazonaws/kvsmlinference/KVSMLInferenceLambda.java [115:191]
/**
 * Reads the customer audio track of a Kinesis Video Streams stream, uploads the resulting local
 * recording to S3, sends the S3 path to the SageMaker inference endpoint and stores the prediction
 * in the DynamoDB table {@code TABLE_ML_INFERENCE}.
 *
 * <p>SageMaker and DynamoDB failures are logged and swallowed (best effort). Failures while reading
 * or closing the KVS streams, or thrown by the S3 upload helper, propagate to the caller.
 *
 * @param streamARN         ARN of the KVS stream; the stream name is taken as the text between the
 *                          first and the last '/' (assumes an ARN shaped like
 *                          {@code ...:stream/<name>/<suffix>} — TODO confirm for all callers)
 * @param startFragmentNum  fragment number handed to the KVS reader as the starting point
 * @param contactId         contact ID used for the upload and as the {@code ContactId} key attribute
 * @param saveCallRecording currently ignored: the recording is always uploaded because its S3 path
 *                          is the payload sent to the inference endpoint
 * @throws Exception propagated from reading/closing the KVS streams or from the S3 upload helper
 */
private void startKVSToPredictionStreaming(String streamARN, String startFragmentNum, String contactId, Optional<Boolean> saveCallRecording) throws Exception {
    String streamName = streamARN.substring(streamARN.indexOf("/") + 1, streamARN.lastIndexOf("/"));
    KVSStreamTrackObject trackObject = getKVSStreamTrackObject(streamName, startFragmentNum, KVSUtils.TrackName.AUDIO_FROM_CUSTOMER.getName(), contactId);
    logger.info("Start to process KVS streaming and make prediction.");
    if (trackObject == null) {
        logger.warn("No KVS track available for contact " + contactId + "; skipping prediction.");
        return;
    }

    try {
        copyKvsAudioToFile(trackObject, contactId);
    } finally {
        // Release the stream handles even if reading from KVS fails part-way through.
        closeTrackStreams(trackObject);
    }

    // Inspect the file only after the output stream is closed so the reported size is final.
    String audioFilePath = trackObject.getSaveAudioFilePath().toString();
    File audioFile = new File(audioFilePath);
    logger.info("file path: " + audioFilePath);
    logger.info("file size: " + audioFile.length());
    if (audioFile.length() <= 0) {
        logger.info("Skipping upload to S3. Audio file has 0 bytes: " + audioFilePath);
        return;
    }

    // Upload the raw audio file to S3; its path is the input for the inference endpoint.
    String s3path = AudioUtils.uploadRawAudio(REGION, RECORDINGS_BUCKET_NAME, RECORDINGS_KEY_PREFIX, audioFilePath, contactId, RECORDINGS_PUBLIC_READ_ACL,
            getAWSCredentials());
    if (s3path == null || s3path.length() <= 1) {
        logger.error("Audio upload to S3 did not return a usable path; skipping prediction for: " + audioFilePath);
        return;
    }
    logger.info("Audio file uploaded successfully to: " + s3path);

    String predictionBody = invokeInferenceEndpoint(s3path);
    if (predictionBody != null) {
        saveInferenceResult(contactId, predictionBody);
    }
}

/**
 * Drains the track's KVS reader chunk by chunk until an empty buffer is returned, writing every
 * chunk to the track object's output stream (presumably backed by getSaveAudioFilePath() — TODO
 * confirm in KVSStreamTrackObject).
 */
private void copyKvsAudioToFile(KVSStreamTrackObject trackObject, String contactId) throws Exception {
    ByteBuffer chunk = readNextAudioChunk(trackObject, contactId);
    while (chunk.remaining() > 0) {
        byte[] audioBytes = new byte[chunk.remaining()];
        chunk.get(audioBytes);
        trackObject.getOutputStream().write(audioBytes);
        chunk = readNextAudioChunk(trackObject, contactId);
    }
}

/** Reads the next buffer of audio for this track from KVS; an empty buffer marks the end. */
private ByteBuffer readNextAudioChunk(KVSStreamTrackObject trackObject, String contactId) throws Exception {
    return KVSUtils.getByteBufferFromStream(trackObject.getStreamingMkvReader(),
            trackObject.getFragmentVisitor(), trackObject.getTagProcessor(), contactId, trackObject.getTrackName());
}

/** Closes the track's input stream, then its output stream, closing the output even if the first close throws. */
private void closeTrackStreams(KVSStreamTrackObject trackObject) throws Exception {
    try {
        trackObject.getInputStream().close();
    } finally {
        trackObject.getOutputStream().close();
    }
}

/**
 * Sends the S3 path of the uploaded recording to the SageMaker endpoint as a {@code text/csv} body.
 *
 * @return the UTF-8 decoded response body, or {@code null} if the invocation failed (failure is logged)
 */
private String invokeInferenceEndpoint(String s3path) {
    AmazonSageMakerRuntime smclient = null;
    try {
        smclient = AmazonSageMakerRuntimeClientBuilder
                .standard()
                .withRegion(REGION)
                .withCredentials(getAWSCredentials())
                .build();
        InvokeEndpointRequest invokeEndpointRequest = new InvokeEndpointRequest();
        invokeEndpointRequest.setContentType("text/csv");
        invokeEndpointRequest.setEndpointName(SM_ENDPOINT_NAME);
        invokeEndpointRequest.setBody(ByteBuffer.wrap(s3path.getBytes(StandardCharsets.UTF_8)));
        InvokeEndpointResult result = smclient.invokeEndpoint(invokeEndpointRequest);
        String body = StandardCharsets.UTF_8.decode(result.getBody()).toString();
        logger.info("SageMaker Inference result for the probability of positive class: " + body);
        return body;
    } catch (RuntimeException e) {
        // Best effort, as before: includes SdkClientException / AmazonServiceException.
        logger.error("Failed to invoke SageMaker Endpoint: ", e);
        return null;
    } finally {
        // The client is created per call, so release its connection pool here.
        if (smclient != null) {
            smclient.shutdown();
        }
    }
}

/** Writes the prediction for {@code contactId} to {@code TABLE_ML_INFERENCE}; failures are logged, not thrown. */
private void saveInferenceResult(String contactId, String predictionBody) {
    AmazonDynamoDB ddbbuilder = null;
    try {
        ddbbuilder = AmazonDynamoDBClientBuilder
                .standard()
                .withRegion(REGION)
                .withCredentials(getAWSCredentials())
                .build();
        DynamoDB ddbclient = new DynamoDB(ddbbuilder);
        Instant now = Instant.now();
        Item ddbItem = new Item()
                .withKeyComponent("ContactId", contactId)
                .withKeyComponent("StartTime", now.toEpochMilli())
                .withString("predictionTime", now.toString())
                .withString("predictionBody", predictionBody);
        PutItemOutcome outcome = ddbclient.getTable(TABLE_ML_INFERENCE).putItem(ddbItem);
        logger.info("DynamoDB putItem result: " + outcome.toString());
    } catch (RuntimeException e) {
        logger.error("Exception while writing to DDB: ", e);
    } finally {
        if (ddbbuilder != null) {
            ddbbuilder.shutdown();
        }
    }
}