proxy_agent/src/proxy/authorization_rules.rs (529 lines of code) (raw):
// Copyright (c) Microsoft Corporation
// SPDX-License-Identifier: MIT
//! This module contains the logic to authorize a request based on the authorization rules.
//! The authorization rules are computed from the access control rules provided by the user.
//!
//! Example
//! ```rust,ignore
//! use proxy_agent::proxy::authorization_rules::ComputedAuthorizationItem;
//! use proxy_agent::proxy::proxy_connection::ConnectionLogger;
//!
//! // convert the authorization item to the computed access control rules
//! let computed_rules = ComputedAuthorizationItem::from_authorization_item(authorization_item);
//!
//! // check if the request is allowed based on the computed rules
//! let is_allowed = computed_rules.is_allowed(&mut logger, request_url, claims);
//! ```
use super::{proxy_connection::ConnectionLogger, Claims};
use crate::common::logger;
use crate::key_keeper::key::{AuthorizationItem, AuthorizationRules, Identity, Privilege, Role};
use proxy_agent_shared::logger::LoggerLevel;
use proxy_agent_shared::misc_helpers;
use serde_derive::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::str::FromStr;
/// How the authorization rules of one endpoint are applied.
///
/// Parsed case-insensitively from a mode string (see the `FromStr` impl) and rendered in
/// lowercase by the `Display` impl.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum AuthorizationMode {
    /// Access control is switched off.
    Disabled,
    /// Rules are evaluated; how a negative verdict is acted upon is decided by the caller
    /// (NOTE(review): presumably logged instead of rejected — confirm at the call site).
    Audit,
    /// Rules are evaluated and enforced.
    Enforce,
}

impl std::fmt::Display for AuthorizationMode {
    /// Writes the lowercase mode name: `disabled`, `audit` or `enforce`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match self {
            AuthorizationMode::Disabled => "disabled",
            AuthorizationMode::Audit => "audit",
            AuthorizationMode::Enforce => "enforce",
        };
        f.write_str(name)
    }
}

impl std::str::FromStr for AuthorizationMode {
    type Err = String;

