in torch_xla/csrc/tensor.cpp [1134:1206]
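// Collects the tensors which need to be synced to device data, computing the
// hash of the graph to be executed and acquiring the device locks along the
// way.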
XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors(
    const std::vector<XLATensor>& tensors, const SyncTensorsConfig& config) {
  tensorflow::profiler::TraceMe activity(
      "CollectSyncTensors", tensorflow::profiler::TraceMeLevel::kInfo);
  xla::util::Unique<Device> unique_device;
  for (size_t i = 0; i < tensors.size(); ++i) {
    unique_device.set(tensors[i].GetDevice());
  }
  SyncTensorCollection coll;
  if (!unique_device) {
    return coll;
  }
  std::vector<at::Tensor> at_tensors;
  std::vector<std::string> devices;
  std::vector<size_t> at_tensor_index;
  std::unordered_set<xla::int64_t> tensor_ids;
  // The force_xla_data flag controls aliasing compilation, so the same graph
  // traced with force_xla_data on and off must not match, hash-wise.
  coll.hash = torch::lazy::MHash(config.force_xla_data);
  coll.config = config;
  coll.device = *unique_device;
  coll.indices.reserve(tensors.size());
  TF_VLOG(4) << "Waiting on device barrier for device " << coll.device
             << " ...";
  {
    XLA_TIMED("DeviceLockWait");
    coll.unlocker = LockDevices(unique_device.AsSet());
  }
  TF_VLOG(4) << "Waiting on device barrier for device " << coll.device
             << " done!";
  for (size_t i = 0; i < tensors.size(); ++i) {
    if (tensor_ids.insert(tensors[i].GetUniqueId()).second &&
        tensors[i].CurrentXlaData() == nullptr) {
      ir::Value ir_value = tensors[i].CurrentIrValue();
      if (ir_value) {
        if (ShouldSyncIrValue(ir_value)) {
          // Add only tensors which need to be synced.
          coll.hash = torch::lazy::HashCombine(coll.hash, ir_value.hash());
          coll.indices.push_back(i);
        }
      } else if (config.force_xla_data) {
        // The tensor only has at::Tensor data. We need to queue it for a
        // device upload.
        c10::optional<at::Tensor> tensor_data = tensors[i].CurrentTensorData();
        XLA_CHECK(tensor_data);
        at_tensors.push_back(*tensor_data);
        devices.push_back(tensors[i].GetDevice().ToString());
        at_tensor_index.push_back(i);
      }
    }
  }
  // Mix the hash with the resource domain hashes, as compile handles are only
  // valid within a domain (usually a single host).
  coll.hash = torch::lazy::MHash(
      coll.hash,
      xla::ComputationClient::Get()->GetResourceDomain(coll.device.ToString()));
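  // Upload all the queued at::Tensor data to the device in a single batch.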
  if (!at_tensors.empty()) {
    XLA_COUNTER("SyncTensorsToData", at_tensors.size());
    std::vector<xla::ComputationClient::DataPtr> handles =
        CreateTensorsData(at_tensors, devices);
    for (size_t i = 0; i < handles.size(); ++i) {
      // If we got here, the tensor has no IR value. We just uploaded its
      // at::Tensor data to the device, but the host-side copy is still valid,
      // so we leave it live on the XLA tensor (so that a following ToTensor()
      // does not need to fetch it back from the device).
      tensors[at_tensor_index[i]].data()->xla_data = std::move(handles[i]);
    }
  }
  TF_VLOG(4) << "Tensors graph hash " << torch::lazy::HashToString(coll.hash)
             << " on device " << coll.device;
  return coll;
}