proxy_agent/src/proxy/proxy_connection.rs (251 lines of code) (raw):
// Copyright (c) Microsoft Corporation
// SPDX-License-Identifier: MIT
//! This module contains the connection context structs for the proxy listener, and writes proxy processing logs to a local file.
use crate::common::error::{Error, HyperErrorType};
use crate::common::hyper_client;
use crate::common::result::Result;
use crate::proxy::Claims;
use crate::redirector::{self, AuditEntry};
use crate::shared_state::proxy_server_wrapper::ProxyServerSharedState;
use crate::shared_state::redirector_wrapper::RedirectorSharedState;
use http_body_util::Full;
use hyper::body::Bytes;
use hyper::client::conn::http1;
use hyper::Request;
use proxy_agent_shared::logger::{self, logger_manager, LoggerLevel};
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;
/// Body type of the requests sent through the connection contexts in this module.
pub type RequestBody = Full<Bytes>;
/// Private wrapper around the hyper HTTP/1 request sender of one host connection.
struct Client {
/// hyper HTTP/1 handle on which the requests are sent.
sender: http1::SendRequest<RequestBody>,
}
impl Client {
async fn send_request(
&mut self,
req: Request<RequestBody>,
) -> Result<hyper::Response<hyper::body::Incoming>> {
if self.sender.is_closed() {
return Err(Error::Hyper(HyperErrorType::HostConnection(
"the connection has been closed".to_string(),
)));
}
let full_url = req.uri().to_string();
self.sender.send_request(req).await.map_err(|e| {
Error::Hyper(HyperErrorType::Custom(
format!("Failed to send request to {}", full_url),
e,
))
})
}
}
/// Per-TCP-connection state built by `TcpConnectionContext::new`.
///
/// Cloning shares the underlying `Client` (it is held behind an `Arc`), while
/// the cloned `ConnectionLogger` starts with an empty queue.
#[derive(Clone)]
pub struct TcpConnectionContext {
/// Identifier of this TCP connection; also passed to the connection logger.
pub id: u128,
/// Address of the client side of the connection.
pub client_addr: SocketAddr,
/// Claims derived from the audit entry; `None` when no audit entry was found
/// or when deriving the claims failed.
pub claims: Option<Claims>,
/// Destination address taken from the audit entry; `None` when no audit entry was found.
pub destination_ip: Option<Ipv4Addr>, // currently, we only support IPv4
/// Destination port taken from the audit entry; `0` when no audit entry was found.
pub destination_port: u16,
/// Shared HTTP client on success, otherwise the text of the error that
/// prevented the sender from being created.
sender: std::result::Result<Arc<Mutex<Client>>, String>,
/// Buffered logger for messages that belong to this connection.
logger: ConnectionLogger,
}
impl TcpConnectionContext {
/// Builds the context for a newly accepted TCP connection.
///
/// Looks up the audit entry for `client_addr`, derives the claims from it and
/// creates the HTTP sender towards the destination recorded in the entry.
/// This constructor never fails: problems are written to the connection
/// logger and reflected in the returned value (`claims` / `destination_ip`
/// are `None`, `destination_port` is `0`, `sender` holds the error text).
pub async fn new(
    id: u128,
    client_addr: SocketAddr,
    redirector_shared_state: RedirectorSharedState,
    proxy_server_shared_state: ProxyServerSharedState,
    #[cfg(windows)] raw_socket_id: usize, // windows only, it is the raw socket id, used to get audit entry from socket stream
) -> Self {
    let mut logger = ConnectionLogger::new(id, 0);

    let lookup = Self::get_audit_entry(
        &client_addr,
        &redirector_shared_state,
        &mut logger,
        #[cfg(windows)]
        raw_socket_id,
    )
    .await;

    // Guard clause: without an audit entry there is no destination to connect to.
    let audit_entry = match lookup {
        Ok(entry) => entry,
        Err(e) => {
            logger.write(
                LoggerLevel::Warn,
                "This tcp connection may send to proxy agent tcp listener directly".to_string(),
            );
            return Self {
                id,
                client_addr,
                claims: None,
                destination_ip: None,
                destination_port: 0,
                sender: Err(e.to_string()),
                logger,
            };
        }
    };

    // A failure to build the claims is logged but does not abort the setup.
    let claims_result = Claims::from_audit_entry(
        &audit_entry,
        client_addr.ip(),
        client_addr.port(),
        proxy_server_shared_state,
    )
    .await;
    let claims = match claims_result {
        Ok(claims) => Some(claims),
        Err(e) => {
            logger.write(
                LoggerLevel::Error,
                format!("Failed to get claims from audit entry: {}", e),
            );
            None
        }
    };

    let destination_ip = audit_entry.destination_ipv4_addr();
    let destination_port = audit_entry.destination_port_in_host_byte_order();
    let host_ip = destination_ip.to_string();

    // The callback owns its own logger clone (which starts with an empty queue).
    let mut callback_logger = logger.clone();
    let on_warning = move |message: String| {
        callback_logger.write(LoggerLevel::Warn, message);
    };

    let sender = hyper_client::build_http_sender(&host_ip, destination_port, on_warning)
        .await
        .map(|sender| {
            logger.write(
                LoggerLevel::Trace,
                "Successfully created http sender".to_string(),
            );
            Arc::new(Mutex::new(Client { sender }))
        })
        .map_err(|e| e.to_string());

    Self {
        id,
        client_addr,
        claims,
        destination_ip: Some(destination_ip),
        destination_port,
        sender,
        logger,
    }
}
/// Resolves the audit entry that belongs to the connection from `client_addr`.
///
/// The entry is looked up by the client source port through the redirector;
/// when found, its removal is attempted as well (a failed removal is only
/// logged). When the lookup fails, Windows builds additionally try to read
/// the entry from the socket identified by `raw_socket_id`.
///
/// # Errors
///
/// Returns `Error::FindAuditEntryError` with the lookup failure message when
/// no audit entry could be obtained.
async fn get_audit_entry(
    client_addr: &SocketAddr,
    redirector_shared_state: &RedirectorSharedState,
    logger: &mut ConnectionLogger,
    #[cfg(windows)] raw_socket_id: usize,
) -> Result<AuditEntry> {
    let source_port = client_addr.port();

    let lookup_error = match redirector::lookup_audit(source_port, redirector_shared_state).await
    {
        Ok(entry) => {
            logger.write(
                LoggerLevel::Trace,
                format!(
                    "Found audit entry with client_source_port '{}' successfully",
                    source_port
                ),
            );

            // The removal outcome only decides which line gets logged.
            let removal = redirector::remove_audit(source_port, redirector_shared_state).await;
            let (level, text) = match removal {
                Ok(_) => (
                    LoggerLevel::Trace,
                    format!(
                        "Removed audit entry with client_source_port '{}' successfully",
                        source_port
                    ),
                ),
                Err(remove_error) => (
                    LoggerLevel::Warn,
                    format!("Failed to remove audit entry: {}", remove_error),
                ),
            };
            logger.write(level, text);

            return Ok(entry);
        }
        Err(e) => e,
    };

    let message = format!(
        "Failed to find audit entry with client_source_port '{}' with error: {}",
        source_port, lookup_error
    );
    logger.write(LoggerLevel::Warn, message.clone());

    #[cfg(not(windows))]
    {
        Err(Error::FindAuditEntryError(message))
    }
    #[cfg(windows)]
    {
        // Windows fallback: ask the socket itself for the audit entry.
        logger.write(
            LoggerLevel::Info,
            "Try to get audit entry from socket stream".to_string(),
        );
        match redirector::get_audit_from_stream_socket(raw_socket_id) {
            Ok(entry) => {
                logger.write(
                    LoggerLevel::Info,
                    "Found audit entry from socket stream successfully".to_string(),
                );
                Ok(entry)
            }
            Err(stream_error) => {
                logger.write(
                    LoggerLevel::Warn,
                    format!(
                        "Failed to get lookup_audit_from_stream with error: {}",
                        stream_error
                    ),
                );
                Err(Error::FindAuditEntryError(message))
            }
        }
    }
}
/// Returns the destination IP address as text, for logging purposes.
///
/// Yields the literal string `"None"` when no destination IP is known.
pub fn get_ip_string(&self) -> String {
    match self.destination_ip {
        Some(ip) => ip.to_string(),
        None => String::from("None"),
    }
}
/// Writes `message` at `logger_level` through this connection's logger.
pub fn log(&mut self, logger_level: LoggerLevel, message: String) {
    let logger = &mut self.logger;
    logger.write(logger_level, message);
}
async fn send_request(
&self,
request: hyper::Request<RequestBody>,
) -> Result<hyper::Response<hyper::body::Incoming>> {
match &self.sender {
Ok(sender) => sender.lock().await.send_request(request).await,
Err(e) => Err(Error::Hyper(HyperErrorType::HostConnection(e.clone()))),
}
}
}
/// Per-HTTP-request state layered on top of a `TcpConnectionContext`.
pub struct HttpConnectionContext {
/// Identifier of this HTTP connection/request.
pub id: u128,
/// NOTE(review): presumably the instant at which processing of this request started — confirm at the construction site.
pub now: Instant,
/// HTTP method of the request.
pub method: hyper::Method,
/// URI of the request.
pub url: hyper::Uri,
/// Context of the TCP connection that carries this request; requests are sent through it.
pub tcp_connection_context: TcpConnectionContext,
/// Buffered logger for messages that belong to this request.
pub logger: ConnectionLogger,
}
impl HttpConnectionContext {
    /// Returns the decision of `hyper_client::should_skip_sig` for this
    /// request's method and URL.
    pub fn should_skip_sig(&self) -> bool {
        let Self { method, url, .. } = self;
        hyper_client::should_skip_sig(method, url)
    }

