driver/iam_proxy.cc (71 lines of code) (raw):

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // // This program is free software; you can redistribute it and/or modify // it under the terms of the GNU General Public License, version 2.0 // (GPLv2), as published by the Free Software Foundation, with the // following additional permissions: // // This program is distributed with certain software that is licensed // under separate terms, as designated in a particular file or component // or in the license documentation. Without limiting your rights under // the GPLv2, the authors of this program hereby grant you an additional // permission to link the program and your derivative works with the // separately licensed software that they have included with the program. // // Without limiting the foregoing grant of rights under the GPLv2 and // additional permission as to separately licensed software, this // program is also subject to the Universal FOSS Exception, version 1.0, // a copy of which can be found along with its FAQ at // http://oss.oracle.com/licenses/universal-foss-exception. // // This program is distributed in the hope that it will be useful, but // WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. // See the GNU General Public License, version 2.0, for more details. // // You should have received a copy of the GNU General Public License // along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. #include <functional> #include <tuple> #include "driver.h" #include "iam_proxy.h" std::unordered_map<std::string, TOKEN_INFO> IAM_PROXY::token_cache; std::mutex IAM_PROXY::token_cache_mutex; IAM_PROXY::IAM_PROXY(DBC* dbc, DataSource* ds) : IAM_PROXY(dbc, ds, nullptr) {}; IAM_PROXY::IAM_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) { this->next_proxy = next_proxy; const char* region = ds->opt_AUTH_REGION ? static_cast<const char*>(ds->opt_AUTH_REGION) : Aws::Region::US_EAST_1; this->auth_util = std::make_shared<AUTH_UTIL>(region); } IAM_PROXY::~IAM_PROXY() { this->auth_util.reset(); } #ifdef UNIT_TEST_BUILD IAM_PROXY::IAM_PROXY(DBC *dbc, DataSource *ds, CONNECTION_PROXY *next_proxy, std::shared_ptr<AUTH_UTIL> auth_util) : CONNECTION_PROXY(dbc, ds) { this->next_proxy = next_proxy; this->auth_util = auth_util; } #endif bool IAM_PROXY::connect(const char* host, const char* user, const char* password, const char* database, unsigned int port, const char* socket, unsigned long flags) { auto f = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port, socket, flags); return invoke_func_with_generated_token(f); } bool IAM_PROXY::change_user(const char* user, const char* passwd, const char* db) { auto f = std::bind(&CONNECTION_PROXY::change_user, next_proxy, user, std::placeholders::_1, db); return invoke_func_with_generated_token(f); } void IAM_PROXY::clear_token_cache() { std::unique_lock<std::mutex> lock(token_cache_mutex); token_cache.clear(); } bool IAM_PROXY::invoke_func_with_generated_token(std::function<bool(const char*)> func) { // Use user provided auth host if present, otherwise, use server host const char *auth_host = ds->opt_AUTH_HOST ? (const char *)ds->opt_AUTH_HOST : (const char *)ds->opt_SERVER; // Go with default region if region is not provided. const char *region = ds->opt_AUTH_REGION ? (const char *)ds->opt_AUTH_REGION : Aws::Region::US_EAST_1; int iam_port = ds->opt_AUTH_PORT; if (iam_port == UNDEFINED_PORT) { // Use regular port if user does not provide IAM port iam_port = ds->opt_PORT; } std::string auth_token; bool using_cached_token; std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token( token_cache, token_cache_mutex, auth_host, region, iam_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION); bool connect_result = func(auth_token.c_str()); if (!connect_result) { if (using_cached_token) { // Retry func with a fresh token std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token(token_cache, token_cache_mutex, auth_host, region, iam_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION, true); if (func(auth_token.c_str())) { return true; } } Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider; Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials(); if (credentials.IsEmpty()) { this->set_custom_error_message( "Could not find AWS Credentials for IAM Authentication. Please set up AWS credentials."); } else if (credentials.IsExpired()) { this->set_custom_error_message( "AWS Credentials for IAM Authentication are expired. Please refresh AWS credentials."); } } return connect_result; }