std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile()

in Sources/x10/xla_client/xrt_computation_client.cc [818:898]


std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
    const std::string& device, const std::vector<std::string>& devices,
    std::vector<CompileInstance> instances) {
  metrics::TimedSection timed(CompileMetric());

  std::mutex lock;
  util::MultiWait mwait(instances.size());
  std::vector<ProgramShape> program_shapes(instances.size());
  std::vector<ComputationPtr> results(instances.size());
  std::vector<CompilationCacheKey> cache_keys(instances.size());
  XrtSessionCache::SessionMap session_map;
  std::map<XrtSession*, SessionWork> session_work_map;
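  // Each instance gets a builder closure, scheduled on the thread pool: it
  // serializes the computation, probes the compilation cache, and queues
  // cache misses as feeds for the owning XrtSession.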
  for (size_t i = 0; i < instances.size(); ++i) {
    auto builder = [&, this, i]() {
      const CompileInstance& instance = instances[i];
      std::unique_ptr<xrt::XLAComputation> xrt_computation =
          CreateXrtComputation(instance.computation, devices,
                               instance.output_shape);
      CompilationCacheKey cache_key(GetResourceDomain(device),
                                    xrt_computation->SerializeAsString());
      auto computation_ptr = compilation_cache_.Get(cache_key);
      if (computation_ptr == nullptr) {
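        // Cache miss: keep the key and program shape, and queue the serialized
        // computation as a feed for this device's cached XRTCompile node.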
        cache_keys[i] = std::move(cache_key);
        program_shapes[i] =
            ProgramShape(xrt_computation->config().program_shape());

        const std::string& xrt_device = SwiftDeviceToXrtDevice(device);
        {
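          // session_map and session_work_map are shared across the builder
          // closures, so they are only mutated while holding the lock.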
          std::lock_guard<std::mutex> slock(lock);
          XrtSession* session = GetSessionForXrtDevice(
              session_cache_.get(), xrt_device, &session_map);
          SessionWork* session_work = &session_work_map[session];
          tensorflow::Scope device_scope =
              session->root()->WithDevice(xrt_device);
          const XrtSession::CachedNode& cached_node =
              GetCompileNode(session, device_scope, device);
          session_work->feed_inputs.insert(
              {cached_node.holders[0], cache_keys[i].serialized_computation});
          session_work->outputs_handles.push_back(cached_node.outputs[0]);
          session_work->index_mapping.push_back(i);
        }
      } else {
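        // Cache hit: reuse the already compiled computation.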
        results[i] = computation_ptr;
      }
    };
    env::ScheduleClosure(mwait.Completer(std::move(builder)));
  }
  mwait.Wait();
  mwait.Reset(session_work_map.size());

  for (auto& session_and_work : session_work_map) {
    XrtSession* session = session_and_work.first;
    const SessionWork& session_work = session_and_work.second;

    auto session_runner = [&, this, session]() {
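      // One Session::Run per XrtSession compiles every cache miss routed to
      // it and returns one XRT computation handle per queued instance.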
      std::vector<tensorflow::Tensor> outputs;
      CheckCompileStatus(
          session->session()->Run(session_work.feed_inputs,
                                  session_work.outputs_handles, &outputs),
          instances, session_work);
      XLA_CHECK_EQ(outputs.size(), session_work.outputs_handles.size());

      double compile_time = timed.Elapsed();
      size_t output_index = 0;
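      // Wrap each returned handle in an XrtComputation, publish it to the
      // compilation cache, and store it at the instance's original index.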
      for (auto li : session_work.index_mapping) {
        CompileInstance* instance = &instances[li];
        MaybeSaveLongCompileHlo(compile_time, instance->computation);
        results[li] = std::make_shared<XrtComputation>(
            this, std::move(instance->computation), program_shapes[li], devices,
            outputs[output_index].scalar<int64>()(), device);
        ++output_index;

        compilation_cache_.Add(std::move(cache_keys[li]), results[li]);
        CreateCompileHandlesCounter()->AddValue(1);
      }
    };
    env::ScheduleIoClosure(mwait.Completer(std::move(session_runner)));
  }
  mwait.Wait();
  return results;
}
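
For reference, a minimal calling sketch follows. It is not taken from the repository: it assumes `client` is a pointer to an XrtComputationClient, that "TPU:0" is a valid Swift device string for this setup, and that CompileInstance (as used above via instance.computation and instance.output_shape) is the nested ComputationClient::CompileInstance type with those fields assignable; check computation_client.h for the actual definition.

// Sketch only: builds a trivial XLA computation and compiles it through the
// client. The CompileInstance field set is an assumption; see
// computation_client.h for the real struct.
xla::XlaBuilder builder("add_one");
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {8});
xla::XlaOp x = xla::Parameter(&builder, 0, shape, "x");
xla::Add(x, xla::ConstantR0<float>(&builder, 1.0f));
xla::XlaComputation computation = builder.Build().ValueOrDie();

std::vector<ComputationClient::CompileInstance> instances;
ComputationClient::CompileInstance instance;
instance.computation = std::move(computation);
instance.output_shape = nullptr;  // assumed optional; the shape is then derived
instances.push_back(std::move(instance));

// Compile for a single device; the returned computations are ready to execute.
std::vector<ComputationClient::ComputationPtr> computations =
    client->Compile("TPU:0", {"TPU:0"}, std::move(instances));

Repeated calls with the same serialized computation hit the per-resource-domain compilation cache and skip the Session::Run phase entirely.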