    /// Returns `true` when the path component of the URL contains `".."`.
    pub fn contains_traversal_characters(&self) -> bool {
        let path = self.url.path();
        path.contains("..")
    }

    /// Writes `message` at `logger_level` through this request's logger.
    pub fn log(&mut self, logger_level: LoggerLevel, message: String) {
        self.get_logger_mut_ref().write(logger_level, message);
    }

    /// Gives mutable access to this request's logger.
    pub fn get_logger_mut_ref(&mut self) -> &mut ConnectionLogger {
        let Self { logger, .. } = self;
        logger
    }

    /// Sends `request` through the underlying TCP connection context and
    /// returns its result unchanged.
    pub async fn send_request(
        &self,
        request: hyper::Request<RequestBody>,
    ) -> Result<hyper::Response<hyper::body::Incoming>> {
        let tcp_context = &self.tcp_connection_context;
        tcp_context.send_request(request).await
    }
}
/// In-memory log buffer for one connection.
///
/// Lines are queued by `write` and handed to `logger_manager::write_many`
/// in one batch when the logger is dropped.
pub struct ConnectionLogger {
/// TCP connection id, printed in square brackets in every queued line.
pub tcp_connection_id: u128,
/// HTTP connection id, printed right after the log header in every queued line.
pub http_connection_id: u128,
/// Formatted lines waiting to be written when the logger is dropped.
queue: Vec<String>,
}
impl ConnectionLogger {
    /// Logger key passed to `logger_manager::write_many` when the buffered
    /// lines are written out on drop.
    pub const CONNECTION_LOGGER_KEY: &'static str = "Connection_Logger";

