plugins/experimental/wasm/lib/include/proxy-wasm/wasm.h (332 lines of code) (raw):
// Copyright 2016-2019 Envoy Project Authors
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string.h>

#include <atomic>
#include <chrono>
#include <cstdint>
#include <deque>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "include/proxy-wasm/context.h"
#include "include/proxy-wasm/exports.h"
#include "include/proxy-wasm/wasm_vm.h"
#include "include/proxy-wasm/vm_id_handle.h"
namespace proxy_wasm {
#include "proxy_wasm_common.h"
class WasmHandleBase;
class ContextBase;

// Returns a WasmVm whose ownership passes to the caller.
using WasmVmFactory = std::function<std::unique_ptr<WasmVm>()>;

// Accepts a task (a nullary callable) and arranges for it to be run; presumably on a thread chosen
// by the embedding host — NOTE(review): confirm against the host's implementation.
using CallOnThreadFunction = std::function<void(std::function<void()>)>;
// Per-capability sanitization settings.
// NOTE(review): presumably `argument_list` acts as an allow-list when `is_allowlist` is true and as
// a deny-list otherwise — confirm against the code that consumes this struct.
struct SanitizationConfig {
  std::vector<std::string> argument_list;
  // Explicitly initialized so a default-initialized config (`SanitizationConfig c;`) never carries
  // an indeterminate value; aggregate initialization `{list, flag}` still works (C++14+).
  bool is_allowlist = false;
};
// Maps a capability name to its sanitization settings.
using AllowedCapabilitiesMap = std::unordered_map<std::string, SanitizationConfig>;
// Wasm execution instance. Manages the host side of the Wasm interface.
class WasmBase : public std::enable_shared_from_this<WasmBase> {
public:
  WasmBase(std::unique_ptr<WasmVm> wasm_vm, std::string_view vm_id,
           std::string_view vm_configuration, std::string_view vm_key,
           std::unordered_map<std::string, std::string> envs,
           AllowedCapabilitiesMap allowed_capabilities);
  // Builds a Wasm from an existing base handle; `factory` supplies the VM.
  // NOTE(review): presumably used for thread-local copies of the base Wasm — confirm in wasm.cc.
  WasmBase(const std::shared_ptr<WasmHandleBase> &base_wasm_handle, const WasmVmFactory &factory);
  virtual ~WasmBase();

  bool load(const std::string &code, bool allow_precompiled = false);
  bool initialize();
  void startVm(ContextBase *root_context);
  bool configure(ContextBase *root_context, std::shared_ptr<PluginBase> plugin);
  // Returns the root ContextBase or nullptr if onStart returns false.
  ContextBase *start(const std::shared_ptr<PluginBase> &plugin);

  std::string_view vm_id() const { return vm_id_; }
  std::string_view vm_key() const { return vm_key_; }
  WasmVm *wasm_vm() const { return wasm_vm_.get(); }
  ContextBase *vm_context() const { return vm_context_.get(); }
  ContextBase *getRootContext(const std::shared_ptr<PluginBase> &plugin, bool allow_closed);
  // Returns the context registered under `id`, or nullptr if there is none.
  ContextBase *getContext(uint32_t id) {
    auto it = contexts_.find(id);
    if (it != contexts_.end())
      return it->second;
    return nullptr;
  }
  uint32_t allocContextId();
  bool isFailed() const { return failed_ != FailState::Ok; }
  FailState fail_state() const { return failed_; }
  const std::string &vm_configuration() const;
  const std::string &moduleBytecode() const { return module_bytecode_; }
  const std::string &modulePrecompiled() const { return module_precompiled_; }
  // Returned by reference: the previous by-value return copied the whole map on every call.
  const std::unordered_map<uint32_t, std::string> &functionNames() const { return function_names_; }
  void timerReady(uint32_t root_context_id);
  void queueReady(uint32_t root_context_id, uint32_t token);

  void startShutdown(std::string_view plugin_key);
  void startShutdown();
  WasmResult done(ContextBase *root_context);
  void finishShutdown();

