in Sources/x10/xla_client/xrt_computation_client.cc [1047:1126]
// Executes a chain of operations as a single XRT chained-execute request.
//
// Translates `ops` into an xrt::XRTChainedExecutePlan, feeds the serialized
// plan and config to the cached chained-execute node of the session selected
// for `device`, and wraps every handle in the returned tensor into an XrtData.
//
// Parameters:
//   ops    - operations in execution order. An op with a non-null
//            `device_data` contributes an existing data handle; otherwise the
//            op must carry a non-null `computation`, whose inputs may only
//            refer to ops at a strictly smaller index.
//   device - device name; resolved through GetEffectiveDevice().
//
// Returns one DataPtr per handle in the run output; the i-th result gets the
// shape recorded for result index i.
//
// Check-fails when an op has neither device data nor a computation, when an
// input refers to the current or a later op, or when the session run does not
// yield exactly one output tensor.
std::vector<ComputationClient::DataPtr> XrtComputationClient::ExecuteChainedXrt(
    absl::Span<const ExecuteChainedOp> ops, const std::string& device) {
  metrics::TimedSection timed(ExecuteChainedMetric());

  XrtSessionCache::SessionMap session_map;
  std::string effective_device = GetEffectiveDevice(device);
  const std::string& xrt_device = SwiftDeviceToXrtDevice(effective_device);
  tensorflow::ClientSession::FeedType feed_inputs;
  XrtSession* session =
      GetSessionForXrtDevice(session_cache_.get(), xrt_device, &session_map);
  tensorflow::Scope device_scope = session->root()->WithDevice(xrt_device);

  xrt::XRTChainedExecuteConfig config;
  config.set_core_index_in_replica(0);
  config.set_rng_seed(rng_seed_);

  xrt::XRTChainedExecutePlan plan;
  // Indexed by `result_index`; grown on demand while walking the op outputs.
  std::vector<xla::Shape> result_shapes;
  for (size_t i = 0; i < ops.size(); ++i) {
    const ExecuteChainedOp& op = ops[i];
    xrt::XRTChainedExecuteOp* plan_op = plan.add_ops();
    const xla::Shape* op_shape = nullptr;
    if (op.device_data != nullptr) {
      // Data op: forwards an already allocated handle.
      const XrtData& xrt_data = dynamic_cast<const XrtData&>(*op.device_data);
      op_shape = &xrt_data.shape();
      plan_op->set_data_handle(xrt_data.get_handle());
    } else {
      // Computation op. Fail loudly instead of dereferencing a null pointer
      // when the op carries neither device data nor a computation.
      XLA_CHECK(op.computation != nullptr);
      const XrtComputation& xrt_computation =
          dynamic_cast<const XrtComputation&>(*op.computation);
      op_shape = &xrt_computation.program_shape().result();
      plan_op->set_computation_handle(xrt_computation.get_handle());
      for (auto& input : op.inputs) {
        // Inputs may only reference ops that appear earlier in the chain.
        XLA_CHECK_LT(input.op_index, i);

        xrt::XRTChainedExecuteOp::Input* plan_input = plan_op->add_inputs();
        plan_input->set_op_index(input.op_index);
        if (input.output_index) {
          // NOTE(review): the +1 suggests the plan uses 1-based element
          // indices with 0 (unset) meaning "the whole value" -- confirm
          // against the XRT proto definition.
          plan_input->set_output_index(*input.output_index + 1);
        }
      }
    }
    for (auto& output : op.outputs) {
      XLA_CHECK(op_shape != nullptr);

      xrt::XRTChainedExecuteOp::Output* plan_output = plan_op->add_outputs();
      plan_output->set_result_index(output.result_index);
      if (output.result_index >= result_shapes.size()) {
        result_shapes.resize(output.result_index + 1);
      }
      if (output.output_index) {
        // Same +1 shift as for inputs; the recorded shape is the selected
        // tuple element rather than the whole op result.
        plan_output->set_output_index(*output.output_index + 1);
        result_shapes[output.result_index] =
            ShapeUtil::GetTupleElementShape(*op_shape, *output.output_index);
      } else {
        result_shapes[output.result_index] = *op_shape;
      }
    }
  }

  const XrtSession::CachedNode& cached_node =
      GetExecuteChainedNode(session, device_scope, effective_device);
  // holders[0] receives the serialized plan, holders[1] the serialized config.
  feed_inputs.insert({cached_node.holders[0], plan.SerializeAsString()});
  feed_inputs.insert({cached_node.holders[1], config.SerializeAsString()});

  std::vector<tensorflow::Tensor> outputs;
  util::CheckComputationStatus(
      session->session()->Run(feed_inputs, {cached_node.outputs[0]}, &outputs),
      {}, {});
  XLA_CHECK_EQ(outputs.size(), 1);

  std::vector<DataPtr> results;
  auto handles_vec = outputs[0].vec<int64>();
  // The final size is known up front; avoid repeated reallocation.
  results.reserve(static_cast<size_t>(handles_vec.size()));
  for (int64 i = 0; i < handles_vec.size(); ++i) {
    // `at()` keeps the bounds check in case more handles come back than
    // result shapes were recorded. Each shape is consumed exactly once, so
    // moving it out is safe.
    results.push_back(std::make_shared<XrtData>(
        dynamic_cast<XrtDevice*>(GetDevice(effective_device)),
        std::move(result_shapes.at(i)), handles_vec(i)));
  }
  CreateDataHandlesCounter()->AddValue(results.size());
  return results;
}