proxy_agent/src/proxy/proxy_authorizer.rs (603 lines of code) (raw):

// Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT //! This module contains the logic to authorize the connection based on the claims. //! The claims are used to determine if the process is allowed to connect to the remote server. //! //! Example //! ```rust //! use proxy_agent::proxy_authorizer; //! use proxy_agent::proxy::Claims; //! use proxy_agent::shared_state::key_keeper_wrapper::KeyKeeperSharedState; //! use proxy_agent::common::constants; //! use std::str::FromStr; //! //! let key_keeper_shared_state = KeyKeeperSharedState::start_new(); //! let vm_metadata = proxy_authorizer::get_access_control_rules(constants::WIRE_SERVER_IP.to_string(), constants::WIRE_SERVER_PORT, key_keeper_shared_state.clone()).await.unwrap(); //! let authorizer = proxy_authorizer::get_authorizer(constants::WIRE_SERVER_IP, constants::WIRE_SERVER_PORT, claims); //! let url = hyper::Uri::from_str("http://localhost/test?").unwrap(); //! authorizer.authorize(logger, url, vm_metadata); //! use super::authorization_rules::{AuthorizationMode, ComputedAuthorizationItem}; use super::proxy_connection::ConnectionLogger; use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; use crate::{common::constants, common::result::Result, proxy::Claims}; use proxy_agent_shared::logger::LoggerLevel; #[derive(PartialEq)] pub enum AuthorizeResult { Ok, OkWithAudit, Forbidden, } pub trait Authorizer { // authorize the connection fn authorize( &self, logger: &mut ConnectionLogger, request_url: hyper::Uri, access_control_rules: Option<ComputedAuthorizationItem>, ) -> AuthorizeResult; fn to_string(&self) -> String; fn type_name(&self) -> String { std::any::type_name::<Self>().to_string() } } struct WireServer { claims: Claims, } impl Authorizer for WireServer { fn authorize( &self, logger: &mut ConnectionLogger, request_url: hyper::Uri, access_control_rules: Option<ComputedAuthorizationItem>, ) -> AuthorizeResult { if !self.claims.runAsElevated { return AuthorizeResult::Forbidden; } if let Some(rules) = access_control_rules { if rules.is_allowed(logger, request_url.clone(), self.claims.clone()) { return AuthorizeResult::Ok; } else { if rules.mode == AuthorizationMode::Audit { logger.write( LoggerLevel::Info, format!("WireServer request {} denied in audit mode, continue forward the request", request_url)); return AuthorizeResult::OkWithAudit; } return AuthorizeResult::Forbidden; } } AuthorizeResult::Ok } fn to_string(&self) -> String { format!( "WireServer {{ runAsElevated: {}, processName: {} }}", self.claims.runAsElevated, self.claims.processName.to_string_lossy() ) } } struct Imds { #[allow(dead_code)] claims: Claims, } impl Authorizer for Imds { fn authorize( &self, logger: &mut ConnectionLogger, request_url: hyper::Uri, access_control_rules: Option<ComputedAuthorizationItem>, ) -> AuthorizeResult { if let Some(rules) = access_control_rules { if rules.is_allowed(logger, request_url.clone(), self.claims.clone()) { return AuthorizeResult::Ok; } else { if rules.mode == AuthorizationMode::Audit { logger.write( LoggerLevel::Info, format!( "IMDS request {} denied in audit mode, continue forward the request", request_url ), ); return AuthorizeResult::OkWithAudit; } return AuthorizeResult::Forbidden; } } AuthorizeResult::Ok } fn to_string(&self) -> String { "IMDS".to_string() } } struct GAPlugin { claims: Claims, } impl Authorizer for GAPlugin { fn authorize( &self, logger: &mut ConnectionLogger, request_url: hyper::Uri, access_control_rules: Option<ComputedAuthorizationItem>, ) -> AuthorizeResult { if !self.claims.runAsElevated { return AuthorizeResult::Forbidden; } if let Some(rules) = access_control_rules { if rules.is_allowed(logger, request_url.clone(), self.claims.clone()) { return AuthorizeResult::Ok; } else { if rules.mode == AuthorizationMode::Audit { logger.write( LoggerLevel::Info, format!("HostGAPlugin request {} denied in audit mode, continue forward the request", request_url)); return AuthorizeResult::OkWithAudit; } return AuthorizeResult::Forbidden; } } AuthorizeResult::Ok } fn to_string(&self) -> String { format!( "GAPlugin {{ runAsElevated: {}, processName: {} }}", self.claims.runAsElevated, self.claims.processName.to_string_lossy() ) } } struct ProxyAgent {} impl Authorizer for ProxyAgent { fn authorize( &self, _logger: &mut ConnectionLogger, _request_url: hyper::Uri, _access_control_rules: Option<ComputedAuthorizationItem>, ) -> AuthorizeResult { // Forbid the request send to this listener directly AuthorizeResult::Forbidden } fn to_string(&self) -> String { "ProxyAgent".to_string() } } struct Default {} impl Authorizer for Default { fn authorize( &self, _logger: &mut ConnectionLogger, _request_url: hyper::Uri, _access_control_rules: Option<ComputedAuthorizationItem>, ) -> AuthorizeResult { AuthorizeResult::Ok } fn to_string(&self) -> String { "Default".to_string() } } pub fn get_authorizer(ip: String, port: u16, claims: Claims) -> Box<dyn Authorizer> { if ip == constants::WIRE_SERVER_IP && port == constants::WIRE_SERVER_PORT { Box::new(WireServer { claims }) } else if ip == constants::GA_PLUGIN_IP && port == constants::GA_PLUGIN_PORT { return Box::new(GAPlugin { claims }); } else if ip == constants::IMDS_IP && port == constants::IMDS_PORT { return Box::new(Imds { claims }); } else if ip == constants::PROXY_AGENT_IP && port == constants::PROXY_AGENT_PORT { return Box::new(ProxyAgent {}); } else { Box::new(Default {}) } } pub async fn get_access_control_rules( ip: String, port: u16, key_keeper_shared_state: KeyKeeperSharedState, ) -> Result<Option<ComputedAuthorizationItem>> { match (ip.as_str(), port) { (constants::WIRE_SERVER_IP, constants::WIRE_SERVER_PORT) => { key_keeper_shared_state.get_wireserver_rules().await } (constants::GA_PLUGIN_IP, constants::GA_PLUGIN_PORT) => { key_keeper_shared_state.get_hostga_rules().await } (constants::IMDS_IP, constants::IMDS_PORT) => { key_keeper_shared_state.get_imds_rules().await } _ => Ok(None), } } pub fn authorize( ip: String, port: u16, logger: &mut ConnectionLogger, request_uri: hyper::Uri, claims: Claims, access_control_rules: Option<ComputedAuthorizationItem>, ) -> AuthorizeResult { let auth = get_authorizer(ip, port, claims); logger.write( LoggerLevel::Trace, format!("Got auth: {}", auth.to_string()), ); auth.authorize(logger, request_uri, access_control_rules) } #[cfg(test)] mod tests { use crate::{ key_keeper::key::AuthorizationItem, proxy::{proxy_authorizer::AuthorizeResult, proxy_connection::ConnectionLogger}, shared_state::key_keeper_wrapper::KeyKeeperSharedState, }; use std::{ffi::OsString, path::PathBuf, str::FromStr}; #[test] fn get_authenticate_test() { let claims = crate::proxy::Claims { userId: 0, userName: "test".to_string(), userGroups: vec!["test".to_string()], processId: std::process::id(), processName: OsString::from("test"), processFullPath: PathBuf::from("test"), processCmdLine: "test".to_string(), runAsElevated: true, clientIp: "127.0.0.1".to_string(), clientPort: 0, // doesn't matter for this test }; let mut test_logger = ConnectionLogger::new(0, 0); let auth: Box<dyn super::Authorizer> = super::get_authorizer( crate::common::constants::WIRE_SERVER_IP.to_string(), crate::common::constants::WIRE_SERVER_PORT, claims.clone(), ); let test_uri = hyper::Uri::from_str("test").unwrap(); assert_eq!( auth.to_string(), "WireServer { runAsElevated: true, processName: test }" ); assert!( AuthorizeResult::Ok == auth.authorize(&mut test_logger, test_uri.clone(), None), "WireServer authentication must be Ok" ); let auth = super::get_authorizer( crate::common::constants::GA_PLUGIN_IP.to_string(), crate::common::constants::GA_PLUGIN_PORT, claims.clone(), ); assert_eq!( auth.to_string(), "GAPlugin { runAsElevated: true, processName: test }" ); assert!( AuthorizeResult::Ok == auth.authorize(&mut test_logger, test_uri.clone(), None), "GAPlugin authentication must be Ok" ); let auth = super::get_authorizer( crate::common::constants::IMDS_IP.to_string(), crate::common::constants::IMDS_PORT, claims.clone(), ); assert_eq!(auth.to_string(), "IMDS"); assert!( AuthorizeResult::Ok == auth.authorize(&mut test_logger, test_uri.clone(), None), "IMDS authentication must be Ok" ); let auth = super::get_authorizer( crate::common::constants::PROXY_AGENT_IP.to_string(), crate::common::constants::PROXY_AGENT_PORT, claims.clone(), ); assert_eq!(auth.to_string(), "ProxyAgent"); assert!( AuthorizeResult::Forbidden == auth.authorize(&mut test_logger, test_uri.clone(), None), "ProxyAgent authentication must be Forbidden" ); let auth = super::get_authorizer( crate::common::constants::PROXY_AGENT_IP.to_string(), crate::common::constants::PROXY_AGENT_PORT + 1, claims.clone(), ); assert_eq!(auth.to_string(), "Default"); } #[tokio::test] async fn wireserver_authenticate_test() { let claims = crate::proxy::Claims { userId: 0, userName: "test".to_string(), userGroups: vec!["test".to_string()], processId: std::process::id(), processName: OsString::from("test"), processFullPath: PathBuf::from("test"), processCmdLine: "test".to_string(), runAsElevated: true, clientIp: "127.0.0.1".to_string(), clientPort: 0, // doesn't matter for this test }; let mut test_logger = ConnectionLogger::new(1, 1); let auth = super::get_authorizer( crate::common::constants::WIRE_SERVER_IP.to_string(), crate::common::constants::WIRE_SERVER_PORT, claims.clone(), ); let url = hyper::Uri::from_str("http://localhost/test?").unwrap(); let key_keeper_shared_state = KeyKeeperSharedState::start_new(); // validate disabled rules let disabled_rules = AuthorizationItem { defaultAccess: "deny".to_string(), mode: "disabled".to_string(), id: "id".to_string(), rules: None, }; key_keeper_shared_state .set_wireserver_rules(Some(disabled_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state .get_wireserver_rules() .await .unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Ok, "WireServer authentication must be Ok with disabled rules" ); // validate audit rules let audit_deny_rules = AuthorizationItem { defaultAccess: "deny".to_string(), mode: "audit".to_string(), id: "id".to_string(), rules: None, }; let audit_allow_rules = AuthorizationItem { defaultAccess: "allow".to_string(), mode: "audit".to_string(), id: "id".to_string(), rules: None, }; key_keeper_shared_state .set_wireserver_rules(Some(audit_allow_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state .get_wireserver_rules() .await .unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Ok, "WireServer authentication must be Ok with audit allow rules" ); key_keeper_shared_state .set_wireserver_rules(Some(audit_deny_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state .get_wireserver_rules() .await .unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::OkWithAudit, "WireServer authentication must be OkWithAudit with audit deny rules" ); // validate enforce rules let enforce_allow_rules = AuthorizationItem { defaultAccess: "allow".to_string(), mode: "enforce".to_string(), id: "id".to_string(), rules: None, }; let enforce_deny_rules = AuthorizationItem { defaultAccess: "deny".to_string(), mode: "enforce".to_string(), id: "id".to_string(), rules: None, }; key_keeper_shared_state .set_wireserver_rules(Some(enforce_allow_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state .get_wireserver_rules() .await .unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Ok, "WireServer authentication must be Ok with enforce allow rules" ); key_keeper_shared_state .set_wireserver_rules(Some(enforce_deny_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state .get_wireserver_rules() .await .unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Forbidden, "WireServer authentication must be Forbidden with enforce deny rules" ); } #[tokio::test] async fn imds_authenticate_test() { let mut test_logger = ConnectionLogger::new(1, 1); let claims = crate::proxy::Claims { userId: 0, userName: "test".to_string(), userGroups: vec!["test".to_string()], processId: std::process::id(), processName: OsString::from("test"), processFullPath: PathBuf::from("test"), processCmdLine: "test".to_string(), runAsElevated: true, clientIp: "127.0.0.1".to_string(), clientPort: 0, // doesn't matter for this test }; let auth = super::get_authorizer( crate::common::constants::IMDS_IP.to_string(), crate::common::constants::IMDS_PORT, claims.clone(), ); let url = hyper::Uri::from_str("http://localhost/test?").unwrap(); let key_keeper_shared_state = KeyKeeperSharedState::start_new(); // validate disabled rules let disabled_rules = AuthorizationItem { defaultAccess: "deny".to_string(), mode: "disabled".to_string(), id: "id".to_string(), rules: None, }; key_keeper_shared_state .set_imds_rules(Some(disabled_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state.get_imds_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules,) == AuthorizeResult::Ok, "IMDS authentication must be Ok with disabled rules" ); // validate audit rules let audit_deny_rules = AuthorizationItem { defaultAccess: "deny".to_string(), mode: "audit".to_string(), id: "id".to_string(), rules: None, }; let audit_allow_rules = AuthorizationItem { defaultAccess: "allow".to_string(), mode: "audit".to_string(), id: "id".to_string(), rules: None, }; key_keeper_shared_state .set_imds_rules(Some(audit_allow_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state.get_imds_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules,) == AuthorizeResult::Ok, "IMDS authentication must be Ok with audit allow rules" ); key_keeper_shared_state .set_imds_rules(Some(audit_deny_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state.get_imds_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules,) == AuthorizeResult::OkWithAudit, "IMDS authentication must be OkWithAudit with audit deny rules" ); // validate enforce rules let enforce_allow_rules = AuthorizationItem { defaultAccess: "allow".to_string(), mode: "enforce".to_string(), id: "id".to_string(), rules: None, }; let enforce_deny_rules = AuthorizationItem { defaultAccess: "deny".to_string(), mode: "enforce".to_string(), id: "id".to_string(), rules: None, }; key_keeper_shared_state .set_imds_rules(Some(enforce_allow_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state.get_imds_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules,) == AuthorizeResult::Ok, "IMDS authentication must be Ok with enforce allow rules" ); key_keeper_shared_state .set_imds_rules(Some(enforce_deny_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state.get_imds_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules,) == AuthorizeResult::Forbidden, "IMDS authentication must be Forbidden with enforce deny rules" ); } #[tokio::test] async fn hostga_authenticate_test() { let claims = crate::proxy::Claims { userId: 0, userName: "test".to_string(), userGroups: vec!["test".to_string()], processId: std::process::id(), processName: OsString::from("test"), processFullPath: PathBuf::from("test"), processCmdLine: "test".to_string(), runAsElevated: true, clientIp: "127.0.0.1".to_string(), clientPort: 0, // doesn't matter for this test }; let mut test_logger = ConnectionLogger::new(1, 1); let auth = super::get_authorizer( crate::common::constants::GA_PLUGIN_IP.to_string(), crate::common::constants::GA_PLUGIN_PORT, claims.clone(), ); let url = hyper::Uri::from_str("http://localhost/test?").unwrap(); let key_keeper_shared_state = KeyKeeperSharedState::start_new(); // validate disabled rules let disabled_rules = AuthorizationItem { defaultAccess: "deny".to_string(), mode: "disabled".to_string(), id: "id".to_string(), rules: None, }; key_keeper_shared_state .set_hostga_rules(Some(disabled_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state.get_hostga_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Ok, "HostGA authentication must be Ok with disabled rules" ); // validate audit rules let audit_deny_rules = AuthorizationItem { defaultAccess: "deny".to_string(), mode: "audit".to_string(), id: "id".to_string(), rules: None, }; let audit_allow_rules = AuthorizationItem { defaultAccess: "allow".to_string(), mode: "audit".to_string(), id: "id".to_string(), rules: None, }; key_keeper_shared_state .set_hostga_rules(Some(audit_allow_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state.get_hostga_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Ok, "HostGA authentication must be Ok with audit allow rules" ); key_keeper_shared_state .set_hostga_rules(Some(audit_deny_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state.get_hostga_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::OkWithAudit, "HostGA authentication must be OkWithAudit with audit deny rules" ); // validate enforce rules let enforce_allow_rules = AuthorizationItem { defaultAccess: "allow".to_string(), mode: "enforce".to_string(), id: "id".to_string(), rules: None, }; let enforce_deny_rules = AuthorizationItem { defaultAccess: "deny".to_string(), mode: "enforce".to_string(), id: "id".to_string(), rules: None, }; key_keeper_shared_state .set_hostga_rules(Some(enforce_allow_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state.get_hostga_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Ok, "HostGA authentication must be Ok with enforce allow rules" ); key_keeper_shared_state .set_hostga_rules(Some(enforce_deny_rules)) .await .unwrap(); let access_control_rules = key_keeper_shared_state.get_hostga_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Forbidden, "HostGA authentication must be Forbidden with enforce deny rules" ); } }