in Sources/x10/xla_client/xrt_computation_client.cc [1589:1656]
// Records the mesh coordinates of every TPU device found in
// `options_.global_device_map` into `device_mesh_coords_`, and, when a mesh
// service address and a multi-processing device are both configured, creates
// the mesh service (ordinal 0 only) and calls SetupGpuRuntime().
//
// `topology_proto` may be null. In that case, if the global device map lists
// at least one TPU device, the topology is obtained through
// InitializeAndFetchTopology() from the first TPU worker (in std::set order).
void XrtComputationClient::InitializeDevices(
    std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto) {
  if (topology_proto == nullptr) {
    // Collect the distinct (job, task) workers which host TPU devices.
    std::set<Worker> tpu_workers;
    for (const auto& dev_target : options_.global_device_map) {
      tensorflow::DeviceNameUtils::ParsedName parsed_device =
          ParseFullXrtDevice(dev_target.second);
      if (parsed_device.type == "TPU") {
        tpu_workers.emplace(parsed_device.job, parsed_device.task);
      }
    }
    if (!tpu_workers.empty()) {
      const Worker& worker = *tpu_workers.begin();
      auto it = options_.workers_map.find(worker);
      XLA_CHECK(it != options_.workers_map.end());
      TF_VLOG(1) << "Configuring TPU for worker " << worker.name << ":"
                 << worker.task_no << " at " << it->second;
      tensorflow::tpu::TopologyProto worker_topology_proto =
          InitializeAndFetchTopology(worker.name, worker.task_no, it->second,
                                     session_cache_->GetConfig());
      // We are inside the `topology_proto == nullptr` branch, so the fetched
      // topology can be adopted unconditionally.
      topology_proto = absl::make_unique<tensorflow::tpu::TopologyProto>(
          std::move(worker_topology_proto));
    }
    if (topology_proto != nullptr) {
      TF_VLOG(1) << "TPU topology: " << topology_proto->DebugString();
    }
  }
  for (const auto& dev_target : options_.global_device_map) {
    tensorflow::DeviceNameUtils::ParsedName parsed_device =
        ParseFullXrtDevice(dev_target.second);
    if (parsed_device.type != "TPU") {
      continue;
    }
    // `topology_proto` is non-null here: either the caller supplied it, or the
    // presence of this TPU device made `tpu_workers` non-empty above and the
    // topology was fetched.
    //
    // Task and device ids are zero-based indices into the topology, so they
    // must be strictly smaller than the respective counts. Accepting equality
    // would index one slot past the end of 'device_coordinates'.
    XLA_CHECK(parsed_device.task < topology_proto->num_tasks())
        << "Task " << parsed_device.task << " out of range for device "
        << dev_target.second << " (num_tasks=" << topology_proto->num_tasks()
        << ")";
    XLA_CHECK(parsed_device.id < topology_proto->num_tpu_devices_per_task())
        << "Device id " << parsed_device.id << " out of range for device "
        << dev_target.second << " (num_tpu_devices_per_task="
        << topology_proto->num_tpu_devices_per_task() << ")";
    // The topology proto 'device_coordinates' is a linear list of
    // [num_tasks][devices_per_task][mesh_shape_size] coordinates, where the
    // mesh coordinates are usually [x, y, z, c] ('x', 'y' and 'z' being the
    // spatial chip coordinates and 'c' the core number).
    const int mesh_size = topology_proto->mesh_shape_size();
    // Widen before multiplying so the index arithmetic is carried out in
    // 64 bits rather than in int.
    const int64 base_index =
        (static_cast<int64>(parsed_device.task) *
             topology_proto->num_tpu_devices_per_task() +
         parsed_device.id) *
        mesh_size;
    // Guard against a topology whose coordinate list is shorter than its
    // declared dimensions imply.
    XLA_CHECK(base_index + mesh_size <=
              topology_proto->device_coordinates_size())
        << "Topology device_coordinates too short for device "
        << dev_target.second;
    std::vector<int> device_mesh_coords(mesh_size);
    for (int i = 0; i < mesh_size; ++i) {
      device_mesh_coords[i] =
          topology_proto->device_coordinates(base_index + i);
    }
    device_mesh_coords_.insert(
        {dev_target.second, std::move(device_mesh_coords)});
  }
  // The mesh service is only relevant when both a mesh service address is
  // configured (via the environment) and this process has a multi-processing
  // device assigned. Only the process bound to ordinal 0 creates the service;
  // SetupGpuRuntime() is then called regardless of the ordinal.
  std::string mesh_service_address =
      sys_util::GetEnvString(env::kEnvMeshService, "");
  std::string mp_device = GetMultiProcessingDevice();
  if (!mesh_service_address.empty() && !mp_device.empty()) {
    DeviceId device(mp_device);
    if (device.ordinal == 0) {
      CreateMeshService(mesh_service_address, topology_proto.get());
    }
    SetupGpuRuntime();
  }
}