in torch_xla/csrc/tensor.cpp [1535:1616]
XLATensor::CompilationResult XLATensor::Compile(
    const std::vector<XLATensor>& tensors,
    absl::Span<const std::string> devices, const SyncTensorCollection& coll,
    PostOrderData* po_data) {
  // Profile the whole compilation, tagging the trace with the graph hash so
  // profiler output can be correlated with a specific SyncTensorsGraph run.
  tensorflow::profiler::TraceMe activity(
      [&coll] {
        return tensorflow::profiler::TraceMeEncode(
            "XLATensor::Compile",
            {{"graph_hash", torch::lazy::HashToString(coll.hash)}});
      },
      tensorflow::profiler::TraceMeLevel::kInfo);
  static const bool enable_aliasing =
      xla::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true);
  // Lower the post-order IR nodes; the lowering context consumes the emission
  // map, so po_data->emission_map is left in a moved-from state.
  ir::LoweringContext lowering_ctx("SyncTensorsGraph", coll.device,
                                   po_data->post_order,
                                   std::move(po_data->emission_map));
  // Register one computation output per tensor being synced.
  for (auto tensor_index : coll.indices) {
    ir::Value ir_value = tensors[tensor_index].CurrentIrValue();
    lowering_ctx.AddResult(lowering_ctx.GetOutputOp(ir_value));
  }
  if (enable_aliasing && coll.config.sync_xla_data) {
    // Input/output aliasing is only safe at the step barrier, i.e. when
    // force_xla_data is true. Consider:
    //   1. Tensor A(DEVICE_DATA)
    //   2. Tensor B = A + 0.9
    //   3. A += 0.4
    // With aliasing enabled for A's graph, a first print(A) would rewrite
    // DEVICE_DATA' = DEVICE_DATA + 0.4 in place, and a second print(A) would
    // rewrite DEVICE_DATA" = DEVICE_DATA' + 0.4 — producing a wrong value.
    // Nor can we simply collapse A's state to DEVICE_DATA in general: if any
    // source is a view, A's value read at different times must reflect view
    // source changes:
    //   1. Tensor A = some_graph_with_view_source(V)
    //   2. print(A)
    //   3. V += 1
    //   4. print(A)
    // The second print must observe V's update. Likewise, in the first
    // example, if B is not part of the graph (i.e. not all live tensors are
    // included), B would later fetch A's mutated value, which is incorrect.
    // At a step barrier (force_xla_data == true) everything is turned into
    // DEVICE_DATA, so aliasing can be activated safely.
    BuildInputOutputAliases(tensors, coll.indices, &lowering_ctx);
  }
  xla::XlaComputation computation = ConsumeValue(lowering_ctx.Build());
  xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
  // Note: the compile instance below keeps a pointer to this shape, so it
  // must stay alive until Compile() returns.
  xla::Shape result_shape =
      MakeShapeWithDeviceLayout(program_shape.result(), coll.device.hw_type);
  std::string compilation_device = coll.device.ToString();
  std::vector<xla::ComputationClient::CompileInstance> instances;
  instances.push_back({std::move(computation), compilation_device,
                       xla::ComputationClient::Get()->GetCompilationDevices(
                           compilation_device, devices),
                       &result_shape});
  TF_VLOG(3) << "Compiling IR graph hash "
             << torch::lazy::HashToString(coll.hash) << " on device "
             << coll.device << " ...";
  auto computations =
      xla::ComputationClient::Get()->Compile(std::move(instances));
  TF_VLOG(3) << "Compiling IR graph hash "
             << torch::lazy::HashToString(coll.hash) << " on device "
             << coll.device << " done!";
  TF_VLOG(5)
      << "Graph hash " << torch::lazy::HashToString(coll.hash)
      << " is computation hash "
      << torch::lazy::HashToString(torch::lazy::Hash(
             computations.front()->computation().proto().SerializeAsString()));
  // Every program parameter must be backed by a device data handle gathered
  // during the post-order walk.
  XLA_CHECK_EQ(program_shape.parameters_size(),
               po_data->parameters_data.size());
  return {/*device=*/coll.device,
          /*emitted_nodes=*/lowering_ctx.GetEmittedNodeCount(),
          /*computation=*/std::move(computations.front()),
          /*parameters_data=*/std::move(po_data->parameters_data)};
}