    /// Creates a logger with an empty queue for the given connection ids.
    pub fn new(tcp_connection_id: u128, http_connection_id: u128) -> Self {
        let queue = Vec::new();
        Self {
            tcp_connection_id,
            http_connection_id,
            queue,
        }
    }

    /// Queues `message` as a formatted log line.
    ///
    /// The message is discarded when `logger_level` is greater than the level
    /// currently reported by `logger_manager::get_logger_level`.
    pub fn write(&mut self, logger_level: LoggerLevel, message: String) {
        if logger_level > logger_manager::get_logger_level() {
            return;
        }

        let Self {
            tcp_connection_id,
            http_connection_id,
            queue,
        } = self;
        let line = format!(
            "{}{}[{}] - {}",
            logger::get_log_header(logger_level),
            http_connection_id,
            tcp_connection_id,
            message
        );
        queue.push(line);
    }
}
impl Drop for ConnectionLogger {
    /// Flushes the buffered lines when the logger goes away.
    ///
    /// Nothing is written when the queue is empty. Otherwise a final marker
    /// line is appended and the whole queue is handed to
    /// `logger_manager::write_many` under `CONNECTION_LOGGER_KEY`.
    fn drop(&mut self) {
        if self.queue.is_empty() {
            return;
        }

        self.queue.push(format!(
            "{}{}[{}] - {}",
            logger::get_log_header(LoggerLevel::Info),
            self.http_connection_id,
            self.tcp_connection_id,
            "------------------------ ConnectionLogger is dropped ------------------------"
        ));

        // The logger is being destroyed, so the queue is never read again:
        // move the lines out instead of deep-cloning every buffered string.
        let lines = std::mem::take(&mut self.queue);
        logger_manager::write_many(Some(Self::CONNECTION_LOGGER_KEY.to_string()), lines);
    }
}
impl Clone for ConnectionLogger {
    /// Creates a logger with the same connection ids and an empty queue.
    ///
    /// The queued lines are deliberately not copied: they stay with the
    /// original logger.
    fn clone(&self) -> Self {
        Self::new(self.tcp_connection_id, self.http_connection_id)
    }
}