key/aziot-key-openssl-engine-shared-test/src/main.rs (338 lines of code) (raw):
// Copyright (c) Microsoft. All rights reserved.
#![deny(rust_2018_idioms)]
#![warn(clippy::all, clippy::pedantic)]
#![allow(
clippy::default_trait_access,
clippy::let_unit_value,
clippy::too_many_lines,
clippy::use_self
)]
use clap::Parser;
/// Parses the command line and runs the selected subcommand.
///
/// * `GenerateCaCert`, `GenerateClientCert` and `GenerateServerCert` delegate to `generate_cert`.
/// * `WebClient` makes one HTTPS GET to `https://127.0.0.1:{port}/`. It uses a client cert whose
///   private key is loaded through the `aziot_keys` engine, and fails unless the server answers
///   `200 OK` with the body `Hello, world!\n`.
/// * `WebServer` serves HTTPS on `0.0.0.0:{port}` and answers every request with
///   `Hello, world!\n`. Its private key is loaded through the same engine.
#[tokio::main]
async fn main() -> Result<(), Error> {
    openssl::init();

    let command = Command::parse();
    match command {
        Command::GenerateCaCert {
            key_handle,
            out_file,
            subject,
        } => generate_cert(key_handle, &out_file, &subject, &GenerateCertKind::Ca)?,

        Command::GenerateClientCert {
            ca_cert,
            ca_key_handle,
            key_handle,
            out_file,
            subject,
        } => generate_cert(
            key_handle,
            &out_file,
            &subject,
            &GenerateCertKind::Client {
                ca_cert,
                ca_key_handle,
            },
        )?,

        Command::GenerateServerCert {
            ca_cert,
            ca_key_handle,
            key_handle,
            out_file,
            subject,
        } => generate_cert(
            key_handle,
            &out_file,
            &subject,
            &GenerateCertKind::Server {
                ca_cert,
                ca_key_handle,
            },
        )?,

        Command::WebClient {
            cert,
            key_handle,
            port,
        } => {
            let mut http_connector = hyper::client::HttpConnector::new();
            // TLS is layered on by the openssl HTTPS connector below, so the plain HTTP
            // connector must accept `https://` URIs.
            http_connector.enforce_http(false);

            let mut engine = load_engine()?;
            let key = load_private_key(&mut engine, key_handle)?;

            let mut tls_connector =
                openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls_client())?;
            tls_connector.set_private_key(&key)?;
            tls_connector.set_certificate_chain_file(&cert)?;

            // The root of the client cert is the CA, and we expect the server cert to be signed by this same CA.
            // So add it to the cert store.
            let ca_cert = {
                let cert_chain_file = std::fs::read(cert)?;
                let mut cert_chain = openssl::x509::X509::stack_from_pem(&cert_chain_file)?;
                // The last cert in the chain file is taken as the CA. Report an empty file as an
                // error instead of panicking.
                cert_chain
                    .pop()
                    .ok_or("client cert file does not contain any certificates")?
            };
            tls_connector.cert_store_mut().add_cert(ca_cert)?;

            // Log the server cert chain. Does not change the verification result from what openssl already concluded.
            tls_connector.set_verify_callback(
                openssl::ssl::SslVerifyMode::PEER,
                |openssl_verification_result, context| {
                    println!("Server cert:");
                    if let Some(chain) = context.chain() {
                        for (i, cert) in chain.into_iter().enumerate() {
                            // Certs with an empty subject or a non-UTF-8 first entry are still
                            // logged, rather than panicking inside the openssl callback.
                            let subject = cert
                                .subject_name()
                                .entries()
                                .next()
                                .and_then(|entry| entry.data().as_utf8().ok())
                                .map_or_else(|| "<unknown>".to_owned(), |s| s.to_string());
                            println!(" #{}: {}", i + 1, subject);
                        }
                    }
                    println!("openssl verification result: {openssl_verification_result}");
                    openssl_verification_result
                },
            );

            let tls_connector =
                hyper_openssl::HttpsConnector::with_connector(http_connector, tls_connector)?;
            let client: hyper::Client<_, hyper::Body> =
                hyper::Client::builder().build(tls_connector);

            let response = client
                .get(format!("https://127.0.0.1:{port}/").parse()?)
                .await?;
            let (http::response::Parts { status, .. }, response_body) = response.into_parts();
            let response_body = hyper::body::to_bytes(response_body).await?;
            println!("server returned {status} {response_body:?}");

            if status != http::StatusCode::OK || &*response_body != b"Hello, world!\n" {
                return Err("server did not return expected response".into());
            }
        }

        Command::WebServer {
            cert,
            key_handle,
            port,
        } => {
            let mut engine = load_engine()?;
            let key = load_private_key(&mut engine, key_handle)?;

            let incoming =
                test_common::tokio_openssl2::Incoming::new("0.0.0.0", port, &cert, &key, true)?;

            // Every request, regardless of method or path, gets the same fixed body.
            let server =
                hyper::Server::builder(incoming).serve(hyper::service::make_service_fn(|_| {
                    futures_util::future::ok::<_, std::convert::Infallible>(
                        hyper::service::service_fn(|_| {
                            futures_util::future::ok::<_, std::convert::Infallible>(
                                hyper::Response::new(hyper::Body::from("Hello, world!\n")),
                            )
                        }),
                    )
                }));

            println!("Starting web server...");
            let () = server.await?;
        }
    }

    Ok(())
}
/// Loads the `aziot_keys` openssl engine and returns a functional reference to it.
///
/// The engine's name is printed to stdout once it has been loaded.
fn load_engine() -> Result<openssl2::FunctionalEngine, Error> {
    const ENGINE_ID: &[u8] = b"aziot_keys\0";

    // NOTE(review): presumably needed so the lookup by ID below can locate the engine
    // (e.g. via openssl's dynamic engine support). Confirm.
    //
    // SAFETY: the call takes no arguments, so there are no pointer or lifetime invariants to
    // uphold here. It presumably relies on openssl having been initialized, which `main` does
    // via `openssl::init()` before running any command.
    unsafe {
        openssl_sys2::ENGINE_load_builtin_engines();
    }

    let engine_id = std::ffi::CStr::from_bytes_with_nul(ENGINE_ID)
        .expect("hard-coded engine ID is valid CStr");

    // Look up the structural engine by ID, then upgrade it to a functional (initialized) one.
    let functional_engine: openssl2::FunctionalEngine =
        openssl2::StructuralEngine::by_id(engine_id)?.try_into()?;

    println!(
        "Loaded engine: [{}]",
        functional_engine.name()?.to_string_lossy()
    );
    Ok(functional_engine)
}
/// Loads the public half of the key pair identified by `key_handle` through `engine`.
///
/// # Errors
///
/// Fails if `key_handle` contains an interior NUL byte (it is handed to the engine as a C
/// string), or if the engine cannot load the key.
fn load_public_key(
    engine: &mut openssl2::FunctionalEngine,
    key_handle: String,
) -> Result<openssl::pkey::PKey<openssl::pkey::Public>, Error> {
    let c_key_handle = std::ffi::CString::new(key_handle)?;
    engine.load_public_key(&c_key_handle).map_err(Into::into)
}
/// Loads the private key identified by `key_handle` through `engine`.
///
/// # Errors
///
/// Fails if `key_handle` contains an interior NUL byte (it is handed to the engine as a C
/// string), or if the engine cannot load the key.
fn load_private_key(
    engine: &mut openssl2::FunctionalEngine,
    key_handle: String,
) -> Result<openssl::pkey::PKey<openssl::pkey::Private>, Error> {
    let c_key_handle = std::ffi::CString::new(key_handle)?;
    engine.load_private_key(&c_key_handle).map_err(Into::into)
}
fn generate_cert(
key_handle: String,
out_file: &std::path::Path,
subject: &str,
kind: &GenerateCertKind,
) -> Result<(), Error> {
let mut engine = load_engine()?;
let mut builder = openssl::x509::X509::builder()?;
builder.set_version(2)?;
let public_key = load_public_key(&mut engine, key_handle.clone())?;
builder.set_pubkey(&public_key)?;
let not_after = openssl::asn1::Asn1Time::days_from_now(match &kind {
GenerateCertKind::Ca => 365,
GenerateCertKind::Client { .. } | GenerateCertKind::Server { .. } => 30,
})?;
builder.set_not_after(std::borrow::Borrow::borrow(¬_after))?;
let not_before = openssl::asn1::Asn1Time::days_from_now(0)?;
builder.set_not_before(std::borrow::Borrow::borrow(¬_before))?;
let mut subject_name = openssl::x509::X509Name::builder()?;
subject_name.append_entry_by_text("CN", subject)?;
let subject_name = subject_name.build();
builder.set_subject_name(&subject_name)?;
match &kind {
GenerateCertKind::Ca => {
builder.set_issuer_name(&subject_name)?;
let ca_extension = openssl::x509::extension::BasicConstraints::new()
.ca()
.build()?;
builder.append_extension(ca_extension)?;
}
GenerateCertKind::Client { ca_cert, .. } | GenerateCertKind::Server { ca_cert, .. } => {
let ca_cert = std::fs::read(ca_cert)?;
let ca_cert = openssl::x509::X509::from_pem(&ca_cert)?;
builder.set_issuer_name(ca_cert.subject_name())?;
match kind {
GenerateCertKind::Ca => unreachable!(),
GenerateCertKind::Client { .. } => {
let client_extension = openssl::x509::extension::ExtendedKeyUsage::new()
.client_auth()
.build()?;
builder.append_extension(client_extension)?;
}
GenerateCertKind::Server { .. } => {
let server_extension = openssl::x509::extension::ExtendedKeyUsage::new()
.server_auth()
.build()?;
builder.append_extension(server_extension)?;
let context = builder.x509v3_context(Some(&ca_cert), None);
let san_extension = openssl::x509::extension::SubjectAlternativeName::new()
.ip("127.0.0.1")
.build(&context)?;
builder.append_extension(san_extension)?;
}
}
}
}
let ca_key_handle = match &kind {
GenerateCertKind::Ca => key_handle,
GenerateCertKind::Client { ca_key_handle, .. }
| GenerateCertKind::Server { ca_key_handle, .. } => ca_key_handle.clone(),
};
let ca_key = load_private_key(&mut engine, ca_key_handle)?;
builder.sign(&ca_key, openssl::hash::MessageDigest::sha256())?;
let cert = builder.build();
let cert = cert.to_pem()?;
let mut out_file = std::fs::File::create(out_file)?;
std::io::Write::write_all(&mut out_file, &cert)?;
match &kind {
GenerateCertKind::Ca => (),
GenerateCertKind::Client { ca_cert, .. } | GenerateCertKind::Server { ca_cert, .. } => {
let ca_cert = std::fs::read(ca_cert)?;
std::io::Write::write_all(&mut out_file, &ca_cert)?;
}
}
std::io::Write::flush(&mut out_file)?;
Ok(())
}
/// Selects which kind of certificate `generate_cert` produces.
#[derive(Debug)]
enum GenerateCertKind {
    /// A self-signed CA cert: its issuer is its own subject, and it is signed with its own key.
    Ca,

