std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps()

in torch_xla/csrc/op_by_op_executor.cpp [85:196]

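BuildOps lowers the IR graph rooted at the given roots into a flat vector of ExecuteChainedOp entries, one per node in post-order. DeviceData nodes pass their existing device buffers straight through, while every other node is lowered to an XLA computation, deduplicated through a shape-aware compile cache, with inputs that point back at earlier entries in the chain.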

std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
    absl::Span<const ir::Value> roots, const std::string& device,
    absl::Span<const std::string> devices) {
  std::vector<const ir::Node*> root_nodes;
  root_nodes.reserve(roots.size());
  for (auto& root : roots) {
    root_nodes.push_back(root.node.get());
  }
  std::vector<const ir::Node*> post_order =
      ir::Util::ComputePostOrder(root_nodes);
  XLA_VALUE_METRIC("OpByOpGraphSize", post_order.size());
  TF_VLOG(5) << "TensorsGraphSize=" << post_order.size();

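  // Record each node's position in the post-order; chained-op inputs refer to
  // their producer ops by this index.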
  std::unordered_map<const ir::Node*, size_t> node_to_index;
  node_to_index.reserve(post_order.size());
  for (size_t i = 0; i < post_order.size(); ++i) {
    node_to_index[post_order[i]] = i;
  }

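  // Seed every per-node cache key with the device configuration, so
  // computations compiled for one device topology are not reused for another.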
  auto compilation_devices =
      xla::ComputationClient::Get()->GetCompilationDevices(device, devices);
  torch::lazy::hash_t nodes_key_seed =
      GetNodesKeySeed(device, compilation_devices);
  Device exec_device(device);
  std::vector<torch::lazy::hash_t> cache_keys;
  std::unordered_map<torch::lazy::hash_t, std::vector<size_t>,
                     torch::lazy::HashReducer>
      compile_indices;
  std::unordered_map<torch::lazy::hash_t, size_t, torch::lazy::HashReducer>
      cache_keys_instance;
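  // A std::list is used so that appending new shapes never invalidates the
  // pointers handed out to compile_instances and ops_shapes below.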
  std::list<xla::Shape> compile_shapes;
  std::vector<bool> device_data_ops(post_order.size());
  std::vector<const xla::Shape*> ops_shapes(post_order.size());
  std::vector<xla::ComputationClient::CompileInstance> compile_instances;
  std::vector<xla::ComputationClient::ExecuteChainedOp> chained_exec_ops(
      post_order.size());
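  // Walk the graph in post-order: every node becomes one chained op, either a
  // direct device-data feed or a compiled computation whose inputs reference
  // earlier ops in the chain.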
  for (size_t i = 0; i < post_order.size(); ++i) {
    const ir::Node* node = post_order[i];
    xla::ComputationClient::ExecuteChainedOp& cxop = chained_exec_ops[i];
    const ir::ops::DeviceData* device_data = ir::ops::DeviceData::Cast(node);
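    // DeviceData nodes wrap buffers already living on the device; they need
    // no compilation and are fed into the chained execution as-is.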
    if (device_data != nullptr) {
      cxop.device_data = device_data->data();
      ops_shapes[i] = &cxop.device_data->shape();
      device_data_ops[i] = true;
    } else {
      std::vector<const xla::Shape*> op_input_shapes;
      for (auto& operand : node->operands()) {
        size_t op_index = node_to_index.at(operand.node);
        cxop.inputs.push_back(
            {op_index,
             GetOutputIndex(device_data_ops[op_index], operand.index)});
        op_input_shapes.push_back(ops_shapes[op_index]);
      }

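      // The cache key hashes the node together with its input shapes, so the
      // same operation on identically shaped inputs reuses one compiled
      // computation.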
      torch::lazy::hash_t cache_key =
          ComputeNodeKey(node, op_input_shapes, nodes_key_seed);
      cxop.computation = compile_cache_.Get(cache_key);
      if (cxop.computation == nullptr) {
        XLA_COUNTER("OpByOpCompileCacheMiss", 1);

        // Within a single IR graph, there can be many duplicated IR nodes, so
        // make sure we do not issue an XLA compilation for each one of those.
        auto& cache_key_indices = compile_indices[cache_key];
        cache_key_indices.push_back(i);
        if (cache_key_indices.size() == 1) {
          cache_keys.push_back(cache_key);
          cache_keys_instance[cache_key] = compile_instances.size();

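          // First time this key is seen: lower the node to an XlaComputation
          // and queue it for batched compilation, recording the result shape
          // with the device-preferred layout.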
          xla::XlaComputation computation =
              BuildNodeComputation(node, op_input_shapes, exec_device);
          xla::ProgramShape program_shape =
              ConsumeValue(computation.GetProgramShape());
          compile_shapes.push_back(MakeShapeWithDeviceLayout(
              program_shape.result(), exec_device.hw_type));
          compile_instances.push_back({std::move(computation), device,
                                       compilation_devices,
                                       &compile_shapes.back()});
          ops_shapes[i] = &compile_shapes.back();
        } else {
          ops_shapes[i] =
              compile_instances[cache_keys_instance.at(cache_key)].output_shape;
        }
      } else {
        ops_shapes[i] = &cxop.computation->program_shape().result();
      }
    }
  }
  // Fix up the requested outputs (roots) within the chained ops vector.
  for (size_t i = 0; i < roots.size(); ++i) {
    size_t op_index = node_to_index.at(roots[i].node.get());
    chained_exec_ops[op_index].outputs.push_back(
        {i, GetOutputIndex(device_data_ops[op_index], roots[i].index)});
  }

  // If we missed the cache for certain ops, compile them now and fix up the
  // chained ops vector.
  if (!compile_instances.empty()) {
    TF_VLOG(3) << "Compiling " << compile_instances.size()
               << " computations on device " << device;
    auto computation_ptrs =
        xla::ComputationClient::Get()->Compile(std::move(compile_instances));
    TF_VLOG(3) << "Compiling " << computation_ptrs.size()
               << " computations on device " << device << " done!";
    for (size_t i = 0; i < computation_ptrs.size(); ++i) {
      compile_cache_.Add(cache_keys[i], computation_ptrs[i]);
      for (auto index : compile_indices[cache_keys[i]]) {
        chained_exec_ops[index].computation = computation_ptrs[i];
      }
    }
  }
  return chained_exec_ops;
}
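
For context, the vector built here is meant to be executed in a single chained call on the computation client. Below is a minimal sketch of the caller, modeled on OpByOpExecutor::Execute from the same file; the ExecuteChained entry point is assumed from the surrounding xla::ComputationClient interface and may differ in detail:

std::vector<xla::ComputationClient::DataPtr> OpByOpExecutor::Execute(
    absl::Span<const ir::Value> roots, const std::string& device,
    absl::Span<const std::string> devices) {
  // Lower the graph into per-node chained ops, compiling as needed.
  auto chained_exec_ops = BuildOps(roots, device, devices);
  // Execute the whole chain server-side (ExecuteChained as assumed from the
  // ComputationClient interface); the returned buffers correspond to the
  // root outputs wired up in BuildOps' fixup pass.
  return xla::ComputationClient::Get()->ExecuteChained(chained_exec_ops,
                                                       device);
}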