tf::Status RunInference()

in tensorflow_decision_forests/tensorflow/ops/inference/kernel.cc [365:461]


  // Evaluates the model on every example of the batch and writes the results
  // into "outputs->dense_predictions", one row per example (row "example_idx"
  // receives the prediction of example "example_idx").
  //
  // Args:
  //   inputs: Input tensors; "inputs.batch_size" examples are evaluated.
  //   feature_index: Forwarded to SetVerticalDataset to fill the dataset.
  //   outputs: Output tensors, expected to be already allocated with
  //     "output_dim" columns in "dense_predictions".
  //   abstract_cache: Must be a "Cache" instance (checked via dynamic_cast);
  //     its "dataset_" is refreshed from "inputs" before prediction.
  //
  // Returns:
  //   - INTERNAL if the cache has an unexpected type or a prediction does not
  //     have the expected number of values.
  //   - UNIMPLEMENTED if the model task is not supported.
  //   - Any error returned by SetVerticalDataset.
  //   - OK otherwise.
  tf::Status RunInference(const InputTensors& inputs,
                          const FeatureIndex& feature_index,
                          OutputTensors* outputs,
                          AbstractCache* abstract_cache) const override {
    // Update the vertical dataset with the input tensors.
    auto* cache = dynamic_cast<Cache*>(abstract_cache);
    if (cache == nullptr) {
      return tf::Status(tf::error::INTERNAL, "Unexpected cache type.");
    }
    TF_RETURN_IF_ERROR(SetVerticalDataset(inputs, feature_index, cache));

    // Run the model. A single "prediction" object is reused for every example.
    model::proto::Prediction prediction;
    for (int example_idx = 0; example_idx < inputs.batch_size; example_idx++) {
      model_->Predict(cache->dataset_, example_idx, &prediction);

      // Copy the predictions to the output tensor.
      switch (model_->task()) {
        case Task::CLASSIFICATION: {
          const auto& distribution = prediction.classification().distribution();
          // Note: "distribution" contains a value for each possible class.
          // Because the label is categorical, the first label value (i.e.
          // index 0) is reserved for the Out-of-vocabulary value. As simpleML
          // models are not expected to output such value, we skip it (see the
          // ".. - 1" and ".. + 1" in the next part of the code).
          DCHECK_EQ(outputs->dense_predictions.dimension(1),
                    outputs->output_dim);
          const bool output_is_proba =
              model_->classification_outputs_probabilities();
          if (outputs->output_dim == 1 && !output_is_proba) {
            // Output the logit of the positive class. Index 0 is OOV, so a
            // binary problem has exactly 3 entries and the positive class is
            // index 2.
            if (distribution.counts().size() != 3) {
              return tf::Status(tf::error::INTERNAL,
                                "Wrong \"distribution\" shape.");
            }
            const float logit = distribution.counts(2) / distribution.sum();
            // The output tensor has a single column here (output_dim == 1), so
            // the value must go in column 0.
            outputs->dense_predictions(example_idx, 0) = logit;
          } else {
            // Output the logit or probabilities, one column per non-OOV class.
            if (outputs->dense_predictions.dimension(1) !=
                distribution.counts().size() - 1) {
              return tf::Status(tf::error::INTERNAL,
                                "Wrong \"distribution\" shape.");
            }
            for (int class_idx = 0; class_idx < outputs->output_dim;
                 class_idx++) {
              // "+ 1" skips the OOV entry at index 0.
              const float output =
                  distribution.counts(class_idx + 1) / distribution.sum();
              outputs->dense_predictions(example_idx, class_idx) =
                  output_is_proba ? utils::clamp(output, 0.f, 1.f) : output;
            }
          }
        } break;

        case Task::REGRESSION: {
          DCHECK_EQ(outputs->output_dim, 1);
          DCHECK_EQ(outputs->dense_predictions.dimension(1), 1);
          outputs->dense_predictions(example_idx, 0) =
              prediction.regression().value();
        } break;

        case Task::RANKING: {
          DCHECK_EQ(outputs->output_dim, 1);
          DCHECK_EQ(outputs->dense_predictions.dimension(1), 1);
          outputs->dense_predictions(example_idx, 0) =
              prediction.ranking().relevance();
        } break;

        case Task::CATEGORICAL_UPLIFT: {
          DCHECK_EQ(outputs->dense_predictions.dimension(1),
                    outputs->output_dim);
          const auto& pred = prediction.uplift();
          // One output column per treatment effect value.
          if (outputs->dense_predictions.dimension(1) !=
              pred.treatment_effect_size()) {
            return tf::Status(tf::error::INTERNAL,
                              "Wrong \"treatment_effect\" shape.");
          }
          for (int uplift_idx = 0; uplift_idx < outputs->output_dim;
               uplift_idx++) {
            outputs->dense_predictions(example_idx, uplift_idx) =
                pred.treatment_effect(uplift_idx);
          }
        } break;

        default:
          return tf::Status(tf::error::UNIMPLEMENTED,
                            absl::Substitute("Non supported task $0",
                                             Task_Name(model_->task())));
      }
    }

    return tf::Status::OK();
  }