    /// Parses `disabled`, `audit` or `enforce`, ignoring case.
    ///
    /// Any other input yields an error message that echoes the original input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let parsed = match lowered.as_str() {
            "disabled" => Some(AuthorizationMode::Disabled),
            "audit" => Some(AuthorizationMode::Audit),
            "enforce" => Some(AuthorizationMode::Enforce),
            _ => None,
        };
        parsed.ok_or_else(|| format!("Invalid AuthorizationMode: {}", s))
    }
}
/// Lookup-friendly form of an `AuthorizationItem`, built by
/// `ComputedAuthorizationItem::from_authorization_item` and evaluated by `is_allowed`.
#[derive(Serialize, Deserialize, Clone)]
// Field names are camelCase so the serialized JSON keys are camelCase as well.
#[allow(non_snake_case)]
pub struct ComputedAuthorizationItem {
    /// Identifier copied from the source authorization item.
    pub id: String,
    /// The default access used when no privilege matches the request:
    /// `true` when the source `defaultAccess` is "allow" (case-insensitive), `false` otherwise.
    pub defaultAllowed: bool,
    /// disabled, audit, enforce
    pub mode: AuthorizationMode,
    /// All the defined unique privileges, distinct by name.
    /// key - privilege name, value - privilege object
    pub privileges: HashMap<String, Privilege>,
    /// The identities assigned to each privilege through role assignments.
    /// key - privilege name, value - the assigned identity names
    pub privilegeAssignments: HashMap<String, HashSet<String>>,
    /// All the defined unique identities, distinct by name.
    /// key - identity name, value - identity object
    pub identities: HashMap<String, Identity>,
}
// NOTE(review): the reason for allowing dead code on the whole impl is not visible here — confirm.
#[allow(dead_code)]
impl ComputedAuthorizationItem {
    /// Builds the computed, lookup-friendly form of an authorization item.
    ///
    /// * `mode` is parsed case-insensitively; an unparsable mode is logged as an error and
    ///   treated as [`AuthorizationMode::Disabled`].
    /// * `defaultAccess` equal to `"allow"` (ASCII case-insensitive) yields
    ///   `defaultAllowed == true`; any other value yields `false`.
    /// * Privileges, identities and assignments are only computed when the input rules define
    ///   all four of `privileges`, `identities`, `roles` and `roleAssignments`. Otherwise the
    ///   computed item has no privileges, so `is_allowed` always falls back to the default access.
    /// * A role assignment that references an undefined role is logged and skipped. Undefined
    ///   privileges and identities referenced by roles/assignments are silently ignored.
    pub fn from_authorization_item(
        authorization_item: AuthorizationItem,
    ) -> ComputedAuthorizationItem {
        let authorization_mode = match AuthorizationMode::from_str(&authorization_item.mode) {
            Ok(mode) => mode,
            Err(err) => {
                // This should not happen, log the error and set the mode to disabled
                logger::write_error(format!("Failed to parse authorization mode: {}", err));
                AuthorizationMode::Disabled
            }
        };

        // These stay empty unless the input rules define all four rule collections.
        let mut privilege_dict: HashMap<String, Privilege> = HashMap::new();
        let mut identity_dict: HashMap<String, Identity> = HashMap::new();
        let mut privilege_assignments: HashMap<String, HashSet<String>> = HashMap::new();

        if let Some(input_rules) = authorization_item.rules {
            if let (Some(privileges), Some(identities), Some(roles), Some(role_assignments)) = (
                input_rules.privileges,
                input_rules.identities,
                input_rules.roles,
                input_rules.roleAssignments,
            ) {
                let role_dict: HashMap<String, Role> = roles
                    .into_iter()
                    .map(|role| (role.name.clone(), role))
                    .collect();
                identity_dict = identities
                    .into_iter()
                    .map(|identity| (identity.name.clone(), identity))
                    .collect();
                privilege_dict = privileges
                    .into_iter()
                    .map(|privilege| (privilege.name.clone(), privilege))
                    .collect();

                for role_assignment in role_assignments {
                    let role = match role_dict.get(&role_assignment.role) {
                        Some(role) => role,
                        None => {
                            // skip the assignment if the role is not defined
                            logger::write_error(format!(
                                "Role '{}' is not defined, skip the role assignment.",
                                role_assignment.role
                            ));
                            continue;
                        }
                    };

                    // Only defined identities can ever be matched, so undefined ones are
                    // dropped once per assignment instead of once per privilege.
                    let defined_identities: Vec<&String> = role_assignment
                        .identities
                        .iter()
                        .filter(|identity_name| identity_dict.contains_key(*identity_name))
                        .collect();

                    // Undefined privileges are skipped. An entry is still created for a defined
                    // privilege even when none of the assigned identities is defined.
                    for privilege_name in role
                        .privileges
                        .iter()
                        .filter(|privilege_name| privilege_dict.contains_key(*privilege_name))
                    {
                        privilege_assignments
                            .entry(privilege_name.clone())
                            .or_default()
                            .extend(defined_identities.iter().map(|name| (*name).clone()));
                    }
                }
            }
        }

        ComputedAuthorizationItem {
            id: authorization_item.id,
            // Equivalent to `to_lowercase() == "allow"`: no non-ASCII character lowercases
            // to any of the ASCII letters in "allow", so no allocation is needed.
            defaultAllowed: authorization_item
                .defaultAccess
                .eq_ignore_ascii_case("allow"),
            mode: authorization_mode,
            identities: identity_dict,
            privileges: privilege_dict,
            privilegeAssignments: privilege_assignments,
        }
    }