  // Proxy specific extension points.
  //
  virtual void registerCallbacks(); // Register functions called out from Wasm.
  virtual void getFunctions();      // Look up the Wasm functions the host calls into.
  // Default implementation reports unimplemented() and returns an empty function.
  virtual CallOnThreadFunction callOnThreadFunction() {
    unimplemented();
    return nullptr;
  }

  // Capability restriction (restricting/exposing the ABI).
  // An empty allowed-capabilities map means no restriction is enforced.
  bool capabilityAllowed(const std::string &capability_name) const {
    return allowed_capabilities_.empty() ||
           allowed_capabilities_.find(capability_name) != allowed_capabilities_.end();
  }

  virtual ContextBase *createVmContext() { return new ContextBase(this); }
  virtual ContextBase *createRootContext(const std::shared_ptr<PluginBase> &plugin) {
    return new ContextBase(this, plugin);
  }
  virtual ContextBase *createContext(const std::shared_ptr<PluginBase> &plugin) {
    return new ContextBase(this, plugin);
  }
  virtual void setTimerPeriod(uint32_t root_context_id, std::chrono::milliseconds period) {
    timer_period_[root_context_id] = period;
  }

  // Support functions.
  //
  // Allocates `size` bytes inside the VM. On success writes the VM address to `*address` and
  // returns a host pointer to that memory; returns nullptr on failure.
  void *allocMemory(uint64_t size, uint64_t *address);
  // Allocate a null-terminated string in the VM and return the pointer to use as a call argument.
  // Returns 0 for an empty string.
  uint64_t copyString(std::string_view s);
  // Copy the data in 's' into the VM along with the pointer-size pair. Returns true on success.
  bool copyToPointerSize(std::string_view s, uint64_t ptr_ptr, uint64_t size_ptr);
  template <typename T> bool setDatatype(uint64_t ptr, const T &t);

  // Reports `message` through error() and then records `fail_state`.
  void fail(FailState fail_state, std::string_view message) {
    error(message);
    failed_ = fail_state;
  }
  virtual void error(std::string_view message) { std::cerr << message << "\n"; }
  virtual void unimplemented() { error("unimplemented proxy-wasm API"); }

  AbiVersion abiVersion() const { return abi_version_; }

  const std::unordered_map<std::string, std::string> &envs() const { return envs_; }

  // Called to raise the flag which indicates that the context should stop iteration regardless of
  // returned filter status from Proxy-Wasm extensions. For example, we ignore
  // FilterHeadersStatus::Continue after a local response is sent by the host.
  void stopNextIteration(bool stop) { stop_iteration_ = stop; }
  bool isNextIterationStopped() const { return stop_iteration_; }

  // Queues `f` to be run by doAfterVmCallActions(). Moved in to avoid copying the callable.
  void addAfterVmCallAction(std::function<void()> f) {
    after_vm_call_actions_.push_back(std::move(f));
  }
  // Runs queued actions in FIFO order, including actions queued while the queue is draining.
  void doAfterVmCallActions() {
    // NB: this may be deleted by a delayed function unless prevented.
    // `self` keeps this object alive until the queue has been fully drained.
    if (!after_vm_call_actions_.empty()) {
      auto self = shared_from_this();
      while (!self->after_vm_call_actions_.empty()) {
        auto f = std::move(self->after_vm_call_actions_.front());
        self->after_vm_call_actions_.pop_front();
        f();
      }
    }
  }

  // Metric ids carry their MetricType in the bits covered by kMetricTypeMask; each per-type
  // counter starts at the type value and advances by kMetricIdIncrement, leaving those bits intact.
  static const uint32_t kMetricTypeMask = 0x3;    // Enough to cover the 3 types.
  static const uint32_t kMetricIdIncrement = 0x4; // Smallest step that preserves the type bits.
  bool isCounterMetricId(uint32_t metric_id) {
    return (metric_id & kMetricTypeMask) == static_cast<uint32_t>(MetricType::Counter);
  }
  bool isGaugeMetricId(uint32_t metric_id) {
    return (metric_id & kMetricTypeMask) == static_cast<uint32_t>(MetricType::Gauge);
  }
  bool isHistogramMetricId(uint32_t metric_id) {
    return (metric_id & kMetricTypeMask) == static_cast<uint32_t>(MetricType::Histogram);
  }
  uint32_t nextCounterMetricId() { return next_counter_metric_id_ += kMetricIdIncrement; }
  uint32_t nextGaugeMetricId() { return next_gauge_metric_id_ += kMetricIdIncrement; }
  uint32_t nextHistogramMetricId() { return next_histogram_metric_id_ += kMetricIdIncrement; }