    /// A client-auth cert issued by a CA.
    Client {
        /// Path of the issuing CA's PEM cert.
        ca_cert: std::path::PathBuf,
        /// Key handle of the CA key pair used to sign the new cert.
        ca_key_handle: String,
    },

    /// A server-auth cert (with an IP SAN of 127.0.0.1) issued by a CA.
    Server {
        /// Path of the issuing CA's PEM cert.
        ca_cert: std::path::PathBuf,
        /// Key handle of the CA key pair used to sign the new cert.
        ca_key_handle: String,
    },
}
/// Error type used throughout this binary.
///
/// Holds any boxed error, plus a `backtrace::Backtrace` created (via `Default`) when the error is
/// converted into this type. Only `Debug` is implemented, because that is what gets printed when
/// `main` returns `Err`. It renders:
/// * the error itself,
/// * each `source()` in its cause chain,
/// * a blank line,
/// * the backtrace.
struct Error(Box<dyn std::error::Error>, backtrace::Backtrace);

impl std::fmt::Debug for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(err, backtrace) = self;

        writeln!(f, "{err}")?;
        // Follow the `source()` links until the chain runs out.
        for cause in std::iter::successors(err.source(), |cause| cause.source()) {
            writeln!(f, "caused by: {cause}")?;
        }

        writeln!(f)?;
        writeln!(f, "{backtrace:?}")
    }
}

