source/CallGraph.cpp (626 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include <re2/re2.h> #include <IRInstruction.h> #include <Resolver.h> #include <Show.h> #include <SpartaWorkQueue.h> #include <mariana-trench/Assert.h> #include <mariana-trench/CallGraph.h> #include <mariana-trench/Features.h> #include <mariana-trench/JsonValidation.h> #include <mariana-trench/Log.h> #include <mariana-trench/Methods.h> namespace marianatrench { namespace { bool is_virtual_invoke(const IRInstruction* instruction) { switch (instruction->opcode()) { case OPCODE_INVOKE_VIRTUAL: case OPCODE_INVOKE_INTERFACE: return true; default: return false; } } /* Return the resolved base callee. */ const DexMethod* MT_NULLABLE resolve_call( const Types& types, const Method* caller, const IRInstruction* instruction) { mt_assert(caller != nullptr); mt_assert(opcode::is_an_invoke(instruction->opcode())); DexMethodRef* dex_method_reference = instruction->get_method(); mt_assert_log( dex_method_reference != nullptr, "invoke instruction has no method reference"); DexMethod* method = nullptr; switch (instruction->opcode()) { case OPCODE_INVOKE_DIRECT: case OPCODE_INVOKE_STATIC: case OPCODE_INVOKE_SUPER: { // No need to consider the runtime type. method = resolve_method( dex_method_reference, opcode_to_search(instruction->opcode()), caller->dex_method()); break; } case OPCODE_INVOKE_VIRTUAL: case OPCODE_INVOKE_INTERFACE: { // Use the inferred runtime type to refine the call. const DexType* type = types.receiver_type(caller, instruction); const DexClass* klass = type ? type_class(type) : nullptr; if (!klass) { method = resolve_method(dex_method_reference, MethodSearch::Virtual); } else { method = resolve_method( klass, dex_method_reference->get_name(), dex_method_reference->get_proto(), MethodSearch::Virtual); } if (!method) { // `MethodSearch::Virtual` returns null for interface methods. method = resolve_method(dex_method_reference, MethodSearch::Interface); } break; } default: mt_assert_log(false, "unexpected opcode"); } return method; } const DexField* MT_NULLABLE resolve_field_access(const Method* caller, const IRInstruction* instruction) { mt_assert(caller != nullptr); mt_assert( opcode::is_an_iget(instruction->opcode()) || opcode::is_an_sget(instruction->opcode()) || opcode::is_an_iput(instruction->opcode()) || opcode::is_an_sput(instruction->opcode())); DexFieldRef* dex_field_reference = instruction->get_field(); mt_assert_log( dex_field_reference != nullptr, "Field access (iget, sget, iput) instruction has no field reference"); if (opcode::is_an_sget(instruction->opcode()) || opcode::is_an_sput(instruction->opcode())) { return resolve_field(dex_field_reference, FieldSearch::Static); } return resolve_field(dex_field_reference, FieldSearch::Instance); } bool is_anonymous_class(const DexType* type) { static const re2::RE2 regex("^.*\\$\\d+;$"); return re2::RE2::FullMatch(show(type), regex); } // Return mapping of argument index to argument type for all anonymous class // arguments. ParameterTypeOverrides anonymous_class_arguments( const Types& types, const Method* caller, const IRInstruction* instruction, const DexMethod* callee) { mt_assert(callee != nullptr); ParameterTypeOverrides parameters; const auto& environment = types.environment(caller, instruction); auto sources = instruction->srcs_vec(); for (std::size_t source_position = 0; source_position < sources.size(); source_position++) { auto parameter_position = source_position; if (!is_static(callee)) { if (source_position == 0) { // Do not override `this`. continue; } else { // Do not count the `this` parameter parameter_position--; } } auto found = environment.find(sources[source_position]); if (found == environment.end()) { continue; } const auto* type = found->second; if (type && is_anonymous_class(type)) { parameters.emplace(parameter_position, type); } } return parameters; } ArtificialCallees anonymous_class_artificial_callees( const Methods& method_factory, const IRInstruction* instruction, const DexType* anonymous_class_type, Register register_id, const FeatureSet& features = {}) { if (!is_anonymous_class(anonymous_class_type)) { return {}; } const DexClass* anonymous_class = type_class(anonymous_class_type); if (!anonymous_class) { return {}; } ArtificialCallees callees; for (const auto* dex_method : anonymous_class->get_vmethods()) { const auto* method = method_factory.get(dex_method); mt_assert(!method->is_constructor()); mt_assert(!method->is_static()); callees.push_back(ArtificialCallee{ /* call_target */ CallTarget::static_call(instruction, method), /* register_parameters */ {register_id}, /* features */ features, }); } return callees; } ArtificialCallees artificial_callees_from_arguments( const Methods& method_factory, const Features& features, const IRInstruction* instruction, const DexMethod* callee, const ParameterTypeOverrides& parameter_type_overrides) { ArtificialCallees callees; // For each anonymous class parameter, simulate calls to all its methods. for (auto [parameter, anonymous_class_type] : parameter_type_overrides) { auto artificial_callees_from_parameter = anonymous_class_artificial_callees( method_factory, instruction, anonymous_class_type, /* register */ instruction->src(parameter + (is_static(callee) ? 0 : 1)), /* features */ FeatureSet{features.get("via-anonymous-class-to-obscure")}); callees.insert( callees.end(), std::make_move_iterator(artificial_callees_from_parameter.begin()), std::make_move_iterator(artificial_callees_from_parameter.end())); } return callees; } /* * Given the DexMethod representing the callee of an instruction, get or create * the Method corresponding to the call */ const Method* get_callee_from_resolved_call( const DexMethod* dex_callee, const IRInstruction* instruction, const ParameterTypeOverrides& parameter_type_overrides, const Options& options, Methods& method_factory, const Features& features, ArtificialCallees& artificial_callees) { const Method* callee = nullptr; if (dex_callee->get_code() == nullptr) { // When passing an anonymous class into a callee, add artificial // calls to all methods of the anonymous class. auto artificial_callees_for_instruction = artificial_callees_from_arguments( method_factory, features, instruction, dex_callee, parameter_type_overrides); if (!artificial_callees_for_instruction.empty()) { artificial_callees = std::move(artificial_callees_for_instruction); } // No need to use type overrides since we don't have the code. callee = method_factory.get(dex_callee); } else if (options.disable_parameter_type_overrides()) { callee = method_factory.get(dex_callee); } else { // Analyze the callee with these particular types. callee = method_factory.create(dex_callee, parameter_type_overrides); } mt_assert(callee != nullptr); return callee; } struct InstructionCallGraphInformation { std::optional<const Method*> callee; ArtificialCallees artificial_callees = {}; std::optional<const Field*> field_access; }; InstructionCallGraphInformation process_instruction( const Method* caller, const IRInstruction* instruction, ConcurrentSet<const Method*>& worklist, ConcurrentSet<const Method*>& processed, const Options& options, Methods& method_factory, Fields& field_factory, const Types& types, Overrides& override_factory, const Features& features) { InstructionCallGraphInformation instruction_information; if (opcode::is_an_iput(instruction->opcode())) { // Add artificial calls to all methods in an anonymous class. const auto* iput_type = types.source_type(caller, instruction, /* source_position */ 0); if (iput_type && is_anonymous_class(iput_type)) { auto artificial_callees_for_instruction = anonymous_class_artificial_callees( method_factory, instruction, iput_type, /* register */ instruction->src(0), /* features */ FeatureSet{features.get("via-anonymous-class-to-field")}); if (!artificial_callees_for_instruction.empty()) { instruction_information.artificial_callees = std::move(artificial_callees_for_instruction); } } const auto* field = resolve_field_access(caller, instruction); if (field != nullptr) { instruction_information.field_access = field_factory.get(field); } return instruction_information; } if (opcode::is_an_iget(instruction->opcode()) || opcode::is_an_sget(instruction->opcode()) || opcode::is_an_sput(instruction->opcode())) { const auto* field = resolve_field_access(caller, instruction); if (field != nullptr) { instruction_information.field_access = field_factory.get(field); } return instruction_information; } if (!opcode::is_an_invoke(instruction->opcode())) { return instruction_information; } const DexMethod* dex_callee = resolve_call(types, caller, instruction); if (!dex_callee) { return instruction_information; } ParameterTypeOverrides parameter_type_overrides = anonymous_class_arguments(types, caller, instruction, dex_callee); const auto* callee = get_callee_from_resolved_call( dex_callee, instruction, parameter_type_overrides, options, method_factory, features, instruction_information.artificial_callees); instruction_information.callee = callee; if (callee->parameter_type_overrides().empty() || processed.count(callee) != 0) { return instruction_information; } // This is a newly introduced method with parameter type // overrides. We need to generate it's method overrides, // and compute callees for them. const Method* original_callee = method_factory.get(callee->dex_method()); std::unordered_set<const Method*> original_methods = override_factory.get(original_callee); original_methods.insert(original_callee); for (const Method* original_method : original_methods) { const Method* method = method_factory.create( original_method->dex_method(), callee->parameter_type_overrides()); std::unordered_set<const Method*> overrides; for (const Method* original_override : override_factory.get(original_method)) { overrides.insert(method_factory.create( original_override->dex_method(), callee->parameter_type_overrides())); } if (!overrides.empty()) { override_factory.set(method, std::move(overrides)); } if (processed.count(method) == 0) { worklist.insert(method); } } return instruction_information; } } // namespace CallTarget::CallTarget( const IRInstruction* instruction, const Method* MT_NULLABLE resolved_base_callee, const DexType* MT_NULLABLE receiver_type, const std::unordered_set<const Method*>* MT_NULLABLE overrides, const std::unordered_set<const DexType*>* MT_NULLABLE receiver_extends) : instruction_(instruction), resolved_base_callee_(resolved_base_callee), receiver_type_(receiver_type), overrides_(overrides), receiver_extends_(receiver_extends) {} CallTarget CallTarget::static_call( const IRInstruction* instruction, const Method* MT_NULLABLE callee) { return CallTarget( instruction, /* resolved_base_callee */ callee, /* receiver_type */ nullptr, /* overrides */ nullptr, /* receiver_extends */ nullptr); } CallTarget CallTarget::virtual_call( const IRInstruction* instruction, const Method* MT_NULLABLE resolved_base_callee, const DexType* MT_NULLABLE receiver_type, const ClassHierarchies& class_hierarchies, const Overrides& override_factory) { // All overrides are potential callees. const std::unordered_set<const Method*>* overrides = nullptr; if (resolved_base_callee != nullptr) { overrides = &override_factory.get(resolved_base_callee); } else { overrides = &override_factory.empty_method_set(); } // If the receiver type does not define the method, `resolved_base_callee` // will reference a method on a parent class. Taking all overrides of // `resolved_base_callee` can be imprecise since it would include overrides // that don't extend the receiver type. Filtering overrides based on classes // extending the receiver type fixes the problem. // // For instance: // ``` // class A { void f() { ... } } // class B implements A {} // class C extends B { void f() { ... } } // class D implements A { void f() { ... } } // ``` // A virtual call to `B::f` has a resolved base callee of `A::f`. Overrides // of `A::f` includes `D::f`, but `D::f` cannot be called since `D` does not // extend `B`. const std::unordered_set<const DexType*>* receiver_extends = nullptr; if (receiver_type != nullptr && receiver_type != type::java_lang_Object()) { receiver_extends = &class_hierarchies.extends(receiver_type); } return CallTarget( instruction, resolved_base_callee, receiver_type, overrides, receiver_extends); } CallTarget CallTarget::from_call_instruction( const Method* caller, const IRInstruction* instruction, const Method* MT_NULLABLE resolved_base_callee, const Types& types, const ClassHierarchies& class_hierarchies, const Overrides& override_factory) { mt_assert(opcode::is_an_invoke(instruction->opcode())); if (is_virtual_invoke(instruction)) { return CallTarget::virtual_call( instruction, resolved_base_callee, types.receiver_type(caller, instruction), class_hierarchies, override_factory); } else { return CallTarget::static_call(instruction, resolved_base_callee); } } bool CallTarget::FilterOverrides::operator()(const Method* method) const { return extends == nullptr || extends->count(method->get_class()) > 0; } CallTarget::OverridesRange CallTarget::overrides() const { mt_assert(resolved()); mt_assert(is_virtual()); return boost::make_iterator_range( boost::make_filter_iterator( FilterOverrides{receiver_extends_}, overrides_->cbegin()), boost::make_filter_iterator( FilterOverrides{receiver_extends_}, overrides_->cend())); } bool CallTarget::operator==(const CallTarget& other) const { return instruction_ == other.instruction_ && resolved_base_callee_ == other.resolved_base_callee_ && receiver_type_ == other.receiver_type_ && overrides_ == other.overrides_ && receiver_extends_ == other.receiver_extends_; } std::ostream& operator<<(std::ostream& out, const CallTarget& call_target) { out << "CallTarget(instruction=`" << show(call_target.instruction()) << "`, resolved_base_callee=`" << show(call_target.resolved_base_callee()) << "`"; if (call_target.is_virtual()) { out << ", receiver_type=`" << show(call_target.receiver_type()) << "`, overrides={"; for (const auto* method : call_target.overrides()) { out << "`" << show(method) << "`, "; } out << "}"; } return out << ")"; } bool ArtificialCallee::operator==(const ArtificialCallee& other) const { return call_target == other.call_target && register_parameters == other.register_parameters && features == other.features; } std::ostream& operator<<(std::ostream& out, const ArtificialCallee& callee) { out << "ArtificialCallee(call_target=" << callee.call_target << ", register_parameters=["; for (auto register_id : callee.register_parameters) { out << register_id << ", "; } return out << "], features=" << callee.features << ")"; } CallGraph::CallGraph( const Options& options, Methods& method_factory, Fields& field_factory, const Types& types, const ClassHierarchies& class_hierarchies, Overrides& override_factory, const Features& features) : types_(types), class_hierarchies_(class_hierarchies), overrides_(override_factory) { ConcurrentSet<const Method*> worklist; ConcurrentSet<const Method*> processed; for (const Method* method : method_factory) { worklist.insert(method); } std::atomic<std::size_t> method_iteration(0); std::size_t number_methods = 0; while (worklist.size() > 0) { auto queue = sparta::work_queue<const Method*>( [&](const Method* caller) { method_iteration++; if (method_iteration % 10000 == 0) { LOG(1, "Processed {}/{} methods.", method_iteration.load(), number_methods); } auto* code = caller->get_code(); if (!code) { return; } mt_assert(code->cfg_built()); std::unordered_map<const IRInstruction*, const Method*> callees; std::unordered_map<const IRInstruction*, ArtificialCallees> artificial_callees; std::unordered_map<const IRInstruction*, const Field*> field_accesses; for (const auto* block : code->cfg().blocks()) { for (const auto& entry : *block) { if (entry.type != MFLOW_OPCODE) { continue; } const auto* instruction = entry.insn; auto instruction_information = process_instruction( caller, instruction, worklist, processed, options, method_factory, field_factory, types, override_factory, features); if (instruction_information.callee) { callees.emplace(instruction, *(instruction_information.callee)); } if (instruction_information.artificial_callees.size() > 0) { artificial_callees.emplace( instruction, instruction_information.artificial_callees); } if (instruction_information.field_access) { field_accesses.emplace( instruction, *(instruction_information.field_access)); } } } if (!callees.empty()) { resolved_base_callees_.insert_or_assign( std::make_pair(caller, std::move(callees))); } if (!artificial_callees.empty()) { artificial_callees_.insert_or_assign( std::make_pair(caller, std::move(artificial_callees))); } if (!field_accesses.empty()) { resolved_fields_.insert_or_assign( std::make_pair(caller, std::move(field_accesses))); } }, sparta::parallel::default_num_threads()); for (const auto* method : worklist) { queue.add_item(method); processed.insert(method); } worklist.clear(); number_methods = method_factory.size(); queue.run_all(); } if (options.dump_call_graph()) { auto call_graph_path = options.call_graph_output_path(); LOG(1, "Writing call graph to `{}`", call_graph_path.native()); JsonValidation::write_json_file( call_graph_path, to_json(/* with_overrides */ false)); } } std::vector<CallTarget> CallGraph::callees(const Method* caller) const { // Note that `find` is not thread-safe, but this is fine because // `resolved_base_callees_` is read-only after the constructor completed. auto callees = resolved_base_callees_.find(caller); if (callees == resolved_base_callees_.end()) { return {}; } std::vector<CallTarget> call_targets; for (auto [instruction, resolved_base_callee] : callees->second) { call_targets.push_back(CallTarget::from_call_instruction( caller, instruction, resolved_base_callee, types_, class_hierarchies_, overrides_)); } return call_targets; } CallTarget CallGraph::callee( const Method* caller, const IRInstruction* instruction) const { return CallTarget::from_call_instruction( caller, instruction, resolved_base_callee(caller, instruction), types_, class_hierarchies_, overrides_); } const Method* MT_NULLABLE CallGraph::resolved_base_callee( const Method* caller, const IRInstruction* instruction) const { // Note that `find` is not thread-safe, but this is fine because // `resolved_base_callees_` is read-only after the constructor completed. auto callees = resolved_base_callees_.find(caller); if (callees == resolved_base_callees_.end()) { return nullptr; } auto callee = callees->second.find(instruction); if (callee == callees->second.end()) { return nullptr; } return callee->second; } const std::unordered_map<const IRInstruction*, ArtificialCallees>& CallGraph::artificial_callees(const Method* caller) const { // Note that `find` is not thread-safe, but this is fine because // `artificial_callees_` is read-only after the constructor completed. auto artificial_callees_map = artificial_callees_.find(caller); if (artificial_callees_map == artificial_callees_.end()) { return empty_artificial_callees_map_; } else { return artificial_callees_map->second; } } const ArtificialCallees& CallGraph::artificial_callees( const Method* caller, const IRInstruction* instruction) const { const auto& artificial_callees_map = this->artificial_callees(caller); auto artificial_callees = artificial_callees_map.find(instruction); if (artificial_callees == artificial_callees_map.end()) { return empty_artificial_callees_; } else { return artificial_callees->second; } } const Field* MT_NULLABLE CallGraph::resolved_field_access( const Method* caller, const IRInstruction* instruction) const { auto fields = resolved_fields_.find(caller); if (fields == resolved_fields_.end()) { return nullptr; } auto field = fields->second.find(instruction); if (field == fields->second.end()) { return nullptr; } return field->second; } Json::Value CallGraph::to_json(bool with_overrides) const { auto value = Json::Value(Json::objectValue); for (const auto& [method, callees] : resolved_base_callees_) { auto method_value = Json::Value(Json::objectValue); std::unordered_set<const Method*> static_callees; std::unordered_set<const Method*> virtual_callees; for (const auto& [instruction, resolved_base_callee] : callees) { auto call_target = CallTarget::from_call_instruction( method, instruction, resolved_base_callee, types_, class_hierarchies_, overrides_); if (!call_target.resolved()) { continue; } else if (call_target.is_virtual()) { virtual_callees.insert(call_target.resolved_base_callee()); if (with_overrides) { for (const auto* override : call_target.overrides()) { virtual_callees.insert(override); } } } else { static_callees.insert(call_target.resolved_base_callee()); } } if (!static_callees.empty()) { auto static_callees_value = Json::Value(Json::arrayValue); for (const auto* callee : static_callees) { static_callees_value.append(Json::Value(show(callee))); } method_value["static"] = static_callees_value; } if (!virtual_callees.empty()) { auto virtual_callees_value = Json::Value(Json::arrayValue); for (const auto* callee : virtual_callees) { virtual_callees_value.append(Json::Value(show(callee))); } method_value["virtual"] = virtual_callees_value; } value[show(method)] = method_value; } for (const auto& [method, instruction_artificial_callees] : artificial_callees_) { std::unordered_set<const Method*> callees; for (const auto& [instruction, artificial_callees] : instruction_artificial_callees) { for (const auto& artificial_callee : artificial_callees) { callees.insert(artificial_callee.call_target.resolved_base_callee()); } } auto callees_value = Json::Value(Json::arrayValue); for (const auto* callee : callees) { callees_value.append(Json::Value(show(callee))); } value[show(method)]["artificial"] = callees_value; } return value; } } // namespace marianatrench