opt/delinit/DelInit.cpp (479 lines of code) (raw):

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "DelInit.h"

#include <algorithm>
#include <cstring>
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "ConcurrentContainers.h"
#include "DexClass.h"
#include "DexUtil.h"
#include "IRInstruction.h"
#include "PassManager.h"
#include "ReachableClasses.h"
#include "Resolver.h"
#include "Show.h"
#include "Trace.h"
#include "Walkers.h"

/*
 * This is not a visitor pattern dead-code eliminator with explicit
 * entry points. Rather it's a delete everything that's never
 * referenced eliminator. Thus the name "delinit".
 */

namespace {

constexpr const char* METRIC_INIT_METHODS_REMOVED = "num_init_methods_removed";
constexpr const char* METRIC_VMETHODS_REMOVED = "num_vmethods_removed";
constexpr const char* METRIC_IFIELDS_REMOVED = "num_ifields_removed";
constexpr const char* METRIC_DMETHODS_REMOVED = "num_dmethods_removed";

// Classes referenced "by name": via string constants that spell a dotted
// class name, type operands of instructions, class literals in runtime-visible
// annotations, and Signature annotations. A class in this set is never
// considered removable (see can_remove(const DexClass*)).
// Rebuilt from scratch by find_referenced_classes() on every run.
static ConcurrentSet<const DexClass*> referenced_classes;

// Class-name prefixes the pass is restricted to. Empty means "every package".
static std::vector<std::string> package_filter;

/**
 * Record the class backing `type` as referenced.
 * Unknown types (nullptr) and types without a class definition are ignored,
 * so the set never contains nullptr.
 */
void mark_referenced(const DexType* type) {
  if (type == nullptr) {
    return;
  }
  if (const DexClass* cls = type_class(type)) {
    referenced_classes.insert(cls);
  }
}

/**
 * Convert a dotted Java class name ("com.foo.Bar") to descriptor form
 * ("Lcom/foo/Bar;") and look the type up.
 *
 * Returns nullptr if `dotname` is nullptr or refers to an unknown type.
 */
DexType* get_dextype_from_dotname(const char* dotname) {
  if (dotname == nullptr) {
    return nullptr;
  }
  std::string buf;
  buf.reserve(strlen(dotname) + 2);
  buf += 'L';
  buf += dotname;
  buf += ';';
  std::replace(buf.begin(), buf.end(), '.', '/');
  return DexType::get_type(buf.c_str());
}

/**
 * Return true if `name` starts with one of the package_filter prefixes, or if
 * no filter is configured (then every package is optimized).
 *
 * `name` comes from DexClass::get_name() (see filter_class). The prefixes are
 * presumably given in the same descriptor form. NOTE(review): confirm against
 * the pass configuration.
 */
bool find_package(const char* name) {
  if (package_filter.empty()) {
    return true;
  }
  return std::any_of(package_filter.begin(), package_filter.end(),
                     [name](const std::string& prefix) {
                       return strncmp(name, prefix.c_str(), prefix.size()) == 0;
                     });
}

/**
 * Mark the class named by one string piece of a dalvik Signature annotation
 * as referenced.
 *
 * Only pieces that start with 'L' and are at least 3 chars long are
 * considered. A piece ending in ';' is looked up as is. Otherwise it is a
 * generic-type prefix such as "Ljava/util/List<". The trailing '<' is dropped
 * and ';' appended so the raw type ("Ljava/util/List;") can be resolved.
 */
void process_signature_anno(const DexString* dstring) {
  const char* cstr = dstring->c_str();
  size_t len = strlen(cstr);
  if (len < 3 || cstr[0] != 'L') {
    return;
  }
  if (cstr[len - 1] == ';') {
    mark_referenced(DexType::get_type(dstring));
    return;
  }
  std::string buf(cstr, len);
  if (buf.back() == '<') {
    // "Lfoo/Bar<" + ";" would give "Lfoo/Bar<;", which never names a type.
    buf.pop_back();
  }
  buf += ';';
  mark_referenced(DexType::get_type(buf.c_str()));
}

/**
 * Populate referenced_classes from annotations and code in `scope`.
 */
void find_referenced_classes(const Scope& scope) {
  // The set is file-static: start clean so a previous run in the same process
  // cannot leak stale class pointers into this one.
  referenced_classes.clear();

  // Resolved once per run rather than cached in a function-local static, so a
  // stale pointer is never reused across runs.
  const DexType* dalviksig =
      DexType::get_type("Ldalvik/annotation/Signature;");

  walk::parallel::annotations(scope, [&](DexAnnotation* anno) {
    // Signature annotations contain strings that Jackson uses
    // to construct the underlying types.
    if (anno->type() == dalviksig) {
      for (auto const& elem : anno->anno_elems()) {
        auto& ev = elem.encoded_value;
        if (ev->evtype() != DEVT_ARRAY) continue;
        auto arrayev = static_cast<DexEncodedValueArray*>(ev.get());
        auto const& evs = arrayev->evalues();
        for (auto& strev : *evs) {
          if (strev->evtype() != DEVT_STRING) continue;
          auto stringev = static_cast<DexEncodedValueString*>(strev.get());
          process_signature_anno(stringev->string());
        }
      }
      return;
    }
    // Class literals in annotations.
    // Example:
    //   @JsonDeserialize(using=MyJsonDeserializer.class)
    if (anno->runtime_visible()) {
      for (auto const& dae : anno->anno_elems()) {
        std::vector<DexType*> ltype;
        dae.encoded_value->gather_types(ltype);
        for (auto dextype : ltype) {
          mark_referenced(dextype);
        }
      }
    }
  });

  walk::parallel::code(
      scope, [](DexMethod*) { return true; },
      [&](DexMethod* /* meth */, IRCode& code) {
        for (const auto& mie : InstructionIterable(&code)) {
          auto insn = mie.insn;
          // Matches any stringref that name-aliases a type.
          if (insn->has_string()) {
            const DexString* dsclzref = insn->get_string();
            DexType* dtexclude = get_dextype_from_dotname(dsclzref->c_str());
            if (dtexclude == nullptr) continue;
            TRACE(PGR, 3, "string_ref: %s", SHOW(dtexclude));
            mark_referenced(dtexclude);
          }
          if (insn->has_type()) {
            TRACE(PGR, 3, "type_ref: %s", SHOW(insn->get_type()));
            mark_referenced(insn->get_type());
          }
        }
      });
}

// A class is removable if it is neither a root (nor string-kept) nor
// referenced by name.
bool can_remove(const DexClass* cls) {
  return !root_or_string(cls) && !referenced_classes.count_unsafe(cls);
}

// A method is removable if nothing calls it, and either its class is
// removable or the method itself is not a root.
bool can_remove(const DexMethod* m, const ConcurrentSet<DexMethod*>& callers) {
  return callers.count_unsafe(const_cast<DexMethod*>(m)) == 0 &&
         (can_remove(type_class(m->get_class())) || !root_or_string(m));
}

/**
 * A constructor can be removed if:
 * - the class can be removed.
 * or
 * - it takes at least one argument, and
 * - it is not a root, and
 * - there is another constructor (with code) of the class that is called.
 */
bool can_remove_init(const DexMethod* m,
                     const ConcurrentSet<DexMethod*>& called) {
  DexClass* clazz = type_class(m->get_class());
  if (can_remove(clazz)) {
    return true;
  }
  // If the class is kept, we should probably keep the no argument constructor
  // because it may be invoked with `Class.newInstance()`.
  if (m->get_proto()->get_args()->size() == 0) {
    return false;
  }
  if (root_or_string(m)) {
    return false;
  }
  auto const& dmeths = clazz->get_dmethods();
  return std::any_of(dmeths.begin(), dmeths.end(), [&](DexMethod* meth) {
    return meth->get_code() != nullptr && method::is_init(meth) &&
           meth != m && called.count_unsafe(meth) > 0;
  });
}

// A field is removable if its class is removable or the field is not a root.
bool can_remove(const DexField* f) {
  return can_remove(type_class(f->get_class())) || !root_or_string(f);
}

/**
 * Return true for classes that should not be processed by the optimization:
 * classes outside package_filter, interfaces and annotations.
 */
bool filter_class(DexClass* clazz) {
  always_assert(!clazz->is_external());
  if (!find_package(clazz->get_name()->c_str())) {
    return true;
  }
  return is_interface(clazz) || is_annotation(clazz);
}

using MethodSet = std::unordered_set<DexMethod*>;
using FieldSet = std::unordered_set<DexField*>;
using MethodVector = std::vector<DexMethod*>;

/**
 * Main class to track DelInit optimizations.
 * For each pass collects all the instance data (vmethods and ifields)
 * for classes that have no ctor or all unreachable ctors.
 * Then it walks all the opcodes to see if there are references to any of those
 * members and if so the member (method or field) is not deleted.
 * In the process it also finds all the methods and ctors unreachable.
 * Repeat the process until no more methods are removed.
 */
struct DeadRefs {
  // all the data is per pass, so it is cleared at the proper time
  // in each step.
  // list of classes that have no reachable ctor
  ConcurrentSet<DexClass*> classes;
  // set of invoked methods
  ConcurrentSet<DexMethod*> called;

