private WritableArray decodeObjects()

in react-native-pytorch-core/android/src/main/java/org/pytorch/rn/core/ml/processing/BaseIValuePacker.java [264:341]


  /**
   * Converts a detection model output into an array of matches, each carrying an
   * {@code objectClass} label and a four-element {@code bounds} array.
   *
   * <p>The model output is read as a string-keyed dictionary with the entries
   * {@code pred_logits} and {@code pred_boxes}. For every candidate along dimension 1 of the
   * logits tensor, the per-class logits are passed through {@code softmax} and the best scoring
   * class is selected. A candidate is dropped when its best probability does not exceed the
   * configured threshold, or when the best class index has no entry in the class list
   * (presumably a trailing "no object" class -- TODO confirm against the model).
   *
   * <p>NOTE(review): only the first {@code shape[1]} candidates of the flattened tensors are
   * read, so this appears to assume a batch size of 1 -- confirm with the caller. The bounds are
   * the first four values of each box row, forwarded unchanged; their coordinate convention is
   * not established here.
   *
   * @param ivalue model output; must be convertible to a string-keyed dictionary
   * @param jobject unpack definition; must provide {@code probabilityThreshold} and, unless the
   *     class list is already stored in {@code packerContext}, a {@code classes} string array
   * @param packerContext per-packer storage used to keep the parsed class list between calls
   * @return an array with one map per accepted detection
   * @throws JSONException if {@code probabilityThreshold} cannot be read as a double
   * @throws IllegalStateException if a required definition property or model output entry is
   *     missing, or if {@code classes} is not an array of strings
   */
  private WritableArray decodeObjects(
      final IValue ivalue, final JSONObject jobject, final PackerContext packerContext)
      throws JSONException {
    final Map<String, IValue> map = ivalue.toDictStringKey();
    final IValue predLogits = map.get("pred_logits");
    final IValue predBoxes = map.get("pred_boxes");

    final String PROBABILITY_THRESHOLD_KEY = "probabilityThreshold";
    if (!jobject.has(PROBABILITY_THRESHOLD_KEY)) {
      throw new IllegalStateException(
          "model param value for " + PROBABILITY_THRESHOLD_KEY + " is missing [0, 1]");
    }
    final double probabilityThreshold = jobject.getDouble(PROBABILITY_THRESHOLD_KEY);

    final String CLASSES_KEY = "classes";
    final String[] classes;
    // Look the class list up once; it is stored in the context after the first successful parse.
    final Object cachedClasses = packerContext.get(CLASSES_KEY);
    if (cachedClasses != null) {
      classes = (String[]) cachedClasses;
    } else {
      if (!jobject.has(CLASSES_KEY)) {
        throw new IllegalStateException(
            CLASSES_KEY
                + " property is missing in the unpack definition for bounding_boxes unpack type");
      }
      try {
        final JSONArray classesArray = jobject.getJSONArray(CLASSES_KEY);
        classes = toStringArray(classesArray);
        packerContext.store(CLASSES_KEY, classes);
      } catch (JSONException e) {
        // Keep the JSON error as the cause so the offending value can be diagnosed.
        throw new IllegalStateException(
            CLASSES_KEY
                + " property in the unpack definition for bounding_boxes needs to be an array of strings",
            e);
      }
    }

    // Fail with a descriptive error instead of a bare NullPointerException on toTensor().
    if (predLogits == null || predBoxes == null) {
      throw new IllegalStateException(
          "model output needs to contain both pred_logits and pred_boxes entries");
    }

    final Tensor predLogitsTensor = predLogits.toTensor();
    final float[] confidencesTensor = predLogitsTensor.getDataAsFloatArray();
    final long[] confidencesShape = predLogitsTensor.shape();
    final int numClasses = (int) confidencesShape[2];

    final Tensor predBoxesTensor = predBoxes.toTensor();
    final float[] locationsTensor = predBoxesTensor.getDataAsFloatArray();
    final long[] locationsShape = predBoxesTensor.shape();

    final WritableArray result = Arguments.createArray();

    for (int i = 0; i < confidencesShape[1]; i++) {
      final float[] scores = softmax(confidencesTensor, i * numClasses, (i + 1) * numClasses);

      // Arg-max seeded with element 0. Starting the index at -1 would leave it invalid whenever
      // class 0 has the highest score, which then fails on the class name lookup below.
      int maxIndex = 0;
      float maxProb = scores[0];
      for (int j = 1; j < scores.length; j++) {
        if (scores[j] > maxProb) {
          maxProb = scores[j];
          maxIndex = j;
        }
      }

      // Skip low-confidence candidates and class indices without a label.
      if (maxProb <= probabilityThreshold || maxIndex >= classes.length) {
        continue;
      }

      final WritableMap match = Arguments.createMap();
      match.putString("objectClass", classes[maxIndex]);

      // Offset of this candidate's row in the flattened boxes tensor.
      final int locationsFrom = (int) (i * locationsShape[2]);
      final WritableArray bounds = Arguments.createArray();
      for (int k = 0; k < 4; k++) {
        bounds.pushDouble(locationsTensor[locationsFrom + k]);
      }
      match.putArray("bounds", bounds);

      result.pushMap(match);
    }

    return result;
  }