void gather_true_virtual_methods()

in service/method-inliner/MethodInliner.cpp [393:629]


// Collects inlining candidates among true-virtual methods of `scope`.
//
// For every callee, the resulting CallerInsns records:
//  - caller_insns: invoke instructions (grouped by caller) that can be treated
//    as monomorphic calls of that callee;
//  - other_call_sites: whether the callee is (or may be) reached in any other
//    way, so that it cannot simply be removed after inlining;
//  - inlined_invokes_need_cast: recorded invokes whose receiver must be cast
//    (to the given type) when the callee body is inlined.
//
// Phase 1 (parallel walk over all methods in `scope`):
//  - Each virtual method not reported by mog::get_non_true_virtuals becomes a
//    candidate. Roots, and methods overriding a root or an external method,
//    are marked as having other call sites.
//  - Each invoke-virtual / invoke-interface / invoke-super is resolved. Callees
//    that may have unknown implementations, or that have more than one
//    non-abstract implementation, are marked (together with their overriders)
//    as having other call sites. Otherwise the invoke is recorded as a
//    monomorphic call site of its single target, or of the representative of
//    a same-implementation cluster.
// Phase 2 (work queue over the candidates):
//  - Callees without code, or whose code fails method::no_invoke_super, drop
//    all recorded call sites and are marked as having other call sites.
//  - The uses of the callee's first load-param (the receiver) determine the
//    types the receiver is required to have. Recorded invokes whose formal
//    class does not satisfy all of them need a cast when inlined; for
//    same-implementation invokes without a combined type demand, the invoke
//    is dropped instead.
// Finally, all entries are moved into `true_virtual_callers`.
//
// @param method_override_graph  Override relation used to find overriding and
//                               overridden methods.
// @param scope                  Classes whose methods are walked.
// @param true_virtual_callers   Output; one entry per candidate callee is
//                               emplaced.
void gather_true_virtual_methods(const mog::Graph& method_override_graph,
                                 const Scope& scope,
                                 CalleeCallerInsns* true_virtual_callers) {
  Timer t("gather_true_virtual_methods");
  auto non_virtual = mog::get_non_true_virtuals(method_override_graph, scope);
  auto same_implementation_map =
      get_same_implementation_map(scope, method_override_graph);
  ConcurrentMap<const DexMethod*, CallerInsns> concurrent_true_virtual_callers;
  ConcurrentMap<IRInstruction*, SameImplementation*>
      same_implementation_invokes;
  // Add mapping from callee to monomorphic callsites.
  auto add_monomorphic_call_site = [&](const DexMethod* caller,
                                       IRInstruction* callsite,
                                       const DexMethod* callee) {
    concurrent_true_virtual_callers.update(
        callee, [&](const DexMethod*, CallerInsns& m, bool) {
          m.caller_insns[caller].emplace(callsite);
        });
  };
  // Record that `callee` is reachable through a call site we cannot inline.
  auto add_other_call_site = [&](const DexMethod* callee) {
    concurrent_true_virtual_callers.update(
        callee, [&](const DexMethod*, CallerInsns& m, bool) {
          m.other_call_sites = true;
        });
  };
  // Make sure `callee` has an entry, even if no call site is ever recorded.
  auto add_candidate = [&](const DexMethod* callee) {
    concurrent_true_virtual_callers.emplace(callee, CallerInsns());
  };

  walk::parallel::methods(scope, [&non_virtual, &method_override_graph,
                                  &add_monomorphic_call_site,
                                  &add_other_call_site, &add_candidate,
                                  &same_implementation_invokes,
                                  &same_implementation_map](DexMethod* method) {
    if (method->is_virtual() && !non_virtual.count(method)) {
      add_candidate(method);
      if (root(method)) {
        add_other_call_site(method);
      } else {
        const auto& overridden_methods = mog::get_overridden_methods(
            method_override_graph, method, /* include_interfaces */ true);
        for (auto overridden_method : overridden_methods) {
          if (root(overridden_method) || overridden_method->is_external()) {
            add_other_call_site(method);
            break;
          }
        }
      }
    }
    auto code = method->get_code();
    if (!code) {
      return;
    }
    for (auto& mie : InstructionIterable(code)) {
      auto insn = mie.insn;
      if (insn->opcode() != OPCODE_INVOKE_VIRTUAL &&
          insn->opcode() != OPCODE_INVOKE_INTERFACE &&
          insn->opcode() != OPCODE_INVOKE_SUPER) {
        continue;
      }
      auto insn_method = insn->get_method();
      auto callee = resolve_method(insn_method, opcode_to_search(insn), method);
      if (callee == nullptr) {
        // There are some invoke-virtual call on methods whose def are
        // actually in interface.
        callee = resolve_method(insn_method, MethodSearch::Interface);
      }
      if (callee == nullptr) {
        continue;
      }
      if (non_virtual.count(callee) != 0) {
        // Not true virtual, no need to continue;
        continue;
      }
      if (can_have_unknown_implementations(method_override_graph, callee)) {
        add_other_call_site(callee);
        if (insn->opcode() != OPCODE_INVOKE_SUPER) {
          auto overriding_methods =
              mog::get_overriding_methods(method_override_graph, callee);
          for (auto overriding_method : overriding_methods) {
            add_other_call_site(overriding_method);
          }
        }
        continue;
      }
      always_assert_log(callee->is_def(), "Resolved method not def %s",
                        SHOW(callee));
      if (insn->opcode() == OPCODE_INVOKE_SUPER) {
        // invoke-super statically targets the resolved method.
        add_monomorphic_call_site(method, insn, callee);
        continue;
      }
      auto it = same_implementation_map.find(callee);
      if (it != same_implementation_map.end()) {
        // We can find the resolved callee in same_implementation_map,
        // just use that piece of info because we know the implementors are all
        // the same
        add_monomorphic_call_site(method, insn, it->second->representative);
        same_implementation_invokes.emplace(insn, it->second.get());
        continue;
      }
      auto overriding_methods =
          mog::get_overriding_methods(method_override_graph, callee);
      std20::erase_if(overriding_methods,
                      [&](auto* m) { return is_abstract(m); });
      if (overriding_methods.empty()) {
        // There is no override for this method
        add_monomorphic_call_site(method, insn, callee);
      } else if (is_abstract(callee) && overriding_methods.size() == 1) {
        // The method is an abstract method, the only override is its
        // implementation.
        auto implementing_method = *overriding_methods.begin();
        add_monomorphic_call_site(method, insn, implementing_method);
      } else {
        add_other_call_site(callee);
        for (auto overriding_method : overriding_methods) {
          add_other_call_site(overriding_method);
        }
      }
    }
  });

  // Post processing candidates.
  std::vector<const DexMethod*> true_virtual_callees;
  for (auto& p : concurrent_true_virtual_callers) {
    true_virtual_callees.push_back(p.first);
  }
  workqueue_run<const DexMethod*>(
      [&](sparta::SpartaWorkerState<const DexMethod*>*,
          const DexMethod* callee) {
        // Each callee is processed by exactly one work item, so its entry can
        // be accessed without further locking here.
        auto& caller_to_invocations =
            concurrent_true_virtual_callers.at_unsafe(callee);
        if (caller_to_invocations.caller_insns.empty()) {
          return;
        }
        auto code = const_cast<DexMethod*>(callee)->get_code();
        if (!code || !method::no_invoke_super(*code)) {
          // caller_insns is known to be non-empty here (checked above).
          caller_to_invocations.caller_insns.clear();
          caller_to_invocations.other_call_sites = true;
          return;
        }
        // Figure out if candidates use the receiver in a way that does require
        // a cast.
        std::unordered_set<live_range::Use> first_load_param_uses;
        {
          code->build_cfg(/* editable */ true);
          live_range::MoveAwareChains chains(code->cfg());
          auto ii = InstructionIterable(code->cfg().get_param_instructions());
          auto first_load_param = ii.begin()->insn;
          first_load_param_uses =
              std::move(chains.get_def_use_chains()[first_load_param]);
          code->clear_cfg();
        }
        std::unordered_set<DexType*> formal_callee_types;
        bool any_same_implementation_invokes{false};
        for (auto& p : caller_to_invocations.caller_insns) {
          for (auto insn : p.second) {
            formal_callee_types.insert(insn->get_method()->get_class());
            if (same_implementation_invokes.count_unsafe(insn)) {
              any_same_implementation_invokes = true;
            }
          }
        }
        auto type_demands = std::make_unique<std::unordered_set<DexType*>>();
        // Note that the callee-rtype is the same for all methods in a
        // same-implementations cluster.
        auto callee_rtype = callee->get_proto()->get_rtype();
        for (const auto& use : first_load_param_uses) {
          if (opcode::is_a_move(use.insn->opcode())) {
            continue;
          }
          auto type_demand = get_receiver_type_demand(callee_rtype, use);
          // A null demand is an expected outcome and must be handled before
          // asserting castability against it.
          if (type_demand == nullptr) {
            formal_callee_types.clear();
            type_demands = nullptr;
            break;
          }
          always_assert_log(type::check_cast(callee->get_class(), type_demand),
                            "For the incoming code to be type correct, %s must "
                            "be castable to %s.",
                            SHOW(callee->get_class()), SHOW(type_demand));
          if (type_demands->insert(type_demand).second) {
            // Keep only formal types that already satisfy every demand seen.
            std20::erase_if(formal_callee_types, [&](auto* t) {
              return !type::check_cast(t, type_demand);
            });
          }
        }
        for (auto& p : caller_to_invocations.caller_insns) {
          for (auto it = p.second.begin(); it != p.second.end();) {
            auto insn = *it;
            if (!formal_callee_types.count(insn->get_method()->get_class())) {
              auto it2 = same_implementation_invokes.find(insn);
              if (it2 != same_implementation_invokes.end()) {
                always_assert(any_same_implementation_invokes);
                auto combined_type_demand = reduce_type_demands(type_demands);
                if (combined_type_demand) {
                  for (auto same_implementation_callee : it2->second->methods) {
                    always_assert_log(
                        type::check_cast(
                            same_implementation_callee->get_class(),
                            combined_type_demand),
                        "For the incoming code to be type correct, %s must "
                        "be castable to %s.",
                        SHOW(same_implementation_callee->get_class()),
                        SHOW(combined_type_demand));
                  }
                  caller_to_invocations.inlined_invokes_need_cast.emplace(
                      insn, combined_type_demand);
                } else {
                  // We can't just cast to the type of the representative. And
                  // it's not trivial to find the right common base type of the
                  // representatives, it might not even exist. (Imagine all
                  // subtypes happen the implement a set of interfaces.)
                  // TODO: Try harder.
                  caller_to_invocations.other_call_sites = true;
                  it = p.second.erase(it);
                  continue;
                }
              } else {
                caller_to_invocations.inlined_invokes_need_cast.emplace(
                    insn, callee->get_class());
              }
            }
            ++it;
          }
        }
        // Drop callers whose invokes were all removed above.
        std20::erase_if(caller_to_invocations.caller_insns,
                        [&](auto& p) { return p.second.empty(); });
      },
      true_virtual_callees);
  for (auto& pair : concurrent_true_virtual_callers) {
    DexMethod* callee = const_cast<DexMethod*>(pair.first);
    true_virtual_callers->emplace(callee, std::move(pair.second));
  }
}