in Java/src/main/java/com/example/customername/Ingest.java [113:176]
/**
 * Ingests every record of {@code featureRecordsList} into each feature group in
 * {@code featureGroupNames} using multiple {@link Ingest} threads, then prints and returns
 * the batch metrics.
 *
 * <p>The thread budget is divided evenly across feature groups, with at least one thread per
 * group. Within a group, the records are split into contiguous, near-equal chunks. Every record
 * is assigned to exactly one thread, and no thread is created without records.
 *
 * @param totalNumOfThreadsToCreate total thread budget shared by all feature groups
 * @param sageMakerFeatureStoreRuntimeClient client handed to every ingestion thread
 * @param featureRecordsList records to ingest into each feature group
 * @param featureGroupNames target feature groups
 * @param eventTimeName name of the event-time feature passed to each thread
 * @return {@code getTotalTime()} of the batch {@code PerfMetrics}, measured between starting
 *     the threads and collecting the last one (units as defined by PerfMetrics — confirm)
 */
public static long batchIngest(int totalNumOfThreadsToCreate, SageMakerFeatureStoreRuntimeClient sageMakerFeatureStoreRuntimeClient, List<List<FeatureValue>> featureRecordsList, String[] featureGroupNames, String eventTimeName) {
    PerfMetrics batchIngestMetrics = new PerfMetrics("Batch ingestion metrics");
    List<Ingest> ingestThreads = createIngestThreads(totalNumOfThreadsToCreate, sageMakerFeatureStoreRuntimeClient,
            featureRecordsList, featureGroupNames, eventTimeName);

    // Run all threads
    System.out.println("Starting batch ingestion");
    batchIngestMetrics.startTimer();
    for (Ingest ingest : ingestThreads) {
        ingest.start();
    }
    System.out.println("Number of created threads: " + ingestThreads.size());

    // Block on each thread rather than busy-polling isAlive(). If this thread is interrupted,
    // keep waiting so every worker's results are collected, and restore the flag at the end.
    long totalNumOfIngestedRecords = 0;
    boolean interrupted = false;
    for (Ingest thread : ingestThreads) {
        while (true) {
            try {
                thread.join();
                break;
            } catch (InterruptedException e) {
                interrupted = true;
            }
        }
        totalNumOfIngestedRecords += thread.getNumIngested();
        batchIngestMetrics.addMultiIntervals(thread.getIngestMetrics().getLatencies());
        System.out.println(String.format("Thread: %1$s, State: %2$s", thread.getName(), thread.getState()));
    }
    batchIngestMetrics.endTimer();
    if (interrupted) {
        Thread.currentThread().interrupt();
    }

    // Long multiplication so the expected total cannot overflow int.
    long expectedNumOfRecords = (long) featureRecordsList.size() * featureGroupNames.length;
    System.out.println(String.format("\nIngestion finished \nIngested %1$d of %2$d", totalNumOfIngestedRecords, expectedNumOfRecords));
    batchIngestMetrics.printMetrics();
    return batchIngestMetrics.getTotalTime();
}

/**
 * Builds the (not yet started) ingestion threads for {@code batchIngest}.
 *
 * <p>Each feature group gets {@code max(1, totalNumOfThreadsToCreate / featureGroupNames.length)}
 * threads, capped at the number of records. Chunk {@code t} of {@code k} covers indices
 * {@code [t*n/k, (t+1)*n/k)}, so the chunks tile {@code [0, n)} exactly and spread any
 * remainder across threads.
 */
private static List<Ingest> createIngestThreads(int totalNumOfThreadsToCreate, SageMakerFeatureStoreRuntimeClient sageMakerFeatureStoreRuntimeClient, List<List<FeatureValue>> featureRecordsList, String[] featureGroupNames, String eventTimeName) {
    int numOfRecords = featureRecordsList.size();
    // Split config for multi-threaded ingestion with multiple feature groups.
    // At least one thread per group avoids a zero divisor when the budget is smaller than the
    // number of groups; never more threads than records avoids creating threads with empty chunks.
    int numOfThreadsPerFeatureGroup = Math.max(1, totalNumOfThreadsToCreate / Math.max(1, featureGroupNames.length));
    numOfThreadsPerFeatureGroup = Math.min(numOfThreadsPerFeatureGroup, numOfRecords);

    List<Ingest> ingestThreads = new ArrayList<>();
    int count = 0;
    for (String featureGroupName : featureGroupNames) {
        for (int t = 0; t < numOfThreadsPerFeatureGroup; t++) {
            // Long arithmetic keeps t * numOfRecords from overflowing int.
            int startIdx = (int) ((long) t * numOfRecords / numOfThreadsPerFeatureGroup);
            int endIdx = (int) ((long) (t + 1) * numOfRecords / numOfThreadsPerFeatureGroup);
            // Deep copy subset to allocate to thread in order to add EventTime timestamp at putRecord call
            List<List<FeatureValue>> subSetList = deepCopy(featureRecordsList.subList(startIdx, endIdx));
            Ingest ingest = new Ingest(sageMakerFeatureStoreRuntimeClient, subSetList, featureGroupName, eventTimeName);
            ingest.setName(String.format("Ingest_%1$d", count++));
            ingestThreads.add(ingest);
        }
    }
    return ingestThreads;
}