driver/custom_endpoint_monitor.cc (147 lines of code) (raw):

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // // This program is free software; you can redistribute it and/or modify // it under the terms of the GNU General Public License, version 2.0 // (GPLv2), as published by the Free Software Foundation, with the // following additional permissions: // // This program is distributed with certain software that is licensed // under separate terms, as designated in a particular file or component // or in the license documentation. Without limiting your rights under // the GPLv2, the authors of this program hereby grant you an additional // permission to link the program and your derivative works with the // separately licensed software that they have included with the program. // // Without limiting the foregoing grant of rights under the GPLv2 and // additional permission as to separately licensed software, this // program is also subject to the Universal FOSS Exception, version 1.0, // a copy of which can be found along with its FAQ at // http://oss.oracle.com/licenses/universal-foss-exception. // // This program is distributed in the hope that it will be useful, but // WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. // See the GNU General Public License, version 2.0, for more details. // // You should have received a copy of the GNU General Public License // along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. 
#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/rds/model/DBClusterEndpoint.h>
#include <aws/rds/model/DescribeDBClusterEndpointsRequest.h>
#include <aws/rds/model/Filter.h>

#include <algorithm>
#include <chrono>
#include <exception>
#include <memory>
#include <set>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "allowed_and_blocked_hosts.h"
#include "aws_sdk_helper.h"
#include "custom_endpoint_monitor.h"
#include "driver.h"
#include "mylog.h"

namespace {
// Reference-counted AWS SDK helper; the monitoring task increments it while it uses the SDK.
AWS_SDK_HELPER SDK_HELPER;
}  // namespace

// Shared (static) cache of custom endpoint info, keyed by custom endpoint host.
CACHE_MAP<std::string, std::shared_ptr<CUSTOM_ENDPOINT_INFO>> CUSTOM_ENDPOINT_MONITOR::custom_endpoint_cache;

// Creates a monitor for one custom endpoint and immediately starts its background monitoring task (see run()).
//
// topology_service     receives the allowed/blocked host sets whenever the endpoint's members change.
// custom_endpoint_host key used for this monitor's entry in custom_endpoint_cache.
// endpoint_identifier  DB cluster endpoint identifier passed to DescribeDBClusterEndpoints.
// region               AWS region for the RDS client; an empty string leaves the SDK default in place.
// refresh_rate_nanos   target interval between two polls, in nanoseconds.
// thread_pool          pool that runs the monitoring task; run() sizes it to a single thread.
// enable_logging       when true, a log file is initialised for this monitor.
CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
                                                 const std::string& custom_endpoint_host,
                                                 const std::string& endpoint_identifier, const std::string& region,
                                                 long long refresh_rate_nanos, ctpl::thread_pool& thread_pool,
                                                 bool enable_logging)
    : topology_service(topology_service),
      custom_endpoint_host(custom_endpoint_host),
      endpoint_identifier(endpoint_identifier),
      region(region),
      refresh_rate_nanos(refresh_rate_nanos),
      thread_pool(thread_pool),
      enable_logging(enable_logging) {
  if (enable_logging) {
    this->logger = init_log_file();
  }

  this->run();
}

#ifdef UNIT_TEST_BUILD
// Test-only constructor; delegates to the regular constructor above.
// NOTE(review): the injected RDS client is currently ignored -- run() always builds its own RDSClient.
// Honouring it would require a client member declared in custom_endpoint_monitor.h.
CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
                                                 const std::string& custom_endpoint_host,
                                                 const std::string& endpoint_identifier, const std::string& region,
                                                 long long refresh_rate_nanos, ctpl::thread_pool& thread_pool,
                                                 bool enable_logging,
                                                 std::shared_ptr<Aws::RDS::RDSClient> /*client*/)
    : CUSTOM_ENDPOINT_MONITOR(topology_service, custom_endpoint_host, endpoint_identifier, region,
                              refresh_rate_nanos, thread_pool, enable_logging) {}
#endif

// Always returns true.
bool CUSTOM_ENDPOINT_MONITOR::should_dispose() { return true; }

// Returns true when custom_endpoint_cache holds a non-null entry for this monitor's custom endpoint host.
bool CUSTOM_ENDPOINT_MONITOR::has_custom_endpoint_info() const {
  auto default_val = std::shared_ptr<CUSTOM_ENDPOINT_INFO>(nullptr);
  return custom_endpoint_cache.get(this->custom_endpoint_host, default_val) != default_val;
}

