in torch_xla/csrc/tensor.cpp [1441:1497]
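// Synchronizes the tensors' pending IR graph by running it one operation at
// a time through OpByOpExecutor, instead of compiling the whole graph into a
// single XLA computation. Returns a handle the caller can use to wait on the
// asynchronously scheduled work.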
XLATensor::OpByOpAsync XLATensor::SyncTensorsGraphOpByOp(
    std::vector<XLATensor>* tensors, absl::Span<const std::string> devices,
    const SyncTensorsConfig& config) {
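  // State that must outlive this stack frame so the async closure can use
  // it: the sync collection (device, graph hash, lock unlockers), the
  // destination data handles, the IR roots to execute, and the list of
  // participating devices.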
  struct Async {
    explicit Async(SyncTensorCollection coll,
                   std::vector<xla::ComputationClient::DataPtr> tensors_data,
                   std::vector<ir::Value> roots,
                   absl::Span<const std::string> devices)
        : coll(std::move(coll)),
          tensors_data(std::move(tensors_data)),
          roots(std::move(roots)),
          devices(devices.begin(), devices.end()) {}
    SyncTensorCollection coll;
    std::vector<xla::ComputationClient::DataPtr> tensors_data;
    std::vector<ir::Value> roots;
    std::vector<std::string> devices;
  };
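  // Gather the tensors that actually need syncing (the device locks held for
  // the sync are released through coll.unlocker) and optionally dump the
  // graph for debugging.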
  SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
  DebugUtil::SaveTensorsGraphInfo("SyncTensorsGraphOpByOp", *tensors,
                                  &coll.indices);
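  // Extract the IR roots of the selected tensors and fetch the device data
  // handles the execution results will be assigned into.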
  std::vector<ir::Value> roots = CollectRoots(*tensors, coll.indices);
  auto tensors_data = FetchTensorData(tensors, coll.config, coll.indices);
  auto async = std::make_shared<Async>(std::move(coll), std::move(tensors_data),
                                       std::move(roots), devices);
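  // The closure executed by the async scheduler: runs the graph op-by-op on
  // the target device and publishes the results into the data handles.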
  auto syncfn = [async]() -> int {
    try {
      TF_VLOG(3) << "Executing (OpByOp) IR graph hash "
                 << torch::lazy::HashToString(async->coll.hash) << " on device "
                 << async->coll.device << " ...";
      std::vector<xla::ComputationClient::DataPtr> results =
          OpByOpExecutor::Get()->Execute(
              async->roots, async->coll.device.ToString(), async->devices);
      TF_VLOG(3) << "Executing (OpByOp) IR graph hash "
                 << torch::lazy::HashToString(async->coll.hash) << " on device "
                 << async->coll.device << " done!";
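      // Store each result into its destination handle; a nullptr entry means
      // there is no tensor data to update for that output.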
      for (size_t i = 0; i < results.size(); ++i) {
        if (async->tensors_data[i] != nullptr) {
          async->tensors_data[i]->Assign(*results[i]);
        }
      }
    } catch (...) {
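      // On failure, record the exception on every device lock unlocker so
      // waiters observe the error instead of hanging, then rethrow for the
      // async machinery.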
      std::exception_ptr exptr = std::current_exception();
      for (auto& unlocker : async->coll.unlocker) {
        unlocker.SetStatus(exptr);
      }
      throw;
    }
    return 0;
  };
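  // Schedule the closure asynchronously and return the wait handle.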
  OpByOpAsync async_op(std::move(syncfn));
  return async_op.Schedule();
}