utils/src/auth.rs (91 lines of code) (raw):
use std::fmt::Debug;
use std::sync::Arc;
#[cfg(not(target_family = "wasm"))]
use std::time::{SystemTime, UNIX_EPOCH};
use crate::errors::AuthError;
/// Helper type describing an auth token: the token string itself, paired with
/// its expiration time as a unix timestamp in seconds.
pub type TokenInfo = (String, u64);
/// Helper to provide auth tokens to CAS.
///
/// Implementations are stored as `Arc<dyn TokenRefresher>` (see [`AuthConfig`]), hence the
/// `Debug + Send + Sync` bounds. On wasm targets the `?Send` form of `async_trait` is used,
/// so the future returned by [`TokenRefresher::refresh`] is not required to be `Send` there.
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
pub trait TokenRefresher: Debug + Send + Sync {
    /// Get a new auth token for CAS and the unixtime (in seconds) for expiration.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] when a new token cannot be produced; the exact
    /// variant is implementation-specific.
    async fn refresh(&self) -> Result<TokenInfo, AuthError>;
}
/// A [`TokenRefresher`] that never fails: every call hands back the placeholder
/// token `"token"` with an expiration of `0` (unix seconds), i.e. a token that is
/// already past its expiration time.
#[derive(Debug)]
pub struct NoOpTokenRefresher;

#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
impl TokenRefresher for NoOpTokenRefresher {
    /// Unconditionally returns `("token", 0)`.
    async fn refresh(&self) -> Result<TokenInfo, AuthError> {
        let placeholder = String::from("token");
        let expires_at: u64 = 0;
        Ok((placeholder, expires_at))
    }
}
/// A [`TokenRefresher`] for configurations that have no way to refresh: every
/// call fails with [`AuthError::RefreshFunctionNotCallable`].
#[derive(Debug)]
pub struct ErrTokenRefresher;

#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
impl TokenRefresher for ErrTokenRefresher {
    /// Always returns an error; refreshing is not supported by this refresher.
    async fn refresh(&self) -> Result<TokenInfo, AuthError> {
        let reason = "Token refresh not expected".to_owned();
        Err(AuthError::RefreshFunctionNotCallable(reason))
    }
}
/// Shared configuration for token-based auth.
///
/// Can be built with [`AuthConfig::maybe_new`] and is read by [`TokenProvider::new`],
/// which copies these fields into its own state.
// NOTE(review): the derived `Debug` prints `token` verbatim — be careful about logging this type.
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Initial token to use.
    pub token: String,
    /// Initial token expiration time in unix epoch seconds. `maybe_new` stores `u64::MAX`
    /// here for a static token with no given expiry (a token expected to live forever).
    pub token_expiration: u64,
    /// A function to refresh tokens.
    pub token_refresher: Arc<dyn TokenRefresher>,
}
impl AuthConfig {
    /// Builds an [`AuthConfig`] from optional pieces.
    ///
    /// * With a `token_refresher`, the token and expiry are optional: missing values
    ///   default to an empty string and `0`, and the refresher is relied on to supply
    ///   real ones.
    /// * Without a refresher, a `token` is mandatory. A missing `token_expiry` becomes
    ///   `u64::MAX` (the token is expected to live forever), and an
    ///   [`ErrTokenRefresher`] is installed so any refresh attempt fails.
    ///
    /// Returns `None` only when neither a refresher nor a token is supplied.
    pub fn maybe_new(
        token: Option<String>,
        token_expiry: Option<u64>,
        token_refresher: Option<Arc<dyn TokenRefresher>>,
    ) -> Option<Self> {
        match token_refresher {
            // A refresher can always mint a new token, so the initial values are optional.
            Some(refresher) => Some(Self {
                token: token.unwrap_or_default(),
                token_expiration: token_expiry.unwrap_or_default(),
                token_refresher: refresher,
            }),
            // No refreshing possible: only a static token can make this config usable.
            None => token.map(|static_token| Self {
                token: static_token,
                token_expiration: token_expiry.unwrap_or(u64::MAX),
                token_refresher: Arc::new(ErrTokenRefresher),
            }),
        }
    }
}
/// Caches an auth token and obtains a replacement from a [`TokenRefresher`]
/// once the cached one has expired.
///
/// Construct it from an [`AuthConfig`] with [`TokenProvider::new`].
pub struct TokenProvider {
    /// The most recently obtained token.
    token: String,
    /// Expiration of `token`, in unix epoch seconds.
    expiration: u64,
    /// Source of replacement tokens.
    refresher: Arc<dyn TokenRefresher>,
}
impl TokenProvider {
pub fn new(cfg: &AuthConfig) -> Self {
Self {
token: cfg.token.clone(),
expiration: cfg.token_expiration,
refresher: cfg.token_refresher.clone(),
}
}
pub async fn get_valid_token(&mut self) -> Result<String, AuthError> {
if self.is_expired() {
let (new_token, new_expiry) = self.refresher.refresh().await?;
self.token = new_token;
self.expiration = new_expiry;
}
Ok(self.token.clone())
}
fn is_expired(&self) -> bool {
#[cfg(not(target_family = "wasm"))]
let cur_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(u64::MAX);
#[cfg(target_family = "wasm")]
let cur_time = web_time::SystemTime::now()
.duration_since(web_time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(u64::MAX);
self.expiration <= cur_time
}
}