    /// Decides whether a request to `request_url` made by the caller described by `claims`
    /// is allowed by these rules.
    ///
    /// 1. In [`AuthorizationMode::Disabled`] mode every request is allowed.
    /// 2. If the request matches at least one privilege, it is allowed only when an identity
    ///    assigned to one of the matched privileges matches `claims`; otherwise it is denied,
    ///    regardless of `defaultAllowed`.
    /// 3. If no privilege matches, `defaultAllowed` decides.
    ///
    /// `Audit` and `Enforce` are evaluated identically here; acting on the verdict is left
    /// to the caller. Every decision step is traced through `logger`.
    pub fn is_allowed(
        &self,
        logger: &mut ConnectionLogger,
        request_url: hyper::Uri,
        claims: Claims,
    ) -> bool {
        if self.mode == AuthorizationMode::Disabled {
            logger.write(
                LoggerLevel::Trace,
                "Access control is in disabled state, skip....".to_string(),
            );
            return true;
        }

        let mut any_privilege_matched = false;
        for privilege in self.privileges.values() {
            let privilege_name = &privilege.name;
            if !privilege.is_match(logger, &request_url) {
                logger.write(
                    LoggerLevel::Trace,
                    format!("Request does not match privilege '{}'.", privilege_name),
                );
                continue;
            }

            any_privilege_matched = true;
            logger.write(
                LoggerLevel::Trace,
                format!("Request matched privilege '{}'.", privilege_name),
            );

            match self.privilegeAssignments.get(privilege_name) {
                Some(assignments) => {
                    for identity_name in assignments {
                        if let Some(identity) = self.identities.get(identity_name) {
                            if identity.is_match(logger, &claims) {
                                logger.write(
                                    LoggerLevel::Trace,
                                    format!(
                                        "Request matched privilege '{}' and identity '{}'.",
                                        privilege_name, identity_name
                                    ),
                                );
                                return true;
                            }
                        }
                    }
                    logger.write(
                        LoggerLevel::Trace,
                        format!(
                            "Request matched privilege '{}' but no identity matched.",
                            privilege_name
                        ),
                    );
                }
                None => {
                    logger.write(
                        LoggerLevel::Trace,
                        format!(
                            "Request matched privilege '{}' but no identity assigned.",
                            privilege_name
                        ),
                    );
                }
            }
        }

        if any_privilege_matched {
            logger.write(
                LoggerLevel::Info,
                "Privilege matched at least once, but no identity matches, deny the access."
                    .to_string(),
            );
            return false;
        }

        logger.write(
            LoggerLevel::Trace,
            format!(
                "No privilege matched, fall back to use the default access: {}.",
                self.defaultAllowed
            ),
        );
        self.defaultAllowed
    }
}
/// Computed authorization rules for each endpoint.
///
/// An endpoint whose rules are `None` is omitted from the serialized JSON.
#[derive(Serialize, Deserialize, Clone)]
// NOTE(review): every field here is already snake_case, so this allow appears unnecessary.
#[allow(non_snake_case)]
pub struct ComputedAuthorizationRules {
    /// Computed rules for the `imds` endpoint, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub imds: Option<ComputedAuthorizationItem>,
    /// Computed rules for the `wireserver` endpoint, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub wireserver: Option<ComputedAuthorizationItem>,
    /// Computed rules for the `hostga` endpoint, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hostga: Option<ComputedAuthorizationItem>,
}
/// Snapshot of the input and computed authorization rules, serialized to JSON for support.
#[derive(Serialize, Deserialize, Clone)]
// Field names are camelCase so the serialized JSON keys are camelCase as well.
#[allow(non_snake_case)]
pub struct AuthorizationRulesForLogging {
    /// The rules as provided by the user; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub inputRules: Option<AuthorizationRules>,
    /// The rules as computed from `inputRules`.
    pub computedRules: ComputedAuthorizationRules,
}
impl AuthorizationRulesForLogging {
    /// Creates a logging snapshot from the raw input rules (if any) and the computed rules.
    pub fn new(
        input_rules: Option<AuthorizationRules>,
        computed_rules: ComputedAuthorizationRules,
    ) -> AuthorizationRulesForLogging {
        AuthorizationRulesForLogging {
            inputRules: input_rules,
            computedRules: computed_rules,
        }
    }

