void XrtComputationClient::InitializeDevices()

in Sources/x10/xla_client/xrt_computation_client.cc [1589:1656]


// Populates device_mesh_coords_ with the TPU mesh coordinates of every TPU
// device listed in options_.global_device_map. When a mesh service address and
// a multi-processing device are both configured, it also sets up the mesh
// service and the GPU runtime.
//
// If `topology_proto` is null, the TPU system is configured through
// InitializeAndFetchTopology() on the first TPU worker (in std::set order)
// found in options_.global_device_map, and the returned topology is used.
// If no TPU device is listed, no topology is fetched and no mesh coordinates
// are recorded.
//
// Args:
//   topology_proto: an already-known TPU topology, or nullptr to fetch one.
void XrtComputationClient::InitializeDevices(
    std::unique_ptr<tensorflow::tpu::TopologyProto> topology_proto) {
  if (topology_proto == nullptr) {
    std::set<Worker> tpu_workers;
    for (const auto& dev_target : options_.global_device_map) {
      tensorflow::DeviceNameUtils::ParsedName parsed_device =
          ParseFullXrtDevice(dev_target.second);
      if (parsed_device.type == "TPU") {
        tpu_workers.emplace(parsed_device.job, parsed_device.task);
      }
    }
    if (!tpu_workers.empty()) {
      // Configuration runs on a single worker: the first one in set order.
      const Worker& worker = *tpu_workers.begin();
      auto it = options_.workers_map.find(worker);
      XLA_CHECK(it != options_.workers_map.end());

      TF_VLOG(1) << "Configuring TPU for worker " << worker.name << ":"
                 << worker.task_no << " at " << it->second;
      topology_proto = absl::make_unique<tensorflow::tpu::TopologyProto>(
          InitializeAndFetchTopology(worker.name, worker.task_no, it->second,
                                     session_cache_->GetConfig()));
      TF_VLOG(1) << "TPU topology: " << topology_proto->DebugString();
    }
  }
  // If a TPU device is listed below, topology_proto is non-null: it was either
  // supplied by the caller or fetched above from that device's worker.
  for (const auto& dev_target : options_.global_device_map) {
    tensorflow::DeviceNameUtils::ParsedName parsed_device =
        ParseFullXrtDevice(dev_target.second);
    if (parsed_device.type != "TPU") {
      continue;
    }
    // 'task' and 'id' are used as zero-based indices into the topology, so
    // they must be strictly smaller than the corresponding counts.
    XLA_CHECK_LT(parsed_device.task, topology_proto->num_tasks())
        << "TPU task index out of range for device " << dev_target.second;
    XLA_CHECK_LT(parsed_device.id, topology_proto->num_tpu_devices_per_task())
        << "TPU device index out of range for device " << dev_target.second;
    const int mesh_rank = topology_proto->mesh_shape_size();
    // The topology proto 'device_coordinates' is a linear list of
    // [num_tasks][devices_per_task][mesh_shape_size] coordinates, where the
    // mesh coordinates are usually [x, y, z, c] ('x', 'y' and 'z' being the
    // spatial chip coordinated and 'c' the core number).
    // The index is computed in int64 to avoid intermediate int overflow.
    const int64 base_index =
        (static_cast<int64>(parsed_device.task) *
             topology_proto->num_tpu_devices_per_task() +
         parsed_device.id) *
        mesh_rank;
    // Guard against a topology whose coordinate list is shorter than its
    // declared dimensions imply; indexing past it would be out of bounds.
    XLA_CHECK_LE(base_index + mesh_rank,
                 topology_proto->device_coordinates_size())
        << "Malformed TPU topology for device " << dev_target.second;
    std::vector<int> device_mesh_coords(mesh_rank);
    for (int i = 0; i < mesh_rank; ++i) {
      device_mesh_coords[i] =
          topology_proto->device_coordinates(base_index + i);
    }
    device_mesh_coords_.insert(
        {dev_target.second, std::move(device_mesh_coords)});
  }

  // The mesh service and GPU runtime are set up only when BOTH a mesh service
  // address (env::kEnvMeshService) and a multi-processing device are
  // configured. Only the process whose device ordinal is 0 calls
  // CreateMeshService(). topology_proto may be null here when no TPU device
  // is listed.
  const std::string mesh_service_address =
      sys_util::GetEnvString(env::kEnvMeshService, "");
  const std::string mp_device = GetMultiProcessingDevice();
  if (!mesh_service_address.empty() && !mp_device.empty()) {
    DeviceId device(mp_device);
    if (device.ordinal == 0) {
      CreateMeshService(mesh_service_address, topology_proto.get());
    }
    SetupGpuRuntime();
  }
}