/// Lets `?` convert anything boxable as `dyn std::error::Error` (concrete error types, `String`,
/// `&str`, ...) into [`Error`].
impl<E: Into<Box<dyn std::error::Error>>> From<E> for Error {
    fn from(err: E) -> Self {
        Self(err.into(), Default::default())
    }
}
// Subcommands of this test tool. The `///` doc comments on variants and fields below become the
// clap `--help` text.
#[derive(Parser)]
enum Command {
    /// Generate a CA cert.
    GenerateCaCert {
        /// A key handle to the key pair that will be used for the CA cert.
        #[arg(long)]
        key_handle: String,
        /// The path where the CA cert PEM file will be stored.
        #[arg(long)]
        out_file: std::path::PathBuf,
        /// The subject CN of the new cert.
        #[arg(long)]
        subject: String,
    },
    /// Generate a client auth cert.
    GenerateClientCert {
        /// The path of the CA cert PEM file.
        #[arg(long)]
        ca_cert: std::path::PathBuf,
        /// A key handle to the key pair of the CA.
        #[arg(long)]
        ca_key_handle: String,
        /// A key handle to the key pair that will be used for the client cert.
        #[arg(long)]
        key_handle: String,
        /// The path where the client cert PEM file will be stored.
        #[arg(long)]
        out_file: std::path::PathBuf,
        /// The subject CN of the new cert.
        #[arg(long)]
        subject: String,
    },
    /// Generate a server auth cert.
    GenerateServerCert {
        /// The path of the CA cert PEM file.
        #[arg(long)]
        ca_cert: std::path::PathBuf,
        /// A key handle to the key pair of the CA.
        #[arg(long)]
        ca_key_handle: String,
        /// A key handle to the key pair that will be used for the server cert.
        #[arg(long)]
        key_handle: String,
        /// The path where the server cert PEM file will be stored.
        #[arg(long)]
        out_file: std::path::PathBuf,
        /// The subject CN of the new cert.
        #[arg(long)]
        subject: String,
    },
    /// Start a web client that uses the specified private key and cert file for TLS.
    WebClient {
        /// Path of the client cert file.
        #[arg(long)]
        cert: std::path::PathBuf,
        /// A key handle to the client cert's key pair.
        #[arg(long)]
        key_handle: String,
        /// The port of the web server (on 127.0.0.1) to connect to.
        #[arg(long, default_value_t = 8443)]
        port: u16,
    },
    /// Start a web server that uses the specified private key and cert file for TLS.
    WebServer {
        /// Path of the server cert file.
        #[arg(long)]
        cert: std::path::PathBuf,
        /// A key handle to the server cert's key pair.
        #[arg(long)]
        key_handle: String,
        /// The port to listen on.
        #[arg(long, default_value_t = 8443)]
        port: u16,
    },
}