public Object doWork()

in solr/solrj-streaming/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java [37:191]


  /**
   * Applies a fitted model or function to one or more predictor values.
   *
   * <p>The first argument selects the model; the second supplies the predictor(s):
   *
   * <ul>
   *   <li>{@code RegressionTuple}: a Number yields one prediction, a List yields a list of
   *       predictions (one per element).
   *   <li>{@code MultipleRegressionTuple}: a List is treated as a single observation, a Matrix as
   *       one observation per row.
   *   <li>{@code KnnRegressionTuple}: in bivariate mode a Number or List of Numbers is accepted;
   *       otherwise a List (single observation) or Matrix (one observation per row), scaled
   *       first when the tuple reports {@code getScale()}.
   *   <li>{@code VectorFunction}: the wrapped function is evaluated at a Number or at each
   *       element of a List.
   *   <li>{@code BivariateFunction}: with three arguments the second and third must be Numbers;
   *       with two arguments the second must be a two-column Matrix.
   * </ul>
   *
   * @param objects two or three evaluated arguments; the third is only read for a {@code
   *     BivariateFunction}
   * @return a single prediction, a List of predictions, or {@code null} when the model does not
   *     handle the supplied (otherwise valid) predictor type
   * @throws IOException if the argument count or argument types are not supported
   */
  public Object doWork(Object... objects) throws IOException {
    if (objects.length != 2 && objects.length != 3) {
      throw new IOException("The predict function expects 2 or 3 parameters.");
    }

    Object first = objects[0];
    Object second = objects[1];

    if (!(first instanceof BivariateFunction)
        && !(first instanceof VectorFunction)
        && !(first instanceof RegressionEvaluator.RegressionTuple)
        && !(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple)
        && !(first instanceof KnnRegressionEvaluator.KnnRegressionTuple)) {
      // A null argument fails every instanceof test, so the type name must be computed null-safely.
      throw new IOException(
          String.format(
              Locale.ROOT,
              "Invalid expression %s - found type %s for the first value, expecting a RegressionTuple",
              toExpression(constructingFactory),
              first == null ? "null" : first.getClass().getSimpleName()));
    }

    if (!(second instanceof Number)
        && !(second instanceof List<?>)
        && !(second instanceof Matrix)) {
      // Report the type of the offending (second) argument, not the first.
      throw new IOException(
          String.format(
              Locale.ROOT,
              "Invalid expression %s - found type %s for the second value, expecting a Number, Array or Matrix",
              toExpression(constructingFactory),
              second == null ? "null" : second.getClass().getSimpleName()));
    }

    if (first instanceof RegressionEvaluator.RegressionTuple regressedTuple) {

      if (second instanceof Number number) {
        return regressedTuple.predict(number.doubleValue());
      } else if (second instanceof List<?> values) {
        return values.stream()
            .map(value -> regressedTuple.predict(((Number) value).doubleValue()))
            .collect(Collectors.toList());
      } else {
        // Only a Matrix can reach this point; it used to fail with a ClassCastException.
        throw new IOException(
            "The predict function requires a Number or Array as the second parameter for a simple regression.");
      }

    } else if (first instanceof OLSRegressionEvaluator.MultipleRegressionTuple regressedTuple) {

      if (second instanceof List<?> values) {
        // The whole list is one observation: one predictor per element.
        double[] predictors =
            values.stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray();
        return regressedTuple.predict(predictors);
      } else if (second instanceof Matrix matrix) {
        // Each matrix row is one observation.
        double[][] data = matrix.getData();
        List<Number> predictions = new ArrayList<>(data.length);
        for (double[] predictors : data) {
          predictions.add(regressedTuple.predict(predictors));
        }
        return predictions;
      }

    } else if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple regressedTuple) {

      if (regressedTuple.getBivariate()) {
        // Handle bi-variate regression: every observation has exactly one predictor.
        if (second instanceof Number number) {
          return regressedTuple.predict(new double[] {number.doubleValue()});
        } else if (second instanceof List<?> values) {
          List<Number> predictions = new ArrayList<>(values.size());
          for (Object value : values) {
            predictions.add(regressedTuple.predict(new double[] {((Number) value).doubleValue()}));
          }
          return predictions;
        }
      } else {
        // Handle multi-variate regression
        if (second instanceof List<?> values) {
          double[] predictors =
              values.stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray();

          if (regressedTuple.getScale()) {
            predictors = regressedTuple.scale(predictors);
          }

          return regressedTuple.predict(predictors);
        } else if (second instanceof Matrix matrix) {
          Matrix observations = regressedTuple.getScale() ? regressedTuple.scale(matrix) : matrix;
          double[][] data = observations.getData();
          List<Number> predictions = new ArrayList<>(data.length);
          for (double[] predictors : data) {
            predictions.add(regressedTuple.predict(predictors));
          }
          return predictions;
        }
      }

    } else if (first instanceof VectorFunction vectorFunction) {

      UnivariateFunction univariateFunction = (UnivariateFunction) vectorFunction.getFunction();
      if (second instanceof Number number) {
        return univariateFunction.value(number.doubleValue());
      } else if (second instanceof List<?> values) {
        return values.stream()
            .map(value -> univariateFunction.value(((Number) value).doubleValue()))
            .collect(Collectors.toList());
      } else {
        // Only a Matrix can reach this point; it used to fail with a ClassCastException.
        throw new IOException(
            "The predict function requires a Number or Array as the second parameter for a univariate function.");
      }

    } else if (first instanceof BivariateFunction bivariateFunction) {

      if (objects.length == 3) {
        if (second instanceof Number x && objects[2] instanceof Number y) {
          return bivariateFunction.value(x.doubleValue(), y.doubleValue());
        } else {
          throw new IOException("BivariateFunction requires two numeric parameters.");
        }
      } else {
        // The length check at the top guarantees exactly two arguments here.
        if (second instanceof Matrix matrix) {
          double[][] data = matrix.getData();
          // An empty matrix has no columns to inspect; reject it instead of indexing data[0].
          if (data.length == 0 || data[0].length != 2) {
            throw new IOException("Bivariate Function expects a matrix with two columns");
          }
          List<Number> out = new ArrayList<>(data.length);
          for (double[] row : data) {
            out.add(bivariateFunction.value(row[0], row[1]));
          }
          return out;
        } else {
          throw new IOException("Bivariate Function requires a matrix parameter.");
        }
      }
    }

    // Reached when a multiple/KNN regression model receives a predictor type it does not handle
    // (e.g. a bare Number for a multi-variate model). Kept as null for backward compatibility.
    return null;
  }