opt/partial-application/PartialApplication.cpp (1,041 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/*
* This pass identifies commonly used constant arguments that methods are
* invoked with, and then introduces helper functions that bind those arguments
* if it seems beneficial to reduce the overall code size by rewriting the
* call-sites. The new helper methods are placed in the same class as the
* callee. Their name is stable, including a hash derived from the bound
* constant arguments.
*
* The most interesting part of this optimization, with likely further tuning
* potential, is the priority-queue based approach to find beneficial subsets of
* common constant arguments.
*
* While this is similar in spirit to what the InstructionSequenceOutliner does,
* a major difference is that this pass specifically targets individual method
* invocations, and it picks up incoming constant arguments based on our
* existing constant-propagation analysis, not caring where earlier in the code,
* or in which order, the constant are defined. And ultimately, it picks a
* beneficial subset of constant arguments regardless of what order they were
* defined in. In contrast, the InstructionSequenceOutliner requires precise
* matches of frequently occurring instruction opcodes sequences (but module
* register names) in order outline any particular call-site.
*
* Here's an example of what the optimization does. Let's say there's a method
* like this:
*
* void foo(int a, int b, Integer c);
*
* And it is invoked 10 times as
*
* foo(10, 20, Integer.valueOf(23));
*
* And another 10 times as
*
* foo(13, 20, Integer.valueOf(23));
*
* Let's say in neither case would a new helper function be beneficial to reduce
* size. However, when we trim off the first argument, we are left with 20 times
*
* foo(*, 20, Integer.valueOf(23));
*
* And this might be beneficial to transform. Then we introduce a helper
* function like the following.
*
* foo$pa$xxxx(int a) { foo(a, 20, Integer.valueOf(23)); }
*
* And rewrite the call-sites to
*
* foo$pa$xxxx(10);
*
* and
*
* foo$pa$xxxx(13);
*
* respectively.
*
* Various safe-guards are in place:
* - We won't introduce helper methods that would contain cross-store or
* non-min-sdk level references.
* - We only transform code with the largest root store id (so not in the
* primary dex, unless there only is one, and not in other auxiliary stores)
* - We won't rewrite code that sits in hot blocks in hot methods, or loops in
* warm methods (reuses logic from InstructionSequenceOutliner)
*
* We don't do anything special for symbolication. Thus, the new helper methods
* will appear in stack traces, but probably won't be confusing, as they have
* names derived from the final callee, appearing as some trampoline method. The
* code in the helper methods will never fail (except maybe under the most
* obscure circumstances such as a stack-overflow), and thus will never be on
* top of a stack trace, and only the top frame is used for symbolication.
*/
#include "PartialApplication.h"
#include <cinttypes>
#include <boost/format.hpp>
#include <boost/pending/disjoint_sets.hpp>
#include <boost/property_map/property_map.hpp>
#include "CFGMutation.h"
#include "CallSiteSummaries.h"
#include "ConfigFiles.h"
#include "Creators.h"
#include "LiveRange.h"
#include "MutablePriorityQueue.h"
#include "OutliningProfileGuidanceImpl.h"
#include "PassManager.h"
#include "ReachableClasses.h"
#include "RefChecker.h"
#include "Shrinker.h"
#include "SourceBlocks.h"
#include "StlUtil.h"
#include "Walkers.h"
using namespace inliner;
using namespace outliner;
using namespace outliner_impl;
using namespace live_range;
using namespace shrinker;
namespace {
// Overhead of introducing a typical new helper method and its metadata.
const size_t COST_METHOD = 28;
// Retrieve list of classes in primary dex, if there is more than one store and
// dexes.
std::unordered_set<const DexType*> get_excluded_classes(
DexStoresVector& stores) {
std::unordered_set<const DexType*> excluded_classes;
bool has_other_stores{false};
bool has_other_dexes{false};
for (auto& store : stores) {
if (store.is_root_store()) {
auto& dexen = store.get_dexen();
always_assert(!dexen.empty());
for (auto cls : dexen.front()) {
excluded_classes.insert(cls->get_type());
}
if (dexen.size() > 1) {
has_other_dexes = true;
}
} else {
has_other_stores = true;
}
}
if (!has_other_stores && !has_other_dexes) {
excluded_classes.clear();
}
return excluded_classes;
}
const api::AndroidSDK* get_min_sdk_api(ConfigFiles& conf, PassManager& mgr) {
int32_t min_sdk = mgr.get_redex_options().min_sdk;
mgr.incr_metric("min_sdk", min_sdk);
TRACE(PA, 2, "min_sdk: %d", min_sdk);
auto min_sdk_api_file = conf.get_android_sdk_api_file(min_sdk);
if (!min_sdk_api_file) {
mgr.incr_metric("min_sdk_no_file", 1);
TRACE(PA, 2, "Android SDK API %d file cannot be found.", min_sdk);
return nullptr;
} else {
return &conf.get_android_sdk_api(min_sdk);
}
}
using EnumUtilsCache = ConcurrentMap<int32_t, DexField*>;
// Check if we have a boxed value for which there is a $EnumUtils field.
DexField* try_get_enum_utils_f_field(EnumUtilsCache& cache,
const ObjectWithImmutAttr& object) {
// This matches EnumUtilsFieldAnalyzer::analyze_sget.
always_assert(object.jvm_cached_singleton);
always_assert(object.attributes.size() == 1);
if (object.type != type::java_lang_Integer()) {
return nullptr;
}
const auto& signed_value =
object.attributes.front().value.get<SignedConstantDomain>();
auto c = signed_value.get_constant();
always_assert(c);
DexField* res;
cache.update(*c, [&res](int32_t key, DexField*& value, bool exists) {
if (!exists) {
auto cls = type_class(DexType::make_type("Lredex/$EnumUtils;"));
if (cls) {
std::string field_name = "f" + std::to_string(key);
value = cls->find_sfield(field_name.c_str(), type::java_lang_Integer());
always_assert(!value || is_static(value));
}
}
res = value;
});
return res;
}
// Identify how many argument slots an invocation needs after expansion of wide
// types, and thus whether a range instruction will be needed.
std::pair<param_index_t, bool> analyze_args(const DexMethod* callee) {
const auto* args = callee->get_proto()->get_args();
param_index_t src_regs = args->size();
if (!is_static(callee)) {
src_regs++;
}
param_index_t expanded_src_regs{!is_static(callee)};
for (auto t : *args) {
expanded_src_regs += type::is_wide_type(t) ? 2 : 1;
}
auto needs_range = expanded_src_regs > 5;
return {src_regs, needs_range};
}
struct ArgExclusivity {
// between 0 and 1
float ownership{0};
bool needs_move{false};
};
struct AggregatedArgExclusivity {
double ownership{0};
uint32_t needs_move{0};
};
using ArgExclusivityVector =
std::vector<std::pair<src_index_t, ArgExclusivity>>;
// Determine whether, or to what extent, the instructions to compute arguments
// to an invocation are exclusive to that invocation. (If not, then eliminating
// the argument in the invocation likely won't give us expected cost savings.)
ArgExclusivityVector get_arg_exclusivity(const UseDefChains& use_def_chains,
const DefUseChains& def_use_chains,
bool needs_range,
IRInstruction* insn) {
ArgExclusivityVector aev;
for (param_index_t src_idx = 0; src_idx < insn->srcs_size(); src_idx++) {
const auto& defs = use_def_chains.at((Use){insn, src_idx});
if (defs.size() != 1) {
continue;
}
const auto def = *defs.begin();
bool other_use = false;
param_index_t count = 0;
for (const auto& use : def_use_chains.at(def)) {
if (!opcode::is_a_move(use.insn->opcode()) &&
(use.insn->opcode() != insn->opcode() ||
use.insn->get_method() != insn->get_method())) {
other_use = true;
break;
}
count++;
}
float ownership = other_use ? 0.0 : (1.0 / count);
// TODO: We also likely need a move if there are more than 16 args
// (including extra wides) live at this point.
bool needs_move = needs_range && (other_use || count > 1);
if (ownership > 0 || needs_move) {
aev.emplace_back(src_idx, (ArgExclusivity){ownership, needs_move});
}
}
return aev;
}
using CalleeCallerClasses =
std::unordered_map<const DexMethod*, std::unordered_set<const DexType*>>;
// Gather all (caller, callee) pairs. Also compute arg exclusivity, which invoke
// instructions we should exclude, and how many classes calls are distributed
// over.
void gather_caller_callees(
const ProfileGuidanceConfig& profile_guidance_config,
const Scope& scope,
const std::unordered_set<DexMethod*>& sufficiently_warm_methods,
const std::unordered_set<DexMethod*>& sufficiently_hot_methods,
const GetCalleeFunction& get_callee_fn,
MethodToMethodOccurrences* callee_caller,
MethodToMethodOccurrences* caller_callee,
std::unordered_map<const IRInstruction*, ArgExclusivityVector>*
arg_exclusivity,
std::unordered_set<const IRInstruction*>* excluded_invoke_insns,
CalleeCallerClasses* callee_caller_classes) {
Timer timer("gather_caller_callees");
using ConcurrentMethodToMethodOccurrences =
ConcurrentMap<const DexMethod*, std::unordered_map<DexMethod*, size_t>>;
ConcurrentMethodToMethodOccurrences concurrent_callee_caller;
ConcurrentMethodToMethodOccurrences concurrent_caller_callee;
ConcurrentSet<const IRInstruction*> concurrent_excluded_invoke_insns;
ConcurrentMap<const IRInstruction*, ArgExclusivityVector>
concurrent_arg_exclusivity;
ConcurrentMap<const DexMethod*, std::unordered_set<const DexType*>>
concurrent_callee_caller_classes;
walk::parallel::code(scope, [&](DexMethod* caller, IRCode& code) {
code.build_cfg(true);
CanOutlineBlockDecider block_decider(
profile_guidance_config, sufficiently_warm_methods.count(caller),
sufficiently_hot_methods.count(caller));
MoveAwareChains move_aware_chains(code.cfg());
const auto use_def_chains = move_aware_chains.get_use_def_chains();
const auto def_use_chains = move_aware_chains.get_def_use_chains();
for (auto& big_block : big_blocks::get_big_blocks(code.cfg())) {
auto can_outline = block_decider.can_outline_from_big_block(big_block) ==
CanOutlineBlockDecider::Result::CanOutline;
for (auto& mie : big_blocks::InstructionIterable(big_block)) {
auto insn = mie.insn;
auto callee = get_callee_fn(caller, insn);
if (!callee) {
continue;
}
if (!can_outline) {
concurrent_excluded_invoke_insns.insert(insn);
continue;
}
auto needs_range = analyze_args(callee).second;
auto ae = get_arg_exclusivity(use_def_chains, def_use_chains,
needs_range, insn);
if (ae.empty()) {
concurrent_excluded_invoke_insns.insert(insn);
continue;
}
concurrent_callee_caller.update(
callee,
[caller](const DexMethod*,
std::unordered_map<DexMethod*, size_t>& v,
bool) { ++v[caller]; });
concurrent_caller_callee.update(
caller,
[callee](const DexMethod*,
std::unordered_map<DexMethod*, size_t>& v,
bool) { ++v[callee]; });
concurrent_arg_exclusivity.emplace(insn, std::move(ae));
concurrent_callee_caller_classes.update(
callee,
[caller](const DexMethod*,
std::unordered_set<const DexType*>& value,
bool) { value.insert(caller->get_class()); });
}
}
});
for (auto& p : concurrent_callee_caller) {
callee_caller->insert(std::move(p));
}
for (auto& p : concurrent_caller_callee) {
caller_callee->insert(std::move(p));
}
excluded_invoke_insns->insert(concurrent_excluded_invoke_insns.begin(),
concurrent_excluded_invoke_insns.end());
for (auto& p : concurrent_arg_exclusivity) {
arg_exclusivity->insert(std::move(p));
}
for (auto& p : concurrent_callee_caller_classes) {
callee_caller_classes->insert(std::move(p));
}
}
using InvokeCallSiteSummaries =
std::unordered_map<const IRInstruction*, const CallSiteSummary*>;
// Whether to include a particular constant argument value. We only include
// actual constant (not just abstract value like NEZ), and only if they don't
// violate anything the ref-checker would complain about. We can also handle
// singletons and immutable objects if they represent jvm cached singletons.
bool filter(const RefChecker& ref_checker,
EnumUtilsCache& enum_utils_cache,
const ConstantValue& value) {
if (const auto& signed_value = value.maybe_get<SignedConstantDomain>()) {
return !!signed_value->get_constant();
} else if (const auto& singleton_value =
value.maybe_get<SingletonObjectDomain>()) {
auto field = *singleton_value->get_constant();
return ref_checker.check_field(field);
} else if (const auto& obj_or_none =
value.maybe_get<ObjectWithImmutAttrDomain>()) {
auto object = obj_or_none->get_constant();
if (!object->jvm_cached_singleton) {
return false;
}
if (DexField* field =
try_get_enum_utils_f_field(enum_utils_cache, *object)) {
return ref_checker.check_field(field);
} else {
always_assert(object->attributes.size() == 1);
const auto& signed_value2 =
object->attributes.front().value.maybe_get<SignedConstantDomain>();
always_assert(signed_value2);
return filter(ref_checker, enum_utils_cache, *signed_value2);
}
} else {
not_reached_log("unexpected value: %s", SHOW(value));
}
}
using CallSiteSummarySet = std::unordered_set<const CallSiteSummary*>;
using CallSiteSummaryVector = std::vector<const CallSiteSummary*>;
CallSiteSummaryVector order_csses(const CallSiteSummarySet& csses) {
CallSiteSummaryVector ordered_csses(csses.begin(), csses.end());
std::sort(ordered_csses.begin(), ordered_csses.end(),
[](const CallSiteSummary* a, const CallSiteSummary* b) {
return a->get_key() < b->get_key();
});
return ordered_csses;
}
// Priority-queue based algorithm to select which invocations and which constant
// arguments are beneficial to transform.
class CalleeInvocationSelector {
private:
EnumUtilsCache& m_enum_utils_cache;
CallSiteSummarizer& m_call_site_summarizer;
const DexMethod* m_callee;
const std::unordered_map<const IRInstruction*, ArgExclusivityVector>&
m_arg_exclusivity;
size_t m_callee_caller_classes;
param_index_t m_src_regs;
bool m_needs_range;
// When we are going to merge different call-site summaries after simplifying,
// we need to efficiently track what all the underlying call-site summaries
// were. We do that via a "disjoint_sets" data structure what all the
// underlying call-site summaries are.
using Rank = std::unordered_map<const CallSiteSummary*, size_t>;
using Parent =
std::unordered_map<const CallSiteSummary*, const CallSiteSummary*>;
using RankPMap = boost::associative_property_map<Rank>;
using ParentPMap = boost::associative_property_map<Parent>;
using CallSiteSummarySets = boost::disjoint_sets<RankPMap, ParentPMap>;
Rank m_rank;
Parent m_parent;
CallSiteSummarySets m_css_sets;
CallSiteSummarySet m_call_site_summaries;
using ArgumentCosts = std::unordered_map<src_index_t, int32_t>;
std::unordered_map<const CallSiteSummary*, ArgumentCosts>
m_call_site_summary_argument_costs;
using KeyedCosts = std::unordered_map<std::string, int32_t>;
std::vector<KeyedCosts> m_total_argument_costs;
using KeyedCsses = std::unordered_map<std::string, CallSiteSummarySet>;
std::vector<KeyedCsses> m_dependencies;
std::vector<std::pair<const IRInstruction*, const CallSiteSummary*>>
m_call_site_invoke_summaries;
std::unordered_map<const CallSiteSummary*,
std::unordered_map<src_index_t, AggregatedArgExclusivity>>
m_aggregated_arg_exclusivity;
static std::string get_key(const ConstantValue& value) {
std::ostringstream oss;
CallSiteSummary::append_key_value(oss, value);
return oss.str();
}
static int32_t sum_call_sites_savings(const ArgumentCosts& ac) {
int32_t savings = 0;
for (const auto& p : ac) {
savings += p.second;
}
return savings;
}
int16_t const_value_cost(const ConstantValue& value) const {
if (const auto& signed_value = value.maybe_get<SignedConstantDomain>()) {
auto c = signed_value->get_constant();
always_assert(c);
auto lit = *c;
if (lit < -2147483648 || lit > 2147483647) {
return 5;
} else if (lit < -32768 || lit > 32767) {
return 3;
} else if (lit < -8 || lit > 7) {
return 2;
} else {
return 1;
}
} else if (const auto& singleton_value =
value.maybe_get<SingletonObjectDomain>()) {
return 2;
} else if (const auto& obj_or_none =
value.maybe_get<ObjectWithImmutAttrDomain>()) {
auto object = obj_or_none->get_constant();
always_assert(object);
if (try_get_enum_utils_f_field(m_enum_utils_cache, *object)) {
return 2;
} else {
always_assert(object->jvm_cached_singleton);
always_assert(object->attributes.size() == 1);
const auto& signed_value2 =
object->attributes.front().value.maybe_get<SignedConstantDomain>();
always_assert(signed_value2);
return 3 + const_value_cost(*signed_value2);
}
} else {
not_reached_log("unexpected value: %s", SHOW(value));
}
}
std::pair<param_index_t, uint32_t> find_argument_with_least_cost(
const CallSiteSummary* css) const {
const auto& bindings = css->arguments.bindings();
boost::optional<int32_t> least_cost;
param_index_t least_cost_src_idx = 0;
for (auto& p : bindings) {
auto& arguments_cost = m_total_argument_costs.at(p.first);
auto it = arguments_cost.find(get_key(p.second));
auto cost = it == arguments_cost.end() ? 0 : it->second;
if (!least_cost || *least_cost > cost ||
(*least_cost == cost && p.first < least_cost_src_idx)) {
least_cost = cost;
least_cost_src_idx = p.first;
}
}
always_assert(least_cost);
return {least_cost_src_idx, *least_cost};
}
int32_t get_net_savings(const CallSiteSummary* css) const {
// The cost for an additional partial-application helper method consists
// of...
// - the basic overhead of having a method
// - an estimated cross-dex penalty, as the PartialApplication pass has to
// run before the InterDex pass, and adding extra method-refs has global
// negative effects on the number of needed cross-dex references.
// - an extra move-result instruction
// - the cost of const instructions
// - some extra potetnail move overhead if we need the range form
int32_t pa_cross_dex_penalty =
2 * std::ceil(std::sqrt(m_callee_caller_classes));
int32_t pa_method_cost =
COST_METHOD + pa_cross_dex_penalty + css->result_used;
const auto& bindings = css->arguments.bindings();
for (auto& r : bindings) {
pa_method_cost += const_value_cost(r.second);
}
if (m_needs_range) {
pa_method_cost += m_src_regs;
}
auto call_sites_savings =
sum_call_sites_savings(m_call_site_summary_argument_costs.at(css));
return call_sites_savings - pa_method_cost;
}
using Priority = uint64_t;
uint64_t m_running_index = 0;
Priority make_priority(const CallSiteSummary* css) {
// We order by...
// - (1 bit) whether net savings are positive
// - (31 bits) if not, (clipped) least argument costs (smaller is better)
// - (32 bits) running index to make the priority unique
auto net_savings = get_net_savings(css);
uint64_t positive = net_savings > 0;
uint64_t a = 0;
if (!positive) {
auto least_cost = find_argument_with_least_cost(css).second;
a = std::min<uint32_t>(least_cost, (1U << 31) - 1);
}
uint64_t b = m_running_index++;
always_assert(positive < 2);
always_assert(a < (1UL << 31));
always_assert(b < (1UL << 32));
return (positive << 63) | (a << 32) | b;
}
MutablePriorityQueue<const CallSiteSummary*, Priority> m_pq;
public:
CalleeInvocationSelector(
EnumUtilsCache& enum_utils_cache,
CallSiteSummarizer& call_site_summarizer,
const DexMethod* callee,
const std::unordered_map<const IRInstruction*, ArgExclusivityVector>&
arg_exclusivity,
size_t callee_caller_classes)
: m_enum_utils_cache(enum_utils_cache),
m_call_site_summarizer(call_site_summarizer),
m_callee(callee),
m_arg_exclusivity(arg_exclusivity),
m_callee_caller_classes(callee_caller_classes),
m_css_sets((RankPMap(m_rank)), (ParentPMap(m_parent))) {
auto callee_call_site_invokes =
call_site_summarizer.get_callee_call_site_invokes(callee);
if (!callee_call_site_invokes) {
return;
}
std::tie(m_src_regs, m_needs_range) = analyze_args(callee);
TRACE(
PA, 2,
"[PartialApplication] Processing %s, %zu caller classes, %u src regs%s",
SHOW(m_callee), callee_caller_classes, m_src_regs,
m_needs_range ? ", needs_range" : "");
m_total_argument_costs = std::vector<KeyedCosts>(m_src_regs, KeyedCosts());
m_dependencies = std::vector<KeyedCsses>(m_src_regs, KeyedCsses());
// Aggregate arg exclusivity across call-sites with the same summary.
for (auto invoke_insn : *callee_call_site_invokes) {
auto css =
call_site_summarizer.get_instruction_call_site_summary(invoke_insn);
if (css->arguments.is_top()) {
continue;
}
if (!is_static(callee) && !css->arguments.get(0).is_top()) {
// We don't want to deal with cases where an instance method is called
// with nullptr.
TRACE(PA, 2,
"[PartialApplication] Ignoring invocation of instance method %s "
"with %s",
SHOW(callee), css->get_key().c_str());
continue;
}
m_call_site_invoke_summaries.emplace_back(invoke_insn, css);
auto& aev = arg_exclusivity.at(invoke_insn);
auto& aaem = m_aggregated_arg_exclusivity[css];
for (auto& p : aev) {
auto& aae = aaem[p.first];
aae.ownership += p.second.ownership;
aae.needs_move += p.second.needs_move;
}
}
// For each call-site summary,
// - initialize disjoint set singleton, and
// - compute current constant argument costs that could potentially be saved
// when introducing partial-application helper method, and
// - keep track of which constant value for which parameter is involved in
// that call-site summary, which we'll need later when re-prioritizing
// call-site summaries in the priority queue.
for (auto& p : m_aggregated_arg_exclusivity) {
auto css = p.first;
auto& aaem = p.second;
m_call_site_summaries.insert(css);
m_css_sets.make_set(css);
auto& ac = m_call_site_summary_argument_costs[css];
const auto& bindings = css->arguments.bindings();
for (auto& q : bindings) {
const auto src_idx = q.first;
const auto& value = q.second;
auto& aae = aaem[src_idx];
int32_t cost =
const_value_cost(value) * aae.ownership + 2 * aae.needs_move;
ac.emplace(src_idx, cost);
auto key = get_key(value);
m_total_argument_costs.at(src_idx)[key] += cost;
m_dependencies.at(src_idx)[key].insert(css);
}
}
}
// Fill priority queue with raw data.
void fill_pq() {
// Populate priority queue
for (auto css : order_csses(m_call_site_summaries)) {
auto priority = make_priority(css);
TRACE(PA, 4,
"[PartialApplication] Considering %s(%s): net savings %d, priority "
"%016" PRIx64,
SHOW(m_callee), css->get_key().c_str(), get_net_savings(css),
priority);
m_pq.insert(css, priority);
}
}
// For all items in the queue which have non-positive net savings, chop
// off the argument with least cost, and lump it together with any
// possibly already existing item.
void reduce_pq() {
while (!m_pq.empty() && get_net_savings(m_pq.back()) <= 0) {
auto css = m_pq.back();
m_pq.erase(css);
auto ac_it = m_call_site_summary_argument_costs.find(css);
auto ac = std::move(ac_it->second);
m_call_site_summary_argument_costs.erase(ac_it);
for (auto& p : css->arguments.bindings()) {
bool erased =
m_dependencies.at(p.first).at(get_key(p.second)).erase(css);
always_assert(erased);
}
auto [src_idx, least_cost] = find_argument_with_least_cost(css);
always_assert(!css->arguments.get(src_idx).is_top());
auto key = get_key(css->arguments.get(src_idx));
m_total_argument_costs.at(src_idx).at(key) -= ac.at(src_idx);
CallSiteSummary reduced_css_val{css->arguments, css->result_used};
reduced_css_val.arguments.set(src_idx, ConstantValue::top());
if (reduced_css_val.arguments.is_top()) {
TRACE(PA, 4,
"[PartialApplication] Removing %s(%s) with least cost %u@%u",
SHOW(m_callee), css->get_key().c_str(), least_cost, src_idx);
} else {
auto reduced_css = m_call_site_summarizer.internalize_call_site_summary(
reduced_css_val);
ac_it = m_call_site_summary_argument_costs.find(reduced_css);
if (ac_it == m_call_site_summary_argument_costs.end()) {
ac_it = m_call_site_summary_argument_costs
.emplace(reduced_css, ArgumentCosts())
.first;
for (auto& p : reduced_css->arguments.bindings()) {
bool inserted = m_dependencies.at(p.first)
.at(get_key(p.second))
.insert(reduced_css)
.second;
always_assert(inserted);
}
} else {
m_pq.erase(reduced_css);
}
for (auto& p : ac) {
ac_it->second[p.first] += p.second;
}
ac_it->second.erase(src_idx);
m_pq.insert(reduced_css, make_priority(reduced_css));
if (m_call_site_summaries.insert(reduced_css).second) {
m_css_sets.make_set(reduced_css);
}
m_css_sets.union_set(css, reduced_css);
TRACE(PA, 4,
"[PartialApplication] Merging %s(%s ===> %s) with least cost "
"%u@%u: net savings %d",
SHOW(m_callee), css->get_key().c_str(),
reduced_css->get_key().c_str(), least_cost, src_idx,
get_net_savings(reduced_css));
}
const auto& csses = m_dependencies.at(src_idx).at(key);
for (auto dependent_css : order_csses(csses)) {
TRACE(PA, 4, "[PartialApplication] Reprioritizing %s(%s)",
SHOW(m_callee), dependent_css->get_key().c_str());
m_pq.update_priority(dependent_css, make_priority(dependent_css));
}
}
}
// Identify all invocations which contributed to groups with combined positive
// expected savings.
void select_invokes(std::atomic<size_t>* total_estimated_savings,
InvokeCallSiteSummaries* selected_invokes) {
size_t partial_application_methods{0};
std::unordered_map<const CallSiteSummary*, const CallSiteSummary*>
selected_css_sets;
uint32_t callee_estimated_savings = 0;
while (!m_pq.empty()) {
auto css = m_pq.front();
auto net_savings = get_net_savings(css);
m_pq.erase(css);
selected_css_sets.emplace(m_css_sets.find_set(css), css);
callee_estimated_savings += net_savings;
partial_application_methods++;
TRACE(PA, 3, "[PartialApplication] Selected %s(%s) with net savings %d",
SHOW(m_callee), css->get_key().c_str(), net_savings);
always_assert(net_savings > 0);
}
for (auto& p : m_call_site_invoke_summaries) {
auto invoke_insn = p.first;
auto css = p.second;
if (!m_call_site_summaries.count(css)) {
continue;
}
auto it = selected_css_sets.find(m_css_sets.find_set(css));
if (it == selected_css_sets.end()) {
continue;
}
auto reduced_css = it->second;
// This invoke got selected because including it together with all
// other invokes with the same css was beneficial on average. Check
// (and filter out) if it's not actually beneficial for this particular
// invoke.
auto& aev = m_arg_exclusivity.at(invoke_insn);
const auto& bindings = reduced_css->arguments.bindings();
if (std::find_if(aev.begin(), aev.end(), [&bindings](auto& q) {
return !bindings.at(q.first).is_top();
}) == aev.end()) {
continue;
}
selected_invokes->emplace(invoke_insn, reduced_css);
}
if (callee_estimated_savings > 0) {
TRACE(PA, 2,
"[PartialApplication] Selected %s(...) for %zu constant argument "
"combinations across %zu invokes with net savings %u",
SHOW(m_callee), partial_application_methods,
selected_invokes->size(), callee_estimated_savings);
*total_estimated_savings += callee_estimated_savings;
}
}
};
// From a call-site summary that include constant-arguments, derive the
// signature of the new helper methods that will bind them.
DexTypeList* get_partial_application_args(bool callee_is_static,
const DexProto* callee_proto,
const CallSiteSummary* css) {
const auto* args = callee_proto->get_args();
DexTypeList::ContainerType new_args;
param_index_t offset = 0;
if (!callee_is_static) {
always_assert(css->arguments.get(0).is_top());
offset++;
}
for (param_index_t i = 0; i < args->size(); i++) {
if (css->arguments.get(offset + i).is_top()) {
new_args.push_back(args->at(i));
}
}
return DexTypeList::make_type_list(std::move(new_args));
}
uint64_t get_stable_hash(uint64_t a, uint64_t b) { return a ^ b; }
uint64_t get_stable_hash(const std::string& s) {
uint64_t stable_hash{s.size()};
for (auto c : s) {
stable_hash = stable_hash * 7 + c;
}
return stable_hash;
}
using PaMethodRefs = ConcurrentMap<CalleeCallSiteSummary,
DexMethodRef*,
boost::hash<CalleeCallSiteSummary>>;
// Run the analysis over all callees.
void select_invokes_and_callers(
EnumUtilsCache& enum_utils_cache,
CallSiteSummarizer& call_site_summarizer,
const MethodToMethodOccurrences& callee_caller,
const std::unordered_map<const IRInstruction*, ArgExclusivityVector>&
arg_exclusivity,
const CalleeCallerClasses& callee_caller_classes,
size_t iteration,
std::atomic<size_t>* total_estimated_savings,
PaMethodRefs* pa_method_refs,
InvokeCallSiteSummaries* selected_invokes,
std::unordered_set<DexMethod*>* selected_callers) {
Timer t("select_invokes_and_callers");
std::vector<const DexMethod*> callees;
std::unordered_map<const DexType*, std::vector<const DexMethod*>>
callees_by_classes;
std::unordered_map<const DexMethod*, InvokeCallSiteSummaries>
selected_invokes_by_callees;
for (auto& p : callee_caller) {
auto callee = p.first;
callees.push_back(callee);
callees_by_classes[callee->get_class()].push_back(callee);
selected_invokes_by_callees[callee];
}
workqueue_run<const DexMethod*>(
[&](const DexMethod* callee) {
CalleeInvocationSelector cis(enum_utils_cache, call_site_summarizer,
callee, arg_exclusivity,
callee_caller_classes.at(callee).size());
cis.fill_pq();
cis.reduce_pq();
cis.select_invokes(total_estimated_savings,
&selected_invokes_by_callees.at(callee));
},
callees);
std::vector<const DexType*> callee_classes;
callee_classes.reserve(callees_by_classes.size());
for (auto& p : callees_by_classes) {
callee_classes.push_back(p.first);
}
std::mutex mutex;
workqueue_run<const DexType*>(
[&](const DexType* callee_class) {
auto& class_callees = callees_by_classes.at(callee_class);
std::sort(class_callees.begin(), class_callees.end(),
compare_dexmethods);
std::unordered_map<uint64_t, uint32_t> stable_hash_indices;
for (auto callee : class_callees) {
auto& callee_selected_invokes =
selected_invokes_by_callees.at(callee);
if (callee_selected_invokes.empty()) {
continue;
}
auto callee_stable_hash = get_stable_hash(show(callee));
std::map<const DexTypeList*,
std::unordered_set<const CallSiteSummary*>,
dextypelists_comparator>
ordered_pa_args_csses;
auto callee_is_static = is_static(callee);
auto callee_proto = callee->get_proto();
for (auto& p : callee_selected_invokes) {
auto css = p.second;
auto pa_args = get_partial_application_args(callee_is_static,
callee_proto, css);
auto inserted = ordered_pa_args_csses[pa_args].insert(css).second;
always_assert(true);
}
for (auto& p : ordered_pa_args_csses) {
auto pa_args = p.first;
auto& csses = p.second;
for (auto css : order_csses(csses)) {
auto css_stable_hash = get_stable_hash(css->get_key());
auto stable_hash =
get_stable_hash(callee_stable_hash, css_stable_hash);
auto stable_hash_index = stable_hash_indices[stable_hash]++;
std::ostringstream oss;
oss << callee->get_name()->str()
<< (is_static(callee) ? "$spa$" : "$ipa$") << iteration << "$"
<< ((boost::format("%08x") % stable_hash).str()) << "$"
<< stable_hash_index;
auto pa_name = DexString::make_string(oss.str());
auto pa_rtype =
css->result_used ? callee_proto->get_rtype() : type::_void();
auto pa_proto = DexProto::make_proto(pa_rtype, pa_args);
auto pa_type = callee->get_class();
auto pa_method_ref =
DexMethod::make_method(pa_type, pa_name, pa_proto);
CalleeCallSiteSummary ccss{callee, css};
pa_method_refs->emplace(ccss, pa_method_ref);
}
}
std::lock_guard<std::mutex> lock_guard(mutex);
selected_invokes->insert(callee_selected_invokes.begin(),
callee_selected_invokes.end());
for (auto& p : callee_caller.at(callee)) {
selected_callers->insert(p.first);
}
}
},
callee_classes);
}
IROpcode get_invoke_opcode(const DexMethod* callee) {
return callee->is_virtual() ? OPCODE_INVOKE_VIRTUAL
: is_static(callee) ? OPCODE_INVOKE_STATIC
: OPCODE_INVOKE_DIRECT;
}
// Given the analysis results, rewrite all callers to invoke the new helper
// methods with bound arguments.
void rewrite_callers(
const Scope& scope,
Shrinker& shrinker,
const GetCalleeFunction& get_callee_fn,
const std::unordered_map<const IRInstruction*, const CallSiteSummary*>&
selected_invokes,
const std::unordered_set<DexMethod*>& selected_callers,
PaMethodRefs& pa_method_refs,
std::atomic<size_t>* removed_args) {
Timer t("rewrite_callers");
auto make_partial_application_invoke_insn =
[&](DexMethod* caller, IRInstruction* insn) -> IRInstruction* {
if (!opcode::is_an_invoke(insn->opcode())) {
return nullptr;
}
auto it = selected_invokes.find(insn);
if (it == selected_invokes.end()) {
return nullptr;
}
auto callee = get_callee_fn(caller, insn);
always_assert(callee != nullptr);
auto css = it->second;
CalleeCallSiteSummary ccss{callee, css};
DexMethodRef* pa_method_ref = pa_method_refs.at_unsafe(ccss);
auto new_insn = (new IRInstruction(get_invoke_opcode(callee)))
->set_method(pa_method_ref);
new_insn->set_srcs_size(insn->srcs_size() - css->arguments.size());
param_index_t idx = 0;
for (param_index_t i = 0; i < insn->srcs_size(); i++) {
if (css->arguments.get(i).is_top()) {
new_insn->set_src(idx++, insn->src(i));
}
}
always_assert(idx == new_insn->srcs_size());
return new_insn;
};
walk::parallel::code(scope, [&](DexMethod* caller, IRCode& code) {
if (selected_callers.count(caller)) {
bool any_changes{false};
auto& cfg = code.cfg();
cfg::CFGMutation mutation(cfg);
auto ii = InstructionIterable(cfg);
size_t removed_srcs{0};
for (auto it = ii.begin(); it != ii.end(); it++) {
auto new_invoke_insn =
make_partial_application_invoke_insn(caller, it->insn);
if (!new_invoke_insn) {
continue;
}
removed_srcs += it->insn->srcs_size() - new_invoke_insn->srcs_size();
std::vector<IRInstruction*> new_insns{new_invoke_insn};
auto move_result_it = cfg.move_result_of(it);
if (!move_result_it.is_end()) {
new_insns.push_back(new IRInstruction(*move_result_it->insn));
}
mutation.replace(it, new_insns);
any_changes = true;
}
mutation.flush();
if (any_changes) {
TRACE(PA, 6, "[PartialApplication] Rewrote %s:\n%s", SHOW(caller),
SHOW(cfg));
shrinker.shrink_method(caller);
(*removed_args) += removed_srcs;
}
}
code.clear_cfg();
});
}
// Helper used to build the partial-assignment helper methods.
void push_callee_arg(EnumUtilsCache& enum_utils_cache,
DexType* type,
const ConstantValue& value,
MethodCreator* method_creator,
MethodBlock* main_block,
std::vector<Location>* callee_args) {
if (const auto& signed_value = value.maybe_get<SignedConstantDomain>()) {
auto c = signed_value->get_constant();
always_assert(c);
auto tmp = method_creator->make_local(type);
main_block->load_const(tmp, *c, type);
callee_args->push_back(tmp);
} else if (const auto& singleton_value =
value.maybe_get<SingletonObjectDomain>()) {
auto c = singleton_value->get_constant();
always_assert(c);
auto field = *c;
always_assert(is_static(field));
auto tmp = method_creator->make_local(type);
main_block->sfield_op(opcode::sget_opcode_for_field(field),
const_cast<DexField*>(field), tmp);
callee_args->push_back(tmp);
} else if (const auto& obj_or_none =
value.maybe_get<ObjectWithImmutAttrDomain>()) {
auto object = obj_or_none->get_constant();
always_assert(object);
if (DexField* field =
try_get_enum_utils_f_field(enum_utils_cache, *object)) {
auto tmp = method_creator->make_local(field->get_type());
main_block->sfield_op(opcode::sget_opcode_for_field(field), field, tmp);
callee_args->push_back(tmp);
} else {
always_assert(object->jvm_cached_singleton);
always_assert(object->attributes.size() == 1);
auto valueOf = type::get_value_of_method_for_type(object->type);
auto valueOf_arg_type = valueOf->get_proto()->get_args()->at(0);
auto tmp = method_creator->make_local(valueOf_arg_type);
const auto& signed_value2 =
object->attributes.front().value.maybe_get<SignedConstantDomain>();
always_assert(signed_value2);
auto c = signed_value2->get_constant();
always_assert(c);
main_block->load_const(tmp, *c, valueOf_arg_type);
main_block->invoke(OPCODE_INVOKE_STATIC, valueOf, {tmp});
tmp = method_creator->make_local(type);
main_block->move_result(tmp, type);
callee_args->push_back(tmp);
}
} else {
not_reached_log("unexpected value: %s", SHOW(value));
}
}
// Create all new helper methods that bind constant arguments
void create_partial_application_methods(EnumUtilsCache& enum_utils_cache,
PaMethodRefs& pa_method_refs) {
Timer t("create_partial_application_methods");
std::map<DexMethodRef*, const CalleeCallSiteSummary*, dexmethods_comparator>
inverse_ordered_pa_method_refs;
for (auto& p : pa_method_refs) {
bool success =
inverse_ordered_pa_method_refs.emplace(p.second, &p.first).second;
always_assert(success);
}
for (auto& p : inverse_ordered_pa_method_refs) {
auto pa_method_ref = p.first;
auto callee = p.second->method;
auto cls = type_class(callee->get_class());
always_assert(cls);
auto css = p.second->call_site_summary;
auto access = callee->get_access() & ~(ACC_ABSTRACT | ACC_NATIVE);
if (callee->is_virtual()) {
access |= ACC_FINAL;
}
MethodCreator method_creator(pa_method_ref, access);
auto main_block = method_creator.get_main_block();
std::vector<Location> callee_args;
param_index_t offset = 0;
param_index_t next_arg_idx = 0;
if (!is_static(callee)) {
always_assert(css->arguments.get(0).is_top());
offset++;
callee_args.push_back(method_creator.get_local(next_arg_idx++));
}
auto proto = callee->get_proto();
const auto* args = proto->get_args();
for (param_index_t i = 0; i < args->size(); i++) {
const auto& value = css->arguments.get(offset + i);
if (value.is_top()) {
callee_args.push_back(method_creator.get_local(next_arg_idx++));
} else {
push_callee_arg(enum_utils_cache, args->at(i), value, &method_creator,
main_block, &callee_args);
}
}
main_block->invoke(get_invoke_opcode(callee),
const_cast<DexMethod*>(callee), callee_args);
if (css->result_used) {
auto tmp = method_creator.make_local(proto->get_rtype());
main_block->move_result(tmp, proto->get_rtype());
main_block->ret(tmp);
} else {
main_block->ret_void();
}
auto pa_method = method_creator.create();
pa_method->rstate.set_generated();
pa_method->rstate.set_dont_inline();
if (!is_static(callee) && is_public(callee)) {
pa_method->set_virtual(true);
}
pa_method->set_deobfuscated_name(show_deobfuscated(pa_method));
cls->add_method(pa_method);
TRACE(PA, 5, "[PartialApplication] Created %s binding %s:\n%s",
SHOW(pa_method), css->get_key().c_str(), SHOW(pa_method->get_code()));
}
}
} // namespace
void PartialApplicationPass::bind_config() {
auto& pg = m_profile_guidance_config;
bind("use_method_profiles", pg.use_method_profiles, pg.use_method_profiles,
"Whether to use provided method-profiles configuration data to "
"determine if certain code should not be outlined from a method");
bind("method_profiles_appear_percent",
pg.method_profiles_appear_percent,
pg.method_profiles_appear_percent,
"Cut off when a method in a method profile is deemed relevant");
bind("method_profiles_hot_call_count",
pg.method_profiles_hot_call_count,
pg.method_profiles_hot_call_count,
"No code is outlined out of hot methods");
bind("method_profiles_warm_call_count",
pg.method_profiles_warm_call_count,
pg.method_profiles_warm_call_count,
"Loops are not outlined from warm methods");
std::string perf_sensitivity_str;
bind("perf_sensitivity", "always-hot", perf_sensitivity_str);
bind("block_profiles_hits",
pg.block_profiles_hits,
pg.block_profiles_hits,
"No code is outlined out of hot blocks in hot methods");
after_configuration([=]() {
always_assert(!perf_sensitivity_str.empty());
m_profile_guidance_config.perf_sensitivity =
parse_perf_sensitivity(perf_sensitivity_str);
});
}
void PartialApplicationPass::run_pass(DexStoresVector& stores,
ConfigFiles& conf,
PassManager& mgr) {
const auto scope = build_class_scope(stores);
init_classes::InitClassesWithSideEffects init_classes_with_side_effects(
scope, conf.create_init_class_insns());
auto excluded_classes = get_excluded_classes(stores);
int min_sdk = mgr.get_redex_options().min_sdk;
auto min_sdk_api = get_min_sdk_api(conf, mgr);
XStoreRefs xstores(stores);
// RefChecker store_idx is initialized with `largest_root_store_id()`, so that
// it rejects all the references from stores with id larger than the largest
// root_store id.
RefChecker ref_checker(&xstores, xstores.largest_root_store_id(),
min_sdk_api);
std::unordered_set<DexMethod*> sufficiently_warm_methods;
std::unordered_set<DexMethod*> sufficiently_hot_methods;
gather_sufficiently_warm_and_hot_methods(
scope, conf, mgr, m_profile_guidance_config, &sufficiently_warm_methods,
&sufficiently_hot_methods);
mgr.incr_metric("num_sufficiently_warm_methods",
sufficiently_warm_methods.size());
mgr.incr_metric("num_sufficiently_hot_methods",
sufficiently_hot_methods.size());
ShrinkerConfig shrinker_config;
shrinker_config.run_local_dce = true;
shrinker_config.compute_pure_methods = false;
Shrinker shrinker(stores, scope, init_classes_with_side_effects,
shrinker_config, min_sdk);
std::unordered_set<const IRInstruction*> excluded_invoke_insns;
auto get_callee_fn = [&excluded_classes, &excluded_invoke_insns](
DexMethod* caller,
IRInstruction* insn) -> DexMethod* {
if (!opcode::is_an_invoke(insn->opcode()) ||
insn->opcode() == OPCODE_INVOKE_SUPER ||
method::is_init(insn->get_method()) ||
excluded_invoke_insns.count(insn) ||
caller->rstate.no_optimizations() ||
excluded_classes.count(caller->get_class())) {
return nullptr;
}
auto callee =
resolve_method(insn->get_method(), opcode_to_search(insn), caller);
if (!callee || callee->is_external()) {
return nullptr;
}
auto cls = type_class(callee->get_class());
if (!cls || cls->is_external() || is_native(cls) ||
excluded_classes.count(cls->get_type())) {
return nullptr;
}
// We'd add helper methods to the class, so we also want to avoid that it's
// being used via reflection.
if (!can_rename(cls)) {
return nullptr;
}
// TODO: Support interface callees.
if (is_interface(cls)) {
return nullptr;
}
return callee;
};
MethodToMethodOccurrences callee_caller;
MethodToMethodOccurrences caller_callee;
std::unordered_map<const IRInstruction*, ArgExclusivityVector>
arg_exclusivity;
CalleeCallerClasses callee_caller_classes;
gather_caller_callees(
m_profile_guidance_config, scope, sufficiently_warm_methods,
sufficiently_hot_methods, get_callee_fn, &callee_caller, &caller_callee,
&arg_exclusivity, &excluded_invoke_insns, &callee_caller_classes);
TRACE(PA, 1, "[PartialApplication] %zu callers, %zu callees",
caller_callee.size(), callee_caller.size());
// By indicating to the call-site summarizer that any callee may have other
// call-sites, we effectively disable top-down constant-propagation, as that
// would be unlikely to find true constants, and yet would take more time by
// limiting parallelism.
auto has_callee_other_call_sites_fn = [](DexMethod*) -> bool { return true; };
EnumUtilsCache enum_utils_cache;
std::function<bool(const ConstantValue& value)> filter_fn =
[&](const ConstantValue& value) {
return filter(ref_checker, enum_utils_cache, value);
};
CallSiteSummaryStats call_site_summarizer_stats;
CallSiteSummarizer call_site_summarizer(
shrinker, callee_caller, caller_callee, get_callee_fn,
has_callee_other_call_sites_fn, &filter_fn, &call_site_summarizer_stats);
call_site_summarizer.summarize();
std::atomic<size_t> total_estimated_savings{0};
PaMethodRefs pa_method_refs;
std::unordered_map<const IRInstruction*, const CallSiteSummary*>
selected_invokes;
std::unordered_set<DexMethod*> selected_callers;
select_invokes_and_callers(
enum_utils_cache, call_site_summarizer, callee_caller, arg_exclusivity,
callee_caller_classes, m_iteration++, &total_estimated_savings,
&pa_method_refs, &selected_invokes, &selected_callers);
std::atomic<size_t> removed_args{0};
rewrite_callers(scope, shrinker, get_callee_fn, selected_invokes,
selected_callers, pa_method_refs, &removed_args);
create_partial_application_methods(enum_utils_cache, pa_method_refs);
TRACE(PA, 1,
"[PartialApplication] Created %zu methods with particular constant "
"argument combinations, rewriting %zu invokes across %zu callers, "
"removing %zu args, with (estimated) net savings %zu",
pa_method_refs.size(), selected_invokes.size(), selected_callers.size(),
(size_t)removed_args, (size_t)total_estimated_savings);
mgr.incr_metric("total_estimated_savings", total_estimated_savings);
mgr.incr_metric("rewritten_invokes", selected_invokes.size());
mgr.incr_metric("removed_args", removed_args);
mgr.incr_metric("affected_callers", selected_callers.size());
mgr.incr_metric("partial_application_methods", pa_method_refs.size());
}
static PartialApplicationPass s_pass;