  enum class CalloutType : uint32_t {
    HttpCall = 0,
    GrpcCall = 1,
    GrpcStream = 2,
  };
  // Callout ids use the same scheme as metric ids: the CalloutType lives in the masked bits.
  static const uint32_t kCalloutTypeMask = 0x3;  // Enough to cover the 3 types.
  static const uint32_t kCalloutIncrement = 0x4; // Smallest step that preserves the type bits.
  bool isHttpCallId(uint32_t callout_id) {
    return (callout_id & kCalloutTypeMask) == static_cast<uint32_t>(CalloutType::HttpCall);
  }
  bool isGrpcCallId(uint32_t callout_id) {
    return (callout_id & kCalloutTypeMask) == static_cast<uint32_t>(CalloutType::GrpcCall);
  }
  bool isGrpcStreamId(uint32_t callout_id) {
    return (callout_id & kCalloutTypeMask) == static_cast<uint32_t>(CalloutType::GrpcStream);
  }
  uint32_t nextHttpCallId() {
    // TODO(PiotrSikora): re-add rollover protection (requires at least 1 billion callouts).
    return next_http_call_id_ += kCalloutIncrement;
  }
  uint32_t nextGrpcCallId() {
    // TODO(PiotrSikora): re-add rollover protection (requires at least 1 billion callouts).
    return next_grpc_call_id_ += kCalloutIncrement;
  }
  uint32_t nextGrpcStreamId() {
    // TODO(PiotrSikora): re-add rollover protection (requires at least 1 billion callouts).
    return next_grpc_stream_id_ += kCalloutIncrement;
  }

protected:
  friend class ContextBase;
  class ShutdownHandle;

  void establishEnvironment(); // Language specific environments.

  std::string vm_id_;  // User-provided vm_id.
  std::string vm_key_; // vm_id + hash of code.
  std::unique_ptr<WasmVm> wasm_vm_;
  std::optional<Cloneable> started_from_;

  uint32_t next_context_id_ = 1;            // 0 is reserved for the VM context.
  std::shared_ptr<ContextBase> vm_context_; // Context unrelated to any specific root or stream
                                            // (e.g. for global constructors).
  std::unordered_map<std::string, std::unique_ptr<ContextBase>> root_contexts_; // Root contexts.
  std::unordered_map<std::string, std::unique_ptr<ContextBase>> pending_done_;  // Root contexts.
  std::unordered_set<std::unique_ptr<ContextBase>> pending_delete_;             // Root contexts.
  std::unordered_map<uint32_t, ContextBase *> contexts_; // Contains all contexts.
  std::unordered_map<uint32_t, std::chrono::milliseconds> timer_period_; // per root_id.
  std::unique_ptr<ShutdownHandle> shutdown_handle_;
  std::unordered_map<std::string, std::string>
      envs_; // environment variables passed through wasi.environ_get

  WasmCallVoid<0> _initialize_; /* WASI reactor (Emscripten v1.39.17+, Rust nightly) */
  WasmCallVoid<0> _start_;      /* WASI command (Emscripten v1.39.0+, TinyGo) */

  WasmCallWord<2> main_;
  WasmCallWord<1> malloc_;

  // Calls into the VM.
  WasmCallWord<2> validate_configuration_;
  WasmCallWord<2> on_vm_start_;
  WasmCallWord<2> on_configure_;
  WasmCallVoid<1> on_tick_;

  WasmCallVoid<2> on_context_create_;

