opt/optimize_enums/OptimizeEnums.cpp (624 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "OptimizeEnums.h"
#include "ClassAssemblingUtils.h"
#include "ConfigFiles.h"
#include "EnumAnalyzeGeneratedMethods.h"
#include "EnumClinitAnalysis.h"
#include "EnumInSwitch.h"
#include "EnumTransformer.h"
#include "EnumUpcastAnalysis.h"
#include "IRCode.h"
#include "MatchFlow.h"
#include "OptimizeEnumsAnalysis.h"
#include "PassManager.h"
#include "ProguardMap.h"
#include "Resolver.h"
#include "ScopedCFG.h"
#include "SwitchEquivFinder.h"
#include "Trace.h"
#include "Walkers.h"
/**
* 1. The pass tries to remove synthetic switch map classes for enums
* completely, by replacing the access to the lookup table with the use of the
* enum ordinal itself.
* Background of synthetic switch map classes:
* javac converts enum switches to a packed switch. In order to do this, for
* every use of an enum in a switch statement, an anonymous class is generated
* in the class where the switch is defined. This class will contain ONLY lookup
* tables (arrays) as static fields and a static initializer.
*
* 2. Try to replace enum objects with boxed Integer objects based on static
* analysis results.
*/
namespace {
// Map the field holding the lookup table to its associated enum type.
using LookupTableToEnum = std::unordered_map<DexField*, DexType*>;
// Map the static fields holding enumerands to their ordinal number (passed in
// to their constructor).
using EnumFieldToOrdinal = std::unordered_map<DexField*, size_t>;
// Sets of types. Intended to be sub-classes of Ljava/lang/Enum; but not
// guaranteed by the type.
using EnumTypes = std::unordered_set<DexType*>;
// Lookup tables in generated classes map enum ordinals to the integers they
// are represented by in switch statements using that lookup table:
//
// lookup[enum.ordinal()] = case;
//
// GeneratedSwitchCases represent the reverse mapping for a lookup table:
//
// gsc[lookup][case] = enum
//
// with lookup and enum identified by their fields.
using GeneratedSwitchCases =
std::unordered_map<DexField*, std::unordered_map<size_t, DexField*>>;
// Names of the metrics published through PassManager::set_metric() by
// OptimizeEnums::stats().
constexpr const char* METRIC_NUM_SYNTHETIC_CLASSES = "num_synthetic_classes";
constexpr const char* METRIC_NUM_LOOKUP_TABLES = "num_lookup_tables";
// NOTE: the constant is named *_REMOVED but the published metric name says
// "replaced"; the value reported under it is the number of lookup table
// fields whose usage was rewritten.
constexpr const char* METRIC_NUM_LOOKUP_TABLES_REMOVED =
"num_lookup_tables_replaced";
constexpr const char* METRIC_NUM_ENUM_CLASSES = "num_candidate_enum_classes";
constexpr const char* METRIC_NUM_ENUM_OBJS = "num_erased_enum_objs";
constexpr const char* METRIC_NUM_INT_OBJS = "num_generated_int_objs";
constexpr const char* METRIC_NUM_SWITCH_EQUIV_FINDER_FAILURES =
"num_switch_equiv_finder_failures";
constexpr const char* METRIC_NUM_CANDIDATE_GENERATED_METHODS =
"num_candidate_generated_enum_methods";
constexpr const char* METRIC_NUM_REMOVED_GENERATED_METHODS =
"num_removed_generated_enum_methods";
/**
 * Simple analysis to determine which of an enum constructor's arguments is
 * passed on as the ordinal.
 *
 * Background: The ordinal for each enum instance is set through the super
 * class's constructor, `Enum.<init>(String, int)`. A constructor of `cls`
 * either calls that constructor directly or delegates to another constructor
 * of `cls`.
 *
 * For each constructor we determine the index (in load-param order) of the
 * parameter that ends up flowing into the ordinal argument.
 *
 * @param cls The enum class whose constructors are analysed.
 * @param java_enum_ctor The resolved `Enum.<init>` method.
 * @param ctor_to_arg_ordinal Output: constructor -> index of the parameter
 *   carrying the ordinal. Always receives an entry for `java_enum_ctor`; may
 *   be left partially populated when the function returns false.
 * @return true if every constructor of `cls` could be analysed. false if a
 *   constructor has no code, does not contain exactly one delegating
 *   constructor call, the ordinal cannot be traced back to a unique
 *   parameter, or the delegation chain never reaches `java_enum_ctor`.
 */
bool analyze_enum_ctors(
    const DexClass* cls,
    const DexMethod* java_enum_ctor,
    std::unordered_map<const DexMethod*, uint32_t>& ctor_to_arg_ordinal) {
  // A constructor of `cls`, its CFG, and the one delegating call found in it.
  struct DelegatingCall {
    DexMethod* ctor;
    cfg::ScopedCFG cfg;
    IRInstruction* invoke;
    DelegatingCall(DexMethod* ctor, cfg::ScopedCFG cfg, IRInstruction* invoke)
        : ctor{ctor}, cfg{std::move(cfg)}, invoke{invoke} {}
  };
  std::queue<DelegatingCall> delegating_calls;
  { // Find delegate constructor calls and queue them up to be processed. The
    // call might be to `Enum.<init>(String;I)` or to a different constructor
    // of the same class.
    mf::flow_t f;
    auto inv = f.insn(m::invoke_direct_(m::has_method(m::resolve_method(
        MethodSearch::Direct,
        m::equals(java_enum_ctor) ||
            (m::is_constructor<DexMethod>() &&
             m::member_of<DexMethod>(m::equals(cls->get_type())))))));
    for (const auto& ctor : cls->get_ctors()) {
      auto code = ctor->get_code();
      if (!code) {
        return false;
      }
      cfg::ScopedCFG cfg{code};
      auto res = f.find(*cfg, inv);
      if (auto* inv_insn = res.matching(inv).unique()) {
        delegating_calls.emplace(ctor, std::move(cfg), inv_insn);
      } else {
        return false;
      }
    }
  }
  // Ordinal represents the third argument.
  // details: https://developer.android.com/reference/java/lang/Enum.html
  ctor_to_arg_ordinal[java_enum_ctor] = 2;

  // Process the queue until it is empty. A call whose delegate has not been
  // analysed yet is moved to the back of the queue. `deferred_in_a_row` counts
  // how many calls were deferred since the last successful one: once it
  // reaches the number of pending calls, every pending call has been tried
  // without progress (e.g. constructors delegating to each other in a cycle),
  // so we give up instead of looping forever.
  size_t deferred_in_a_row = 0;
  while (!delegating_calls.empty()) {
    auto dc = std::move(delegating_calls.front());
    delegating_calls.pop();
    auto* delegate =
        resolve_method(dc.invoke->get_method(), MethodSearch::Direct);
    uint32_t delegate_ordinal;
    { // Only proceed if the delegate constructor has already been processed.
      auto known = ctor_to_arg_ordinal.find(delegate);
      if (known == ctor_to_arg_ordinal.end()) {
        delegating_calls.emplace(std::move(dc));
        if (++deferred_in_a_row >= delegating_calls.size()) {
          return false;
        }
        continue;
      }
      delegate_ordinal = known->second;
    }
    deferred_in_a_row = 0;
    // Track which param in dc.ctor flows into the ordinal arg of the delegate.
    mf::flow_t f;
    auto param = f.insn(m::load_param_());
    auto invoke_delegate =
        f.insn(m::equals(dc.invoke))
            .src(delegate_ordinal, param, mf::unique | mf::alias);
    auto res = f.find(*dc.cfg, invoke_delegate);
    auto* load_ordinal = res.matching(param).unique();
    if (!load_ordinal) {
      // Couldn't find a unique parameter flowing into the ordinal argument.
      return false;
    }
    // Figure out which param is being loaded.
    uint32_t ctor_ordinal = 0;
    auto ii = InstructionIterable(dc.cfg->get_param_instructions());
    for (auto it = ii.begin(), end = ii.end();; ++it, ++ctor_ordinal) {
      always_assert(it != end && "Unable to locate load_ordinal");
      if (it->insn == load_ordinal) break;
    }
    ctor_to_arg_ordinal[dc.ctor] = ctor_ordinal;
  }
  return true;
}
/**
 * Discover the mapping from enums to cases in lookup tables defined on
 * `generated_cls` by detecting the following patterns in its `<clinit>` (modulo
 * ordering and interleaved unrelated instructions):
 *
 *   sget-object <lookup>
 *   move-result-pseudo-object v0
 *
 * Or:
 *
 *   new-array ..., [I
 *   move-result-pseudo-object v0
 *   sput-object v0, <lookup>
 *
 * Followed by:
 *
 *   sget-object <enum>
 *   move-result-pseudo-object v1
 *   invoke-virtual {v1}, Ljava/lang/Enum;.ordinal:()I
 *   move-result v2
 *   const v3, <kase>
 *   aput v3, v0, v2
 *
 * For each instance of the pattern found, a `generated_switch_cases` entry is
 * added:
 *
 *   generated_switch_cases[lookup][kase] = enum;
 *
 * An `aput` into a fresh array that is never stored into a static field of
 * `generated_cls` is ignored, because there is no lookup table field to file
 * the case under.
 *
 * @param generated_cls The synthetic class holding the lookup tables.
 * @param clinit_cfg CFG of the static initializer of `generated_cls`.
 * @param collected_enums Enum types whose fields count as enumerands.
 * @param generated_switch_cases Output map, extended in place.
 */
void collect_generated_switch_cases(
    DexClass* generated_cls,
    cfg::ControlFlowGraph& clinit_cfg,
    const EnumTypes& collected_enums,
    GeneratedSwitchCases& generated_switch_cases) {
  mf::flow_t f;
  DexMethod* Enum_ordinal =
      resolve_method(DexMethod::get_method("Ljava/lang/Enum;.ordinal:()I"),
                     MethodSearch::Virtual);
  always_assert(Enum_ordinal);
  auto m__generated_field = m::has_field(
      m::member_of<DexFieldRef>(m::equals(generated_cls->get_type())));
  auto m__lookup = m::sget_object_(m__generated_field) || m::new_array_();
  auto m__sget_enum = m::sget_object_(m::has_field(
      m::member_of<DexFieldRef>(m::in<DexType*>(collected_enums))));
  auto m__invoke_ordinal = m::invoke_virtual_(m::has_method(
      m::resolve_method(MethodSearch::Virtual, m::equals(Enum_ordinal))));
  auto uniq = mf::alias | mf::unique;
  auto look = f.insn(m__lookup);
  auto gete = f.insn(m__sget_enum);
  auto kase = f.insn(m::const_());
  auto ordi = f.insn(m__invoke_ordinal).src(0, gete, uniq);
  auto aput = f.insn(m::aput_())
                  .src(0, kase, uniq)
                  .src(1, look, uniq)
                  .src(2, ordi, uniq);
  auto res = f.find(clinit_cfg, aput);
  // new-array instruction -> sput-object that stores it into a field of the
  // generated class (nullptr until/unless such an sput is found).
  std::unordered_map<IRInstruction*, IRInstruction*> new_array_to_sput;
  for (auto* insn_look : res.matching(look)) {
    if (opcode::is_new_array(insn_look->opcode())) {
      new_array_to_sput.emplace(insn_look, nullptr);
    }
  }
  // Some lookup tables are accessed fresh rather than via an sget-object, so
  // look at where the new arrays are put to determine the field.
  if (!new_array_to_sput.empty()) {
    mf::flow_t g;
    auto m__sput_lookup = m::sput_object_(m__generated_field);
    auto newa = g.insn(m::in<IRInstruction*>(new_array_to_sput));
    auto sput = g.insn(m__sput_lookup).src(0, newa, uniq);
    auto res_sputs = g.find(clinit_cfg, sput);
    for (auto* insn_sput : res_sputs.matching(sput)) {
      auto* insn_newa = res_sputs.matching(sput, insn_sput, 0).unique();
      if (insn_newa != nullptr) {
        new_array_to_sput[insn_newa] = insn_sput;
      }
    }
  }
  for (auto* insn_aput : res.matching(aput)) {
    auto* insn_kase = res.matching(aput, insn_aput, 0).unique();
    auto* insn_look = res.matching(aput, insn_aput, 1).unique();
    auto* insn_ordi = res.matching(aput, insn_aput, 2).unique();
    auto* insn_gete = res.matching(ordi, insn_ordi, 0).unique();
    if (opcode::is_new_array(insn_look->opcode())) {
      // If the array being assigned to came from a new-array, look for the sput
      // it flowed into.
      auto sput_it = new_array_to_sput.find(insn_look);
      if (sput_it == new_array_to_sput.end() || sput_it->second == nullptr) {
        // The array never reached a static field of the generated class, so
        // it cannot be identified as a lookup table. Skip it rather than
        // dereferencing a null instruction below.
        continue;
      }
      insn_look = sput_it->second;
    }
    auto switch_case = insn_kase->get_literal();
    auto* lookup_table =
        resolve_field(insn_look->get_field(), FieldSearch::Static);
    auto* enum_field =
        resolve_field(insn_gete->get_field(), FieldSearch::Static);
    always_assert(lookup_table);
    always_assert(enum_field && is_enum(enum_field));
    always_assert_log(switch_case > 0,
                      "The generated SwitchMap should have positive keys");
    generated_switch_cases[lookup_table].emplace(switch_case, enum_field);
  }
}
/**
 * Get `java.lang.Enum`'s ctor.
 * Details: https://developer.android.com/reference/java/lang/Enum.html
 *
 * Asserts that a class definition for `java.lang.Enum` is available and that
 * it declares exactly one constructor.
 *
 * @return The single constructor of `java.lang.Enum`.
 */
DexMethod* get_java_enum_ctor() {
  DexType* java_enum_type = type::java_lang_Enum();
  DexClass* java_enum_cls = type_class(java_enum_type);
  // type_class() yields nullptr when no class definition is known for the
  // type; fail with a clear message instead of dereferencing it.
  always_assert_log(java_enum_cls != nullptr,
                    "No class definition available for Ljava/lang/Enum;");
  const std::vector<DexMethod*>& java_enum_ctors = java_enum_cls->get_ctors();
  always_assert(java_enum_ctors.size() == 1);
  return java_enum_ctors.at(0);
}
class OptimizeEnums {
public:
// Builds the class scope from `stores`, keeps references to the stores and the
// proguard map, and looks up the `java.lang.Enum` constructor once up front.
OptimizeEnums(DexStoresVector& stores, ConfigFiles& conf)
    : m_scope(build_class_scope(stores)),
      m_stores(stores),
      m_java_enum_ctor(get_java_enum_ctor()),
      m_pg_map(conf.get_proguard_map()) {}
/**
 * Replaces uses of the synthetic switch-map lookup tables with direct
 * switches on the enum ordinal.
 *
 * Steps: collect the generated classes, collect the ordinal of every enum
 * field, analyse each generated class's `<clinit>` to learn which enum each
 * lookup table belongs to and which case value maps to which enum field, then
 * rewrite the switches that go through those tables.
 */
void remove_redundant_generated_classes() {
  auto generated_classes = collect_generated_classes();
  auto enum_field_to_ordinal = collect_enum_field_ordinals();
  EnumTypes collected_enums;
  for (const auto& pair : enum_field_to_ordinal) {
    collected_enums.emplace(pair.first->get_class());
  }
  LookupTableToEnum lookup_table_to_enum;
  GeneratedSwitchCases generated_switch_cases;
  for (const auto& generated_cls : generated_classes) {
    auto* generated_clinit = generated_cls->get_clinit();
    if (generated_clinit == nullptr ||
        generated_clinit->get_code() == nullptr) {
      // collect_generated_classes() only checks that the class has exactly one
      // direct method, not that it is a <clinit> with a body. Without one
      // there is nothing to analyse, so skip the class.
      continue;
    }
    cfg::ScopedCFG clinit_cfg{generated_clinit->get_code()};
    associate_lookup_tables_to_enums(generated_cls, *clinit_cfg,
                                     collected_enums, lookup_table_to_enum);
    collect_generated_switch_cases(generated_cls, *clinit_cfg,
                                   collected_enums, generated_switch_cases);
    // update stats.
    m_stats.num_lookup_tables += generated_cls->get_sfields().size();
  }
  remove_generated_classes_usage(lookup_table_to_enum, enum_field_to_ordinal,
                                 generated_switch_cases);
}
/**
 * Publishes every counter gathered by this pass as a metric on `mgr` and
 * traces it at ENUM level 1.
 */
void stats(PassManager& mgr) {
  auto publish = [&mgr](const char* metric, size_t value) {
    mgr.set_metric(metric, value);
    TRACE(ENUM, 1, "\t%s : %zu", metric, value);
  };

  // Lookup-table removal.
  publish(METRIC_NUM_SYNTHETIC_CLASSES, m_stats.num_synthetic_classes);
  publish(METRIC_NUM_LOOKUP_TABLES, m_stats.num_lookup_tables);
  publish(METRIC_NUM_LOOKUP_TABLES_REMOVED, m_lookup_tables_replaced.size());

  // Enum -> boxed Integer replacement.
  publish(METRIC_NUM_ENUM_CLASSES, m_stats.num_enum_classes);
  publish(METRIC_NUM_ENUM_OBJS, m_stats.num_enum_objs);
  publish(METRIC_NUM_INT_OBJS, m_stats.num_int_objs);

  publish(METRIC_NUM_SWITCH_EQUIV_FINDER_FAILURES,
          m_stats.num_switch_equiv_finder_failures);

  // Generated valueOf()/values() removal.
  publish(METRIC_NUM_CANDIDATE_GENERATED_METHODS,
          m_stats.num_candidate_generated_methods);
  publish(METRIC_NUM_REMOVED_GENERATED_METHODS,
          m_stats.num_removed_generated_methods);
}
/**
 * Replace enum objects with boxed Integer objects.
 *
 * Does nothing when `max_enum_size` is not positive. Otherwise gathers the
 * candidate enum classes, lets `reject_unsafe_enums` prune them, and hands the
 * remainder to `transform_enums`. Updates the enum-related counters in
 * `m_stats`.
 *
 * @param max_enum_size Forwarded to `optimize_enums::Config`.
 * @param allowlist Forwarded to `optimize_enums::Config`.
 */
void replace_enum_with_int(int max_enum_size,
                           const std::vector<DexType*>& allowlist) {
  if (max_enum_size <= 0) {
    return;
  }
  optimize_enums::Config config(max_enum_size, allowlist);
  const auto override_graph = method_override_graph::build_graph(m_scope);
  calculate_param_summaries(m_scope, *override_graph,
                            &config.param_summary_map);

  // An enum is safe if it is not external, is final and deletable, has no
  // interfaces, has exactly one static synthetic field and has only one simple
  // enum constructor. Its non-static, non-constructor direct methods and its
  // virtual methods must be renamable, and its instance fields must be
  // primitives or strings.
  auto is_safe_enum = [this](const DexClass* cls) {
    if (!is_enum(cls) || cls->is_external() || !is_final(cls) ||
        !can_delete(cls) || cls->get_interfaces()->size() != 0 ||
        !only_one_static_synth_field(cls)) {
      return false;
    }
    const auto& ctors = cls->get_ctors();
    if (ctors.size() != 1 ||
        !is_simple_enum_constructor(cls, ctors.front())) {
      return false;
    }
    for (auto& dmethod : cls->get_dmethods()) {
      const bool exempt = is_static(dmethod) || method::is_constructor(dmethod);
      if (!exempt && !can_rename(dmethod)) {
        return false;
      }
    }
    for (auto& vmethod : cls->get_vmethods()) {
      if (!can_rename(vmethod)) {
        return false;
      }
    }
    const auto& ifields = cls->get_ifields();
    return std::all_of(ifields.begin(), ifields.end(), [](DexField* field) {
      auto field_type = field->get_type();
      return type::is_primitive(field_type) ||
             field_type == type::java_lang_String();
    });
  };

  walk::parallel::classes(m_scope, [&config, is_safe_enum](DexClass* cls) {
    if (!is_safe_enum(cls)) {
      return;
    }
    config.candidate_enums.insert(cls->get_type());
  });
  optimize_enums::reject_unsafe_enums(m_scope, &config);

  if (traceEnabled(ENUM, 4)) {
    for (auto candidate : config.candidate_enums) {
      TRACE(ENUM, 4, "candidate_enum %s", SHOW(candidate));
    }
  }

  m_stats.num_enum_objs = optimize_enums::transform_enums(
      config, &m_stores, &m_stats.num_int_objs);
  m_stats.num_enum_classes = config.candidate_enums.size();
}
/**
 * Remove the static methods `valueOf()` and `values()` when safe.
 *
 * Enums that qualify (see `should_consider_enum` below) and that declare both
 * methods are registered with the analyzer, which then transforms the code.
 * Records the candidate and removed method counts in `m_stats`.
 */
void remove_enum_generated_methods() {
  optimize_enums::EnumAnalyzeGeneratedMethods analyzer;

  // We conservatively reject all enums that are instance fields of classes
  // because we don't know if the classes will be serialized or not.
  ConcurrentSet<const DexType*> types_used_as_instance_fields;
  walk::parallel::classes(m_scope, [&](DexClass* cls) {
    for (auto& ifield : cls->get_ifields()) {
      types_used_as_instance_fields.insert(
          type::get_element_type_if_array(ifield->get_type()));
    }
  });

  // Only consider enums that are final, not external, do not have
  // interfaces, and are not instance fields of any classes.
  auto should_consider_enum = [&](DexClass* cls) {
    if (!is_enum(cls) || cls->is_external() || !is_final(cls) ||
        !can_delete(cls)) {
      return false;
    }
    if (cls->get_interfaces()->size() != 0) {
      return false;
    }
    return !types_used_as_instance_fields.count(cls->get_type());
  };

  walk::parallel::classes(
      m_scope, [&should_consider_enum, &analyzer](DexClass* cls) {
        if (!should_consider_enum(cls)) {
          return;
        }
        auto& dmethods = cls->get_dmethods();
        auto valueof_mit = std::find_if(dmethods.begin(), dmethods.end(),
                                        optimize_enums::is_enum_valueof);
        auto values_mit = std::find_if(dmethods.begin(), dmethods.end(),
                                       optimize_enums::is_enum_values);
        if (valueof_mit == dmethods.end() || values_mit == dmethods.end()) {
          return;
        }
        analyzer.consider_enum_type(cls->get_type(), *valueof_mit,
                                    *values_mit);
      });

  m_stats.num_candidate_generated_methods =
      analyzer.num_candidate_enum_methods();
  m_stats.num_removed_generated_methods = analyzer.transform_code(m_scope);
}
private:
/**
* There is usually one synthetic static field in enum class, typically named
* "$VALUES", but also may be renamed.
* Return true if there is only one static synthetic field in the class,
* otherwise return false.
*/
bool only_one_static_synth_field(const DexClass* cls) {
DexField* synth_field = nullptr;
auto synth_access = optimize_enums::synth_access();
for (auto field : cls->get_sfields()) {
if (check_required_access_flags(synth_access, field->get_access())) {
if (synth_field) {
TRACE(ENUM, 2, "Multiple synthetic fields %s %s", SHOW(synth_field),
SHOW(field));
return false;
}
synth_field = field;
}
}
if (!synth_field) {
TRACE(ENUM, 2, "No synthetic field found on %s", SHOW(cls));
return false;
}
return true;
}
/**
 * Returns true if the constructor invokes `Enum.<init>`, sets its instance
 * fields, and then returns. We want to make sure there are no side effects.
 *
 * SubEnum.<init>:(Ljava/lang/String;I[other parameters...])V
 *   load-param * // multiple load parameter instructions
 *   invoke-direct {} Ljava/lang/Enum;.<init>:(Ljava/lang/String;I)V
 *   (iput|const) * // put/const instructions for primitive instance fields
 *   return-void
 *
 * @param cls The enum class declaring `method` (currently not inspected).
 * @param method The constructor to check.
 * @return false if the constructor is not private, takes fewer than two
 *   arguments, has no code, or its instructions deviate from the shape above.
 */
static bool is_simple_enum_constructor(const DexClass* cls,
                                       const DexMethod* method) {
  const auto* params = method->get_proto()->get_args();
  if (!is_private(method) || params->size() < 2) {
    return false;
  }
  // A constructor without a body cannot match the pattern. Checking here
  // mirrors analyze_enum_ctors() and avoids iterating a null code object.
  auto* method_code = method->get_code();
  if (method_code == nullptr) {
    return false;
  }
  auto code = InstructionIterable(method_code);
  auto it = code.begin();
  // Load parameter instructions.
  while (it != code.end() && opcode::is_a_load_param(it->insn->opcode())) {
    ++it;
  }
  if (it == code.end()) {
    return false;
  }
  // invoke-direct {} Ljava/lang/Enum;.<init>:(Ljava/lang/String;I)V
  if (!opcode::is_invoke_direct(it->insn->opcode())) {
    return false;
  }
  const DexMethodRef* ref = it->insn->get_method();
  // Enum.<init>
  if (ref->get_class() != type::java_lang_Enum() ||
      !method::is_constructor(ref)) {
    return false;
  }
  if (++it == code.end()) {
    return false;
  }
  auto is_iput_or_const = [](IROpcode opcode) {
    // `const-string` is followed by `move-result-pseudo-object`
    return opcode::is_an_iput(opcode) || opcode::is_a_literal_const(opcode) ||
           opcode == OPCODE_CONST_STRING ||
           opcode == IOPCODE_MOVE_RESULT_PSEUDO_OBJECT;
  };
  while (it != code.end() && is_iput_or_const(it->insn->opcode())) {
    ++it;
  }
  if (it == code.end()) {
    return false;
  }
  // return-void is the last instruction
  return opcode::is_return_void(it->insn->opcode()) && (++it) == code.end();
}
/**
* We determine which classes are generated based on:
* - classes that only have 1 dmethods: <clinit>
* - no instance fields, nor virtual methods
* - all static fields match `$SwitchMap$<enum_path>`
*/
std::vector<DexClass*> collect_generated_classes() {
std::vector<DexClass*> generated_classes;
// To avoid any cross store references, only accept generated classes
// that are in the root store (same for the Enums they reference).
XStoreRefs xstores(m_stores);
for (const auto& cls : m_scope) {
size_t cls_store_idx = xstores.get_store_idx(cls->get_type());
if (cls_store_idx > 1) {
continue;
}
auto& sfields = cls->get_sfields();
const auto all_sfield_names_contain = [&sfields](const char* sub) {
return std::all_of(sfields.begin(), sfields.end(),
[sub](DexField* sfield) {
const auto& deobfuscated_name =
sfield->get_deobfuscated_name_or_empty();
const auto& name = deobfuscated_name.empty()
? sfield->get_name()->str()
: deobfuscated_name;
return name.find(sub) != std::string::npos;
});
};
// We expect the generated classes to ONLY contain the lookup tables
// and the static initializer (<clinit>)
//
// Lookup tables for Java Enums all contain $SwitchMap$ in the field name
// and lookup tables for Kotlin Enums all contain $EnumSwitchMapping$ in
// the field name. The two are not expected to mix in a single generated
// class.
if (!sfields.empty() && cls->get_dmethods().size() == 1 &&
cls->get_vmethods().empty() && cls->get_ifields().empty()) {
if (all_sfield_names_contain("$SwitchMap$") ||
all_sfield_names_contain("$EnumSwitchMapping$")) {
generated_classes.emplace_back(cls);
}
}
}
// Update stats.
m_stats.num_synthetic_classes = generated_classes.size();
return generated_classes;
}
/**
 * Collects the enum-field -> ordinal mapping over every enum class in scope.
 */
EnumFieldToOrdinal collect_enum_field_ordinals() {
  EnumFieldToOrdinal ordinals;
  for (const auto& cls : m_scope) {
    if (!is_enum(cls)) {
      continue;
    }
    collect_enum_field_ordinals(cls, ordinals);
  }
  return ordinals;
}
/**
 * Adds the enum-field -> ordinal entries of `cls` to `enum_field_to_ordinal`.
 *
 * Nothing is added when `cls` is null, when it has no `<clinit>` with code, or
 * when its constructors could not be analysed by `analyze_enum_ctors`.
 */
void collect_enum_field_ordinals(const DexClass* cls,
                                 EnumFieldToOrdinal& enum_field_to_ordinal) {
  if (cls == nullptr) {
    return;
  }
  auto* clinit = cls->get_clinit();
  const bool has_clinit_code = clinit != nullptr && clinit->get_code();
  if (!has_clinit_code) {
    return;
  }
  std::unordered_map<const DexMethod*, uint32_t> ctor_to_arg_ordinal;
  const bool ctors_understood =
      analyze_enum_ctors(cls, m_java_enum_ctor, ctor_to_arg_ordinal);
  if (!ctors_understood) {
    return;
  }
  optimize_enums::OptimizeEnumsAnalysis analysis(cls, ctor_to_arg_ordinal);
  analysis.collect_ordinals(enum_field_to_ordinal);
}
/**
* Removes the usage of the generated lookup table, by rewriting switch cases
* based on enum ordinals.
*
* The initial switch looks like:
*
* switch (enum_element) {
* case enum_0:
* // do something
* case enum_7:
* // do something
* }
*
* which was re-written to:
*
* switch (int_element) {
* case 1:
* // do something for enum_0
* case 2:
* // do something for enum_7
* }
*
* which we are changing to:
*
* switch (ordinal_element) {
* case 0:
* // do something for enum_0
* case 7:
* // do something for enum_7
* }
*/
void remove_generated_classes_usage(
const LookupTableToEnum& lookup_table_to_enum,
const EnumFieldToOrdinal& enum_field_to_ordinal,
const GeneratedSwitchCases& generated_switch_cases) {
namespace cp = constant_propagation;
walk::parallel::code(m_scope, [&](DexMethod*, IRCode& code) {
cfg::ScopedCFG cfg(&code);
cfg->calculate_exit_block();
optimize_enums::Iterator fixpoint(cfg.get());
fixpoint.run(optimize_enums::Environment());
std::unordered_set<IRInstruction*> switches;
for (const auto& info : fixpoint.collect()) {
const auto pair = switches.insert((*info.branch)->insn);
bool insert_occurred = pair.second;
if (!insert_occurred) {
// Make sure we don't have any duplicate switch opcodes. We can't
// change the register of a switch opcode to two different registers.
continue;
}
if (!check_lookup_table_usage(lookup_table_to_enum,
generated_switch_cases, info)) {
continue;
}
remove_lookup_table_usage(enum_field_to_ordinal, generated_switch_cases,
info);
}
});
}
/**
 * Check to make sure this is a valid match. Return false to abort the
 * optimization.
 *
 * A match is valid when:
 * - `ordinal()` is invoked on `java.lang.Enum` or on an enum class,
 * - the array being read is a known lookup table, i.e. it has both an
 *   associated enum in `lookup_table_to_enum` and recorded cases in
 *   `generated_switch_cases` (the rewrite needs both), and
 * - unless the invoke is on `java.lang.Enum` itself, the enum the table was
 *   built for is the type `ordinal()` is invoked on.
 */
bool check_lookup_table_usage(
    const LookupTableToEnum& lookup_table_to_enum,
    const GeneratedSwitchCases& generated_switch_cases,
    const optimize_enums::Info& info) {
  // Check this is called on an enum.
  auto invoke_ordinal = (*info.invoke)->insn;
  auto invoke_type = invoke_ordinal->get_method()->get_class();
  auto invoke_cls = type_class(invoke_type);
  const bool invoked_on_java_enum = invoke_type == type::java_lang_Enum();
  if (!invoke_cls || (!invoked_on_java_enum && !is_enum(invoke_cls))) {
    return false;
  }
  auto lookup_table = info.array_field;
  if (!lookup_table) {
    return false;
  }
  auto enum_it = lookup_table_to_enum.find(lookup_table);
  if (enum_it == lookup_table_to_enum.end()) {
    return false;
  }
  // Without recorded cases for this table there is nothing to map the switch
  // keys back to, so the rewrite cannot proceed.
  if (generated_switch_cases.count(lookup_table) == 0) {
    return false;
  }
  // Check the current enum corresponds.
  auto current_enum = enum_it->second;
  if (!invoked_on_java_enum && current_enum != invoke_type) {
    return false;
  }
  return true;
}
/**
 * Replaces the usage of the lookup table, by converting
 *
 *   INVOKE_VIRTUAL <v_enum> Enum;.ordinal:()
 *   MOVE_RESULT <v_ordinal>
 *   AGET <v_field>, <v_ordinal>
 *   MOVE_RESULT_PSEUDO <v_dest>
 *   *_SWITCH <v_dest> ; or IF_EQZ <v_dest> <v_some_constant>
 *
 * to
 *
 *   INVOKE_VIRTUAL <v_enum> Enum;.ordinal:()
 *   MOVE_RESULT <v_ordinal>
 *   MOVE <v_dest>, <v_ordinal>
 *   SPARSE_SWITCH <v_dest>
 *
 * if <v_field> was fetched using SGET_OBJECT <lookup_table_holder>
 *
 * and updating switch cases (on the edges) to the enum field's ordinal
 *
 * NOTE: We leave unused code around, since LDCE should remove it
 *       if it isn't used afterwards (which is expected), but we are
 *       being conservative.
 *
 * The function first gathers and validates everything it needs and only then
 * starts editing the CFG, so that bailing out never leaves a half-rewritten
 * method behind. It bails out (leaving the code untouched) when the
 * SwitchEquivFinder fails, when no cases are recorded for the lookup table,
 * when the `ordinal()` invoke has no move-result, or when the ordinal of a
 * referenced enum field is unknown.
 *
 * @param enum_field_to_ordinal Ordinal of each enum field.
 * @param generated_switch_cases Per lookup table: case key -> enum field.
 * @param info The matched invoke / array field / branch to rewrite.
 */
void remove_lookup_table_usage(
    const EnumFieldToOrdinal& enum_field_to_ordinal,
    const GeneratedSwitchCases& generated_switch_cases,
    const optimize_enums::Info& info) {
  auto& cfg = info.branch->cfg();
  auto branch_block = info.branch->block();

  // ---- Phase 1: gather and validate; the CFG is not modified here. ----

  // Use the SwitchEquivFinder to handle not just switch statements but also
  // trees of if and switch statements
  SwitchEquivFinder finder(&cfg, *info.branch, *info.reg,
                           50 /* leaf_duplication_threshold */);
  if (!finder.success()) {
    ++m_stats.num_switch_equiv_finder_failures;
    return;
  }

  auto table_cases_it = generated_switch_cases.find(info.array_field);
  if (table_cases_it == generated_switch_cases.end()) {
    // No cases were recorded for this lookup table.
    return;
  }
  const auto& field_enum_map = table_cases_it->second;

  auto move_ordinal_it = cfg.move_result_of(*info.invoke);
  if (move_ordinal_it.is_end()) {
    return;
  }

  const auto& key_to_case = finder.key_to_case();
  auto fallthrough_it = key_to_case.find(boost::none);
  cfg::Block* fallthrough =
      fallthrough_it == key_to_case.end() ? nullptr : fallthrough_it->second;

  // New (ordinal, target block) pairs for the rebuilt branch.
  std::vector<std::pair<int32_t, cfg::Block*>> cases;
  for (const auto& pair : key_to_case) {
    auto old_case_key = pair.first;
    cfg::Block* leaf = pair.second;
    if (old_case_key == boost::none) {
      continue;
    }
    auto search = field_enum_map.find(*old_case_key);
    if (search != field_enum_map.end()) {
      auto field_enum = search->second;
      auto ordinal_it = enum_field_to_ordinal.find(field_enum);
      if (ordinal_it == enum_field_to_ordinal.end()) {
        // The ordinal of this enum field was not collected, so the case
        // cannot be translated. Leave the switch as it is.
        return;
      }
      cases.emplace_back(ordinal_it->second, leaf);
    } else {
      // Ignore blocks with...
      // - negative case key, which should be dead code
      // - 0 case key, as long as the leaf block is the fallthrough block, as
      //   0 encodes the default case
      always_assert_log(
          *old_case_key < 0 || (*old_case_key == 0 && fallthrough == leaf),
          "can't find case key %d leaving block %zu\n%s\nin %s\n",
          *old_case_key, branch_block->id(), info.str().c_str(), SHOW(cfg));
    }
  }

  // ---- Phase 2: rewrite the CFG. ----

  // if-else chains will load constants to compare against. Sometimes the
  // leaves will use these values so we have to copy those values to the
  // beginning of the leaf blocks. Any dead instructions will be cleaned up
  // by LDCE.
  const auto& extra_loads = finder.extra_loads();
  for (const auto& pair : key_to_case) {
    cfg::Block* leaf = pair.second;
    auto loads_for_this_leaf = extra_loads.find(leaf);
    if (loads_for_this_leaf == extra_loads.end()) {
      continue;
    }
    for (const auto& register_and_insn : loads_for_this_leaf->second) {
      IRInstruction* insn = register_and_insn.second;
      if (insn != nullptr) {
        // null instruction pointers are used to signify the upper half of a
        // wide load.
        auto copy = new IRInstruction(*insn);
        TRACE(ENUM, 4, "adding %s to B%zu", SHOW(copy), leaf->id());
        leaf->push_front(copy);
      }
    }
  }

  // Add a new register to hold the ordinal and then use it to
  // switch on the actual ordinal, instead of using the lookup table.
  //
  // Basically, the bytecode will be:
  //
  //   INVOKE_VIRTUAL <v_enum> <Enum>;.ordinal:()
  //   MOVE_RESULT <v_ordinal>
  //   MOVE <v_new_reg> <v_ordinal> // Newly added
  //   ...
  //   AGET <v_field>, <v_ordinal>
  //   MOVE_RESULT_PSEUDO <v_dest>
  //   ...
  //   SPARSE_SWITCH <v_new_reg> // will use <v_new_reg> instead of <v_dest>
  //
  // NOTE: We leave CopyPropagation to clean up the extra moves and
  //       LDCE the array access.

  // Remove the switch statement so we can rebuild it with the correct case
  // keys. This removes all edges to the if-else blocks and the blocks will
  // eventually be removed by cfg.simplify()
  cfg.remove_insn(*info.branch);
  auto move_ordinal = move_ordinal_it->insn;
  auto reg_ordinal = move_ordinal->dest();
  auto new_ordinal_reg = cfg.allocate_temp();
  auto move_ordinal_result = new IRInstruction(OPCODE_MOVE);
  move_ordinal_result->set_src(0, reg_ordinal);
  move_ordinal_result->set_dest(new_ordinal_reg);
  cfg.insert_after(move_ordinal_it, move_ordinal_result);
  // TODO?: Would it be possible to keep the original if-else tree form that
  // D8 made? It would probably be better to build a more general purpose
  // switch -> if-else converter pass and run it after this pass.
  if (cases.size() > 1) {
    // Dex Lowering will decide if packed or sparse would be better
    IRInstruction* new_switch = new IRInstruction(OPCODE_SWITCH);
    new_switch->set_src(0, new_ordinal_reg);
    cfg.create_branch(branch_block, new_switch, fallthrough, cases);
  } else if (cases.size() == 1) {
    // Only one non-fallthrough case, we can use an if statement.
    //   const vKey, case_key
    //   if-eq vOrdinal, vKey
    int32_t key = cases[0].first;
    cfg::Block* target = cases[0].second;
    IRInstruction* const_load = new IRInstruction(OPCODE_CONST);
    auto key_reg = cfg.allocate_temp();
    const_load->set_dest(key_reg);
    const_load->set_literal(key);
    branch_block->push_back(const_load);
    IRInstruction* new_if = new IRInstruction(OPCODE_IF_EQ);
    new_if->set_src(0, new_ordinal_reg);
    new_if->set_src(1, key_reg);
    cfg.create_branch(branch_block, new_if, fallthrough, target);
  } else {
    // No non-fallthrough cases are left: set the goto edge's target to the
    // fallthrough block.
    always_assert(cases.empty());
    if (fallthrough != nullptr) {
      auto existing_goto =
          cfg.get_succ_edge_of_type(branch_block, cfg::EDGE_GOTO);
      always_assert(existing_goto != nullptr);
      cfg.set_edge_target(existing_goto, fallthrough);
    }
  }
  m_lookup_tables_replaced.emplace(info.array_field);
}
/**
 * In the following example, `lookup_table` corresponds to `$SwitchMap$Foo`,
 * and `clinit_cfg` is expected to be the body of the static initializer:
 *
 *   private static class $1 {
 *     public static final synthetic int[] $SwitchMap$Foo;
 *     static {
 *       $SwitchMap$Foo = new int[Foo.values().length];
 *       $SwitchMap$Foo[Foo.Bar.ordinal()] = 1;
 *       $SwitchMap$Foo[Foo.Baz.ordinal()] = 2;
 *       // ...
 *     }
 *   }
 *
 * This function finds the enum class corresponding to `lookup_table` (`Foo`
 * in the example) by tracing back from its initialization:
 *
 *   invoke-static {}, LFoo;.values:()[LFoo;  <- Find this,
 *   move-result-object v0
 *   array-length v0
 *   move-result-pseudo v1
 *   new-array v1
 *   move-result-pseudo-object v2
 *   sput-object v2, $1;.$SwitchMap$Foo:[I    <- Starting here.
 *
 * Populates `mapping` with all the Enum types corresponding to lookup table
 * fields initialised in the provided CFG. Only `values()` calls on types in
 * `collected_enums` and stores into fields of `generated_cls` are considered.
 */
void associate_lookup_tables_to_enums(DexClass* generated_cls,
                                      cfg::ControlFlowGraph& clinit_cfg,
                                      const EnumTypes& collected_enums,
                                      LookupTableToEnum& mapping) {
  mf::flow_t flow;
  const auto unique_src = mf::alias | mf::unique;

  // Describe the chain values() -> array-length -> new-array -> sput-object,
  // where each step must be the unique source of the next.
  auto values_call = flow.insn(m::invoke_static_(m::has_method(
      m::named<DexMethodRef>("values") &&
      m::member_of<DexMethodRef>(m::in<DexType*>(collected_enums)))));
  auto length_of =
      flow.insn(m::array_length_()).src(0, values_call, unique_src);
  auto table_alloc = flow.insn(m::new_array_()).src(0, length_of, unique_src);
  auto table_store =
      flow.insn(m::sput_object_(m::has_field(m::member_of<DexFieldRef>(
                    m::equals(generated_cls->get_type())))))
          .src(0, table_alloc, unique_src);

  auto found = flow.find(clinit_cfg, table_store);
  for (auto* store_insn : found.matching(table_store)) {
    // Walk the chain backwards from the store to the values() call.
    auto* alloc_insn = found.matching(table_store, store_insn, 0).unique();
    auto* length_insn = found.matching(table_alloc, alloc_insn, 0).unique();
    auto* values_insn = found.matching(length_of, length_insn, 0).unique();
    always_assert(values_insn && "sput only valid if unique vals exists.");

    auto* lookup_field =
        resolve_field(store_insn->get_field(), FieldSearch::Static);
    auto* enum_type = values_insn->get_method()->get_class();
    always_assert(lookup_field);
    mapping.emplace(lookup_field, enum_type);
  }
}
// Class scope built from the stores by build_class_scope() in the constructor.
Scope m_scope;
// The stores handed to the pass; also passed on to transform_enums().
DexStoresVector& m_stores;
// Counters gathered while the pass runs; published by stats().
struct Stats {
size_t num_synthetic_classes{0};
size_t num_lookup_tables{0};
size_t num_enum_classes{0};
size_t num_enum_objs{0};
size_t num_int_objs{0};
// Atomic: incremented from remove_lookup_table_usage(), which is invoked from
// inside a walk::parallel::code callback.
std::atomic<size_t> num_switch_equiv_finder_failures{0};
size_t num_candidate_generated_methods{0};
size_t num_removed_generated_methods{0};
};
Stats m_stats;
// Lookup table fields whose usage was rewritten by remove_lookup_table_usage().
ConcurrentSet<DexField*> m_lookup_tables_replaced;
// Constructor of java.lang.Enum, as returned by get_java_enum_ctor().
const DexMethod* m_java_enum_ctor;
// Proguard map obtained from the ConfigFiles in the constructor.
// NOTE(review): not read anywhere in the visible part of this file.
const ProguardMap& m_pg_map;
};
} // namespace
namespace optimize_enums {
// Declares the configuration options of the pass: `max_enum_size` (default
// 100) and `break_reference_equality_allowlist` (default empty).
void OptimizeEnumsPass::bind_config() {
  bind("max_enum_size",
       100,
       m_max_enum_size,
       "The maximum number of enum field substitutions "
       "that are generated and stored in primary dex.");
  bind("break_reference_equality_allowlist",
       {},
       m_enum_to_integer_allowlist,
       "A allowlist of enum classes that may have more than `max_enum_size` "
       "enum fields, "
       "try to erase them without considering reference equality of the enum "
       "objects. "
       "Do not add enums to the allowlist!");
}
// Pass entry point. Runs the three optimizations in order - lookup table
// removal, enum -> boxed Integer replacement, removal of generated
// valueOf()/values() methods - and then publishes the collected metrics.
void OptimizeEnumsPass::run_pass(DexStoresVector& stores,
                                 ConfigFiles& conf,
                                 PassManager& mgr) {
  OptimizeEnums optimizer(stores, conf);

  optimizer.remove_redundant_generated_classes();
  optimizer.replace_enum_with_int(m_max_enum_size,
                                  m_enum_to_integer_allowlist);
  optimizer.remove_enum_generated_methods();

  optimizer.stats(mgr);
}
// File-local static instance of the pass.
// NOTE(review): presumably constructing it is what registers the pass with
// the pass framework - confirm against the Pass base class.
static OptimizeEnumsPass s_pass;
} // namespace optimize_enums