in Sources/x10/xla_client/xrt_computation_client.cc [818:898]
std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
const std::string& device, const std::vector<std::string>& devices,
std::vector<CompileInstance> instances) {
metrics::TimedSection timed(CompileMetric());
std::mutex lock;
util::MultiWait mwait(instances.size());
std::vector<ProgramShape> program_shapes(instances.size());
std::vector<ComputationPtr> results(instances.size());
std::vector<CompilationCacheKey> cache_keys(instances.size());
XrtSessionCache::SessionMap session_map;
std::map<XrtSession*, SessionWork> session_work_map;
for (size_t i = 0; i < instances.size(); ++i) {
auto builder = [&, this, i]() {
const CompileInstance& instance = instances[i];
std::unique_ptr<xrt::XLAComputation> xrt_computation =
CreateXrtComputation(instance.computation, devices,
instance.output_shape);
CompilationCacheKey cache_key(GetResourceDomain(device),
xrt_computation->SerializeAsString());
auto computation_ptr = compilation_cache_.Get(cache_key);
if (computation_ptr == nullptr) {
cache_keys[i] = std::move(cache_key);
program_shapes[i] =
ProgramShape(xrt_computation->config().program_shape());
const std::string& xrt_device = SwiftDeviceToXrtDevice(device);
{
std::lock_guard<std::mutex> slock(lock);
XrtSession* session = GetSessionForXrtDevice(
session_cache_.get(), xrt_device, &session_map);
SessionWork* session_work = &session_work_map[session];
tensorflow::Scope device_scope =
session->root()->WithDevice(xrt_device);
const XrtSession::CachedNode& cached_node =
GetCompileNode(session, device_scope, device);
session_work->feed_inputs.insert(
{cached_node.holders[0], cache_keys[i].serialized_computation});
session_work->outputs_handles.push_back(cached_node.outputs[0]);
session_work->index_mapping.push_back(i);
}
} else {
results[i] = computation_ptr;
}
};
env::ScheduleClosure(mwait.Completer(std::move(builder)));
}
mwait.Wait();
mwait.Reset(session_work_map.size());
for (auto& session_and_work : session_work_map) {
XrtSession* session = session_and_work.first;
const SessionWork& session_work = session_and_work.second;
auto session_runner = [&, this, session]() {
std::vector<tensorflow::Tensor> outputs;
CheckCompileStatus(
session->session()->Run(session_work.feed_inputs,
session_work.outputs_handles, &outputs),
instances, session_work);
XLA_CHECK_EQ(outputs.size(), session_work.outputs_handles.size());
double compile_time = timed.Elapsed();
size_t output_index = 0;
for (auto li : session_work.index_mapping) {
CompileInstance* instance = &instances[li];
MaybeSaveLongCompileHlo(compile_time, instance->computation);
results[li] = std::make_shared<XrtComputation>(
this, std::move(instance->computation), program_shapes[li], devices,
outputs[output_index].scalar<int64>()(), device);
++output_index;
compilation_cache_.Add(std::move(cache_keys[li]), results[li]);
CreateCompileHandlesCounter()->AddValue(1);
}
};
env::ScheduleIoClosure(mwait.Completer(std::move(session_runner)));
}
mwait.Wait();
return results;
}