driver/adfs_proxy.cc (201 lines of code) (raw):
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.
#include "adfs_proxy.h"
#include <regex>
#include "driver.h"
#define SIGN_IN_PAGE_URL "/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=urn:amazon:webservices"
std::unordered_map<std::string, TOKEN_INFO> ADFS_PROXY::token_cache;
std::mutex ADFS_PROXY::token_cache_mutex;
ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds) : ADFS_PROXY(dbc, ds, nullptr) {};
ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
std::string host{static_cast<const char*>(ds->opt_IDP_ENDPOINT)};
host += ":" + std::to_string(ds->opt_IDP_PORT);
const int client_connect_timeout = ds->opt_CLIENT_CONNECT_TIMEOUT;
const int client_socket_timeout = ds->opt_CLIENT_SOCKET_TIMEOUT;
const bool enable_ssl = ds->opt_ENABLE_SSL;
this->saml_util = std::make_shared<ADFS_SAML_UTIL>(host, client_connect_timeout, client_socket_timeout, enable_ssl);
}
void ADFS_PROXY::clear_token_cache() {
std::unique_lock<std::mutex> lock(token_cache_mutex);
token_cache.clear();
}
ADFS_SAML_UTIL::ADFS_SAML_UTIL(const std::shared_ptr<SAML_HTTP_CLIENT>& client) { this->http_client = client; }
ADFS_SAML_UTIL::ADFS_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) {
this->http_client =
std::make_shared<SAML_HTTP_CLIENT>("https://" + host, connect_timeout, socket_timeout, enable_ssl);
}
std::string ADFS_SAML_UTIL::get_saml_assertion(DataSource* ds) {
nlohmann::json res;
try {
res = this->http_client->get(std::string(SIGN_IN_PAGE_URL));
} catch (SAML_HTTP_EXCEPTION& e) {
const std::string error =
"Failed to get sign-in page from ADFS: " + e.error_message() + ". Please verify your IDP endpoint.";
throw SAML_HTTP_EXCEPTION(error);
}
const auto body = std::string(res);
std::smatch m;
if (!std::regex_search(body, m, ADFS_REGEX::FORM_ACTION_PATTERN)) {
return std::string();
}
std::string form_action = unescape_html_entity(m.str(1));
const std::string params = get_parameters_from_html(ds, body);
const std::string content = get_form_action_body(form_action, params);
if (std::regex_search(content, m, ADFS_REGEX::SAML_RESPONSE_PATTERN)) {
return m.str(1);
}
return std::string();
}
std::string ADFS_SAML_UTIL::unescape_html_entity(const std::string& html) {
std::string retval("");
int i = 0;
int length = html.length();
while (i < length) {
char c = html[i];
if (c != '&') {
retval.append(1, c);
i++;
continue;
}
if (html.substr(i, 4) == "<") {
retval.append(1, '<');
i += 4;
} else if (html.substr(i, 4) == ">") {
retval.append(1, '>');
i += 4;
} else if (html.substr(i, 5) == "&") {
retval.append(1, '&');
i += 5;
} else if (html.substr(i, 6) == "'") {
retval.append(1, '\'');
i += 6;
} else if (html.substr(i, 6) == """) {
retval.append(1, '"');
i += 6;
} else {
retval.append(1, c);
++i;
}
}
return retval;
}
std::vector<std::string> ADFS_SAML_UTIL::get_input_tags_from_html(const std::string& body) {
std::unordered_set<std::string> hashSet;
std::vector<std::string> retval;
std::smatch matches;
std::regex pattern(ADFS_REGEX::INPUT_TAG_PATTERN);
std::string source = body;
while (std::regex_search(source, matches, pattern)) {
std::string tag = matches.str(0);
std::string tagName = get_value_by_key(tag, std::string("name"));
std::transform(tagName.begin(), tagName.end(), tagName.begin(), [](unsigned char c) { return std::tolower(c); });
if (!tagName.empty() && hashSet.find(tagName) == hashSet.end()) {
hashSet.insert(tagName);
retval.push_back(tag);
}
source = matches.suffix().str();
}
return retval;
}
std::string ADFS_SAML_UTIL::get_value_by_key(const std::string& input, const std::string& key) {
std::string pattern("(");
pattern += key;
pattern += ")\\s*=\\s*\"(.*?)\"";
std::smatch matches;
if (std::regex_search(input, matches, std::regex(pattern))) {
MYLOG_TRACE(init_log_file(), 0, "get_value_by_key");
return unescape_html_entity(matches.str(2));
}
return "";
}
std::string ADFS_SAML_UTIL::get_parameters_from_html(DataSource* ds, const std::string& body) {
std::map<std::string, std::string> parameters;
for (auto& inputTag : get_input_tags_from_html(body)) {
std::string name = get_value_by_key(inputTag, std::string("name"));
std::string value = get_value_by_key(inputTag, std::string("value"));
std::string nameLower = name;
std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(),
[](unsigned char c) { return std::tolower(c); });
const std::string username = static_cast<const char*>(ds->opt_IDP_USERNAME);
const std::string password = static_cast<const char*>(ds->opt_IDP_PASSWORD);
if (nameLower.find("username") != std::string::npos) {
parameters.insert(std::pair<std::string, std::string>(name, username));
} else if ((nameLower.find("authmethod") != std::string::npos) && !value.empty()) {
parameters.insert(std::pair<std::string, std::string>(name, value));
} else if (nameLower.find("password") != std::string::npos) {
parameters.insert(std::pair<std::string, std::string>(name, password));
} else if (!name.empty()) {
parameters.insert(std::pair<std::string, std::string>(name, value));
}
}
// Convert parameters to a & delimited string, e.g. username=u&password=p
const std::string delimiter = "&";
const std::string result =
std::accumulate(parameters.begin(), parameters.end(), std::string(),
[delimiter](const std::string& s, const std::pair<const std::string, std::string>& p) {
return s + (s.empty() ? std::string() : delimiter) + p.first + "=" + p.second;
});
return result;
}
std::string ADFS_SAML_UTIL::get_form_action_body(const std::string& url, const std::string& params) {
nlohmann::json res;
try {
res = this->http_client->post(url, params, "application/x-www-form-urlencoded");
} catch (SAML_HTTP_EXCEPTION& e) {
const std::string error =
"Failed to get SAML Assertion from ADFS : " + e.error_message() + ". Please verify your ADFS credentials.";
throw SAML_HTTP_EXCEPTION(error);
}
return res.empty() ? "" : res;
}
#ifdef UNIT_TEST_BUILD
ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, std::shared_ptr<AUTH_UTIL> auth_util,
const std::shared_ptr<SAML_HTTP_CLIENT>& client)
: CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
this->auth_util = auth_util;
this->saml_util = std::make_shared<ADFS_SAML_UTIL>(client);
}
#endif
ADFS_PROXY::~ADFS_PROXY() { this->auth_util.reset(); }
bool ADFS_PROXY::connect(const char* host, const char* user, const char* password, const char* database,
unsigned int port, const char* socket, unsigned long flags) {
auto func = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port,
socket, flags);
const char* region =
ds->opt_FED_AUTH_REGION ? static_cast<const char*>(ds->opt_FED_AUTH_REGION) : Aws::Region::US_EAST_1;
std::string assertion;
try {
assertion = this->saml_util->get_saml_assertion(ds);
} catch (SAML_HTTP_EXCEPTION& e) {
this->set_custom_error_message(e.error_message().c_str());
return false;
}
auto idp_host = static_cast<const char*>(ds->opt_IDP_ENDPOINT);
auto iam_role_arn = static_cast<const char*>(ds->opt_IAM_ROLE_ARN);
auto idp_arn = static_cast<const char*>(ds->opt_IAM_IDP_ARN);
const Aws::Auth::AWSCredentials credentials =
this->saml_util->get_aws_credentials(idp_host, region, iam_role_arn, idp_arn, assertion);
this->auth_util = std::make_shared<AUTH_UTIL>(region, credentials);
const char* auth_host = ds->opt_FED_AUTH_HOST ? static_cast<const char*>(ds->opt_FED_AUTH_HOST)
: static_cast<const char*>(ds->opt_SERVER);
const int auth_port = ds->opt_FED_AUTH_PORT;
std::string auth_token;
bool using_cached_token;
std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token(
token_cache, token_cache_mutex, auth_host, region, auth_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION);
bool connect_result = func(auth_token.c_str());
if (!connect_result) {
if (using_cached_token) {
// Retry func with a fresh token
std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token(
token_cache, token_cache_mutex, auth_host, region, auth_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION, true);
if (func(auth_token.c_str())) {
return true;
}
}
if (credentials.IsEmpty()) {
this->set_custom_error_message(
"Unable to generate temporary AWS credentials from the SAML assertion. Please ensure the ADFS identity "
"provider is correctly configured with AWS.");
}
}
return connect_result;
}