  WasmCallWord<1> on_new_connection_;
  WasmCallWord<3> on_downstream_data_;
  WasmCallWord<3> on_upstream_data_;
  WasmCallVoid<2> on_downstream_connection_close_;
  WasmCallVoid<2> on_upstream_connection_close_;

  WasmCallWord<2> on_request_headers_abi_01_;
  WasmCallWord<3> on_request_headers_abi_02_;
  WasmCallWord<3> on_request_body_;
  WasmCallWord<2> on_request_trailers_;
  WasmCallWord<2> on_request_metadata_;

  WasmCallWord<2> on_response_headers_abi_01_;
  WasmCallWord<3> on_response_headers_abi_02_;
  WasmCallWord<3> on_response_body_;
  WasmCallWord<2> on_response_trailers_;
  WasmCallWord<2> on_response_metadata_;

  WasmCallVoid<5> on_http_call_response_;

  WasmCallVoid<3> on_grpc_receive_;
  WasmCallVoid<3> on_grpc_close_;
  WasmCallVoid<3> on_grpc_create_initial_metadata_;
  WasmCallVoid<3> on_grpc_receive_initial_metadata_;
  WasmCallVoid<3> on_grpc_receive_trailing_metadata_;

  WasmCallVoid<2> on_queue_ready_;
  WasmCallVoid<3> on_foreign_function_;

  WasmCallWord<1> on_done_;
  WasmCallVoid<1> on_log_;
  WasmCallVoid<1> on_delete_;

#define FOR_ALL_MODULE_FUNCTIONS(_f) \
  _f(validate_configuration) _f(on_vm_start) _f(on_configure) _f(on_tick) _f(on_context_create) \
  _f(on_new_connection) _f(on_downstream_data) _f(on_upstream_data) \
  _f(on_downstream_connection_close) _f(on_upstream_connection_close) _f(on_request_body) \
  _f(on_request_trailers) _f(on_request_metadata) _f(on_response_body) \
  _f(on_response_trailers) _f(on_response_metadata) _f(on_http_call_response) \
  _f(on_grpc_receive) _f(on_grpc_close) _f(on_grpc_receive_initial_metadata) \
  _f(on_grpc_receive_trailing_metadata) _f(on_queue_ready) _f(on_done) \
  _f(on_log) _f(on_delete)

  // Capabilities which are allowed to be linked to the module. If this is empty, restriction
  // is not enforced.
  AllowedCapabilitiesMap allowed_capabilities_;

  std::shared_ptr<WasmHandleBase> base_wasm_handle_;

  // Used by the base_wasm to enable non-clonable thread local Wasm(s) to be constructed.
  std::string module_bytecode_;
  std::string module_precompiled_;
  std::unordered_map<uint32_t, std::string> function_names_;

  // ABI version.
  AbiVersion abi_version_ = AbiVersion::Unknown;

  std::string vm_configuration_;
  bool stop_iteration_ = false;
  FailState failed_ = FailState::Ok; // Wasm VM fatal error.

  // Plugin Stats/Metrics
  uint32_t next_counter_metric_id_ = static_cast<uint32_t>(MetricType::Counter);
  uint32_t next_gauge_metric_id_ = static_cast<uint32_t>(MetricType::Gauge);
  uint32_t next_histogram_metric_id_ = static_cast<uint32_t>(MetricType::Histogram);

  // HTTP/gRPC callouts.
  uint32_t next_http_call_id_ = static_cast<uint32_t>(CalloutType::HttpCall);
  uint32_t next_grpc_call_id_ = static_cast<uint32_t>(CalloutType::GrpcCall);
  uint32_t next_grpc_stream_id_ = static_cast<uint32_t>(CalloutType::GrpcStream);

  // Actions to be done after the call into the VM returns.
  std::deque<std::function<void()>> after_vm_call_actions_;