  struct ClassInfo {
    // list of vmethods from classes with no reachable ctor
    MethodSet vmethods;
    // list of ifields from classes with no reachable ctor
    FieldSet ifields;
    // set of all ctors that are known
    MethodVector initmethods;
    // set of dmethods (no init or clinit) that are known
    MethodVector dmethods;
  };
  // One entry per class in scope, created up front in delinit() so the
  // parallel walks only ever look entries up (never insert).
  std::unordered_map<DexClass*, ClassInfo> class_infos;

  // statistic info
  struct stats {
    size_t deleted_inits{0};
    size_t deleted_vmeths{0};
    size_t deleted_ifields{0};
    size_t deleted_dmeths{0};
  } del_init_res;

  void delinit(Scope& scope);
  int find_new_unreachable(Scope& scope);
  void find_unreachable(Scope& scope);
  void find_unreachable_data(DexClass* clazz);
  void collect_dmethods(Scope& scope);
  void track_callers(Scope& scope);
  int remove_unreachable(Scope& scope);
};

/**
 * Entry point for DelInit.
 * Loop through the different steps until no more members are deleted.
 */
void DeadRefs::delinit(Scope& scope) {
  int removed = 0;
  int passnum = 0;
  for (auto cls : scope) {
    class_infos.emplace(cls, ClassInfo{});
  }
  do {
    passnum++;
    TRACE(DELINIT, 2, "Summary for pass %d", passnum);
    removed = find_new_unreachable(scope);
    collect_dmethods(scope);
    track_callers(scope);
    removed += remove_unreachable(scope);
  } while (removed > 0);
}

/**
 * Find new unreachable classes.
 * First it deletes all removable ctors collected in the previous iteration
 * (uncalled and accepted by can_remove_init), then calls into
 * find_unreachable. Returns the number of ctors deleted.
 */
int DeadRefs::find_new_unreachable(Scope& scope) {
  struct LocalStats {
    int init_deleted = 0;
    int init_called = 0;
    int init_cant_delete = 0;
  };
  ConcurrentMap<DexClass*, LocalStats> local_stats;
  walk::parallel::classes(scope, [&](DexClass* clazz) {
    auto& ci = class_infos.at(clazz);
    LocalStats stats;
    for (auto init : ci.initmethods) {
      if (called.count_unsafe(init) > 0) {
        stats.init_called++;
        continue;
      }
      if (!can_remove_init(init, called)) {
        stats.init_cant_delete++;
        continue;
      }
      always_assert(clazz == type_class(init->get_class()));
      clazz->remove_method(init);
      TRACE(DELINIT, 5, "Delete init %s.%s %s", SHOW(init->get_class()),
            SHOW(init->get_name()), SHOW(init->get_proto()));
      stats.init_deleted++;
    }
    local_stats.emplace(clazz, stats);
  });
  LocalStats acc;
  for (auto& p : local_stats) {
    acc.init_called += p.second.init_called;
    acc.init_cant_delete += p.second.init_cant_delete;
    acc.init_deleted += p.second.init_deleted;
  }
  TRACE(DELINIT, 2, "Removed %d <init> methods", acc.init_deleted);
  TRACE(DELINIT, 3, "%d <init> methods called", acc.init_called);
  TRACE(DELINIT, 3, "%d <init> methods do not delete", acc.init_cant_delete);
  find_unreachable(scope);
  del_init_res.deleted_inits += acc.init_deleted;
  return acc.init_deleted;
}

/**
 * Collect instance data for classes that do not have <init> routines.
 * This means the vtable and the ifields. Rebuilds `classes` and every
 * ClassInfo's vmethods/ifields from scratch.
 */
void DeadRefs::find_unreachable(Scope& scope) {
  classes.clear();
  walk::parallel::classes(scope, [&](DexClass* clazz) {
    auto& ci = class_infos.at(clazz);
    ci.vmethods.clear();
    ci.ifields.clear();
    if (filter_class(clazz)) return;
    auto const& dmeths = clazz->get_dmethods();
    bool has_init = std::any_of(dmeths.begin(), dmeths.end(),
                                [](DexMethod* meth) {
                                  return method::is_init(meth);
                                });
    if (has_init) return;
    find_unreachable_data(clazz);
  });
  size_t vmethods = 0;
  size_t ifields = 0;
  for (const auto& p : class_infos) {
    vmethods += p.second.vmethods.size();
    ifields += p.second.ifields.size();
  }
  TRACE(DELINIT, 2, "Uninstantiable classes %zu: vmethods %zu, ifields %zu",
        classes.size(), vmethods, ifields);
}

/**
 * Collect all instance data (ifields, vmethods) given the class is
 * uninstantiable, and record the class in `classes`.
 */
void DeadRefs::find_unreachable_data(DexClass* clazz) {
  ClassInfo& ci = class_infos.at(clazz);
  for (const auto& meth : clazz->get_vmethods()) {
    if (!can_remove(meth, called)) continue;
    ci.vmethods.insert(meth);
  }
  for (const auto& field : clazz->get_ifields()) {
    if (!can_remove(field)) continue;
    ci.ifields.insert(field);
  }
  classes.insert(clazz);
}

/**
 * Collect all init and direct methods with code, but not other VM methods
 * (names starting with '<', e.g. <clinit>). Filtered classes get empty lists.
 */
void DeadRefs::collect_dmethods(Scope& scope) {
  struct LocalStats {
    size_t initmethods{0};
    size_t dmethods{0};
  };
  ConcurrentMap<DexClass*, LocalStats> local_stats;
  walk::parallel::classes(scope, [&](DexClass* clazz) {
    auto& ci = class_infos.at(clazz);
    ci.initmethods.clear();
    ci.dmethods.clear();
    if (filter_class(clazz)) return;
    for (auto meth : clazz->get_dmethods()) {
      if (meth->get_code() == nullptr) continue;
      if (method::is_init(meth)) {
        ci.initmethods.push_back(meth);
      } else if (meth->get_name()->c_str()[0] != '<') {
        // Method names beginning with '<' are internal VM calls
        // except <init>
        ci.dmethods.push_back(meth);
      }
    }
    if (!ci.initmethods.empty() || !ci.dmethods.empty()) {
      local_stats.emplace(
          clazz, LocalStats{ci.initmethods.size(), ci.dmethods.size()});
    }
  });
  LocalStats acc;
  for (auto& p : local_stats) {
    acc.initmethods += p.second.initmethods;
    acc.dmethods += p.second.dmethods;
  }
  TRACE(DELINIT, 3, "Found %zu init and %zu dmethods", acc.initmethods,
        acc.dmethods);
}

/**
 * Walk all opcodes and find all methods called (live in scope).
 * Also remove all potentially unreachable members - if a reference exists -
 * from the set of removable instance data.
 */
void DeadRefs::track_callers(Scope& scope) {
  called.clear();
  // Members referenced by some instruction, keyed by their declaring type.
  struct ToErase {
    MethodSet vmethods;
    FieldSet ifields;
  };
  ConcurrentMap<DexType*, ToErase> to_erase;
  walk::parallel::opcodes(
      scope, [](DexMethod*) { return true; },
      [&](DexMethod* m, IRInstruction* insn) {
        if (insn->has_method()) {
          auto callee =
              resolve_method(insn->get_method(), opcode_to_search(insn), m);
          if (callee == nullptr || !callee->is_concrete()) return;
          to_erase.update(
              callee->get_class(),
              [callee](const DexType*, ToErase& te, bool /* exists */) {
                te.vmethods.insert(callee);
              });
          called.insert(callee);
          return;
        }
        if (insn->has_field()) {
          FieldSearch search = FieldSearch::Any;
          if (opcode::is_an_ifield_op(insn->opcode())) {
            search = FieldSearch::Instance;
          } else if (opcode::is_an_sfield_op(insn->opcode())) {
            search = FieldSearch::Static;
          }
          auto field = resolve_field(insn->get_field(), search);
          if (field == nullptr || !field->is_concrete()) return;
          to_erase.update(
              field->get_class(),
              [field](const DexType*, ToErase& te, bool /* exists */) {
                te.ifields.insert(field);
              });
          return;
        }
      });

  std::vector<DexClass*> classes_to_erase;
  // NOTE(review): these totals count every member referenced by some
  // instruction (across all classes), although the trace below labels them
  // "Unreachable (not called)" — confirm the intended wording.
  size_t vmethods = 0;
  size_t ifields = 0;
  for (auto& p : to_erase) {
    auto cls = type_class(p.first);
    always_assert(cls);
    classes_to_erase.push_back(cls);
    vmethods += p.second.vmethods.size();
    ifields += p.second.ifields.size();
  }
  walk::parallel::classes(classes_to_erase, [&](DexClass* cls) {
    // Only uninstantiable classes have removal candidates to prune.
    if (!classes.count(cls)) {
      return;
    }
    auto& ci = class_infos.at(cls);
    auto& te = to_erase.at_unsafe(cls->get_type());
    for (DexMethod* vmethod : te.vmethods) {
      ci.vmethods.erase(vmethod);
    }
    for (DexField* ifield : te.ifields) {
      ci.ifields.erase(ifield);
    }
  });
  TRACE(DELINIT, 3, "Unreachable (not called) %zu vmethods and %zu ifields",
        vmethods, ifields);
}

/**
 * Delete all unreachable members: the remaining vmethod/ifield candidates
 * and every collected dmethod that is not called and passes can_remove.
 * Returns the total number of members deleted.
 */
int DeadRefs::remove_unreachable(Scope& scope) {
  struct LocalStats {
    int vmethodcnt = 0;
    int dmethodcnt = 0;
    int ifieldcnt = 0;
    int called_dmeths = 0;
    int dont_delete_dmeths = 0;
  };
  ConcurrentMap<DexClass*, LocalStats> local_stats;
  walk::parallel::classes(scope, [&](DexClass* cls) {
    // Bind by reference: copying ClassInfo would duplicate two hash sets and
    // two vectors per class on every iteration.
    const auto& ci = class_infos.at(cls);
    LocalStats stats;
    for (const auto& meth : ci.vmethods) {
      redex_assert(meth->is_virtual());
      always_assert(cls == type_class(meth->get_class()));
      auto& methods = cls->get_vmethods();
      auto meth_it = std::find(methods.begin(), methods.end(), meth);
      if (meth_it == methods.end()) continue;
      methods.erase(meth_it);
      stats.vmethodcnt++;
      TRACE(DELINIT, 6, "Delete vmethod: %s.%s %s", SHOW(meth->get_class()),
            SHOW(meth->get_name()), SHOW(meth->get_proto()));
    }
    for (const auto& field : ci.ifields) {
      redex_assert(!is_static(field));
      always_assert(cls == type_class(field->get_class()));
      auto& fields = cls->get_ifields();
      auto field_it = std::find(fields.begin(), fields.end(), field);
      if (field_it == fields.end()) continue;
      fields.erase(field_it);
      stats.ifieldcnt++;
      TRACE(DELINIT, 6, "Delete ifield: %s.%s %s", SHOW(field->get_class()),
            SHOW(field->get_name()), SHOW(field->get_type()));
    }
    for (const auto& meth : ci.dmethods) {
      redex_assert(!meth->is_virtual());
      if (called.count_unsafe(meth) > 0) {
        stats.called_dmeths++;
        continue;
      }
      if (!can_remove(meth, called)) {
        stats.dont_delete_dmeths++;
        continue;
      }
      auto clazz = type_class(meth->get_class());
      clazz->remove_method(meth);
      stats.dmethodcnt++;
      TRACE(DELINIT, 6, "Delete dmethod: %s.%s %s", SHOW(meth->get_class()),
            SHOW(meth->get_name()), SHOW(meth->get_proto()));
    }
    if (stats.vmethodcnt || stats.dmethodcnt || stats.ifieldcnt ||
        stats.called_dmeths || stats.dont_delete_dmeths) {
      local_stats.emplace(cls, stats);
    }
  });
  LocalStats acc;
  for (auto& p : local_stats) {
    acc.vmethodcnt += p.second.vmethodcnt;
    acc.dmethodcnt += p.second.dmethodcnt;
    acc.ifieldcnt += p.second.ifieldcnt;
    acc.called_dmeths += p.second.called_dmeths;
    acc.dont_delete_dmeths += p.second.dont_delete_dmeths;
  }
  del_init_res.deleted_vmeths += acc.vmethodcnt;
  TRACE(DELINIT, 2, "Removed %d vmethods", acc.vmethodcnt);
  del_init_res.deleted_ifields += acc.ifieldcnt;
  TRACE(DELINIT, 2, "Removed %d ifields", acc.ifieldcnt);
  del_init_res.deleted_dmeths += acc.dmethodcnt;
  TRACE(DELINIT, 2, "Removed %d dmethods", acc.dmethodcnt);
  TRACE(DELINIT, 3, "%d called dmethods", acc.called_dmeths);
  TRACE(DELINIT, 3, "%d don't delete dmethods", acc.dont_delete_dmeths);
  return acc.vmethodcnt + acc.ifieldcnt + acc.dmethodcnt;
}

} // namespace