    /// Writes the authorization rules to a new JSON file under `path_dir` for support purposes.
    ///
    /// The file is named `AuthorizationRules_{timestamp}-{unix_nanos}.json` (with every `:`
    /// replaced by `.`) and contains the JSON serialization of `self`.
    ///
    /// Before writing, existing `AuthorizationRules_*.json` files are pruned, those sorting
    /// first by name going first, so that at most `max_file_count` files remain once the new
    /// file is written. With `max_file_count == 0` all existing files are removed, and the new
    /// file is still written.
    ///
    /// This is best effort: failures are logged and never propagated.
    pub fn write_all(&self, path_dir: &Path, max_file_count: usize) {
        let mut files =
            match misc_helpers::search_files(path_dir, r"^AuthorizationRules_.*\.json$") {
                Ok(files) => files,
                Err(e) => {
                    // This should not happen, log the error and skip write the file
                    logger::write_error(format!(
                        "Failed to search the old authorization rules files under dir {} with error: {}",
                        path_dir.display(),
                        e
                    ));
                    return;
                }
            };

        // The file names start with a timestamp, so sorting by name is expected to put the
        // oldest files first (assumes the timestamp format sorts lexicographically — TODO
        // confirm against misc_helpers). Without sorting, the pruning order depends on the
        // directory listing order, which may evict the newest files.
        files.sort();

        // Leave room for the file about to be written: keep at most `max_file_count - 1`.
        let remove_count = (files.len() + 1)
            .saturating_sub(max_file_count)
            .min(files.len());
        for file in files.iter().take(remove_count) {
            if let Err(e) = std::fs::remove_file(file) {
                logger::write_error(format!(
                    "Failed to remove the old authorization rules file {} with error: {}",
                    file.display(),
                    e
                ));
            }
        }

        // compute the file name; ':' is not allowed in Windows file names
        let new_file_name = format!(
            "AuthorizationRules_{}-{}.json",
            misc_helpers::get_date_time_string_with_milliseconds(),
            misc_helpers::get_date_time_unix_nano()
        )
        .replace(':', ".");
        let full_file_path = path_dir.join(new_file_name);
        match misc_helpers::json_write_to_file(self, &full_file_path) {
            Ok(_) => {
                logger::write_information(format!(
                    "Authorization rules are written to file: {}",
                    full_file_path.display()
                ));
            }
            Err(e) => {
                logger::write_error(format!(
                    "Failed to write the authorization rules to file {} with error: {}",
                    full_file_path.display(),
                    e
                ));
            }
        };
    }
}
#[cfg(test)]
mod tests {
    use super::{AuthorizationRulesForLogging, ComputedAuthorizationRules};
    use crate::key_keeper::key::{
        AccessControlRules, AuthorizationItem, AuthorizationRules, Identity, Privilege, Role,
        RoleAssignment,
    };
    use crate::proxy::authorization_rules::{AuthorizationMode, ComputedAuthorizationItem};
    use crate::proxy::{proxy_connection::ConnectionLogger, Claims};
    use proxy_agent_shared::misc_helpers;
    use std::ffi::OsString;
    use std::path::PathBuf;
    use std::str::FromStr;

    /// Builds the access control rules shared by the tests.
    ///
    /// Role `test` grants privileges `test` (path `/test`) and `test1` (not defined). The role
    /// is assigned to `identity_name`, which is also the name of the single defined identity
    /// (user/group/process/exe all `test`).
    fn build_access_control_rules(identity_name: &str) -> AccessControlRules {
        AccessControlRules {
            roles: Some(vec![Role {
                name: "test".to_string(),
                privileges: vec!["test".to_string(), "test1".to_string()],
            }]),
            privileges: Some(vec![Privilege {
                name: "test".to_string(),
                path: "/test".to_string(),
                queryParameters: None,
            }]),
            identities: Some(vec![Identity {
                name: identity_name.to_string(),
                exePath: Some("test".to_string()),
                groupName: Some("test".to_string()),
                processName: Some("test".to_string()),
                userName: Some("test".to_string()),
            }]),
            roleAssignments: Some(vec![RoleAssignment {
                role: "test".to_string(),
                identities: vec![identity_name.to_string()],
            }]),
        }
    }

    /// Wraps the shared rules in a deny-by-default authorization item with the given mode.
    fn build_authorization_item(mode: &str, identity_name: &str) -> AuthorizationItem {
        AuthorizationItem {
            defaultAccess: "deny".to_string(),
            mode: mode.to_string(),
            rules: Some(build_access_control_rules(identity_name)),
            id: "0".to_string(),
        }
    }

    /// Asserts deny-by-default, the expected mode and non-empty computed rule maps.
    fn assert_computed(rules: &ComputedAuthorizationItem, mode: AuthorizationMode) {
        assert!(!rules.defaultAllowed);
        assert_eq!(rules.mode, mode);
        assert!(!rules.privilegeAssignments.is_empty());
        assert!(!rules.identities.is_empty());
        assert!(!rules.privileges.is_empty());
    }