  std::shared_ptr<VmIdHandle> vm_id_handle_;
};
// Produces a Wasm handle for the VM identified by `vm_id`.
using WasmHandleFactory =
    std::function<std::shared_ptr<WasmHandleBase>(std::string_view vm_id)>;

// Produces a new Wasm handle from an existing one.
// NOTE(review): presumably a clone of `wasm` — confirm at the call sites.
using WasmHandleCloneFactory =
    std::function<std::shared_ptr<WasmHandleBase>(std::shared_ptr<WasmHandleBase> wasm)>;
// Handle which enables shutdown operations to run post deletion (e.g. post listener drain).
// Destroying the handle starts shutdown of the wrapped Wasm unless kill() released it first.
class WasmHandleBase : public std::enable_shared_from_this<WasmHandleBase> {
public:
  // `wasm_base` is a sink parameter: moved in to avoid an extra atomic refcount round-trip.
  explicit WasmHandleBase(std::shared_ptr<WasmBase> wasm_base)
      : wasm_base_(std::move(wasm_base)) {}

  // Virtual because this class is a base for host-specific handles that are owned through
  // WasmHandleBase pointers (see the factory aliases returning shared_ptr<WasmHandleBase>).
  virtual ~WasmHandleBase() {
    if (wasm_base_) {
      wasm_base_->startShutdown();
    }
  }

  bool canary(const std::shared_ptr<PluginBase> &plugin,
              const WasmHandleCloneFactory &clone_factory);

  // Drops the reference to the wrapped Wasm; afterwards wasm() is null and the destructor does
  // not start a shutdown.
  void kill() { wasm_base_ = nullptr; }

  std::shared_ptr<WasmBase> &wasm() { return wasm_base_; }

protected:
  std::shared_ptr<WasmBase> wasm_base_;
  std::unordered_map<std::string, bool> plugin_canary_cache_;
};
// Derives the key that identifies a VM from its id, configuration and code.
// NOTE(review): presumably the id combined with a hash of the other inputs — confirm in wasm.cc.
std::string makeVmKey(std::string_view vm_id,
                      std::string_view configuration,
                      std::string_view code);
// Creates a Wasm handle for `code` under `vm_key`.
// Returns nullptr on failure (i.e. when initialization of the VM fails).
std::shared_ptr<WasmHandleBase> createWasm(const std::string &vm_key,
                                           const std::string &code,
                                           const std::shared_ptr<PluginBase> &plugin,
                                           const WasmHandleFactory &factory,
                                           const WasmHandleCloneFactory &clone_factory,
                                           bool allow_precompiled);

// Looks up the thread-local VM whose key is `vm_key`; returns nullptr when none exists.
std::shared_ptr<WasmHandleBase> getThreadLocalWasm(std::string_view vm_key);
// Ties a plugin to the Wasm handle it runs in. Destroying it starts shutdown of that plugin
// (identified by the plugin's key) in the wrapped Wasm.
class PluginHandleBase : public std::enable_shared_from_this<PluginHandleBase> {
public:
  // Both arguments are sink parameters and are moved into the members.
  explicit PluginHandleBase(std::shared_ptr<WasmHandleBase> wasm_handle,
                            std::shared_ptr<PluginBase> plugin)
      : plugin_(std::move(plugin)), wasm_handle_(std::move(wasm_handle)) {}

  // Virtual because host-specific plugin handles derive from this class and are owned through
  // PluginHandleBase pointers (see PluginHandleFactory).
  virtual ~PluginHandleBase() {
    // WasmHandleBase::kill() nulls the wrapped Wasm, and the plugin pointer may be null. In either
    // case there is no plugin shutdown to start, and dereferencing would crash.
    if (wasm_handle_ && wasm_handle_->wasm() && plugin_) {
      wasm_handle_->wasm()->startShutdown(plugin_->key());
    }
  }

