source/CallPositionFrames.h (179 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <boost/iterator/transform_iterator.hpp>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <optional>
#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <json/json.h>
#include <AbstractDomain.h>
#include <PatriciaTreeMapAbstractPartition.h>
#include <mariana-trench/FlattenIterator.h>
#include <mariana-trench/Frame.h>
#include <mariana-trench/GroupHashedSetAbstractDomain.h>
namespace marianatrench {
/**
 * Represents a set of frames with the same call position.
 * Based on its position in `Taint`, it is expected that all frames within
 * this class have the same callee and call position.
 *
 * Internally, frames are stored in a partition keyed by `Kind`; each kind maps
 * to a grouped set of `Frame`s. The (nullable) call position shared by the
 * frames is stored alongside that partition.
 */
class CallPositionFrames final
    : public sparta::AbstractDomain<CallPositionFrames> {
 private:
  // Set of frames for a single kind, grouped using `Frame`'s group hash and
  // group equality.
  using Frames =
      GroupHashedSetAbstractDomain<Frame, Frame::GroupHash, Frame::GroupEqual>;

  // Partition from kind to the frames of that kind.
  using FramesByKind =
      sparta::PatriciaTreeMapAbstractPartition<const Kind*, Frames>;

 private:
  // Iterator based on `FlattenIterator`.
  // Tells `FlattenIterator` how to obtain the inner range (the frames) from
  // an element of the outer range (a `(kind, frames)` binding).
  struct KindToFramesMapDereference {
    static Frames::iterator begin(const std::pair<const Kind*, Frames>& pair) {
      return pair.second.begin();
    }
    static Frames::iterator end(const std::pair<const Kind*, Frames>& pair) {
      return pair.second.end();
    }
  };

  // Flattened iterator over every frame of every kind.
  using ConstIterator = FlattenIterator<
      /* OuterIterator */ FramesByKind::MapType::iterator,
      /* InnerIterator */ Frames::iterator,
      KindToFramesMapDereference>;

 public:
  // C++ container concept member types
  // Both `iterator` and `const_iterator` alias the same read-only iterator.
  using iterator = ConstIterator;
  using const_iterator = ConstIterator;
  using value_type = Frame;
  using difference_type = std::ptrdiff_t;
  using size_type = std::size_t;
  using const_reference = const Frame&;
  using const_pointer = const Frame*;

 private:
  /* Build a frame set directly from its two components. */
  explicit CallPositionFrames(
      const Position* MT_NULLABLE position,
      FramesByKind frames)
      : position_(position), frames_(std::move(frames)) {}

 public:
  /* Create the bottom (i.e, empty) frame set. */
  CallPositionFrames() : position_(nullptr), frames_(FramesByKind::bottom()) {}

  /* Create a frame set from a list of frames. Defined out of line. */
  explicit CallPositionFrames(std::initializer_list<Frame> frames);

  CallPositionFrames(const CallPositionFrames&) = default;
  CallPositionFrames(CallPositionFrames&&) = default;
  CallPositionFrames& operator=(const CallPositionFrames&) = default;
  CallPositionFrames& operator=(CallPositionFrames&&) = default;

  /* Return the bottom element: no position and a bottom partition. */
  static CallPositionFrames bottom() {
    return CallPositionFrames(
        /* position */ nullptr, FramesByKind::bottom());
  }

  /* Return the top element: no position and a top partition. */
  static CallPositionFrames top() {
    return CallPositionFrames(
        /* position */ nullptr, FramesByKind::top());
  }

  /* Bottom-ness is determined solely by the underlying partition. */
  bool is_bottom() const override {
    return frames_.is_bottom();
  }

  /* Top-ness is determined solely by the underlying partition. */
  bool is_top() const override {
    return frames_.is_top();
  }

  /* Reset to bottom; the position is cleared as well. */
  void set_to_bottom() override {
    position_ = nullptr;
    frames_.set_to_bottom();
  }

  /* Reset to top; the position is cleared as well. */
  void set_to_top() override {
    position_ = nullptr;
    frames_.set_to_top();
  }

  /* Return true if this holds no frames (i.e, the partition is bottom). */
  bool empty() const {
    return frames_.is_bottom();
  }

  /* Return the call position of the frames, which may be null. */
  const Position* MT_NULLABLE position() const {
    return position_;
  }

  // The members below are defined out of line.
  void add(const Frame& frame);

  // Abstract domain operations required by `sparta::AbstractDomain`.
  bool leq(const CallPositionFrames& other) const override;
  bool equals(const CallPositionFrames& other) const override;
  void join_with(const CallPositionFrames& other) override;
  void widen_with(const CallPositionFrames& other) override;
  void meet_with(const CallPositionFrames& other) override;
  void narrow_with(const CallPositionFrames& other) override;

  void difference_with(const CallPositionFrames& other);

  void map(const std::function<void(Frame&)>& f);

  /*
   * Iterate over every frame, across all kinds.
   *
   * NOTE(review): both iterators are built from `frames_.bindings()`, so this
   * presumably must not be used on a top element -- confirm.
   */
  ConstIterator begin() const {
    return ConstIterator(frames_.bindings().begin(), frames_.bindings().end());
  }
  ConstIterator end() const {
    return ConstIterator(frames_.bindings().end(), frames_.bindings().end());
  }

  void add_inferred_features(const FeatureMayAlwaysSet& features);

  LocalPositionSet local_positions() const;

  void add_local_position(const Position* position);

  void set_local_positions(const LocalPositionSet& positions);

  void add_inferred_features_and_local_position(
      const FeatureMayAlwaysSet& features,
      const Position* MT_NULLABLE position);

  /**
   * Propagate the taint from the callee to the caller.
   *
   * Return bottom if the taint should not be propagated.
   */
  CallPositionFrames propagate(
      const Method* callee,
      const AccessPath& callee_port,
      const Position* call_position,
      int maximum_source_sink_distance,
      Context& context,
      const std::vector<const DexType * MT_NULLABLE>& source_register_types,
      const std::vector<std::optional<std::string>>& source_constant_arguments)
      const;

  /* Return the set of leaf frames with the given position. */
  CallPositionFrames attach_position(const Position* position) const;

  CallPositionFrames transform_kind_with_features(
      const std::function<std::vector<const Kind*>(const Kind*)>&,
      const std::function<FeatureMayAlwaysSet(const Kind*)>&) const;

  void append_callee_port(
      Path::Element path_element,
      const std::function<bool(const Kind*)>& filter);

  void filter_invalid_frames(
      const std::function<
          bool(const Method* MT_NULLABLE, const AccessPath&, const Kind*)>&
          is_valid);

  bool contains_kind(const Kind*) const;

  /**
   * Split the frames into groups according to `map_kind`.
   *
   * `map_kind` is called once per kind present in this object. All kinds that
   * map to the same value of `T` are joined into a single `CallPositionFrames`
   * which keeps this object's position.
   *
   * `T` must be usable as a key of `std::unordered_map`.
   */
  template <class T>
  std::unordered_map<T, CallPositionFrames> partition_by_kind(
      const std::function<T(const Kind*)>& map_kind) const {
    std::unordered_map<T, CallPositionFrames> result;
    for (const auto& [kind, kind_frames] : frames_.bindings()) {
      // `operator[]` default-constructs a missing entry, and the default
      // constructor yields bottom, so joining in place gives the same result
      // as "find or bottom, join, store back" without copying the accumulated
      // frames out of and back into the map on every kind.
      result[map_kind(kind)].join_with(CallPositionFrames(
          position_, FramesByKind{std::pair(kind, kind_frames)}));
    }
    return result;
  }

  /**
   * Group individual frames by the value `map` returns for them.
   *
   * `map` is called once per frame. The result holds references (not copies)
   * to frames stored in this object, so presumably it must not outlive this
   * object nor be used after this object is modified.
   *
   * `T` must be usable as a key of `std::unordered_map`.
   */
  template <class T>
  std::unordered_map<T, std::vector<std::reference_wrapper<const Frame>>>
  partition_map(const std::function<T(const Frame&)>& map) const {
    std::unordered_map<T, std::vector<std::reference_wrapper<const Frame>>>
        result;
    for (const auto& [_, frames] : frames_.bindings()) {
      for (const auto& frame : frames) {
        // Passing the temporary key directly lets the map move it on insert.
        result[map(frame)].push_back(std::cref(frame));
      }
    }
    return result;
  }

  friend std::ostream& operator<<(
      std::ostream& out,
      const CallPositionFrames& frames);

 private:
  // Helpers for `propagate`, defined out of line.
  Frame propagate_frames(
      const Method* callee,
      const AccessPath& callee_port,
      const Position* call_position,
      int maximum_source_sink_distance,
      Context& context,
      const std::vector<const DexType * MT_NULLABLE>& source_register_types,
      const std::vector<std::optional<std::string>>& source_constant_arguments,
      std::vector<std::reference_wrapper<const Frame>> frames,
      std::vector<const Feature*>& via_type_of_features_added) const;

  CallPositionFrames propagate_crtex_frames(
      const Method* callee,
      const AccessPath& callee_port,
      const Position* call_position,
      int maximum_source_sink_distance,
      Context& context,
      const std::vector<const DexType * MT_NULLABLE>& source_register_types,
      std::vector<std::reference_wrapper<const Frame>> frames) const;

 private:
  // Call position of the frames; null for bottom and top.
  const Position* MT_NULLABLE position_;
  // Frames, partitioned by kind.
  FramesByKind frames_;
  // TODO(T91357916): Move local_positions and local_features here from `Frame`.
};
} // namespace marianatrench