rust/azure_iot_operations_services/src/schema_registry/client.rs (159 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
//! Client for Schema Registry operations.
//!
//! To use this client, the `schema_registry` feature must be enabled.
use std::sync::Arc;
use std::time::Duration;
use azure_iot_operations_mqtt::interface::ManagedClient;
use azure_iot_operations_protocol::application::ApplicationContext;
use azure_iot_operations_protocol::rpc_command;
use crate::schema_registry::schemaregistry_gen::common_types::options::CommandInvokerOptionsBuilder;
use crate::schema_registry::schemaregistry_gen::schema_registry::client::{
GetCommandInvoker, GetRequestPayloadBuilder, GetRequestSchemaBuilder, PutCommandInvoker,
PutRequestPayloadBuilder, PutRequestSchemaBuilder,
};
use crate::schema_registry::{Error, ErrorKind, GetRequest, PutRequest, Schema};
/// Schema registry client implementation.
///
/// Holds the generated `get` and `put` command invokers behind [`Arc`]s, so cloning a
/// [`Client`] shares the invokers rather than creating new ones.
#[derive(Clone)]
pub struct Client<C: ManagedClient + Clone + Send + Sync + 'static>
where
    C::PubReceiver: Send + Sync,
{
    /// Invoker used to issue schema registry `get` requests.
    get_command_invoker: Arc<GetCommandInvoker<C>>,
    /// Invoker used to issue schema registry `put` requests.
    put_command_invoker: Arc<PutCommandInvoker<C>>,
    /// Client id of the managed MQTT client, attached to each request as the `__invId` user data.
    // TODO: Temporary until the schema registry service updates their executor
    client_id: String,
}
impl<C> Client<C>
where
    C: ManagedClient + Clone + Send + Sync + 'static,
    C::PubReceiver: Send + Sync,
{
    /// Create a new Schema Registry Client.
    ///
    /// # Arguments
    /// * `application_context` - The application context handed to both underlying command invokers.
    /// * `client` - The managed client; a clone of it is given to each underlying command invoker.
    ///
    /// # Panics
    /// Panics if the options for the underlying command invokers cannot be built. Not possible since
    /// the options are statically generated.
    pub fn new(application_context: ApplicationContext, client: &C) -> Self {
        let options = CommandInvokerOptionsBuilder::default()
            .build()
            .expect("Statically generated options should not fail.");
        Self {
            get_command_invoker: Arc::new(GetCommandInvoker::new(
                application_context.clone(),
                client.clone(),
                &options,
            )),
            put_command_invoker: Arc::new(PutCommandInvoker::new(
                application_context,
                client.clone(),
                &options,
            )),
            client_id: client.client_id().to_string(), // TODO: Temporary until the schema registry service updates their executor
        }
    }

    /// Retrieves schema information from a schema registry service.
    ///
    /// # Arguments
    /// * `get_request` - The request to get a schema from the schema registry.
    /// * `timeout` - The duration until the Schema Registry Client stops waiting for a response to the request, it is rounded up to the nearest second.
    ///
    /// Returns a [`Schema`] if the schema was found, otherwise returns `None`.
    ///
    /// # Errors
    /// [`struct@Error`] of kind [`InvalidArgument`](ErrorKind::InvalidArgument)
    /// if the `timeout` is zero or > `u32::MAX`, or there is an error building the request.
    ///
    /// [`struct@Error`] of kind [`SerializationError`](ErrorKind::SerializationError)
    /// if there is an error serializing the request.
    ///
    /// [`struct@Error`] of kind [`ServiceError`](ErrorKind::ServiceError)
    /// if there is an error returned by the Schema Registry Service.
    ///
    /// [`struct@Error`] of kind [`AIOProtocolError`](ErrorKind::AIOProtocolError)
    /// if there are any underlying errors from the AIO RPC protocol.
    pub async fn get(
        &self,
        get_request: GetRequest,
        timeout: Duration,
    ) -> Result<Option<Schema>, Error> {
        let get_request_payload = GetRequestPayloadBuilder::default()
            .get_schema_request(
                GetRequestSchemaBuilder::default()
                    .name(Some(get_request.id))
                    .version(Some(get_request.version))
                    .build()
                    .map_err(|e| Error(ErrorKind::InvalidArgument(e.to_string())))?,
            )
            .build()
            .map_err(|e| Error(ErrorKind::InvalidArgument(e.to_string())))?;

        let command_request = rpc_command::invoker::RequestBuilder::default()
            .custom_user_data(vec![("__invId".to_string(), self.client_id.clone())]) // TODO: Temporary until the schema registry service updates their executor
            .payload(get_request_payload)
            .map_err(|e| Error(ErrorKind::SerializationError(e.to_string())))?
            .timeout(timeout)
            .build()
            .map_err(|e| Error(ErrorKind::InvalidArgument(e.to_string())))?;

        match self.get_command_invoker.invoke(command_request).await {
            Ok(response) => Ok(response.payload.schema),
            Err(e) => {
                // An entirely empty response payload is treated as "schema not found": the
                // invoker reports it as `PayloadInvalid` wrapping a serde_json EOF error
                // positioned at line 1, column 0 (i.e. nothing was read at all).
                let is_empty_payload = matches!(
                    e.kind,
                    azure_iot_operations_protocol::common::aio_protocol_error::AIOProtocolErrorKind::PayloadInvalid
                ) && e
                    .nested_error
                    .as_ref()
                    .and_then(|nested| nested.downcast_ref::<serde_json::Error>())
                    .is_some_and(|json| json.is_eof() && json.column() == 0 && json.line() == 1);

                if is_empty_payload {
                    Ok(None)
                } else {
                    Err(Error(ErrorKind::from(e)))
                }
            }
        }
    }

    /// Adds or updates a schema in the schema registry service.
    ///
    /// # Arguments
    /// * `put_request` - The request to put a schema in the schema registry.
    /// * `timeout` - The duration until the Schema Registry Client stops waiting for a response to the request, it is rounded up to the nearest second.
    ///
    /// Returns the [`Schema`] that was put if the request was successful.
    ///
    /// # Errors
    /// [`struct@Error`] of kind [`InvalidArgument`](ErrorKind::InvalidArgument)
    /// if the `content` is empty, the `timeout` is zero or > `u32::MAX`, or there is an error building the request.
    ///
    /// [`struct@Error`] of kind [`SerializationError`](ErrorKind::SerializationError)
    /// if there is an error serializing the request.
    ///
    /// [`struct@Error`] of kind [`ServiceError`](ErrorKind::ServiceError)
    /// if there is an error returned by the Schema Registry Service.
    ///
    /// [`struct@Error`] of kind [`AIOProtocolError`](ErrorKind::AIOProtocolError)
    /// if there are any underlying errors from the AIO RPC protocol.
    pub async fn put(&self, put_request: PutRequest, timeout: Duration) -> Result<Schema, Error> {
        let put_request_payload = PutRequestPayloadBuilder::default()
            .put_schema_request(
                PutRequestSchemaBuilder::default()
                    .format(Some(put_request.format))
                    .schema_content(Some(put_request.content))
                    .version(Some(put_request.version))
                    .tags(Some(put_request.tags))
                    .schema_type(Some(put_request.schema_type))
                    .build()
                    .map_err(|e| Error(ErrorKind::InvalidArgument(e.to_string())))?,
            )
            .build()
            .map_err(|e| Error(ErrorKind::InvalidArgument(e.to_string())))?;

        let command_request = rpc_command::invoker::RequestBuilder::default()
            .custom_user_data(vec![("__invId".to_string(), self.client_id.clone())]) // TODO: Temporary until the schema registry service updates their executor
            .payload(put_request_payload)
            .map_err(|e| Error(ErrorKind::SerializationError(e.to_string())))?
            .timeout(timeout)
            .build()
            .map_err(|e| Error(ErrorKind::InvalidArgument(e.to_string())))?;

        Ok(self
            .put_command_invoker
            .invoke(command_request)
            .await
            .map_err(ErrorKind::from)?
            .payload
            .schema)
    }

    /// Shutdown the [`Client`]. Shuts down the underlying command invokers for get and put operations.
    ///
    /// Both invokers are always asked to shut down, even if the first one fails, so that one
    /// failure does not leave the other invoker subscribed. If both fail, the error from the
    /// get command invoker is returned.
    ///
    /// Note: If this method is called, the [`Client`] should not be used again.
    /// If the method returns an error, it may be called again to re-attempt unsubscribing.
    ///
    /// Returns Ok(()) on success, otherwise returns [`struct@Error`].
    ///
    /// # Errors
    /// [`struct@Error`] of kind [`AIOProtocolError`](ErrorKind::AIOProtocolError)
    /// if the unsubscribe fails or if the unsuback reason code doesn't indicate success.
    pub async fn shutdown(&self) -> Result<(), Error> {
        // Run both shutdowns before inspecting either result; an early `?` on the first
        // would skip the second invoker entirely.
        let get_shutdown = self.get_command_invoker.shutdown().await;
        let put_shutdown = self.put_command_invoker.shutdown().await;

        get_shutdown.map_err(ErrorKind::from)?;
        put_shutdown.map_err(ErrorKind::from)?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use azure_iot_operations_mqtt::{
        MqttConnectionSettingsBuilder,
        session::{Session, SessionOptionsBuilder},
    };
    use azure_iot_operations_protocol::application::ApplicationContextBuilder;

    use crate::schema_registry::{
        Client, DEFAULT_SCHEMA_VERSION, Error, ErrorKind, Format, GetRequestBuilder,
        GetRequestBuilderError, PutRequestBuilder, SchemaType,
    };

    // TODO: This should return a mock ManagedClient instead.
    // Until that's possible, need to return a Session so that the Session doesn't go out of
    // scope and render the ManagedClient unable to be used correctly.
    fn create_session() -> Session {
        // TODO: Make a real mock that implements MqttProvider
        let connection_settings = MqttConnectionSettingsBuilder::default()
            .hostname("localhost")
            .client_id("test_client")
            .build()
            .unwrap();
        let session_options = SessionOptionsBuilder::default()
            .connection_settings(connection_settings)
            .build()
            .unwrap();
        Session::new(session_options).unwrap()
    }

    const TEST_SCHEMA_ID: &str = "test_schema_id";

    // A well-formed JSON Schema (draft-07) document used as put-request content.
    // Keep this valid JSON (no trailing commas) so it is a realistic fixture.
    const TEST_SCHEMA_CONTENT: &str = r#"
    {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "object",
        "properties": {
            "test": {
                "type": "integer"
            }
        }
    }
    "#;

    #[tokio::test]
    async fn test_get_request_valid() {
        let get_request = GetRequestBuilder::default()
            .id(TEST_SCHEMA_ID.to_string())
            .build()
            .unwrap();
        assert_eq!(get_request.id, TEST_SCHEMA_ID);
        assert_eq!(get_request.version, DEFAULT_SCHEMA_VERSION.to_string());
    }

    #[tokio::test]
    async fn test_get_request_invalid_id() {
        let get_request = GetRequestBuilder::default().build();
        assert!(matches!(
            get_request.unwrap_err(),
            GetRequestBuilderError::UninitializedField(_)
        ));

        let get_request = GetRequestBuilder::default().id(String::new()).build();
        assert!(matches!(
            get_request.unwrap_err(),
            GetRequestBuilderError::ValidationError(_)
        ));
    }

    #[tokio::test]
    async fn test_put_request_valid() {
        let put_request = PutRequestBuilder::default()
            .content(TEST_SCHEMA_CONTENT.to_string())
            .format(Format::JsonSchemaDraft07)
            .build()
            .unwrap();
        assert_eq!(put_request.content, TEST_SCHEMA_CONTENT);
        assert!(matches!(put_request.format, Format::JsonSchemaDraft07));
        assert!(matches!(put_request.schema_type, SchemaType::MessageSchema));
        assert_eq!(put_request.tags, HashMap::new());
        assert_eq!(put_request.version, DEFAULT_SCHEMA_VERSION.to_string());
    }

    #[tokio::test]
    async fn test_get_timeout_invalid() {
        let session = create_session();
        let client = Client::new(
            ApplicationContextBuilder::default().build().unwrap(),
            &session.create_managed_client(),
        );

        let get_result = client
            .get(
                GetRequestBuilder::default()
                    .id(TEST_SCHEMA_ID.to_string())
                    .build()
                    .unwrap(),
                std::time::Duration::from_millis(0),
            )
            .await;
        assert!(matches!(
            get_result.unwrap_err(),
            Error(ErrorKind::InvalidArgument(_))
        ));

        let get_result = client
            .get(
                GetRequestBuilder::default()
                    .id(TEST_SCHEMA_ID.to_string())
                    .build()
                    .unwrap(),
                std::time::Duration::from_secs(u64::from(u32::MAX) + 1),
            )
            .await;
        assert!(matches!(
            get_result.unwrap_err(),
            Error(ErrorKind::InvalidArgument(_))
        ));
    }

    #[tokio::test]
    async fn test_put_timeout_invalid() {
        let session = create_session();
        let client = Client::new(
            ApplicationContextBuilder::default().build().unwrap(),
            &session.create_managed_client(),
        );

        let put_result = client
            .put(
                PutRequestBuilder::default()
                    .content(TEST_SCHEMA_CONTENT.to_string())
                    .format(Format::JsonSchemaDraft07)
                    .build()
                    .unwrap(),
                std::time::Duration::from_millis(0),
            )
            .await;
        assert!(matches!(
            put_result.unwrap_err(),
            Error(ErrorKind::InvalidArgument(_))
        ));

        let put_result = client
            .put(
                PutRequestBuilder::default()
                    .content(TEST_SCHEMA_CONTENT.to_string())
                    .format(Format::JsonSchemaDraft07)
                    .build()
                    .unwrap(),
                std::time::Duration::from_secs(u64::from(u32::MAX) + 1),
            )
            .await;
        assert!(matches!(
            put_result.unwrap_err(),
            Error(ErrorKind::InvalidArgument(_))
        ));
    }
}