in react-native-pytorch-core/android/src/main/java/org/pytorch/rn/core/ml/processing/BaseIValuePacker.java [264:341]
/**
 * Decodes an object-detection model output into an array of matches.
 *
 * <p>The model output is read as a string-keyed dictionary containing a {@code pred_logits}
 * tensor and a {@code pred_boxes} tensor. For every entry along dimension 1 of the logits, the
 * class scores are passed through {@code softmax} and the most probable class is selected. An
 * entry is emitted only if its probability is strictly greater than the configured threshold and
 * its class index is within the {@code classes} array (indices beyond the array are skipped).
 *
 * <p>Each emitted match is a map with an {@code objectClass} string and a {@code bounds} array
 * holding the first four values of the corresponding box row, passed through unchanged.
 *
 * <p>NOTE(review): the flat indexing used here only addresses the first batch element; this
 * presumably assumes a batch size of 1 - confirm against the model export.
 *
 * @param ivalue model output; must be a dictionary with {@code pred_logits} and {@code
 *     pred_boxes} entries
 * @param jobject unpack definition; must contain {@code probabilityThreshold} and, unless the
 *     class names are already cached in {@code packerContext}, a {@code classes} array
 * @param packerContext context used to cache the parsed class names under the {@code classes} key
 * @return array of detected objects, possibly empty
 * @throws JSONException if {@code probabilityThreshold} cannot be read as a double
 * @throws IllegalStateException if a required model output or unpack definition property is
 *     missing or malformed
 */
private WritableArray decodeObjects(
    final IValue ivalue, final JSONObject jobject, final PackerContext packerContext)
    throws JSONException {
  final Map<String, IValue> map = ivalue.toDictStringKey();
  final IValue predLogits = map.get("pred_logits");
  final IValue predBoxes = map.get("pred_boxes");
  if (predLogits == null || predBoxes == null) {
    // Fail with a descriptive error instead of a bare NullPointerException further down.
    throw new IllegalStateException(
        "model output is missing the pred_logits and/or pred_boxes entries");
  }

  final String PROBABILITY_THRESHOLD_KEY = "probabilityThreshold";
  if (!jobject.has(PROBABILITY_THRESHOLD_KEY)) {
    throw new IllegalStateException(
        "model param value for " + PROBABILITY_THRESHOLD_KEY + " is missing [0, 1]");
  }
  final double probabilityThreshold = jobject.getDouble(PROBABILITY_THRESHOLD_KEY);

  final String CLASSES_KEY = "classes";
  final String[] classes;
  final Object cachedClasses = packerContext.get(CLASSES_KEY);
  if (cachedClasses != null) {
    // Class names were parsed on an earlier call; reuse them.
    classes = (String[]) cachedClasses;
  } else {
    if (!jobject.has(CLASSES_KEY)) {
      throw new IllegalStateException(
          "'"
              + CLASSES_KEY
              + "' property is missing in the unpack definition for bounding_boxes unpack type");
    }
    try {
      final JSONArray classesArray = jobject.getJSONArray(CLASSES_KEY);
      classes = toStringArray(classesArray);
      packerContext.store(CLASSES_KEY, classes);
    } catch (JSONException e) {
      // Keep the original exception as the cause so the parse failure stays diagnosable.
      throw new IllegalStateException(
          "'"
              + CLASSES_KEY
              + "' property in the unpack definition for bounding_boxes needs to be an array of strings",
          e);
    }
  }

  final Tensor predLogitsTensor = predLogits.toTensor();
  final float[] confidencesTensor = predLogitsTensor.getDataAsFloatArray();
  final long[] confidencesShape = predLogitsTensor.shape();
  final int numClasses = (int) confidencesShape[2];

  final Tensor predBoxesTensor = predBoxes.toTensor();
  final float[] locationsTensor = predBoxesTensor.getDataAsFloatArray();
  final long[] locationsShape = predBoxesTensor.shape();

  final WritableArray result = Arguments.createArray();

  for (int i = 0; i < confidencesShape[1]; i++) {
    final float[] scores = softmax(confidencesTensor, i * numClasses, (i + 1) * numClasses);

    // Argmax over the class scores. The index must start at 0 together with scores[0];
    // starting at -1 left the index invalid whenever class 0 was the most probable one.
    float maxProb = scores[0];
    int maxIndex = 0;
    for (int j = 1; j < scores.length; j++) {
      if (scores[j] > maxProb) {
        maxProb = scores[j];
        maxIndex = j;
      }
    }

    // Skip low-confidence entries and class indices that have no name in the classes array.
    if (maxProb <= probabilityThreshold || maxIndex >= classes.length) {
      continue;
    }

    final WritableMap match = Arguments.createMap();
    match.putString("objectClass", classes[maxIndex]);

    // Offset of this entry's row in the flattened boxes tensor.
    final int locationsFrom = (int) (i * locationsShape[2]);
    final WritableArray bounds = Arguments.createArray();
    bounds.pushDouble(locationsTensor[locationsFrom]);
    bounds.pushDouble(locationsTensor[locationsFrom + 1]);
    bounds.pushDouble(locationsTensor[locationsFrom + 2]);
    bounds.pushDouble(locationsTensor[locationsFrom + 3]);
    match.putArray("bounds", bounds);

    result.pushMap(match);
  }

  return result;
}