std::map XrtComputationClient::GetMetrics()

in Sources/x10/xla_client/xrt_computation_client.cc [1743:1851]


std::map<std::string, Metric> XrtComputationClient::GetMetrics() const {
  static const std::map<std::string, std::string>* metric_remap =
      new std::map<std::string, std::string>{
          {"/tensorflow/xrt/ops/allocate", "XrtAllocate"},
          {"/tensorflow/xrt/ops/allocate_from_tensor", "XrtAllocateFromTensor"},
          {"/tensorflow/xrt/ops/sub_tuple", "XrtSubTuple"},
          {"/tensorflow/xrt/ops/make_tuple", "XrtMakeTuple"},
          {"/tensorflow/xrt/ops/compile", "XrtCompile"},
          {"/tensorflow/xrt/ops/release_compilation", "XrtReleaseCompilation"},
          {"/tensorflow/xrt/ops/execute", "XrtExecute"},
          {"/tensorflow/xrt/ops/execute_chained", "XrtExecuteChained"},
          {"/tensorflow/xrt/ops/read_literal", "XrtReadLiteral"},
          {"/tensorflow/xrt/ops/read_tensor", "XrtReadTensor"},
          {"/tensorflow/xrt/ops/write_literal", "XrtWriteLiteral"},
          {"/tensorflow/xrt/ops/release_allocation", "XrtReleaseAllocation"},
          {"/tensorflow/xrt/ops/release_all_allocations",
           "XrtReleaseAllAllocations"},
          {"/tensorflow/xrt/ops/compact_allocations", "XrtCompactAllocations"},
          {"/tensorflow/xrt/memory_manager/compaction", "XrtCompaction"},
          {"/tensorflow/xrt/memory_manager/try_free_memory",
           "XrtTryFreeMemory"},
          {"/tensorflow/xrt/executor/program_memory_evict", "XrtExecutorEvict"},
          {"/tensorflow/xrt/ds_executor/program_memory_evict",
           "XrtExecutorEvict"}};

  std::map<std::string, Metric> metrics_data;
  xrt::XRTMetricsCollect metrics;
  metrics.add_metrics_regex("/tensorflow/xrt/.*");

  for (auto& worker_target : options_.workers_map) {
    tensorflow::SessionOptions session_options;
    session_options.env = tensorflow::Env::Default();
    session_options.target = worker_target.second;
    session_options.config = session_cache_->GetConfig();

    tensorflow::Scope root = tensorflow::Scope::NewRootScope();
    tensorflow::ClientSession session(root, session_options);
    std::string cpu0_device = absl::StrCat(
        "/job:", worker_target.first.name,
        "/replica:0/task:", worker_target.first.task_no, "/device:CPU:0");
    tensorflow::Scope cpu_system_scope = root.WithDevice(cpu0_device);
    auto metrics_value =
        tensorflow::ops::Const(cpu_system_scope, metrics.SerializeAsString());
    tensorflow::Output result =
        tensorflow::ops::XRTMetricsCollect(cpu_system_scope, metrics_value);
    XLA_CHECK_OK(cpu_system_scope.status());

    std::vector<tensorflow::Tensor> outputs;
    XLA_CHECK_OK(session.Run({result}, &outputs));
    XLA_CHECK_EQ(outputs.size(), 1);

    xrt::MetricsReport report = ParseProto<xrt::MetricsReport>(outputs[0]);
    for (auto& xrt_metric : report.metrics()) {
      Metric metric;
      if (xrt_metric.values_oneof_case() ==
          xrt::MetricValues::kPercentilesValue) {
        const xrt::Percentiles& xrt_percentile = xrt_metric.percentiles_value();
        Percentile percentile;
        switch (xrt_metric.unit_of_measure()) {
          case xrt::MetricValues::NUMBER:
            percentile.unit_of_measure = Percentile::UnitOfMeaure::kNumber;
            break;
          case xrt::MetricValues::TIME:
            percentile.unit_of_measure = Percentile::UnitOfMeaure::kTime;
            break;
          case xrt::MetricValues::BYTES:
            percentile.unit_of_measure = Percentile::UnitOfMeaure::kBytes;
            break;
          default:
            TF_LOG(FATAL) << "Invalid unit of measure";
            break;
        }
        percentile.start_nstime = xrt_percentile.start_nstime();
        percentile.end_nstime = xrt_percentile.end_nstime();
        percentile.min_value = xrt_percentile.min_value();
        percentile.max_value = xrt_percentile.max_value();
        percentile.mean = xrt_percentile.mean();
        percentile.stddev = xrt_percentile.stddev();
        percentile.num_samples = xrt_percentile.num_samples();
        percentile.total_samples = xrt_percentile.total_samples();
        percentile.accumulator = xrt_percentile.accumulator();
        for (auto& xrt_point : xrt_percentile.points()) {
          percentile.points.push_back(
              Percentile::Point{xrt_point.percentile(), xrt_point.value()});
        }
        metric.percentile = std::move(percentile);
      } else if (xrt_metric.values_oneof_case() ==
                 xrt::MetricValues::kInt64Value) {
        metric.int64_value = xrt_metric.int64_value();
      } else {
        continue;
      }

      std::string metric_name;
      auto it = metric_remap->find(xrt_metric.name());
      if (it != metric_remap->end()) {
        metric_name = it->second;
      } else {
        metric_name = xrt_metric.name();
      }
      if (options_.workers_map.size() > 1) {
        absl::StrAppend(&metric_name, ".", worker_target.first.name, ".",
                        worker_target.first.task_no);
      }
      metrics_data.emplace(std::move(metric_name), std::move(metric));
    }
  }
  return metrics_data;
}