std::vector<ComputationClient::DataPtr> XrtComputationClient::ExecuteChainedXrt()

in Sources/x10/xla_client/xrt_computation_client.cc [1047:1126]


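// Builds an xrt::XRTChainedExecutePlan from the given ops, runs it on the
// target device through the cached session, and wraps the returned allocation
// handles as XrtData results.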
std::vector<ComputationClient::DataPtr> XrtComputationClient::ExecuteChainedXrt(
    absl::Span<const ExecuteChainedOp> ops, const std::string& device) {
  metrics::TimedSection timed(ExecuteChainedMetric());

  XrtSessionCache::SessionMap session_map;
  std::string effective_device = GetEffectiveDevice(device);
  const std::string& xrt_device = SwiftDeviceToXrtDevice(effective_device);
  tensorflow::ClientSession::FeedType feed_inputs;
  XrtSession* session =
      GetSessionForXrtDevice(session_cache_.get(), xrt_device, &session_map);
  tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device);

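  // Chained execution targets core 0 of the replica and reuses the client's
  // RNG seed.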
  xrt::XRTChainedExecuteConfig config;
  config.set_core_index_in_replica(0);
  config.set_rng_seed(rng_seed_);

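  // Translate each ExecuteChainedOp into a plan entry: either a reference to
  // existing device data or a compiled computation wired to earlier ops.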
  xrt::XRTChainedExecutePlan plan;
  std::vector<xla::Shape> result_shapes;
  for (size_t i = 0; i < ops.size(); ++i) {
    const ExecuteChainedOp& op = ops[i];
    xrt::XRTChainedExecuteOp* plan_op = plan.add_ops();
    const xla::Shape* op_shape = nullptr;
    if (op.device_data != nullptr) {
      const XrtData& xrt_data = dynamic_cast<const XrtData&>(*op.device_data);
      op_shape = &xrt_data.shape();
      plan_op->set_data_handle(xrt_data.get_handle());
    } else {
      const XrtComputation& xrt_computation =
          dynamic_cast<const XrtComputation&>(*op.computation);
      op_shape = &xrt_computation.program_shape().result();
      plan_op->set_computation_handle(xrt_computation.get_handle());
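      // Inputs may only refer to earlier ops in the plan. XRT encodes "whole
      // result" as output_index 0, so tuple element i is stored as i + 1.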
      for (auto& input : op.inputs) {
        XLA_CHECK_LT(input.op_index, i);

        xrt::XRTChainedExecuteOp::Input* plan_input = plan_op->add_inputs();
        plan_input->set_op_index(input.op_index);
        if (input.output_index) {
          plan_input->set_output_index(*input.output_index + 1);
        }
      }
    }
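    // Record which final results this op produces; output_index uses the same
    // +1 tuple-element encoding, and result_shapes tracks each result's shape.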
    for (auto& output : op.outputs) {
      XLA_CHECK(op_shape != nullptr);

      xrt::XRTChainedExecuteOp::Output* plan_output = plan_op->add_outputs();
      plan_output->set_result_index(output.result_index);
      if (output.result_index >= result_shapes.size()) {
        result_shapes.resize(output.result_index + 1);
      }
      if (output.output_index) {
        plan_output->set_output_index(*output.output_index + 1);
        result_shapes[output.result_index] =
            ShapeUtil::GetTupleElementShape(*op_shape, *output.output_index);
      } else {
        result_shapes[output.result_index] = *op_shape;
      }
    }
  }

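  // Feed the serialized plan and config to the cached execute-chained node
  // and run it in a single session call.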
  const XrtSession::CachedNode& cached_node =
      GetExecuteChainedNode(session, device_scope, effective_device);
  feed_inputs.insert({cached_node.holders[0], plan.SerializeAsString()});
  feed_inputs.insert({cached_node.holders[1], config.SerializeAsString()});

  std::vector<tensorflow::Tensor> outputs;
  util::CheckComputationStatus(
      session->session()->Run(feed_inputs, {cached_node.outputs[0]}, &outputs),
      {}, {});
  XLA_CHECK_EQ(outputs.size(), 1);

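  // The run returns a single vector of int64 allocation handles, one per
  // result slot; wrap each handle together with its recorded shape.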
  std::vector<DataPtr> results;
  auto handles_vec = outputs[0].vec<int64>();
  for (int64 i = 0; i < handles_vec.size(); ++i) {
    results.push_back(std::make_shared<XrtData>(
        dynamic_cast<XrtDevice*>(GetDevice(effective_device)),
        std::move(result_shapes.at(i)), handles_vec(i)));
  }
  CreateDataHandlesCounter()->AddValue(results.size());
  return results;
}
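
For reference, a minimal caller-side sketch of how the ops span consumed above
might be assembled. This is an assumption-laden illustration: it presumes the
ExecuteChainedOp layout implied by the loop above (a device_data/computation
choice plus inputs and outputs vectors, with nested Input/Output types), and
that a public ExecuteChained entry point dispatches to this function;
argument, computation, client, and device are placeholders.

// Hypothetical usage sketch, not part of the file above.
std::vector<ComputationClient::ExecuteChainedOp> ops(2);

// Op 0 feeds an existing device allocation into the chain.
ops[0].device_data = argument;  // DataPtr already resident on `device`.

// Op 1 runs a compiled computation whose only parameter is op 0's value and
// exports tuple element 0 of its result as final result 0.
ops[1].computation = computation;  // ComputationPtr from a prior Compile().

ComputationClient::ExecuteChainedOp::Input input0;
input0.op_index = 0;  // consume op 0's value (whole result, no output_index)
ops[1].inputs.push_back(input0);

ComputationClient::ExecuteChainedOp::Output out0;
out0.result_index = 0;  // final result slot 0
out0.output_index = 0;  // tuple element 0 of op 1's result
ops[1].outputs.push_back(out0);

std::vector<ComputationClient::DataPtr> results =
    client->ExecuteChained(ops, device);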