/**
 * Run DelInit over all stores. Skipped entirely when no ProGuard rules were
 * provided. Reports the removal counts as trace output and pass metrics.
 */
void DelInitPass::run_pass(DexStoresVector& stores,
                           ConfigFiles& /* conf */,
                           PassManager& mgr) {
  if (mgr.no_proguard_rules()) {
    TRACE(
        DELINIT,
        1,
        "DelInitPass not run because no ProGuard configuration was provided.");
    return;
  }
  package_filter = m_package_filter;
  auto scope = build_class_scope(stores);
  find_referenced_classes(scope);
  DeadRefs drefs;
  drefs.delinit(scope);
  TRACE(DELINIT, 1, "Removed %zu <init> methods",
        drefs.del_init_res.deleted_inits);
  TRACE(DELINIT, 1, "Removed %zu vmethods", drefs.del_init_res.deleted_vmeths);
  TRACE(DELINIT, 1, "Removed %zu ifields", drefs.del_init_res.deleted_ifields);
  TRACE(DELINIT, 1, "Removed %zu dmethods", drefs.del_init_res.deleted_dmeths);
  mgr.incr_metric(METRIC_INIT_METHODS_REMOVED,
                  drefs.del_init_res.deleted_inits);
  mgr.incr_metric(METRIC_VMETHODS_REMOVED, drefs.del_init_res.deleted_vmeths);
  mgr.incr_metric(METRIC_IFIELDS_REMOVED, drefs.del_init_res.deleted_ifields);
  mgr.incr_metric(METRIC_DMETHODS_REMOVED, drefs.del_init_res.deleted_dmeths);

  post_dexen_changes(scope, stores);
}

static DelInitPass s_pass;