    #[tokio::test]
    async fn test_authorization_rules() {
        let mut test_logger = ConnectionLogger::new(0, 0);

        // Test Enforce Mode
        let rules = ComputedAuthorizationItem::from_authorization_item(build_authorization_item(
            "enforce", "test",
        ));
        let _clone_rules = rules.clone();
        assert_computed(&rules, AuthorizationMode::Enforce);
        let mut claims = Claims {
            userId: 0,
            userName: "test".to_string(),
            userGroups: vec!["test".to_string()],
            processId: 0,
            processFullPath: PathBuf::from("test"),
            clientIp: "0".to_string(),
            clientPort: 0, // doesn't matter for this test
            processName: OsString::from("test"),
            processCmdLine: "test".to_string(),
            runAsElevated: true,
        };
        // assert the claim is allowed given the rules above
        let url = hyper::Uri::from_str("http://localhost/test/test").unwrap();
        assert!(rules.is_allowed(&mut test_logger, url, claims.clone()));
        let relative_url = hyper::Uri::from_str("/test/test").unwrap();
        assert!(rules.is_allowed(&mut test_logger, relative_url.clone(), claims.clone()));
        claims.userName = "test1".to_string();
        assert!(!rules.is_allowed(&mut test_logger, relative_url, claims.clone()));

        // Test Audit Mode: the rules are evaluated exactly like enforce mode, so the
        // non-matching user is reported as not allowed.
        let rules = ComputedAuthorizationItem::from_authorization_item(build_authorization_item(
            "audit", "test",
        ));
        assert_computed(&rules, AuthorizationMode::Audit);
        let relative_url = hyper::Uri::from_str("/test/test").unwrap();
        assert!(!rules.is_allowed(&mut test_logger, relative_url, claims.clone()));

        // Test Disabled Mode: everything is allowed, even for the non-matching user.
        let rules = ComputedAuthorizationItem::from_authorization_item(build_authorization_item(
            "disabled", "test",
        ));
        assert_computed(&rules, AuthorizationMode::Disabled);
        let url = hyper::Uri::from_str("http://localhost/test/test1").unwrap();
        assert!(rules.is_allowed(&mut test_logger, url, claims.clone()));
        let relative_url = hyper::Uri::from_str("/test/test1").unwrap();
        assert!(rules.is_allowed(&mut test_logger, relative_url, claims.clone()));

        // Test enforce mode, identity not match
        let rules = ComputedAuthorizationItem::from_authorization_item(build_authorization_item(
            "enforce", "test1",
        ));
        assert_computed(&rules, AuthorizationMode::Enforce);
        let url = hyper::Uri::from_str("http://localhost/test?").unwrap();
        assert!(!rules.is_allowed(&mut test_logger, url, claims.clone()));
        let relative_url = hyper::Uri::from_str("/test?").unwrap();
        assert!(!rules.is_allowed(&mut test_logger, relative_url, claims.clone()));
    }

    #[tokio::test]
    async fn test_authorization_rules_for_logging() {
        let mut temp_test_path = std::env::temp_dir();
        temp_test_path.push("test_authorization_rules_for_logging");
        // clean up and ignore the clean up errors
        if let Err(e) = std::fs::remove_dir_all(&temp_test_path) {
            print!("Failed to remove_dir_all with error {}.", e);
        }
        misc_helpers::try_create_folder(&temp_test_path).unwrap();

        let authorization_item = build_authorization_item("enforce", "test");
        let computed_authorization_item =
            ComputedAuthorizationItem::from_authorization_item(authorization_item.clone());
        let authorization_rules_for_logging = AuthorizationRulesForLogging::new(
            Some(AuthorizationRules {
                imds: Some(authorization_item.clone()),
                wireserver: Some(authorization_item.clone()),
                hostga: Some(authorization_item.clone()),
            }),
            ComputedAuthorizationRules {
                imds: Some(computed_authorization_item.clone()),
                wireserver: Some(computed_authorization_item.clone()),
                hostga: Some(computed_authorization_item.clone()),
            },
        );

        let max_file_count = 5;
        for _ in 0..10 {
            authorization_rules_for_logging.write_all(&temp_test_path, max_file_count);
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        }
        let files =
            misc_helpers::search_files(&temp_test_path, r"^AuthorizationRules_.*\.json$").unwrap();
        assert_eq!(files.len(), max_file_count);

        // clean up and ignore the clean up errors
        _ = std::fs::remove_dir_all(&temp_test_path);
    }
}