// Starts the background monitoring task on `thread_pool` (at most one per monitor).
//
// Every `refresh_rate_nanos` the task calls DescribeDBClusterEndpoints for `endpoint_identifier`.
// When the returned endpoint info differs from the cached one, it pushes new allowed/blocked host sets
// to the topology service and updates custom_endpoint_cache. The loop ends when `should_stop` is set
// or an unexpected exception occurs; in both cases `should_stop` is left true.
void CUSTOM_ENDPOINT_MONITOR::run() {
  if (thread_pool.size() == 1) {
    // Each monitor should only have 1 thread; if the pool already has it, don't push a second task.
    return;
  }

  MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str());
  thread_pool.resize(1);
  thread_pool.push([this](int /*id*/) {
    ++SDK_HELPER;

    // Sleeps for whatever is left of the refresh interval that began at `start` (never a negative duration).
    const auto sleep_until_next_refresh = [this](const std::chrono::steady_clock::time_point& start) {
      const long long elapsed_nanos =
          std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
      std::this_thread::sleep_for(
          std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_nanos)));
    };

    // All SDK objects live inside this try block so that (a) any exception thrown while building them is
    // handled here instead of escaping the pool task, and (b) they are destroyed before --SDK_HELPER below.
    try {
      Aws::RDS::RDSClientConfiguration client_config;
      if (!region.empty()) {
        client_config.region = region;
      }

      const Aws::RDS::RDSClient rds_client(Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(),
                                           client_config);

      Aws::RDS::Model::Filter filter;
      filter.SetName("db-cluster-endpoint-type");
      filter.AddValues("custom");
      Aws::RDS::Model::DescribeDBClusterEndpointsRequest request;
      request.SetDBClusterEndpointIdentifier(this->endpoint_identifier);
      // TODO: Investigate why filters returns `InvalidParameterCombination` error saying filter values are null.
      // request.AddFilters(filter);

      while (!should_stop) {
        const auto start = std::chrono::steady_clock::now();
        const auto response = rds_client.DescribeDBClusterEndpoints(request);

        // A failed call yields an empty result; report the actual error rather than "found 0 endpoints".
        if (!response.IsSuccess()) {
          MYLOG_TRACE(this->logger, 0, "Failed to describe custom endpoint '%s' in region '%s': %s",
                      endpoint_identifier.c_str(), region.c_str(), response.GetError().GetMessage().c_str());
          sleep_until_next_refresh(start);
          continue;
        }

        const auto& custom_endpoints = response.GetResult().GetDBClusterEndpoints();
        if (custom_endpoints.size() != 1) {
          MYLOG_TRACE(this->logger, 0,
                      "Unexpected number of custom endpoints with endpoint identifier '%s' in region '%s'. Expected 1 "
                      "custom endpoint, but found %d. Endpoints: %s",
                      endpoint_identifier.c_str(), region.c_str(), static_cast<int>(custom_endpoints.size()),
                      this->get_endpoints_as_string(custom_endpoints).c_str());
          sleep_until_next_refresh(start);
          continue;
        }

        const std::shared_ptr<CUSTOM_ENDPOINT_INFO> endpoint_info =
            CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]);
        const std::shared_ptr<CUSTOM_ENDPOINT_INFO> cache_endpoint_info =
            custom_endpoint_cache.get(this->custom_endpoint_host, nullptr);

        // Compare by value, not by pointer: two shared_ptrs to equal-but-distinct objects compare unequal,
        // which would make every poll look like a change. Relies on CUSTOM_ENDPOINT_INFO's operator==.
        if (cache_endpoint_info != nullptr && *cache_endpoint_info == *endpoint_info) {
          sleep_until_next_refresh(start);
          continue;
        }

        MYLOG_TRACE(this->logger, 0, "Detected change in custom endpoint info for '%s':\n{%s}",
                    custom_endpoint_host.c_str(), endpoint_info->to_string().c_str());

        // The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts.
        // A STATIC_LIST endpoint allows exactly its static members; otherwise its excluded members are blocked.
        std::shared_ptr<ALLOWED_AND_BLOCKED_HOSTS> allowed_and_blocked_hosts;
        if (endpoint_info->get_member_list_type() == STATIC_LIST) {
          allowed_and_blocked_hosts = std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(
              endpoint_info->get_static_members(), std::set<std::string>());
        } else {
          allowed_and_blocked_hosts = std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(
              std::set<std::string>(), endpoint_info->get_excluded_members());
        }

        this->topology_service->set_allowed_and_blocked_hosts(allowed_and_blocked_hosts);
        custom_endpoint_cache.put(this->custom_endpoint_host, endpoint_info, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS);
        sleep_until_next_refresh(start);
      }
    } catch (const std::exception& e) {
      // Log the error; monitoring ends here (should_stop is set below).
      MYLOG_TRACE(this->logger, 0, "Error while monitoring custom endpoint: %s", e.what());
    } catch (...) {
      // Never let an exception escape the pool task.
      MYLOG_TRACE(this->logger, 0, "Unknown error while monitoring custom endpoint '%s'",
                  this->custom_endpoint_host.c_str());
    }

    // Single release point for the SDK reference taken above, reached on both normal and error exits.
    --SDK_HELPER;
    should_stop = true;
  });
}

// Formats the identifiers of `custom_endpoints` as "[id1,id2,...]", or "<no endpoints>" when empty.
std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string(
    const std::vector<Aws::RDS::Model::DBClusterEndpoint>& custom_endpoints) {
  if (custom_endpoints.empty()) {
    return "<no endpoints>";
  }

  std::string endpoints("[");
  for (auto const& e : custom_endpoints) {
    endpoints += e.GetDBClusterEndpointIdentifier();
    endpoints += ",";
  }
  endpoints.pop_back();  // drop the trailing comma
  endpoints += "]";
  return endpoints;
}

// Requests the monitoring loop to stop, waits for the pool's tasks to finish (thread_pool.stop(true)),
// shrinks the pool to zero threads and drops this monitor's entry from custom_endpoint_cache.
// NOTE(review): the loop only checks should_stop between polls, so this may block for up to one refresh interval.
void CUSTOM_ENDPOINT_MONITOR::stop() {
  should_stop = true;
  thread_pool.stop(true);
  thread_pool.resize(0);
  custom_endpoint_cache.remove(this->custom_endpoint_host);
  MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str());
}

// Removes every entry from the shared custom_endpoint_cache (for all monitors).
void CUSTOM_ENDPOINT_MONITOR::clear_cache() { custom_endpoint_cache.clear(); }