in torch_xla/csrc/op_by_op_executor.cpp [85:196]
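// Builds one ExecuteChainedOp per IR node reachable from 'roots', in post
// order. Device data nodes are forwarded as-is; every other node is lowered
// into its own single-node XLA computation. Computations are reused through a
// compile cache keyed on the node and its operand shapes, and any cache
// misses are compiled in a single batched Compile() call before returning.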
std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
absl::Span<const ir::Value> roots, const std::string& device,
absl::Span<const std::string> devices) {
std::vector<const ir::Node*> root_nodes;
root_nodes.reserve(roots.size());
for (auto& root : roots) {
root_nodes.push_back(root.node.get());
}
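// Walk the graph in post order so that every node is visited after all of
// its operands.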
std::vector<const ir::Node*> post_order =
ir::Util::ComputePostOrder(root_nodes);
XLA_VALUE_METRIC("OpByOpGraphSize", post_order.size());
TF_VLOG(5) << "TensorsGraphSize=" << post_order.size();
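// Remember each node's position in the post order; it is used later to
// reference producer ops when wiring up chained-op inputs and root outputs.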
std::unordered_map<const ir::Node*, size_t> node_to_index;
node_to_index.reserve(post_order.size());
for (size_t i = 0; i < post_order.size(); ++i) {
node_to_index[post_order[i]] = i;
}
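// Fold the device and compilation-devices configuration into the seed used
// for every per-node cache key, so keys are specific to this device setup.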
auto compilation_devices =
xla::ComputationClient::Get()->GetCompilationDevices(device, devices);
torch::lazy::hash_t nodes_key_seed =
GetNodesKeySeed(device, compilation_devices);
Device exec_device(device);
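// Bookkeeping for the compilation step: the cache keys that need compiling,
// the post-order indices sharing each key, and the output shape of every op.
// compile_shapes is a std::list, so the raw xla::Shape pointers handed out
// below stay valid as more shapes are appended.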
std::vector<torch::lazy::hash_t> cache_keys;
std::unordered_map<torch::lazy::hash_t, std::vector<size_t>,
torch::lazy::HashReducer>
compile_indices;
std::unordered_map<torch::lazy::hash_t, size_t, torch::lazy::HashReducer>
cache_keys_instance;
std::list<xla::Shape> compile_shapes;
std::vector<bool> device_data_ops(post_order.size());
std::vector<const xla::Shape*> ops_shapes(post_order.size());
std::vector<xla::ComputationClient::CompileInstance> compile_instances;
std::vector<xla::ComputationClient::ExecuteChainedOp> chained_exec_ops(
post_order.size());
for (size_t i = 0; i < post_order.size(); ++i) {
const ir::Node* node = post_order[i];
xla::ComputationClient::ExecuteChainedOp& cxop = chained_exec_ops[i];
const ir::ops::DeviceData* device_data = ir::ops::DeviceData::Cast(node);
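// Device data nodes need no computation; the op's result is the existing
// device buffer itself.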
if (device_data != nullptr) {
cxop.device_data = device_data->data();
ops_shapes[i] = &cxop.device_data->shape();
device_data_ops[i] = true;
} else {
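// Each operand becomes a reference to a previously emitted chained op's
// output; also collect the operand shapes, which feed both the lowering
// and the cache key.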
std::vector<const xla::Shape*> op_input_shapes;
for (auto& operand : node->operands()) {
size_t op_index = node_to_index.at(operand.node);
cxop.inputs.push_back(
{op_index,
GetOutputIndex(device_data_ops[op_index], operand.index)});
op_input_shapes.push_back(ops_shapes[op_index]);
}
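// The cache key covers the node and its operand shapes: the same IR op
// lowered against different input shapes yields a different computation.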
torch::lazy::hash_t cache_key =
ComputeNodeKey(node, op_input_shapes, nodes_key_seed);
cxop.computation = compile_cache_.Get(cache_key);
if (cxop.computation == nullptr) {
XLA_COUNTER("OpByOpCompileCacheMiss", 1);
// Within a single IR graph, there can be many duplicated IR nodes, so
// make sure we do not issue an XLA compilation for each one of those.
auto& cache_key_indices = compile_indices[cache_key];
cache_key_indices.push_back(i);
if (cache_key_indices.size() == 1) {
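// First occurrence of this key in the graph: lower the node and queue it
// for batched compilation.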
cache_keys.push_back(cache_key);
cache_keys_instance[cache_key] = compile_instances.size();
xla::XlaComputation computation =
BuildNodeComputation(node, op_input_shapes, exec_device);
xla::ProgramShape program_shape =
ConsumeValue(computation.GetProgramShape());
compile_shapes.push_back(MakeShapeWithDeviceLayout(
program_shape.result(), exec_device.hw_type));
compile_instances.push_back({std::move(computation), device,
compilation_devices,
&compile_shapes.back()});
ops_shapes[i] = &compile_shapes.back();
} else {
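// Duplicate of an already queued node: reuse the output shape of the
// pending compile instance.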
ops_shapes[i] =
compile_instances[cache_keys_instance.at(cache_key)].output_shape;
}
} else {
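// Cache hit: the output shape comes from the cached computation.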
ops_shapes[i] = &cxop.computation->program_shape().result();
}
}
}
// Fix up the requested outputs (roots) within the chained ops vector.
for (size_t i = 0; i < roots.size(); ++i) {
size_t op_index = node_to_index.at(roots[i].node.get());
chained_exec_ops[op_index].outputs.push_back(
{i, GetOutputIndex(device_data_ops[op_index], roots[i].index)});
}
// If we missed the cache for certain ops, compile them now and fix up the
// chained ops vector.
if (!compile_instances.empty()) {
TF_VLOG(3) << "Compiling " << compile_instances.size()
<< " computations on device " << device;
auto computation_ptrs =
xla::ComputationClient::Get()->Compile(std::move(compile_instances));
TF_VLOG(3) << "Compiling " << computation_ptrs.size()
<< " computations on device " << device << " done!";
for (size_t i = 0; i < computation_ptrs.size(); ++i) {
compile_cache_.Add(cache_keys[i], computation_ptrs[i]);
for (auto index : compile_indices[cache_keys[i]]) {
chained_exec_ops[index].computation = computation_ptrs[i];
}
}
}
return chained_exec_ops;
}