// opt/partial-application/PartialApplication.cpp

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
 * This pass identifies commonly used constant arguments that methods are
 * invoked with, and then introduces helper functions that bind those arguments
 * if it seems beneficial to reduce the overall code size by rewriting the
 * call-sites. The new helper methods are placed in the same class as the
 * callee. Their name is stable, including a hash derived from the bound
 * constant arguments.
 *
 * The most interesting part of this optimization, with likely further tuning
 * potential, is the priority-queue based approach to find beneficial subsets
 * of common constant arguments.
 *
 * While this is similar in spirit to what the InstructionSequenceOutliner
 * does, a major difference is that this pass specifically targets individual
 * method invocations, and it picks up incoming constant arguments based on our
 * existing constant-propagation analysis, not caring where earlier in the
 * code, or in which order, the constants are defined. And ultimately, it picks
 * a beneficial subset of constant arguments regardless of what order they were
 * defined in. In contrast, the InstructionSequenceOutliner requires precise
 * matches of frequently occurring instruction opcode sequences (modulo
 * register names) in order to outline any particular call-site.
 *
 * Here's an example of what the optimization does. Let's say there's a method
 * like this:
 *
 *   void foo(int a, int b, Integer c);
 *
 * And it is invoked 10 times as
 *
 *   foo(10, 20, Integer.valueOf(23));
 *
 * And another 10 times as
 *
 *   foo(13, 20, Integer.valueOf(23));
 *
 * Let's say in neither case would a new helper function be beneficial to
 * reduce size. However, when we trim off the first argument, we are left with
 * 20 occurrences of
 *
 *   foo(*, 20, Integer.valueOf(23));
 *
 * And this might be beneficial to transform. Then we introduce a helper
 * function like the following:
 *
 *   foo$pa$xxxx(int a) { foo(a, 20, Integer.valueOf(23)); }
 *
 * And rewrite the call-sites to
 *
 *   foo$pa$xxxx(10);
 *
 * and
 *
 *   foo$pa$xxxx(13);
 *
 * respectively.
 *
 * Various safe-guards are in place:
 * - We won't introduce helper methods that would contain cross-store or
 *   non-min-sdk level references.
 * - We only transform code with the largest root store id (so not in the
 *   primary dex, unless there only is one, and not in other auxiliary
 *   stores).
 * - We won't rewrite code that sits in hot blocks in hot methods, or loops in
 *   warm methods (reusing logic from the InstructionSequenceOutliner).
 *
 * We don't do anything special for symbolication. Thus, the new helper methods
 * will appear in stack traces, but probably won't be confusing, as they have
 * names derived from the final callee, appearing as some trampoline method.
 * The code in the helper methods will never fail (except maybe under the most
 * obscure circumstances such as a stack-overflow), and thus will never be on
 * top of a stack trace, and only the top frame is used for symbolication.
 */
#include "PartialApplication.h"

#include <cinttypes>

#include <boost/format.hpp>
#include <boost/pending/disjoint_sets.hpp>
#include <boost/property_map/property_map.hpp>

#include "CFGMutation.h"
#include "CallSiteSummaries.h"
#include "ConfigFiles.h"
#include "Creators.h"
#include "LiveRange.h"
#include "MutablePriorityQueue.h"
#include "OutliningProfileGuidanceImpl.h"
#include "PassManager.h"
#include "ReachableClasses.h"
#include "RefChecker.h"
#include "Shrinker.h"
#include "SourceBlocks.h"
#include "StlUtil.h"
#include "Walkers.h"

using namespace inliner;
using namespace outliner;
using namespace outliner_impl;
using namespace live_range;
using namespace shrinker;

namespace {

// Overhead of introducing a typical new helper method and its metadata.
const size_t COST_METHOD = 28;

// Retrieve the set of classes in the primary dex, if there is more than one
// store or more than one dex; otherwise, nothing needs to be excluded.
std::unordered_set<const DexType*> get_excluded_classes(
    DexStoresVector& stores) {
  std::unordered_set<const DexType*> excluded_classes;
  bool has_other_stores{false};
  bool has_other_dexes{false};
  for (auto& store : stores) {
    if (store.is_root_store()) {
      auto& dexen = store.get_dexen();
      always_assert(!dexen.empty());
      for (auto cls : dexen.front()) {
        excluded_classes.insert(cls->get_type());
      }
      if (dexen.size() > 1) {
        has_other_dexes = true;
      }
    } else {
      has_other_stores = true;
    }
  }
  if (!has_other_stores && !has_other_dexes) {
    excluded_classes.clear();
  }
  return excluded_classes;
}

const api::AndroidSDK* get_min_sdk_api(ConfigFiles& conf, PassManager& mgr) {
  int32_t min_sdk = mgr.get_redex_options().min_sdk;
  mgr.incr_metric("min_sdk", min_sdk);
  TRACE(PA, 2, "min_sdk: %d", min_sdk);
  auto min_sdk_api_file = conf.get_android_sdk_api_file(min_sdk);
  if (!min_sdk_api_file) {
    mgr.incr_metric("min_sdk_no_file", 1);
    TRACE(PA, 2, "Android SDK API %d file cannot be found.", min_sdk);
    return nullptr;
  } else {
    return &conf.get_android_sdk_api(min_sdk);
  }
}

using EnumUtilsCache = ConcurrentMap<int32_t, DexField*>;

// Check if we have a boxed value for which there is a $EnumUtils field.
DexField* try_get_enum_utils_f_field(EnumUtilsCache& cache,
                                     const ObjectWithImmutAttr& object) {
  // This matches EnumUtilsFieldAnalyzer::analyze_sget.
  always_assert(object.jvm_cached_singleton);
  always_assert(object.attributes.size() == 1);
  if (object.type != type::java_lang_Integer()) {
    return nullptr;
  }
  const auto& signed_value =
      object.attributes.front().value.get<SignedConstantDomain>();
  auto c = signed_value.get_constant();
  always_assert(c);
  DexField* res;
  cache.update(*c, [&res](int32_t key, DexField*& value, bool exists) {
    if (!exists) {
      auto cls = type_class(DexType::make_type("Lredex/$EnumUtils;"));
      if (cls) {
        std::string field_name = "f" + std::to_string(key);
        value =
            cls->find_sfield(field_name.c_str(), type::java_lang_Integer());
        always_assert(!value || is_static(value));
      }
    }
    res = value;
  });
  return res;
}

// Identify how many argument slots an invocation needs after expansion of wide
// types, and thus whether a range instruction will be needed.
std::pair<param_index_t, bool> analyze_args(const DexMethod* callee) {
  const auto* args = callee->get_proto()->get_args();
  param_index_t src_regs = args->size();
  if (!is_static(callee)) {
    src_regs++;
  }
  param_index_t expanded_src_regs{!is_static(callee)};
  for (auto t : *args) {
    expanded_src_regs += type::is_wide_type(t) ? 2 : 1;
  }
  auto needs_range = expanded_src_regs > 5;
  return {src_regs, needs_range};
}

struct ArgExclusivity {
  // between 0 and 1
  float ownership{0};
  bool needs_move{false};
};

struct AggregatedArgExclusivity {
  double ownership{0};
  uint32_t needs_move{0};
};

using ArgExclusivityVector =
    std::vector<std::pair<src_index_t, ArgExclusivity>>;

// Determine whether, or to what extent, the instructions to compute arguments
// to an invocation are exclusive to that invocation. (If not, then eliminating
// the argument in the invocation likely won't give us the expected cost
// savings.)
ArgExclusivityVector get_arg_exclusivity(const UseDefChains& use_def_chains,
                                         const DefUseChains& def_use_chains,
                                         bool needs_range,
                                         IRInstruction* insn) {
  ArgExclusivityVector aev;
  for (param_index_t src_idx = 0; src_idx < insn->srcs_size(); src_idx++) {
    const auto& defs = use_def_chains.at((Use){insn, src_idx});
    if (defs.size() != 1) {
      continue;
    }
    const auto def = *defs.begin();
    bool other_use = false;
    param_index_t count = 0;
    for (const auto& use : def_use_chains.at(def)) {
      if (!opcode::is_a_move(use.insn->opcode()) &&
          (use.insn->opcode() != insn->opcode() ||
           use.insn->get_method() != insn->get_method())) {
        other_use = true;
        break;
      }
      count++;
    }
    float ownership = other_use ? 0.0 : (1.0 / count);
    // TODO: We also likely need a move if there are more than 16 args
    // (including extra wides) live at this point.
    bool needs_move = needs_range && (other_use || count > 1);
    if (ownership > 0 || needs_move) {
      aev.emplace_back(src_idx, (ArgExclusivity){ownership, needs_move});
    }
  }
  return aev;
}
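/* Illustrative sketch (not from the original source) of how the ownership
 * fraction behaves. Given IR along these lines:
 *
 *   const v0, 20
 *   invoke-static {v0, ...}, LFoo;.foo:(...)V  // use 1 of v0
 *   invoke-static {v0, ...}, LFoo;.foo:(...)V  // use 2 of v0
 *
 * the const is shared by two invocations of the same callee, so each
 * invocation owns it with weight 1/2 (ownership = 0.5): binding the constant
 * into a helper only lets us delete the const instruction once all of its
 * uses go away. A use by any other kind of instruction (or an invocation of a
 * different callee) drops ownership to 0 for this invocation.
 */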
using CalleeCallerClasses =
    std::unordered_map<const DexMethod*, std::unordered_set<const DexType*>>;

// Gather all (caller, callee) pairs. Also compute arg exclusivity, which
// invoke instructions we should exclude, and how many classes the calls are
// distributed over.
void gather_caller_callees(
    const ProfileGuidanceConfig& profile_guidance_config,
    const Scope& scope,
    const std::unordered_set<DexMethod*>& sufficiently_warm_methods,
    const std::unordered_set<DexMethod*>& sufficiently_hot_methods,
    const GetCalleeFunction& get_callee_fn,
    MethodToMethodOccurrences* callee_caller,
    MethodToMethodOccurrences* caller_callee,
    std::unordered_map<const IRInstruction*, ArgExclusivityVector>*
        arg_exclusivity,
    std::unordered_set<const IRInstruction*>* excluded_invoke_insns,
    CalleeCallerClasses* callee_caller_classes) {
  Timer timer("gather_caller_callees");
  using ConcurrentMethodToMethodOccurrences =
      ConcurrentMap<const DexMethod*, std::unordered_map<DexMethod*, size_t>>;
  ConcurrentMethodToMethodOccurrences concurrent_callee_caller;
  ConcurrentMethodToMethodOccurrences concurrent_caller_callee;
  ConcurrentSet<const IRInstruction*> concurrent_excluded_invoke_insns;
  ConcurrentMap<const IRInstruction*, ArgExclusivityVector>
      concurrent_arg_exclusivity;
  ConcurrentMap<const DexMethod*, std::unordered_set<const DexType*>>
      concurrent_callee_caller_classes;
  walk::parallel::code(scope, [&](DexMethod* caller, IRCode& code) {
    code.build_cfg(true);
    CanOutlineBlockDecider block_decider(
        profile_guidance_config, sufficiently_warm_methods.count(caller),
        sufficiently_hot_methods.count(caller));
    MoveAwareChains move_aware_chains(code.cfg());
    const auto use_def_chains = move_aware_chains.get_use_def_chains();
    const auto def_use_chains = move_aware_chains.get_def_use_chains();
    for (auto& big_block : big_blocks::get_big_blocks(code.cfg())) {
      auto can_outline =
          block_decider.can_outline_from_big_block(big_block) ==
          CanOutlineBlockDecider::Result::CanOutline;
      for (auto& mie : big_blocks::InstructionIterable(big_block)) {
        auto insn = mie.insn;
        auto callee = get_callee_fn(caller, insn);
        if (!callee) {
          continue;
        }
        if (!can_outline) {
          concurrent_excluded_invoke_insns.insert(insn);
          continue;
        }
        auto needs_range = analyze_args(callee).second;
        auto ae = get_arg_exclusivity(use_def_chains, def_use_chains,
                                      needs_range, insn);
        if (ae.empty()) {
          concurrent_excluded_invoke_insns.insert(insn);
          continue;
        }
        concurrent_callee_caller.update(
            callee,
            [caller](const DexMethod*,
                     std::unordered_map<DexMethod*, size_t>& v,
                     bool) { ++v[caller]; });
        concurrent_caller_callee.update(
            caller,
            [callee](const DexMethod*,
                     std::unordered_map<DexMethod*, size_t>& v,
                     bool) { ++v[callee]; });
        concurrent_arg_exclusivity.emplace(insn, std::move(ae));
        concurrent_callee_caller_classes.update(
            callee,
            [caller](const DexMethod*,
                     std::unordered_set<const DexType*>& value,
                     bool) { value.insert(caller->get_class()); });
      }
    }
  });
  for (auto& p : concurrent_callee_caller) {
    callee_caller->insert(std::move(p));
  }
  for (auto& p : concurrent_caller_callee) {
    caller_callee->insert(std::move(p));
  }
  excluded_invoke_insns->insert(concurrent_excluded_invoke_insns.begin(),
                                concurrent_excluded_invoke_insns.end());
  for (auto& p : concurrent_arg_exclusivity) {
    arg_exclusivity->insert(std::move(p));
  }
  for (auto& p : concurrent_callee_caller_classes) {
    callee_caller_classes->insert(std::move(p));
  }
}

using InvokeCallSiteSummaries =
    std::unordered_map<const IRInstruction*, const CallSiteSummary*>;
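/* For illustration (hypothetical values, not from the original source), the
 * filter below would admit
 *
 *   SignedConstantDomain(42)               // a definite numeric constant
 *   SingletonObjectDomain(LFoo;.INSTANCE)  // if the ref-checker accepts it
 *   ObjectWithImmutAttr(Integer, 23)       // a jvm-cached Integer.valueOf(23)
 *
 * but would reject a value merely known to be non-zero (NEZ), since no
 * concrete const/sget/valueOf sequence could reproduce such an abstract value
 * in a helper method.
 */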
// Whether to include a particular constant argument value. We only include
// actual constants (not just abstract values like NEZ), and only if they
// don't violate anything the ref-checker would complain about. We can also
// handle singletons and immutable objects if they represent jvm cached
// singletons.
bool filter(const RefChecker& ref_checker,
            EnumUtilsCache& enum_utils_cache,
            const ConstantValue& value) {
  if (const auto& signed_value = value.maybe_get<SignedConstantDomain>()) {
    return !!signed_value->get_constant();
  } else if (const auto& singleton_value =
                 value.maybe_get<SingletonObjectDomain>()) {
    auto field = *singleton_value->get_constant();
    return ref_checker.check_field(field);
  } else if (const auto& obj_or_none =
                 value.maybe_get<ObjectWithImmutAttrDomain>()) {
    auto object = obj_or_none->get_constant();
    if (!object->jvm_cached_singleton) {
      return false;
    }
    if (DexField* field =
            try_get_enum_utils_f_field(enum_utils_cache, *object)) {
      return ref_checker.check_field(field);
    } else {
      always_assert(object->attributes.size() == 1);
      const auto& signed_value2 =
          object->attributes.front().value.maybe_get<SignedConstantDomain>();
      always_assert(signed_value2);
      return filter(ref_checker, enum_utils_cache, *signed_value2);
    }
  } else {
    not_reached_log("unexpected value: %s", SHOW(value));
  }
}

using CallSiteSummarySet = std::unordered_set<const CallSiteSummary*>;
using CallSiteSummaryVector = std::vector<const CallSiteSummary*>;

CallSiteSummaryVector order_csses(const CallSiteSummarySet& csses) {
  CallSiteSummaryVector ordered_csses(csses.begin(), csses.end());
  std::sort(ordered_csses.begin(), ordered_csses.end(),
            [](const CallSiteSummary* a, const CallSiteSummary* b) {
              return a->get_key() < b->get_key();
            });
  return ordered_csses;
}

// Priority-queue based algorithm to select which invocations and which
// constant arguments are beneficial to transform.
class CalleeInvocationSelector {
 private:
  EnumUtilsCache& m_enum_utils_cache;
  CallSiteSummarizer& m_call_site_summarizer;
  const DexMethod* m_callee;
  const std::unordered_map<const IRInstruction*, ArgExclusivityVector>&
      m_arg_exclusivity;
  size_t m_callee_caller_classes;
  param_index_t m_src_regs;
  bool m_needs_range;

  // When we are going to merge different call-site summaries after
  // simplifying, we need to efficiently track what all the underlying
  // call-site summaries were. We do that via a "disjoint_sets" data
  // structure.
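  /* A minimal, self-contained sketch (illustrative only, with int elements)
   * of the boost::disjoint_sets usage pattern relied upon below:
   *
   *   std::unordered_map<int, std::size_t> rank;
   *   std::unordered_map<int, int> parent;
   *   boost::associative_property_map<std::unordered_map<int, std::size_t>>
   *       rank_pmap(rank);
   *   boost::associative_property_map<std::unordered_map<int, int>>
   *       parent_pmap(parent);
   *   boost::disjoint_sets<decltype(rank_pmap), decltype(parent_pmap)>
   *       sets(rank_pmap, parent_pmap);
   *   sets.make_set(1);
   *   sets.make_set(2);
   *   sets.union_set(1, 2);  // 1 and 2 now share a representative
   *   assert(sets.find_set(1) == sets.find_set(2));
   *
   * Here the elements are CallSiteSummary pointers, and union_set records
   * that a reduced summary subsumes the summaries it was merged from.
   */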
  using Rank = std::unordered_map<const CallSiteSummary*, size_t>;
  using Parent =
      std::unordered_map<const CallSiteSummary*, const CallSiteSummary*>;
  using RankPMap = boost::associative_property_map<Rank>;
  using ParentPMap = boost::associative_property_map<Parent>;
  using CallSiteSummarySets = boost::disjoint_sets<RankPMap, ParentPMap>;
  Rank m_rank;
  Parent m_parent;
  CallSiteSummarySets m_css_sets;
  CallSiteSummarySet m_call_site_summaries;

  using ArgumentCosts = std::unordered_map<src_index_t, int32_t>;
  std::unordered_map<const CallSiteSummary*, ArgumentCosts>
      m_call_site_summary_argument_costs;
  using KeyedCosts = std::unordered_map<std::string, int32_t>;
  std::vector<KeyedCosts> m_total_argument_costs;
  using KeyedCsses = std::unordered_map<std::string, CallSiteSummarySet>;
  std::vector<KeyedCsses> m_dependencies;
  std::vector<std::pair<const IRInstruction*, const CallSiteSummary*>>
      m_call_site_invoke_summaries;
  std::unordered_map<const CallSiteSummary*,
                     std::unordered_map<src_index_t, AggregatedArgExclusivity>>
      m_aggregated_arg_exclusivity;

  static std::string get_key(const ConstantValue& value) {
    std::ostringstream oss;
    CallSiteSummary::append_key_value(oss, value);
    return oss.str();
  }

  static int32_t sum_call_sites_savings(const ArgumentCosts& ac) {
    int32_t savings = 0;
    for (const auto& p : ac) {
      savings += p.second;
    }
    return savings;
  }

  // Estimated size, in 16-bit code units, of the instruction(s) needed to
  // materialize the constant value at a call-site.
  int16_t const_value_cost(const ConstantValue& value) const {
    if (const auto& signed_value = value.maybe_get<SignedConstantDomain>()) {
      auto c = signed_value->get_constant();
      always_assert(c);
      auto lit = *c;
      if (lit < -2147483648 || lit > 2147483647) {
        return 5;
      } else if (lit < -32768 || lit > 32767) {
        return 3;
      } else if (lit < -8 || lit > 7) {
        return 2;
      } else {
        return 1;
      }
    } else if (const auto& singleton_value =
                   value.maybe_get<SingletonObjectDomain>()) {
      return 2;
    } else if (const auto& obj_or_none =
                   value.maybe_get<ObjectWithImmutAttrDomain>()) {
      auto object = obj_or_none->get_constant();
      always_assert(object);
      if (try_get_enum_utils_f_field(m_enum_utils_cache, *object)) {
        return 2;
      } else {
        always_assert(object->jvm_cached_singleton);
        always_assert(object->attributes.size() == 1);
        const auto& signed_value2 =
            object->attributes.front()
                .value.maybe_get<SignedConstantDomain>();
        always_assert(signed_value2);
        return 3 + const_value_cost(*signed_value2);
      }
    } else {
      not_reached_log("unexpected value: %s", SHOW(value));
    }
  }

  std::pair<param_index_t, uint32_t> find_argument_with_least_cost(
      const CallSiteSummary* css) const {
    const auto& bindings = css->arguments.bindings();
    boost::optional<int32_t> least_cost;
    param_index_t least_cost_src_idx = 0;
    for (auto& p : bindings) {
      auto& arguments_cost = m_total_argument_costs.at(p.first);
      auto it = arguments_cost.find(get_key(p.second));
      auto cost = it == arguments_cost.end() ? 0 : it->second;
      if (!least_cost || *least_cost > cost ||
          (*least_cost == cost && p.first < least_cost_src_idx)) {
        least_cost = cost;
        least_cost_src_idx = p.first;
      }
    }
    always_assert(least_cost);
    return {least_cost_src_idx, *least_cost};
  }

  int32_t get_net_savings(const CallSiteSummary* css) const {
    // The cost for an additional partial-application helper method consists
    // of...
    // - the basic overhead of having a method
    // - an estimated cross-dex penalty, as the PartialApplication pass has to
    //   run before the InterDex pass, and adding extra method-refs has global
    //   negative effects on the number of needed cross-dex references.
    // - an extra move-result instruction
    // - the cost of const instructions
    // - some extra potential move overhead if we need the range form
    int32_t pa_cross_dex_penalty =
        2 * std::ceil(std::sqrt(m_callee_caller_classes));
    int32_t pa_method_cost =
        COST_METHOD + pa_cross_dex_penalty + css->result_used;
    const auto& bindings = css->arguments.bindings();
    for (auto& r : bindings) {
      pa_method_cost += const_value_cost(r.second);
    }
    if (m_needs_range) {
      pa_method_cost += m_src_regs;
    }
    auto call_sites_savings =
        sum_call_sites_savings(m_call_site_summary_argument_costs.at(css));
    return call_sites_savings - pa_method_cost;
  }

  using Priority = uint64_t;
  uint64_t m_running_index = 0;

  Priority make_priority(const CallSiteSummary* css) {
    // We order by...
    // - (1 bit) whether net savings are positive
    // - (31 bits) if not, (clipped) least argument costs (smaller is better)
    // - (32 bits) running index to make the priority unique
    auto net_savings = get_net_savings(css);
    uint64_t positive = net_savings > 0;
    uint64_t a = 0;
    if (!positive) {
      auto least_cost = find_argument_with_least_cost(css).second;
      a = std::min<uint32_t>(least_cost, (1U << 31) - 1);
    }
    uint64_t b = m_running_index++;
    always_assert(positive < 2);
    always_assert(a < (1UL << 31));
    always_assert(b < (1UL << 32));
    return (positive << 63) | (a << 32) | b;
  }

  MutablePriorityQueue<const CallSiteSummary*, Priority> m_pq;

 public:
  CalleeInvocationSelector(
      EnumUtilsCache& enum_utils_cache,
      CallSiteSummarizer& call_site_summarizer,
      const DexMethod* callee,
      const std::unordered_map<const IRInstruction*, ArgExclusivityVector>&
          arg_exclusivity,
      size_t callee_caller_classes)
      : m_enum_utils_cache(enum_utils_cache),
        m_call_site_summarizer(call_site_summarizer),
        m_callee(callee),
        m_arg_exclusivity(arg_exclusivity),
        m_callee_caller_classes(callee_caller_classes),
        m_css_sets((RankPMap(m_rank)), (ParentPMap(m_parent))) {
    auto callee_call_site_invokes =
        call_site_summarizer.get_callee_call_site_invokes(callee);
    if (!callee_call_site_invokes) {
      return;
    }
    std::tie(m_src_regs, m_needs_range) = analyze_args(callee);
    TRACE(PA, 2,
          "[PartialApplication] Processing %s, %zu caller classes, %u src "
          "regs%s",
          SHOW(m_callee), callee_caller_classes, m_src_regs,
          m_needs_range ? ", needs_range" : "");
    m_total_argument_costs = std::vector<KeyedCosts>(m_src_regs, KeyedCosts());
    m_dependencies = std::vector<KeyedCsses>(m_src_regs, KeyedCsses());
    // Aggregate arg exclusivity across call-sites with the same summary.
    for (auto invoke_insn : *callee_call_site_invokes) {
      auto css =
          call_site_summarizer.get_instruction_call_site_summary(invoke_insn);
      if (css->arguments.is_top()) {
        continue;
      }
      if (!is_static(callee) && !css->arguments.get(0).is_top()) {
        // We don't want to deal with cases where an instance method is called
        // with nullptr.
        TRACE(PA, 2,
              "[PartialApplication] Ignoring invocation of instance method "
              "%s with %s",
              SHOW(callee), css->get_key().c_str());
        continue;
      }
      m_call_site_invoke_summaries.emplace_back(invoke_insn, css);
      auto& aev = arg_exclusivity.at(invoke_insn);
      auto& aaem = m_aggregated_arg_exclusivity[css];
      for (auto& p : aev) {
        auto& aae = aaem[p.first];
        aae.ownership += p.second.ownership;
        aae.needs_move += p.second.needs_move;
      }
    }
    // For each call-site summary,
    // - initialize a disjoint set singleton, and
    // - compute current constant argument costs that could potentially be
    //   saved when introducing a partial-application helper method, and
    // - keep track of which constant value for which parameter is involved in
    //   that call-site summary, which we'll need later when re-prioritizing
    //   call-site summaries in the priority queue.
    for (auto& p : m_aggregated_arg_exclusivity) {
      auto css = p.first;
      auto& aaem = p.second;
      m_call_site_summaries.insert(css);
      m_css_sets.make_set(css);
      auto& ac = m_call_site_summary_argument_costs[css];
      const auto& bindings = css->arguments.bindings();
      for (auto& q : bindings) {
        const auto src_idx = q.first;
        const auto& value = q.second;
        auto& aae = aaem[src_idx];
        int32_t cost =
            const_value_cost(value) * aae.ownership + 2 * aae.needs_move;
        ac.emplace(src_idx, cost);
        auto key = get_key(value);
        m_total_argument_costs.at(src_idx)[key] += cost;
        m_dependencies.at(src_idx)[key].insert(css);
      }
    }
  }

  // Fill priority queue with raw data.
  void fill_pq() {
    // Populate priority queue
    for (auto css : order_csses(m_call_site_summaries)) {
      auto priority = make_priority(css);
      TRACE(PA, 4,
            "[PartialApplication] Considering %s(%s): net savings %d, "
            "priority %016" PRIx64,
            SHOW(m_callee), css->get_key().c_str(), get_net_savings(css),
            priority);
      m_pq.insert(css, priority);
    }
  }
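  /* To illustrate reduce_pq() below (hypothetical numbers, echoing the
   * example in the file header): suppose neither foo(10, 20, I(23)) nor
   * foo(13, 20, I(23)) has positive net savings on its own. reduce_pq() pops
   * one of them, drops the argument whose total cost across all summaries is
   * smallest (here the first one, as 10 and 13 are not shared), and folds its
   * per-argument costs into the reduced summary foo(*, 20, I(23)). Once both
   * have been reduced, that merged summary carries the call-site savings of
   * all 20 invocations against the cost of a single helper method, which may
   * flip its net savings positive.
   */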
  // For all items in the queue which have non-positive net savings, chop off
  // the argument with least cost, and lump it together with any possibly
  // already existing item.
  void reduce_pq() {
    while (!m_pq.empty() && get_net_savings(m_pq.back()) <= 0) {
      auto css = m_pq.back();
      m_pq.erase(css);
      auto ac_it = m_call_site_summary_argument_costs.find(css);
      auto ac = std::move(ac_it->second);
      m_call_site_summary_argument_costs.erase(ac_it);
      for (auto& p : css->arguments.bindings()) {
        bool erased =
            m_dependencies.at(p.first).at(get_key(p.second)).erase(css);
        always_assert(erased);
      }
      auto [src_idx, least_cost] = find_argument_with_least_cost(css);
      always_assert(!css->arguments.get(src_idx).is_top());
      auto key = get_key(css->arguments.get(src_idx));
      m_total_argument_costs.at(src_idx).at(key) -= ac.at(src_idx);
      CallSiteSummary reduced_css_val{css->arguments, css->result_used};
      reduced_css_val.arguments.set(src_idx, ConstantValue::top());
      if (reduced_css_val.arguments.is_top()) {
        TRACE(PA, 4,
              "[PartialApplication] Removing %s(%s) with least cost %u@%u",
              SHOW(m_callee), css->get_key().c_str(), least_cost, src_idx);
      } else {
        auto reduced_css =
            m_call_site_summarizer.internalize_call_site_summary(
                reduced_css_val);
        ac_it = m_call_site_summary_argument_costs.find(reduced_css);
        if (ac_it == m_call_site_summary_argument_costs.end()) {
          ac_it = m_call_site_summary_argument_costs
                      .emplace(reduced_css, ArgumentCosts())
                      .first;
          for (auto& p : reduced_css->arguments.bindings()) {
            bool inserted = m_dependencies.at(p.first)
                                .at(get_key(p.second))
                                .insert(reduced_css)
                                .second;
            always_assert(inserted);
          }
        } else {
          m_pq.erase(reduced_css);
        }
        for (auto& p : ac) {
          ac_it->second[p.first] += p.second;
        }
        ac_it->second.erase(src_idx);
        m_pq.insert(reduced_css, make_priority(reduced_css));
        if (m_call_site_summaries.insert(reduced_css).second) {
          m_css_sets.make_set(reduced_css);
        }
        m_css_sets.union_set(css, reduced_css);
        TRACE(PA, 4,
              "[PartialApplication] Merging %s(%s ===> %s) with least cost "
              "%u@%u: net savings %d",
              SHOW(m_callee), css->get_key().c_str(),
              reduced_css->get_key().c_str(), least_cost, src_idx,
              get_net_savings(reduced_css));
      }
      const auto& csses = m_dependencies.at(src_idx).at(key);
      for (auto dependent_css : order_csses(csses)) {
        TRACE(PA, 4, "[PartialApplication] Reprioritizing %s(%s)",
              SHOW(m_callee), dependent_css->get_key().c_str());
        m_pq.update_priority(dependent_css, make_priority(dependent_css));
      }
    }
  }

  // Identify all invocations which contributed to groups with combined
  // positive expected savings.
  void select_invokes(std::atomic<size_t>* total_estimated_savings,
                      InvokeCallSiteSummaries* selected_invokes) {
    size_t partial_application_methods{0};
    std::unordered_map<const CallSiteSummary*, const CallSiteSummary*>
        selected_css_sets;
    uint32_t callee_estimated_savings = 0;
    while (!m_pq.empty()) {
      auto css = m_pq.front();
      auto net_savings = get_net_savings(css);
      m_pq.erase(css);
      selected_css_sets.emplace(m_css_sets.find_set(css), css);
      callee_estimated_savings += net_savings;
      partial_application_methods++;
      TRACE(PA, 3, "[PartialApplication] Selected %s(%s) with net savings %d",
            SHOW(m_callee), css->get_key().c_str(), net_savings);
      always_assert(net_savings > 0);
    }
    for (auto& p : m_call_site_invoke_summaries) {
      auto invoke_insn = p.first;
      auto css = p.second;
      if (!m_call_site_summaries.count(css)) {
        continue;
      }
      auto it = selected_css_sets.find(m_css_sets.find_set(css));
      if (it == selected_css_sets.end()) {
        continue;
      }
      auto reduced_css = it->second;
      // This invoke got selected because including it together with all
      // other invokes with the same css was beneficial on average. Check
      // (and filter out) if it's not actually beneficial for this particular
      // invoke.
      auto& aev = m_arg_exclusivity.at(invoke_insn);
      const auto& bindings = reduced_css->arguments.bindings();
      if (std::find_if(aev.begin(), aev.end(), [&bindings](auto& q) {
            return !bindings.at(q.first).is_top();
          }) == aev.end()) {
        continue;
      }
      selected_invokes->emplace(invoke_insn, reduced_css);
    }
    if (callee_estimated_savings > 0) {
      TRACE(PA, 2,
            "[PartialApplication] Selected %s(...) for %zu constant argument "
            "combinations across %zu invokes with net savings %u",
            SHOW(m_callee), partial_application_methods,
            selected_invokes->size(), callee_estimated_savings);
      *total_estimated_savings += callee_estimated_savings;
    }
  }
};

// From a call-site summary that includes constant arguments, derive the
// signature of the new helper method that will bind them.
DexTypeList* get_partial_application_args(bool callee_is_static,
                                          const DexProto* callee_proto,
                                          const CallSiteSummary* css) {
  const auto* args = callee_proto->get_args();
  DexTypeList::ContainerType new_args;
  param_index_t offset = 0;
  if (!callee_is_static) {
    always_assert(css->arguments.get(0).is_top());
    offset++;
  }
  for (param_index_t i = 0; i < args->size(); i++) {
    if (css->arguments.get(offset + i).is_top()) {
      new_args.push_back(args->at(i));
    }
  }
  return DexTypeList::make_type_list(std::move(new_args));
}

uint64_t get_stable_hash(uint64_t a, uint64_t b) { return a ^ b; }

uint64_t get_stable_hash(const std::string& s) {
  uint64_t stable_hash{s.size()};
  for (auto c : s) {
    stable_hash = stable_hash * 7 + c;
  }
  return stable_hash;
}

using PaMethodRefs = ConcurrentMap<CalleeCallSiteSummary,
                                   DexMethodRef*,
                                   boost::hash<CalleeCallSiteSummary>>;

// Run the analysis over all callees.
void select_invokes_and_callers(
    EnumUtilsCache& enum_utils_cache,
    CallSiteSummarizer& call_site_summarizer,
    const MethodToMethodOccurrences& callee_caller,
    const std::unordered_map<const IRInstruction*, ArgExclusivityVector>&
        arg_exclusivity,
    const CalleeCallerClasses& callee_caller_classes,
    size_t iteration,
    std::atomic<size_t>* total_estimated_savings,
    PaMethodRefs* pa_method_refs,
    InvokeCallSiteSummaries* selected_invokes,
    std::unordered_set<DexMethod*>* selected_callers) {
  Timer t("select_invokes_and_callers");
  std::vector<const DexMethod*> callees;
  std::unordered_map<const DexType*, std::vector<const DexMethod*>>
      callees_by_classes;
  std::unordered_map<const DexMethod*, InvokeCallSiteSummaries>
      selected_invokes_by_callees;
  for (auto& p : callee_caller) {
    auto callee = p.first;
    callees.push_back(callee);
    callees_by_classes[callee->get_class()].push_back(callee);
    selected_invokes_by_callees[callee];
  }
  workqueue_run<const DexMethod*>(
      [&](const DexMethod* callee) {
        CalleeInvocationSelector cis(enum_utils_cache, call_site_summarizer,
                                     callee, arg_exclusivity,
                                     callee_caller_classes.at(callee).size());
        cis.fill_pq();
        cis.reduce_pq();
        cis.select_invokes(total_estimated_savings,
                           &selected_invokes_by_callees.at(callee));
      },
      callees);
  std::vector<const DexType*> callee_classes;
  callee_classes.reserve(callees_by_classes.size());
  for (auto& p : callees_by_classes) {
    callee_classes.push_back(p.first);
  }
  std::mutex mutex;
  workqueue_run<const DexType*>(
      [&](const DexType* callee_class) {
        auto& class_callees = callees_by_classes.at(callee_class);
        std::sort(class_callees.begin(), class_callees.end(),
                  compare_dexmethods);
        std::unordered_map<uint64_t, uint32_t> stable_hash_indices;
        for (auto callee : class_callees) {
          auto& callee_selected_invokes =
              selected_invokes_by_callees.at(callee);
          if (callee_selected_invokes.empty()) {
            continue;
          }
          auto callee_stable_hash = get_stable_hash(show(callee));
          std::map<const DexTypeList*,
                   std::unordered_set<const CallSiteSummary*>,
                   dextypelists_comparator>
              ordered_pa_args_csses;
          auto callee_is_static = is_static(callee);
          auto callee_proto = callee->get_proto();
          for (auto& p : callee_selected_invokes) {
            auto css = p.second;
            auto pa_args = get_partial_application_args(callee_is_static,
                                                        callee_proto, css);
            // Multiple invokes may share the same summary; the set dedups.
            ordered_pa_args_csses[pa_args].insert(css);
          }
          for (auto& p : ordered_pa_args_csses) {
            auto pa_args = p.first;
            auto& csses = p.second;
            for (auto css : order_csses(csses)) {
              auto css_stable_hash = get_stable_hash(css->get_key());
              auto stable_hash =
                  get_stable_hash(callee_stable_hash, css_stable_hash);
              auto stable_hash_index = stable_hash_indices[stable_hash]++;
              std::ostringstream oss;
              oss << callee->get_name()->str()
                  << (is_static(callee) ? "$spa$" : "$ipa$") << iteration
                  << "$" << ((boost::format("%08x") % stable_hash).str())
                  << "$" << stable_hash_index;
              auto pa_name = DexString::make_string(oss.str());
              auto pa_rtype =
                  css->result_used ? callee_proto->get_rtype() : type::_void();
              auto pa_proto = DexProto::make_proto(pa_rtype, pa_args);
              auto pa_type = callee->get_class();
              auto pa_method_ref =
                  DexMethod::make_method(pa_type, pa_name, pa_proto);
              CalleeCallSiteSummary ccss{callee, css};
              pa_method_refs->emplace(ccss, pa_method_ref);
            }
          }
          std::lock_guard<std::mutex> lock_guard(mutex);
          selected_invokes->insert(callee_selected_invokes.begin(),
                                   callee_selected_invokes.end());
          for (auto& p : callee_caller.at(callee)) {
            selected_callers->insert(p.first);
          }
        }
      },
      callee_classes);
}

IROpcode get_invoke_opcode(const DexMethod* callee) {
  return callee->is_virtual() ? OPCODE_INVOKE_VIRTUAL
         : is_static(callee)  ? OPCODE_INVOKE_STATIC
                              : OPCODE_INVOKE_DIRECT;
}
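/* For illustration (hypothetical hash value): for a static callee foo
 * selected in iteration 0 whose combined stable hash comes out as 0x1a2b3c4d,
 * the naming loop above produces a helper named
 *
 *   foo$spa$0$1a2b3c4d$0
 *
 * ("$ipa$" for instance methods), with the trailing counter disambiguating
 * distinct summaries that happen to collide on the same stable hash.
 */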
// Given the analysis results, rewrite all callers to invoke the new helper
// methods with bound arguments.
void rewrite_callers(
    const Scope& scope,
    Shrinker& shrinker,
    const GetCalleeFunction& get_callee_fn,
    const std::unordered_map<const IRInstruction*, const CallSiteSummary*>&
        selected_invokes,
    const std::unordered_set<DexMethod*>& selected_callers,
    PaMethodRefs& pa_method_refs,
    std::atomic<size_t>* removed_args) {
  Timer t("rewrite_callers");
  auto make_partial_application_invoke_insn =
      [&](DexMethod* caller, IRInstruction* insn) -> IRInstruction* {
    if (!opcode::is_an_invoke(insn->opcode())) {
      return nullptr;
    }
    auto it = selected_invokes.find(insn);
    if (it == selected_invokes.end()) {
      return nullptr;
    }
    auto callee = get_callee_fn(caller, insn);
    always_assert(callee != nullptr);
    auto css = it->second;
    CalleeCallSiteSummary ccss{callee, css};
    DexMethodRef* pa_method_ref = pa_method_refs.at_unsafe(ccss);
    auto new_insn = (new IRInstruction(get_invoke_opcode(callee)))
                        ->set_method(pa_method_ref);
    new_insn->set_srcs_size(insn->srcs_size() - css->arguments.size());
    param_index_t idx = 0;
    for (param_index_t i = 0; i < insn->srcs_size(); i++) {
      if (css->arguments.get(i).is_top()) {
        new_insn->set_src(idx++, insn->src(i));
      }
    }
    always_assert(idx == new_insn->srcs_size());
    return new_insn;
  };
  walk::parallel::code(scope, [&](DexMethod* caller, IRCode& code) {
    if (selected_callers.count(caller)) {
      bool any_changes{false};
      auto& cfg = code.cfg();
      cfg::CFGMutation mutation(cfg);
      auto ii = InstructionIterable(cfg);
      size_t removed_srcs{0};
      for (auto it = ii.begin(); it != ii.end(); it++) {
        auto new_invoke_insn =
            make_partial_application_invoke_insn(caller, it->insn);
        if (!new_invoke_insn) {
          continue;
        }
        removed_srcs += it->insn->srcs_size() - new_invoke_insn->srcs_size();
        std::vector<IRInstruction*> new_insns{new_invoke_insn};
        auto move_result_it = cfg.move_result_of(it);
        if (!move_result_it.is_end()) {
          new_insns.push_back(new IRInstruction(*move_result_it->insn));
        }
        mutation.replace(it, new_insns);
        any_changes = true;
      }
      mutation.flush();
      if (any_changes) {
        TRACE(PA, 6, "[PartialApplication] Rewrote %s:\n%s", SHOW(caller),
              SHOW(cfg));
        shrinker.shrink_method(caller);
        (*removed_args) += removed_srcs;
      }
    }
    code.clear_cfg();
  });
}
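/* A sketch of the rewrite performed above (registers, names, and the helper's
 * hash are illustrative only):
 *
 *   // before
 *   const/16 v1, 20
 *   invoke-static {v0, v1, v2}, LFoo;.foo:(IILjava/lang/Integer;)V
 *
 *   // after (v1 and v2 are no longer passed; the helper re-materializes
 *   // the bound constants internally)
 *   invoke-static {v0}, LFoo;.foo$spa$0$1a2b3c4d$0:(I)V
 *
 * The now-unused const instructions in the caller are left for the shrinker
 * to clean up via shrink_method().
 */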
// Helper used to build the partial-application helper methods: materializes
// one bound constant argument in the helper's body.
void push_callee_arg(EnumUtilsCache& enum_utils_cache,
                     DexType* type,
                     const ConstantValue& value,
                     MethodCreator* method_creator,
                     MethodBlock* main_block,
                     std::vector<Location>* callee_args) {
  if (const auto& signed_value = value.maybe_get<SignedConstantDomain>()) {
    auto c = signed_value->get_constant();
    always_assert(c);
    auto tmp = method_creator->make_local(type);
    main_block->load_const(tmp, *c, type);
    callee_args->push_back(tmp);
  } else if (const auto& singleton_value =
                 value.maybe_get<SingletonObjectDomain>()) {
    auto c = singleton_value->get_constant();
    always_assert(c);
    auto field = *c;
    always_assert(is_static(field));
    auto tmp = method_creator->make_local(type);
    main_block->sfield_op(opcode::sget_opcode_for_field(field),
                          const_cast<DexField*>(field), tmp);
    callee_args->push_back(tmp);
  } else if (const auto& obj_or_none =
                 value.maybe_get<ObjectWithImmutAttrDomain>()) {
    auto object = obj_or_none->get_constant();
    always_assert(object);
    if (DexField* field =
            try_get_enum_utils_f_field(enum_utils_cache, *object)) {
      auto tmp = method_creator->make_local(field->get_type());
      main_block->sfield_op(opcode::sget_opcode_for_field(field), field, tmp);
      callee_args->push_back(tmp);
    } else {
      always_assert(object->jvm_cached_singleton);
      always_assert(object->attributes.size() == 1);
      auto valueOf = type::get_value_of_method_for_type(object->type);
      auto valueOf_arg_type = valueOf->get_proto()->get_args()->at(0);
      auto tmp = method_creator->make_local(valueOf_arg_type);
      const auto& signed_value2 =
          object->attributes.front().value.maybe_get<SignedConstantDomain>();
      always_assert(signed_value2);
      auto c = signed_value2->get_constant();
      always_assert(c);
      main_block->load_const(tmp, *c, valueOf_arg_type);
      main_block->invoke(OPCODE_INVOKE_STATIC, valueOf, {tmp});
      tmp = method_creator->make_local(type);
      main_block->move_result(tmp, type);
      callee_args->push_back(tmp);
    }
  } else {
    not_reached_log("unexpected value: %s", SHOW(value));
  }
}

// Create all new helper methods that bind constant arguments.
void create_partial_application_methods(EnumUtilsCache& enum_utils_cache,
                                        PaMethodRefs& pa_method_refs) {
  Timer t("create_partial_application_methods");
  std::map<DexMethodRef*, const CalleeCallSiteSummary*, dexmethods_comparator>
      inverse_ordered_pa_method_refs;
  for (auto& p : pa_method_refs) {
    bool success =
        inverse_ordered_pa_method_refs.emplace(p.second, &p.first).second;
    always_assert(success);
  }
  for (auto& p : inverse_ordered_pa_method_refs) {
    auto pa_method_ref = p.first;
    auto callee = p.second->method;
    auto cls = type_class(callee->get_class());
    always_assert(cls);
    auto css = p.second->call_site_summary;
    auto access = callee->get_access() & ~(ACC_ABSTRACT | ACC_NATIVE);
    if (callee->is_virtual()) {
      access |= ACC_FINAL;
    }
    MethodCreator method_creator(pa_method_ref, access);
    auto main_block = method_creator.get_main_block();
    std::vector<Location> callee_args;
    param_index_t offset = 0;
    param_index_t next_arg_idx = 0;
    if (!is_static(callee)) {
      always_assert(css->arguments.get(0).is_top());
      offset++;
      callee_args.push_back(method_creator.get_local(next_arg_idx++));
    }
    auto proto = callee->get_proto();
    const auto* args = proto->get_args();
    for (param_index_t i = 0; i < args->size(); i++) {
      const auto& value = css->arguments.get(offset + i);
      if (value.is_top()) {
        callee_args.push_back(method_creator.get_local(next_arg_idx++));
      } else {
        push_callee_arg(enum_utils_cache, args->at(i), value, &method_creator,
                        main_block, &callee_args);
      }
    }
    main_block->invoke(get_invoke_opcode(callee),
                       const_cast<DexMethod*>(callee), callee_args);
    if (css->result_used) {
      auto tmp = method_creator.make_local(proto->get_rtype());
      main_block->move_result(tmp, proto->get_rtype());
      main_block->ret(tmp);
    } else {
      main_block->ret_void();
    }
    auto pa_method = method_creator.create();
    pa_method->rstate.set_generated();
    pa_method->rstate.set_dont_inline();
    if (!is_static(callee) && is_public(callee)) {
      pa_method->set_virtual(true);
    }
    pa_method->set_deobfuscated_name(show_deobfuscated(pa_method));
    cls->add_method(pa_method);
    TRACE(PA, 5, "[PartialApplication] Created %s binding %s:\n%s",
          SHOW(pa_method), css->get_key().c_str(),
          SHOW(pa_method->get_code()));
  }
}

} // namespace

void PartialApplicationPass::bind_config() {
  auto& pg = m_profile_guidance_config;
  bind("use_method_profiles", pg.use_method_profiles, pg.use_method_profiles,
       "Whether to use provided method-profiles configuration data to "
       "determine if certain code should not be outlined from a method");
  bind("method_profiles_appear_percent", pg.method_profiles_appear_percent,
       pg.method_profiles_appear_percent,
       "Cut off when a method in a method profile is deemed relevant");
  bind("method_profiles_hot_call_count", pg.method_profiles_hot_call_count,
       pg.method_profiles_hot_call_count,
       "No code is outlined out of hot methods");
  bind("method_profiles_warm_call_count", pg.method_profiles_warm_call_count,
       pg.method_profiles_warm_call_count,
       "Loops are not outlined from warm methods");
  std::string perf_sensitivity_str;
  bind("perf_sensitivity", "always-hot", perf_sensitivity_str);
  bind("block_profiles_hits", pg.block_profiles_hits, pg.block_profiles_hits,
       "No code is outlined out of hot blocks in hot methods");
  after_configuration([=]() {
    always_assert(!perf_sensitivity_str.empty());
    m_profile_guidance_config.perf_sensitivity =
        parse_perf_sensitivity(perf_sensitivity_str);
  });
}
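// An illustrative pass-config snippet wiring up the options bound above (the
// values are hypothetical, not recommended defaults):
//
//   "PartialApplicationPass": {
//     "use_method_profiles": true,
//     "method_profiles_appear_percent": 1,
//     "method_profiles_hot_call_count": 10,
//     "method_profiles_warm_call_count": 1,
//     "perf_sensitivity": "always-hot",
//     "block_profiles_hits": 100
//   }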
void PartialApplicationPass::run_pass(DexStoresVector& stores,
                                      ConfigFiles& conf,
                                      PassManager& mgr) {
  const auto scope = build_class_scope(stores);
  init_classes::InitClassesWithSideEffects init_classes_with_side_effects(
      scope, conf.create_init_class_insns());
  auto excluded_classes = get_excluded_classes(stores);
  int min_sdk = mgr.get_redex_options().min_sdk;
  auto min_sdk_api = get_min_sdk_api(conf, mgr);
  XStoreRefs xstores(stores);
  // The RefChecker store_idx is initialized with `largest_root_store_id()`,
  // so that it rejects all references from stores with an id larger than the
  // largest root_store id.
  RefChecker ref_checker(&xstores, xstores.largest_root_store_id(),
                         min_sdk_api);
  std::unordered_set<DexMethod*> sufficiently_warm_methods;
  std::unordered_set<DexMethod*> sufficiently_hot_methods;
  gather_sufficiently_warm_and_hot_methods(
      scope, conf, mgr, m_profile_guidance_config, &sufficiently_warm_methods,
      &sufficiently_hot_methods);
  mgr.incr_metric("num_sufficiently_warm_methods",
                  sufficiently_warm_methods.size());
  mgr.incr_metric("num_sufficiently_hot_methods",
                  sufficiently_hot_methods.size());
  ShrinkerConfig shrinker_config;
  shrinker_config.run_local_dce = true;
  shrinker_config.compute_pure_methods = false;
  Shrinker shrinker(stores, scope, init_classes_with_side_effects,
                    shrinker_config, min_sdk);
  std::unordered_set<const IRInstruction*> excluded_invoke_insns;
  auto get_callee_fn = [&excluded_classes, &excluded_invoke_insns](
                           DexMethod* caller,
                           IRInstruction* insn) -> DexMethod* {
    if (!opcode::is_an_invoke(insn->opcode()) ||
        insn->opcode() == OPCODE_INVOKE_SUPER ||
        method::is_init(insn->get_method()) ||
        excluded_invoke_insns.count(insn) ||
        caller->rstate.no_optimizations() ||
        excluded_classes.count(caller->get_class())) {
      return nullptr;
    }
    auto callee =
        resolve_method(insn->get_method(), opcode_to_search(insn), caller);
    if (!callee || callee->is_external()) {
      return nullptr;
    }
    auto cls = type_class(callee->get_class());
    if (!cls || cls->is_external() || is_native(cls) ||
        excluded_classes.count(cls->get_type())) {
      return nullptr;
    }
    // We'd add helper methods to the class, so we also want to avoid it
    // being used via reflection.
    if (!can_rename(cls)) {
      return nullptr;
    }
    // TODO: Support interface callees.
    if (is_interface(cls)) {
      return nullptr;
    }
    return callee;
  };
  MethodToMethodOccurrences callee_caller;
  MethodToMethodOccurrences caller_callee;
  std::unordered_map<const IRInstruction*, ArgExclusivityVector>
      arg_exclusivity;
  CalleeCallerClasses callee_caller_classes;
  gather_caller_callees(
      m_profile_guidance_config, scope, sufficiently_warm_methods,
      sufficiently_hot_methods, get_callee_fn, &callee_caller, &caller_callee,
      &arg_exclusivity, &excluded_invoke_insns, &callee_caller_classes);
  TRACE(PA, 1, "[PartialApplication] %zu callers, %zu callees",
        caller_callee.size(), callee_caller.size());
  // By indicating to the call-site summarizer that any callee may have other
  // call-sites, we effectively disable top-down constant-propagation, as that
  // would be unlikely to find true constants, and yet would take more time by
  // limiting parallelism.
  auto has_callee_other_call_sites_fn = [](DexMethod*) -> bool {
    return true;
  };
  EnumUtilsCache enum_utils_cache;
  std::function<bool(const ConstantValue& value)> filter_fn =
      [&](const ConstantValue& value) {
        return filter(ref_checker, enum_utils_cache, value);
      };
  CallSiteSummaryStats call_site_summarizer_stats;
  CallSiteSummarizer call_site_summarizer(
      shrinker, callee_caller, caller_callee, get_callee_fn,
      has_callee_other_call_sites_fn, &filter_fn,
      &call_site_summarizer_stats);
  call_site_summarizer.summarize();
  std::atomic<size_t> total_estimated_savings{0};
  PaMethodRefs pa_method_refs;
  std::unordered_map<const IRInstruction*, const CallSiteSummary*>
      selected_invokes;
  std::unordered_set<DexMethod*> selected_callers;
  select_invokes_and_callers(enum_utils_cache, call_site_summarizer,
                             callee_caller, arg_exclusivity,
                             callee_caller_classes, m_iteration++,
                             &total_estimated_savings, &pa_method_refs,
                             &selected_invokes, &selected_callers);
  std::atomic<size_t> removed_args{0};
  rewrite_callers(scope, shrinker, get_callee_fn, selected_invokes,
                  selected_callers, pa_method_refs, &removed_args);
  create_partial_application_methods(enum_utils_cache, pa_method_refs);
  TRACE(PA, 1,
        "[PartialApplication] Created %zu methods with particular constant "
        "argument combinations, rewriting %zu invokes across %zu callers, "
        "removing %zu args, with (estimated) net savings %zu",
        pa_method_refs.size(), selected_invokes.size(),
        selected_callers.size(), (size_t)removed_args,
        (size_t)total_estimated_savings);
  mgr.incr_metric("total_estimated_savings", total_estimated_savings);
  mgr.incr_metric("rewritten_invokes", selected_invokes.size());
  mgr.incr_metric("removed_args", removed_args);
  mgr.incr_metric("affected_callers", selected_callers.size());
  mgr.incr_metric("partial_application_methods", pa_method_refs.size());
}

static PartialApplicationPass s_pass;