in solr/solrj-streaming/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java [37:191]
public Object doWork(Object... objects) throws IOException {
  // Applies the model/function given as the first parameter to the predictor(s) given as the
  // second parameter.
  //
  // First parameter: RegressionTuple, MultipleRegressionTuple, KnnRegressionTuple,
  // VectorFunction (whose wrapped function is cast to UnivariateFunction) or BivariateFunction.
  // Second parameter: Number, List (of Numbers) or Matrix; which forms are accepted depends on
  // the first parameter (see the branches below).
  // Third parameter (optional): only used by BivariateFunction, as the y value paired with a
  // numeric second parameter.
  //
  // Returns a single prediction for scalar/vector input, a List of predictions for list/matrix
  // input, and null for combinations no branch handles (e.g. a MultipleRegressionTuple with a
  // Number, or a bivariate KnnRegressionTuple with a Matrix).
  // Throws IOException for a wrong parameter count, unsupported parameter types, or a matrix of
  // the wrong shape.
  if (objects.length != 2 && objects.length != 3) {
    throw new IOException("The predict function expects 2 or 3 parameters.");
  }
  Object first = objects[0];
  Object second = objects[1];
  if (!(first instanceof BivariateFunction)
      && !(first instanceof VectorFunction)
      && !(first instanceof RegressionEvaluator.RegressionTuple)
      && !(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple)
      && !(first instanceof KnnRegressionEvaluator.KnnRegressionTuple)) {
    throw new IOException(
        String.format(
            Locale.ROOT,
            "Invalid expression %s - found type %s for the first value, expecting a RegressionTuple",
            toExpression(constructingFactory),
            describeType(first)));
  }
  if (!(second instanceof Number)
      && !(second instanceof List<?>)
      && !(second instanceof Matrix)) {
    // Report the type of the offending *second* value.
    throw new IOException(
        String.format(
            Locale.ROOT,
            "Invalid expression %s - found type %s for the second value, expecting a Number, Array or Matrix",
            toExpression(constructingFactory),
            describeType(second)));
  }

  if (first instanceof RegressionEvaluator.RegressionTuple regressedTuple) {
    if (second instanceof Number number) {
      return regressedTuple.predict(number.doubleValue());
    } else if (second instanceof List<?> list) {
      return list.stream()
          .map(value -> regressedTuple.predict(((Number) value).doubleValue()))
          .collect(Collectors.toList());
    }
    // A Matrix passes the validation above but has no meaning for a simple regression; fail with
    // a descriptive IOException instead of a ClassCastException.
    throw unsupportedSecondValue(first, second);
  } else if (first instanceof OLSRegressionEvaluator.MultipleRegressionTuple regressedTuple) {
    if (second instanceof List<?> list) {
      // A single observation: one value per predictor variable.
      return regressedTuple.predict(toPredictorArray(list));
    } else if (second instanceof Matrix m) {
      // One observation per matrix row.
      double[][] data = m.getData();
      List<Number> predictions = new ArrayList<>(data.length);
      for (double[] predictors : data) {
        predictions.add(regressedTuple.predict(predictors));
      }
      return predictions;
    }
  } else if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple regressedTuple) {
    if (regressedTuple.getBivariate()) {
      // Bivariate KNN: each prediction is made from a single predictor value.
      if (second instanceof Number number) {
        return regressedTuple.predict(new double[] {number.doubleValue()});
      } else if (second instanceof List<?> list) {
        List<Number> predictions = new ArrayList<>(list.size());
        for (Object value : list) {
          predictions.add(regressedTuple.predict(new double[] {((Number) value).doubleValue()}));
        }
        return predictions;
      }
    } else {
      // Multivariate KNN: inputs are scaled first when the model was built with scaling.
      if (second instanceof List<?> list) {
        double[] predictors = toPredictorArray(list);
        if (regressedTuple.getScale()) {
          predictors = regressedTuple.scale(predictors);
        }
        return regressedTuple.predict(predictors);
      } else if (second instanceof Matrix m) {
        Matrix input = regressedTuple.getScale() ? regressedTuple.scale(m) : m;
        double[][] data = input.getData();
        List<Number> predictions = new ArrayList<>(data.length);
        for (double[] predictors : data) {
          predictions.add(regressedTuple.predict(predictors));
        }
        return predictions;
      }
    }
  } else if (first instanceof VectorFunction vectorFunction) {
    UnivariateFunction univariateFunction = (UnivariateFunction) vectorFunction.getFunction();
    if (second instanceof Number number) {
      return univariateFunction.value(number.doubleValue());
    } else if (second instanceof List<?> list) {
      return list.stream()
          .map(value -> univariateFunction.value(((Number) value).doubleValue()))
          .collect(Collectors.toList());
    }
    // A Matrix cannot be fed to a univariate function; report it instead of a ClassCastException.
    throw unsupportedSecondValue(first, second);
  } else if (first instanceof BivariateFunction bivariateFunction) {
    if (objects.length == 3) {
      Object third = objects[2];
      if (second instanceof Number x && third instanceof Number y) {
        return bivariateFunction.value(x.doubleValue(), y.doubleValue());
      } else {
        throw new IOException("BivariateFunction requires two numeric parameters.");
      }
    } else if (objects.length == 2) {
      if (second instanceof Matrix m) {
        double[][] data = m.getData();
        // Guard the column check so a zero-row matrix yields an empty result rather than an
        // ArrayIndexOutOfBoundsException on data[0].
        if (data.length > 0 && data[0].length != 2) {
          throw new IOException("Bivariate Function expects a matrix with two columns");
        }
        List<Number> out = new ArrayList<>(data.length);
        for (double[] row : data) {
          out.add(bivariateFunction.value(row[0], row[1]));
        }
        return out;
      } else {
        throw new IOException("Bivariate Function requires a matrix parameter.");
      }
    }
  }
  // Combination of first/second parameter types that no branch above handles.
  return null;
}

// Null-safe simple class name for error messages, so a null argument produces the intended
// validation IOException rather than a NullPointerException.
private static String describeType(Object value) {
  return value == null ? "null" : value.getClass().getSimpleName();
}

// Converts a list of Numbers to a primitive predictor array; a non-Number element raises
// ClassCastException. Streams the list so LinkedList-style inputs are not walked by index.
private static double[] toPredictorArray(List<?> values) {
  return values.stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray();
}

// Builds the error raised when the second value passed type validation but cannot be combined
// with the given first value.
private IOException unsupportedSecondValue(Object first, Object second) throws IOException {
  return new IOException(
      String.format(
          Locale.ROOT,
          "Invalid expression %s - a %s cannot be used to predict from a %s",
          toExpression(constructingFactory),
          describeType(first),
          describeType(second)));
}