opt/remove_redundant_check_casts/CheckCastAnalysis.cpp (480 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "CheckCastAnalysis.h"
#include "DexUtil.h"
#include "ReachingDefinitions.h"
#include "Show.h"
#include "StlUtil.h"
namespace check_casts {
namespace impl {
// Determines the type that `insn` requires of its source operand `src_index`.
//
// A nullptr result means the demand could not be computed exactly; in that
// case no weakening should take place.
DexType* CheckCastAnalysis::get_type_demand(IRInstruction* insn,
                                            size_t src_index) const {
  always_assert(src_index < insn->srcs_size());
  const auto op = insn->opcode();
  switch (op) {
  // Any object reference satisfies these consumers.
  case OPCODE_MOVE_OBJECT:
  case OPCODE_MONITOR_ENTER:
  case OPCODE_MONITOR_EXIT:
  case OPCODE_INSTANCE_OF:
  case OPCODE_CHECK_CAST:
  case OPCODE_IF_EQ:
  case OPCODE_IF_NE:
  case OPCODE_IF_EQZ:
  case OPCODE_IF_NEZ:
    return type::java_lang_Object();

  case OPCODE_THROW:
    return type::java_lang_Throwable();

  case OPCODE_RETURN_OBJECT:
    return m_method->get_proto()->get_rtype();

  case OPCODE_FILLED_NEW_ARRAY:
    return type::get_array_component_type(insn->get_type());

  // Instance field reads: the operand must be of the field's declaring class.
  case OPCODE_IGET:
  case OPCODE_IGET_BOOLEAN:
  case OPCODE_IGET_BYTE:
  case OPCODE_IGET_CHAR:
  case OPCODE_IGET_SHORT:
  case OPCODE_IGET_WIDE:
  case OPCODE_IGET_OBJECT:
    return insn->get_field()->get_class();

  case OPCODE_SPUT_OBJECT:
    return insn->get_field()->get_type();

  case OPCODE_IPUT_OBJECT:
    // Operand 0 must match the field's type, operand 1 the declaring class.
    if (src_index == 0) {
      return insn->get_field()->get_type();
    }
    return src_index == 1 ? insn->get_field()->get_class() : nullptr;

  case OPCODE_IPUT:
  case OPCODE_IPUT_BOOLEAN:
  case OPCODE_IPUT_BYTE:
  case OPCODE_IPUT_CHAR:
  case OPCODE_IPUT_SHORT:
  case OPCODE_IPUT_WIDE:
    return src_index == 1 ? insn->get_field()->get_class() : nullptr;

  case OPCODE_APUT_OBJECT:
    switch (src_index) {
    case 0:
      // There seems to be very little static verification for this
      // instruction, as most is deferred to runtime.
      // https://android.googlesource.com/platform/dalvik/+/android-cts-4.4_r4/vm/analysis/CodeVerify.cpp#186
      // So, we can just get away with the following:
      return type::java_lang_Object();
    case 1:
      return DexType::make_type("[Ljava/lang/Object;");
    default:
      return nullptr;
    }

  // No exact demand is computed for these array accesses.
  case OPCODE_ARRAY_LENGTH:
  case OPCODE_FILL_ARRAY_DATA:
  case OPCODE_AGET:
  case OPCODE_AGET_BOOLEAN:
  case OPCODE_AGET_BYTE:
  case OPCODE_AGET_CHAR:
  case OPCODE_AGET_SHORT:
  case OPCODE_AGET_WIDE:
  case OPCODE_AGET_OBJECT:
  case OPCODE_APUT:
  case OPCODE_APUT_BOOLEAN:
  case OPCODE_APUT_BYTE:
  case OPCODE_APUT_CHAR:
  case OPCODE_APUT_SHORT:
  case OPCODE_APUT_WIDE:
    return nullptr;

  case OPCODE_INVOKE_VIRTUAL:
  case OPCODE_INVOKE_SUPER:
  case OPCODE_INVOKE_DIRECT:
  case OPCODE_INVOKE_STATIC:
  case OPCODE_INVOKE_INTERFACE: {
    DexMethodRef* callee = insn->get_method();
    const auto* declared_args = callee->get_proto()->get_args();
    const bool has_receiver = op != OPCODE_INVOKE_STATIC;
    size_t expected_args = (has_receiver ? 1 : 0) + declared_args->size();
    always_assert(insn->srcs_size() == expected_args);
    if (!has_receiver) {
      return declared_args->at(src_index);
    }
    // The first argument is a reference to the object instance on which the
    // method is invoked; the declared arguments follow it.
    if (src_index == 0) {
      return callee->get_class();
    }
    return declared_args->at(src_index - 1);
  }

  case OPCODE_INVOKE_CUSTOM:
  case OPCODE_INVOKE_POLYMORPHIC:
    not_reached_log("Unsupported instruction {%s}\n", SHOW(insn));

  // The analysis never expects to be asked about the opcodes below
  // (presumably because none of them consumes an object reference); getting
  // here is treated as a fatal error.
  //
  // Control flow, parameter loading, result moves, constants, allocation.
  case OPCODE_GOTO:
  case IOPCODE_LOAD_PARAM:
  case IOPCODE_LOAD_PARAM_OBJECT:
  case IOPCODE_LOAD_PARAM_WIDE:
  case OPCODE_NOP:
  case IOPCODE_MOVE_RESULT_PSEUDO:
  case OPCODE_MOVE_RESULT:
  case IOPCODE_MOVE_RESULT_PSEUDO_OBJECT:
  case OPCODE_MOVE_RESULT_OBJECT:
  case IOPCODE_MOVE_RESULT_PSEUDO_WIDE:
  case OPCODE_MOVE_RESULT_WIDE:
  case OPCODE_MOVE_EXCEPTION:
  case OPCODE_RETURN_VOID:
  case OPCODE_CONST:
  case OPCODE_CONST_WIDE:
  case OPCODE_CONST_STRING:
  case OPCODE_CONST_CLASS:
  case OPCODE_NEW_INSTANCE:
  // Static field reads.
  case OPCODE_SGET:
  case OPCODE_SGET_BOOLEAN:
  case OPCODE_SGET_BYTE:
  case OPCODE_SGET_CHAR:
  case OPCODE_SGET_SHORT:
  case OPCODE_SGET_WIDE:
  case OPCODE_SGET_OBJECT:
  // Non-object returns and moves, array allocation, switch.
  case OPCODE_RETURN:
  case OPCODE_RETURN_WIDE:
  case OPCODE_MOVE:
  case OPCODE_MOVE_WIDE:
  case OPCODE_NEW_ARRAY:
  case OPCODE_SWITCH:
  // Int arithmetic and conversions.
  case OPCODE_NEG_INT:
  case OPCODE_NOT_INT:
  case OPCODE_INT_TO_BYTE:
  case OPCODE_INT_TO_CHAR:
  case OPCODE_INT_TO_SHORT:
  case OPCODE_INT_TO_LONG:
  case OPCODE_INT_TO_FLOAT:
  case OPCODE_INT_TO_DOUBLE:
  case OPCODE_ADD_INT:
  case OPCODE_SUB_INT:
  case OPCODE_MUL_INT:
  case OPCODE_AND_INT:
  case OPCODE_OR_INT:
  case OPCODE_XOR_INT:
  case OPCODE_SHL_INT:
  case OPCODE_SHR_INT:
  case OPCODE_USHR_INT:
  case OPCODE_DIV_INT:
  case OPCODE_REM_INT:
  case OPCODE_ADD_INT_LIT16:
  case OPCODE_RSUB_INT:
  case OPCODE_MUL_INT_LIT16:
  case OPCODE_AND_INT_LIT16:
  case OPCODE_OR_INT_LIT16:
  case OPCODE_XOR_INT_LIT16:
  case OPCODE_ADD_INT_LIT8:
  case OPCODE_RSUB_INT_LIT8:
  case OPCODE_MUL_INT_LIT8:
  case OPCODE_AND_INT_LIT8:
  case OPCODE_OR_INT_LIT8:
  case OPCODE_XOR_INT_LIT8:
  case OPCODE_SHL_INT_LIT8:
  case OPCODE_SHR_INT_LIT8:
  case OPCODE_USHR_INT_LIT8:
  case OPCODE_DIV_INT_LIT16:
  case OPCODE_REM_INT_LIT16:
  case OPCODE_DIV_INT_LIT8:
  case OPCODE_REM_INT_LIT8:
  // Float arithmetic and conversions.
  case OPCODE_CMPL_FLOAT:
  case OPCODE_CMPG_FLOAT:
  case OPCODE_NEG_FLOAT:
  case OPCODE_FLOAT_TO_INT:
  case OPCODE_FLOAT_TO_LONG:
  case OPCODE_FLOAT_TO_DOUBLE:
  case OPCODE_ADD_FLOAT:
  case OPCODE_SUB_FLOAT:
  case OPCODE_MUL_FLOAT:
  case OPCODE_DIV_FLOAT:
  case OPCODE_REM_FLOAT:
  // Double arithmetic and conversions.
  case OPCODE_CMPL_DOUBLE:
  case OPCODE_CMPG_DOUBLE:
  case OPCODE_NEG_DOUBLE:
  case OPCODE_DOUBLE_TO_INT:
  case OPCODE_DOUBLE_TO_LONG:
  case OPCODE_DOUBLE_TO_FLOAT:
  case OPCODE_ADD_DOUBLE:
  case OPCODE_SUB_DOUBLE:
  case OPCODE_MUL_DOUBLE:
  case OPCODE_DIV_DOUBLE:
  case OPCODE_REM_DOUBLE:
  // Long arithmetic and conversions.
  case OPCODE_CMP_LONG:
  case OPCODE_NEG_LONG:
  case OPCODE_NOT_LONG:
  case OPCODE_LONG_TO_INT:
  case OPCODE_LONG_TO_FLOAT:
  case OPCODE_LONG_TO_DOUBLE:
  case OPCODE_ADD_LONG:
  case OPCODE_SUB_LONG:
  case OPCODE_MUL_LONG:
  case OPCODE_AND_LONG:
  case OPCODE_OR_LONG:
  case OPCODE_XOR_LONG:
  case OPCODE_DIV_LONG:
  case OPCODE_REM_LONG:
  case OPCODE_SHL_LONG:
  case OPCODE_SHR_LONG:
  case OPCODE_USHR_LONG:
  // Ordered comparisons.
  case OPCODE_IF_LTZ:
  case OPCODE_IF_GEZ:
  case OPCODE_IF_GTZ:
  case OPCODE_IF_LEZ:
  case OPCODE_IF_LT:
  case OPCODE_IF_GE:
  case OPCODE_IF_GT:
  case OPCODE_IF_LE:
  // Non-object static field writes and class initialization.
  case OPCODE_SPUT:
  case OPCODE_SPUT_BOOLEAN:
  case OPCODE_SPUT_BYTE:
  case OPCODE_SPUT_CHAR:
  case OPCODE_SPUT_SHORT:
  case OPCODE_SPUT_WIDE:
  case IOPCODE_INIT_CLASS:
    not_reached();
  }
}
// Tells whether `type` (or, for an array, its element type) resolves to a
// class that is not an interface. A type is "interfacy" if it's an interface,
// or an array of an interface. The answer is conservative: when type_class
// cannot resolve the class, the result is false.
static bool is_not_interfacy(DexType* type) {
  auto* resolved = type_class(type::get_element_type_if_array(type));
  if (resolved == nullptr) {
    return false;
  }
  return !is_interface(resolved);
}
// Produces the next-weaker type for `type`, honoring the check-cast
// relationship of arrays (interfaces get no special treatment here).
//
// For an array with a non-primitive element type, the element type is weakened
// recursively and wrapped back into an array type. If that is not possible, or
// `type` is not an array, the super class of `type`'s class is returned.
// Returns nullptr when type_class cannot resolve `type`.
static DexType* weaken_type(DexType* type) {
  if (type::is_array(type)) {
    auto* element = type::get_array_element_type(type);
    auto* weaker_element =
        type::is_primitive(element) ? nullptr : weaken_type(element);
    if (weaker_element != nullptr) {
      return type::make_array_type(weaker_element);
    }
  }
  auto* resolved = type_class(type);
  return resolved != nullptr ? resolved->get_super_class() : nullptr;
}
// Computes the weakest type the check-cast `insn` may cast to, starting from
// `type`, while still satisfying every demand recorded for `insn`.
//
// - Returns `type` unchanged when weakening is disabled (no demand map).
// - Returns java.lang.Object when no demand was recorded for `insn`.
// - With a single recorded demand, returns that demand directly, unless it is
//   nullptr or java.lang.Enum (then `type` is kept), or it is interfacy while
//   `weaken_to_not_interfacy` is set (then the super-class walk below is used).
// - Otherwise walks up from `type` via weaken_type for as long as the result
//   still meets all demands and is safe to reference.
DexType* CheckCastAnalysis::weaken_to_demand(
    IRInstruction* insn, DexType* type, bool weaken_to_not_interfacy) const {
  if (m_insn_demands == nullptr) {
    // Weakening is disabled.
    return type;
  }

  const auto entry = m_insn_demands->find(insn);
  if (entry == m_insn_demands->end()) {
    return type::java_lang_Object();
  }

  auto& demands = entry->second;
  always_assert(!demands.empty());

  if (demands.size() == 1) {
    auto only_demand = *demands.begin();
    // Nullptr indicates that the type demand could not be computed exactly, and
    // no weakening should take place.
    // TODO: Weakening across enums is technically correct, but exposes a
    // limitation in the EnumTransformer, so we just don't do it for now
    if (only_demand == nullptr || only_demand == type::java_lang_Enum()) {
      return type;
    }
    // Note that this singleton-demand may be an interface
    if (!weaken_to_not_interfacy || is_not_interfacy(only_demand)) {
      return only_demand;
    }
  }

  always_assert(!demands.count(nullptr));

  // True if `t` can be check-cast to every recorded demand.
  auto meets_demands = [&](DexType* t) {
    return std::all_of(demands.begin(), demands.end(),
                       [&](DexType* d) { return type::check_cast(t, d); });
  };

  // A function that checks if a given type can be safely used.
  // In particular, we need to filter out external types that are not already
  // explicitly mentioned (in the demand set), as they might refer to a type
  // that's only available on a particular Android platform.
  auto is_safe = [&](DexType* t) {
    auto* resolved =
        type_class(type::is_array(t) ? type::get_array_element_type(t) : t);
    if (resolved == nullptr) {
      return false;
    }
    return !resolved->is_external() || demands.count(t) != 0;
  };

  DexType* current = type;
  for (;;) {
    DexType* next = weaken_type(current);
    const bool usable =
        next != nullptr && meets_demands(next) && is_safe(next);
    // TODO: Weakening across enums is technically correct, but exposes a
    // limitation in the EnumTransformer, so we just don't do it for now
    if (!usable || next == type::java_lang_Enum()) {
      return current;
    }
    current = next;
  }
}
// Sets up the analysis for `method`.
//
// Always records the ClassCastException type and the method. If the method has
// code and is not one of the "$xXX" magic methods, collects iterators to all
// check-cast instructions in its CFG. If at least one check-cast exists and
// `config.weaken` is set, additionally builds m_insn_demands: for every
// check-cast, the set of types demanded by the instructions consuming its
// result. When m_insn_demands is left unset, weaken_to_demand treats weakening
// as disabled.
CheckCastAnalysis::CheckCastAnalysis(const CheckCastConfig& config,
DexMethod* method)
: m_class_cast_exception_type(
DexType::make_type("Ljava/lang/ClassCastException;")),
m_method(method) {
always_assert(m_class_cast_exception_type);
// Nothing to analyze without a method body.
if (!method || !method->get_code()) {
return;
}
if (method->str().find("$xXX") != std::string::npos) {
// There is some Ultralight/SwitchInline magic that trips up when
// casts get weakened, so that we don't operate on those magic methods.
return;
}
// Remember the CFG position of every check-cast; the iterators (not just the
// instructions) are kept because later steps need the containing block and
// the associated move-result.
auto& cfg = method->get_code()->cfg();
auto iterable = cfg::InstructionIterable(cfg);
for (auto it = iterable.begin(); it != iterable.end(); ++it) {
IRInstruction* insn = it->insn;
if (insn->opcode() == OPCODE_CHECK_CAST) {
m_check_cast_its.push_back(it);
}
}
if (m_check_cast_its.empty()) {
return;
}
// Without weakening, m_insn_demands stays unset.
if (!config.weaken) {
return;
}
m_insn_demands = std::make_unique<InstructionTypeDemands>();
// Reaching definitions tell us, for each source register of each
// instruction, which instructions may have defined it.
reaching_defs::MoveAwareFixpointIterator reaching_definitions(cfg);
reaching_definitions.run({});
for (cfg::Block* block : cfg.blocks()) {
auto env = reaching_definitions.get_entry_state_at(block);
// Blocks whose entry state is bottom are skipped.
if (env.is_bottom()) {
continue;
}
for (auto& mie : InstructionIterable(block)) {
IRInstruction* insn = mie.insn;
for (size_t src_index = 0; src_index < insn->srcs_size(); src_index++) {
auto src = insn->src(src_index);
const auto& defs = env.get(src);
always_assert(!defs.is_bottom() && !defs.is_top());
for (auto def : defs.elements()) {
auto def_opcode = def->opcode();
if (def_opcode == OPCODE_CHECK_CAST) {
// When two check-casts interact, we prevent weakening of the
// first to avoid situations where both get removed as they may
// make each other redundant.
auto t = insn->opcode() == OPCODE_CHECK_CAST
? nullptr
: get_type_demand(insn, src_index);
always_assert(t == nullptr || type::is_object(t));
// A demand of java.lang.Object is not recorded; a check-cast with no
// recorded demand at all is later weakened to Object anyway.
if (t != type::java_lang_Object()) {
(*m_insn_demands)[def].insert(t);
}
}
}
}
// Advance the environment past this instruction only after its sources
// were inspected, so the lookups above see the state before `insn`.
reaching_definitions.analyze_instruction(insn, &env);
}
}
// Simplify demands
for (auto& p : *m_insn_demands) {
auto& demands = p.second;
if (demands.count(nullptr)) {
// no need to keep around anything else
std20::erase_if(demands, [](auto* t) { return t; });
always_assert(demands.count(nullptr));
always_assert(demands.size() == 1);
continue;
}
// Remove weakened types.
// A demand that is reachable from another demand by repeated weakening
// (super classes, weakened array element types, implemented interfaces)
// is implied by that other demand and can be dropped.
std::unordered_set<DexType*> weakened_types;
std::queue<DexType*> queue;
auto enqueue_weakened_types = [&queue](DexType* type) {
auto weakened_type = weaken_type(type);
if (weakened_type) {
queue.push(weakened_type);
}
// We also handle interface hierarchies here.
auto cls = type_class(type);
if (cls) {
for (auto interface : *cls->get_interfaces()) {
queue.push(interface);
}
}
};
// Seed the worklist with the direct weakenings of every demand ...
for (auto demand : demands) {
enqueue_weakened_types(demand);
}
// ... and expand transitively; the set insertion doubles as the visited
// check.
while (!queue.empty()) {
auto weakened_type = queue.front();
queue.pop();
if (weakened_types.insert(weakened_type).second) {
enqueue_weakened_types(weakened_type);
}
}
for (auto weakened_type : weakened_types) {
if (demands.erase(weakened_type)) {
// Double check that the just erased demand was indeed redundant
always_assert(
std::find_if(demands.begin(), demands.end(), [&](DexType* demand) {
return !weakened_types.count(demand) &&
type::check_cast(demand, weakened_type);
}) != demands.end());
}
}
}
}
// Walks all collected check-casts and decides, for each one, whether it can be
// dropped, replaced by a plain move, or retargeted to a weaker type.
//
// Each returned entry carries the block, the check-cast instruction, an
// optional replacement instruction and an optional replacement type:
// - redundant cast, source register equals the move-result destination:
//   neither replacement is set;
// - redundant cast, registers differ: a newly allocated move-object from the
//   source to the destination is the replacement instruction;
// - non-redundant cast that can be weakened: the weaker type is the
//   replacement type.
CheckCastReplacements CheckCastAnalysis::collect_redundant_checks_replacement()
    const {
  CheckCastReplacements replacements;
  for (const auto& cast_it : m_check_cast_its) {
    cfg::Block* block = cast_it.block();
    IRInstruction* insn = cast_it->insn;
    always_assert(insn->opcode() == OPCODE_CHECK_CAST);

    auto* const original_type = insn->get_type();
    auto* target_type = original_type;
    // Weakening is only attempted when no handler reachable from this block
    // could catch the ClassCastException.
    if (!can_catch_class_cast_exception(block)) {
      target_type = weaken_to_demand(insn, original_type,
                                     /* weaken_to_not_interfacy */ false);
    }

    if (!is_check_cast_redundant(insn, target_type)) {
      if (target_type == original_type) {
        continue;
      }
      // We don't want to weaken a class to an interface for performance
      // reason. Re-compute the weakened type in that case, excluding
      // interfaces.
      if (is_not_interfacy(original_type) && !is_not_interfacy(target_type)) {
        target_type = weaken_to_demand(insn, original_type,
                                       /* weaken_to_not_interfacy */ true);
      }
      if (target_type != original_type) {
        replacements.emplace_back(block, insn, boost::none,
                                  boost::optional<DexType*>(target_type));
      }
      continue;
    }

    // The cast is redundant. Without a move-result there is nothing recorded.
    auto move = m_method->get_code()->cfg().move_result_of(cast_it);
    if (move.is_end()) {
      continue;
    }
    const auto src = insn->src(0);
    const auto dst = move->insn->dest();
    if (src == dst) {
      replacements.emplace_back(block, insn, boost::none, boost::none);
      continue;
    }
    // Registers differ: the value still has to be copied over.
    auto* copy = new IRInstruction(OPCODE_MOVE_OBJECT);
    copy->set_src(0, src);
    copy->set_dest(dst);
    replacements.emplace_back(block,
                              insn,
                              boost::optional<IRInstruction*>(copy),
                              boost::none);
  }
  return replacements;
}
// Decides whether the check-cast `insn` is unnecessary when cast to
// `check_type`. That is the case when
// - `check_type` is java.lang.Object, or
// - type inference says the source register holds ZERO at `insn`, or
// - the inferred dex type of the source register is already check-castable to
//   `check_type`.
bool CheckCastAnalysis::is_check_cast_redundant(IRInstruction* insn,
                                                DexType* check_type) const {
  always_assert(insn->opcode() == OPCODE_CHECK_CAST);
  if (check_type == type::java_lang_Object()) {
    return true;
  }

  const auto src_reg = insn->src(0);
  auto& env = get_type_inference()->get_type_environments().at(insn);

  if (env.get_type(src_reg).equals(type_inference::TypeDomain(ZERO))) {
    return true;
  }

  const auto inferred = env.get_dex_type(src_reg);
  return inferred && type::check_cast(*inferred, check_type);
}
// Returns the type inference for m_method's CFG, creating and running it on
// first use; later calls hand back the cached instance.
type_inference::TypeInference* CheckCastAnalysis::get_type_inference() const {
  if (m_type_inference) {
    return m_type_inference.get();
  }
  m_type_inference = std::make_unique<type_inference::TypeInference>(
      m_method->get_code()->cfg());
  m_type_inference->run(m_method);
  return m_type_inference.get();
}
// Reports whether any throw edge leaving `block` leads to a handler that would
// catch a ClassCastException: either the edge has no catch type, or
// type::is_subclass(catch type, ClassCastException type) holds.
bool CheckCastAnalysis::can_catch_class_cast_exception(
    cfg::Block* block) const {
  const auto& successors = block->succs();
  return std::any_of(
      successors.begin(), successors.end(), [this](auto* edge) {
        if (edge->type() != cfg::EDGE_THROW) {
          return false;
        }
        auto* caught = edge->throw_info()->catch_type;
        return caught == nullptr ||
               type::is_subclass(caught, m_class_cast_exception_type);
      });
}
} // namespace impl
} // namespace check_casts