  std::shared_ptr<PluginBase> &plugin() { return plugin_; }
  // NOTE: requires a non-null wasm handle.
  std::shared_ptr<WasmBase> &wasm() { return wasm_handle_->wasm(); }

protected:
  std::shared_ptr<PluginBase> plugin_;
  std::shared_ptr<WasmHandleBase> wasm_handle_;
};
// Produces the plugin handle binding `plugin` to the Wasm held by `base_wasm`.
using PluginHandleFactory = std::function<std::shared_ptr<PluginHandleBase>(
    std::shared_ptr<WasmHandleBase> base_wasm, std::shared_ptr<PluginBase> plugin)>;

// Returns a plugin handle backed by the calling thread's VM matching the base VM. If no such VM
// exists yet, one is built from 'base_handle', either by cloning it or by using it as a template.
std::shared_ptr<PluginHandleBase>
getOrCreateThreadLocalPlugin(const std::shared_ptr<WasmHandleBase> &base_handle,
                             const std::shared_ptr<PluginBase> &plugin,
                             const WasmHandleCloneFactory &clone_factory,
                             const PluginHandleFactory &plugin_factory);

// Empties the base Wasm cache and the calling thread's thread-local Wasm sandbox cache.
void clearWasmCachesForTesting();
// Returns the VM configuration. An instance built from a base Wasm handle reports the base
// Wasm's configuration rather than its own member.
inline const std::string &WasmBase::vm_configuration() const {
  if (!base_wasm_handle_) {
    return vm_configuration_;
  }
  return base_wasm_handle_->wasm()->vm_configuration_;
}
inline void *WasmBase::allocMemory(uint64_t size, uint64_t *address) {
if (!malloc_) {
return nullptr;
}
wasm_vm_->setRestrictedCallback(
true, {// logging (Proxy-Wasm)
"env.proxy_log",
// logging (stdout/stderr)
"wasi_unstable.fd_write", "wasi_snapshot_preview1.fd_write",
// time
"wasi_unstable.clock_time_get", "wasi_snapshot_preview1.clock_time_get"});
Word a = malloc_(vm_context(), size);
wasm_vm_->setRestrictedCallback(false);
if (!a.u64_) {
return nullptr;
}
auto memory = wasm_vm_->getMemory(a.u64_, size);
if (!memory) {
return nullptr;
}
*address = a.u64_;
return const_cast<void *>(reinterpret_cast<const void *>(memory.value().data()));
}
// Allocate a null-terminated copy of `s` in the VM and return its VM address for use as a call
// argument. Returns 0 (the VM nullptr) when `s` is empty or when the allocation fails.
inline uint64_t WasmBase::copyString(std::string_view s) {
  if (s.empty()) {
    return 0; // nullptr
  }
  uint64_t pointer = 0;
  auto *m = static_cast<uint8_t *>(allocMemory(s.size() + 1, &pointer));
  if (m == nullptr) {
    // allocMemory() returns nullptr on failure (e.g. when malloc_ is unavailable).
    // Copying through a null host pointer would crash the host.
    return 0;
  }
  memcpy(m, s.data(), s.size());
  m[s.size()] = 0;
  return pointer;
}
// Copies `s` into newly allocated VM memory, then writes the VM address to the word at `ptr_ptr`
// and the length to the word at `size_ptr`. An empty `s` allocates nothing and records address 0
// with length 0. Returns false if the allocation or either word write fails; the size word is not
// written when the address write fails.
inline bool WasmBase::copyToPointerSize(std::string_view s, uint64_t ptr_ptr, uint64_t size_ptr) {
  uint64_t vm_address = 0;
  const uint64_t length = s.size();
  if (length != 0) {
    void *destination = allocMemory(length, &vm_address);
    if (destination == nullptr) {
      return false;
    }
    memcpy(destination, s.data(), length);
  }
  return wasm_vm_->setWord(ptr_ptr, Word(vm_address)) && wasm_vm_->setWord(size_ptr, Word(length));
}
// Writes the object representation of `t` (sizeof(T) bytes) into VM memory starting at `ptr`.
// Returns the result of WasmVm::setMemory (presumably true on success).
template <typename T> inline bool WasmBase::setDatatype(uint64_t ptr, const T &t) {
  const auto source = &t;
  const auto byte_count = sizeof(T);
  return wasm_vm_->setMemory(ptr, byte_count, source);
}
} // namespace proxy_wasm