in tensorflow_decision_forests/tensorflow/ops/inference/kernel.cc [365:461]
tf::Status RunInference(const InputTensors& inputs,
const FeatureIndex& feature_index,
OutputTensors* outputs,
AbstractCache* abstract_cache) const override {
// Update the vertical dataset with the input tensors.
auto* cache = dynamic_cast<Cache*>(abstract_cache);
if (cache == nullptr) {
return tf::Status(tf::error::INTERNAL, "Unexpected cache type.");
}
TF_RETURN_IF_ERROR(SetVerticalDataset(inputs, feature_index, cache));
// Run the model.
model::proto::Prediction prediction;
for (int example_idx = 0; example_idx < inputs.batch_size; example_idx++) {
model_->Predict(cache->dataset_, example_idx, &prediction);
// Copy the predictions to the output tensor.
switch (model_->task()) {
case Task::CLASSIFICATION: {
const auto& pred = prediction.classification();
// Note: "pred" contains a probability for each possible classes.
// Because the label is categorical, the first label value (i.e. index
// 0) is reserved for the Out-of-vocabulary value. As simpleML models
// are not expected to output such value, we skip it (see the ".. - 1"
// and ".. + 1" in the next part of the code).
DCHECK_EQ(outputs->dense_predictions.dimension(1),
outputs->output_dim);
const bool output_is_proba =
model_->classification_outputs_probabilities();
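          // If the model outputs probabilities, the values are clamped to
          // [0, 1] below; otherwise the raw scores are forwarded unchanged.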
if (outputs->output_dim == 1 && !output_is_proba) {
// Output the logit of the positive class.
if (pred.distribution().counts().size() != 3) {
return tf::Status(tf::error::INTERNAL,
"Wrong \"distribution\" shape.");
}
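            // counts() has 3 entries here: the OOV slot, the negative class,
            // and the positive class (index 2).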
            const float logit = pred.distribution().counts(2) /
                                pred.distribution().sum();
            outputs->dense_predictions(example_idx, 0) = logit;
} else {
            // Output the logits or the probabilities.
if (outputs->dense_predictions.dimension(1) !=
pred.distribution().counts().size() - 1) {
return tf::Status(tf::error::INTERNAL,
"Wrong \"distribution\" shape.");
}
for (int class_idx = 0; class_idx < outputs->output_dim;
class_idx++) {
              const float output =
                  pred.distribution().counts(class_idx + 1) /
                  pred.distribution().sum();
              outputs->dense_predictions(example_idx, class_idx) =
                  output_is_proba ? utils::clamp(output, 0.f, 1.f) : output;
}
}
} break;
case Task::REGRESSION: {
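          // Regression models produce a single scalar value per example.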
DCHECK_EQ(outputs->output_dim, 1);
DCHECK_EQ(outputs->dense_predictions.dimension(1), 1);
outputs->dense_predictions(example_idx, 0) =
prediction.regression().value();
} break;
case Task::RANKING: {
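          // Ranking models produce a single relevance score per example.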
DCHECK_EQ(outputs->output_dim, 1);
DCHECK_EQ(outputs->dense_predictions.dimension(1), 1);
outputs->dense_predictions(example_idx, 0) =
prediction.ranking().relevance();
} break;
case Task::CATEGORICAL_UPLIFT: {
DCHECK_EQ(outputs->dense_predictions.dimension(1),
outputs->output_dim);
const auto& pred = prediction.uplift();
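          // The model reports one treatment effect value per output column.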
          if (outputs->dense_predictions.dimension(1) !=
              pred.treatment_effect_size()) {
            return tf::Status(tf::error::INTERNAL,
                              "Wrong \"treatment_effect\" shape.");
          }
for (int uplift_idx = 0; uplift_idx < outputs->output_dim;
uplift_idx++) {
outputs->dense_predictions(example_idx, uplift_idx) =
pred.treatment_effect(uplift_idx);
}
} break;
default:
return tf::Status(tf::error::UNIMPLEMENTED,
absl::Substitute("Non supported task $0",
Task_Name(model_->task())));
}
}
return tf::Status::OK();
}