driver/auth_util.cc (71 lines of code) (raw):

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // // This program is free software; you can redistribute it and/or modify // it under the terms of the GNU General Public License, version 2.0 // (GPLv2), as published by the Free Software Foundation, with the // following additional permissions: // // This program is distributed with certain software that is licensed // under separate terms, as designated in a particular file or component // or in the license documentation. Without limiting your rights under // the GPLv2, the authors of this program hereby grant you an additional // permission to link the program and your derivative works with the // separately licensed software that they have included with the program. // // Without limiting the foregoing grant of rights under the GPLv2 and // additional permission as to separately licensed software, this // program is also subject to the Universal FOSS Exception, version 1.0, // a copy of which can be found along with its FAQ at // http://oss.oracle.com/licenses/universal-foss-exception. // // This program is distributed in the hope that it will be useful, but // WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. // See the GNU General Public License, version 2.0, for more details. // // You should have received a copy of the GNU General Public License // along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. 
#include "auth_util.h"
#include "aws_sdk_helper.h"
#include "driver.h"

namespace {
    // Shared by every AUTH_UTIL instance in this translation unit; incremented
    // in the constructors and decremented in the destructor.
    // NOTE(review): presumably a reference count that drives AWS SDK
    // initialization/shutdown -- confirm against aws_sdk_helper.h.
    AWS_SDK_HELPER SDK_HELPER;
}

// Creates an RDS client that uses the credentials returned by the default AWS
// credentials provider chain at construction time.
//
// region: AWS region for the client configuration; when null the
//         configuration's default region is left untouched.
AUTH_UTIL::AUTH_UTIL(const char* region) {
    // Must stay first: SDK_HELPER is bumped before any AWS SDK object is
    // created, mirroring the destructor, which releases the client before the
    // decrement.
    ++SDK_HELPER;

    Aws::RDS::RDSClientConfiguration client_config;
    if (region) {
        client_config.region = region;
    }

    this->rds_client = std::make_shared<Aws::RDS::RDSClient>(
        Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(),
        client_config);
}

// Creates an RDS client that uses the explicitly supplied credentials.
//
// region:      AWS region for the client configuration; when null the
//              configuration's default region is left untouched.
// credentials: credentials handed to the RDS client.
AUTH_UTIL::AUTH_UTIL(const char* region, Aws::Auth::AWSCredentials credentials) {
    // Must stay first, see the single-argument constructor.
    ++SDK_HELPER;

    Aws::RDS::RDSClientConfiguration client_config;
    if (region) {
        client_config.region = region;
    }

    this->rds_client = std::make_shared<Aws::RDS::RDSClient>(credentials, client_config);
}

// Returns an auth token for the given endpoint/user, served from `token_cache`
// when a non-expired entry exists and freshly generated otherwise.
//
// token_cache:              cache keyed by build_cache_key(); updated in place.
// token_cache_mutex:        guards `token_cache`; held for the whole
//                           lookup/generate/store sequence, so concurrent
//                           callers sharing this mutex are serialized.
// host, region, user:       null pointers are treated as empty strings.
// port:                     port of the database endpoint.
// time_until_expiration:    passed to TOKEN_INFO when a new token is stored.
// force_generate_new_token: when true any cached entry is dropped and a new
//                           token is always generated.
//
// Returns {token, true} when the token came from the cache and
// {token, false} when it was newly generated.
std::pair<std::string, bool> AUTH_UTIL::get_auth_token(
    std::unordered_map<std::string, TOKEN_INFO>& token_cache,
    std::mutex& token_cache_mutex,
    const char* host,
    const char* region,
    unsigned int port,
    const char* user,
    unsigned int time_until_expiration,
    bool force_generate_new_token) {

    if (!host) {
        host = "";
    }
    if (!region) {
        region = "";
    }
    if (!user) {
        user = "";
    }

    const std::string cache_key = build_cache_key(host, region, port, user);

    std::lock_guard<std::mutex> lock(token_cache_mutex);

    if (force_generate_new_token) {
        token_cache.erase(cache_key);
    } else {
        auto cached = token_cache.find(cache_key);
        if (cached != token_cache.end()) {
            // Inspect the entry in place instead of copying it out of the map.
            TOKEN_INFO& info = cached->second;
            if (!info.is_expired()) {
                return std::make_pair(info.token, true);
            }
            // Expired: drop it via the iterator (no second hash lookup) and
            // fall through to generate a replacement.
            token_cache.erase(cached);
        }
    }

    // No usable cached token: generate one and remember it.
    std::string auth_token = this->generate_token(host, region, port, user);
    token_cache[cache_key] = TOKEN_INFO(auth_token, time_until_expiration);
    return std::make_pair(std::move(auth_token), false);
}

// Asks the RDS client to generate a connect auth token for the given
// host/region/port/user. Arguments are forwarded unchanged.
std::string AUTH_UTIL::generate_token(const char* host, const char* region, unsigned int port, const char* user) {
    return this->rds_client->GenerateConnectAuthToken(host, region, port, user);
}

// Builds the token-cache key in the form "<region>:<host>:<port>:<user>".
// Null `host`, `region` or `user` pointers are treated as empty strings so the
// function is safe to call on its own (constructing/appending a std::string
// from a null pointer is undefined behavior).
std::string AUTH_UTIL::build_cache_key(const char* host, const char* region, unsigned int port, const char* user) {
    std::string key(region ? region : "");
    key.append(":");
    key.append(host ? host : "");
    key.append(":");
    key.append(std::to_string(port));
    key.append(":");
    key.append(user ? user : "");
    return key;
}

// Releases the RDS client first and only then decrements SDK_HELPER, the
// reverse of the construction order.
AUTH_UTIL::~AUTH_UTIL() {
    this->rds_client.reset();
    --SDK_HELPER;
}