// cas_client/src/remote_client.rs
use std::collections::HashMap;
use std::io::Write;
use std::mem::take;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::anyhow;
use cas_object::SerializedCasObject;
use cas_types::{
BatchQueryReconstructionResponse, CASReconstructionTerm, ChunkRange, FileRange, HttpRange, Key,
QueryReconstructionResponse, UploadShardResponse, UploadShardResponseType, UploadXorbResponse,
};
use chunk_cache::{CacheConfig, ChunkCache};
use error_printer::ErrorPrinter;
use file_utils::SafeFileCreator;
use http::header::{CONTENT_LENGTH, RANGE};
use http::HeaderValue;
use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDBFileInfo};
use mdb_shard::shard_file_reconstructor::FileReconstructor;
use mdb_shard::utils::shard_file_name;
use merklehash::{HashedWrite, MerkleHash};
use progress_tracking::item_tracking::SingleItemProgressUpdater;
use progress_tracking::upload_tracking::CompletionTracker;
use reqwest::{Body, Response, StatusCode, Url};
use reqwest_middleware::ClientWithMiddleware;
use tokio::sync::{mpsc, OwnedSemaphorePermit, Semaphore};
use tokio::task::{JoinHandle, JoinSet};
use tracing::{debug, info, instrument};
use utils::auth::AuthConfig;
#[cfg(not(target_family = "wasm"))]
use utils::singleflight::Group;
#[cfg(not(target_family = "wasm"))]
use crate::download_utils::*;
use crate::error::{CasClientError, Result};
use crate::http_client::{Api, ResponseErrorLogger, RetryConfig};
use crate::interface::*;
#[cfg(not(target_family = "wasm"))]
use crate::output_provider::OutputProvider;
use crate::retry_utils::retry_wrapper;
use crate::{http_client, Client, RegistrationClient, ShardClientInterface};
const FORCE_SYNC_METHOD: reqwest::Method = reqwest::Method::PUT;
const NON_FORCE_SYNC_METHOD: reqwest::Method = reqwest::Method::POST;
pub const CAS_ENDPOINT: &str = "http://localhost:8080";
pub const PREFIX_DEFAULT: &str = "default";
utils::configurable_constants! {
// Env (HF_XET_NUM_CONCURRENT_RANGE_GETS) to set the number of concurrent range gets.
// Setting this value to 0 disables the limit (effectively setting it to the maximum);
// this is not recommended as it may lead to errors.
ref NUM_CONCURRENT_RANGE_GETS: usize = GlobalConfigMode::HighPerformanceOption {
standard: 128,
high_performance: 512,
};
// Send a report of successful partial upload progress every 512 KiB.
ref UPLOAD_REPORTING_BLOCK_SIZE : usize = 512 * 1024;
}
lazy_static::lazy_static! {
pub static ref DOWNLOAD_CONNECTION_CONCURRENCY_LIMITER: Arc<Semaphore> = Arc::new(Semaphore::new(*NUM_CONCURRENT_RANGE_GETS));
}
utils::configurable_bool_constants! {
// Env (HF_XET_RECONSTRUCT_WRITE_SEQUENTIALLY) to switch to writing terms sequentially to disk.
// Benchmarks have shown that on SSD machines, writing terms in parallel seems to far
// outperform writing them sequentially. However, this is likely not the case when writing
// to an HDD, where parallel writes may in fact be worse, so on those machines setting this
// env variable may help download performance.
ref RECONSTRUCT_WRITE_SEQUENTIALLY = false;
}
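/// Client for a remote CAS (content-addressable storage) service.
///
/// Holds the endpoint plus several pre-configured HTTP clients (authenticated with and without
/// retries, a conservative variant that does not retry on 429s, and an unauthenticated client),
/// an optional on-disk chunk cache, a single-flight group used to deduplicate concurrent range
/// downloads, and an optional shard cache directory for global-dedup shards. When `dry_run` is
/// set, uploads are skipped but reported as successful.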
pub struct RemoteClient {
endpoint: String,
dry_run: bool,
http_client: Arc<ClientWithMiddleware>,
authenticated_http_client: Arc<ClientWithMiddleware>,
authenticated_http_client_no_retry: Arc<ClientWithMiddleware>,
conservative_authenticated_http_client: Arc<ClientWithMiddleware>,
chunk_cache: Option<Arc<dyn ChunkCache>>,
#[cfg(not(target_family = "wasm"))]
range_download_single_flight: RangeDownloadSingleFlight,
shard_cache_directory: Option<PathBuf>,
}
impl RemoteClient {
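/// Create a new [`RemoteClient`] for `endpoint`.
///
/// The disk chunk cache is enabled only when `cache_config` is provided with a nonzero
/// `cache_size`; when `dry_run` is set, uploads are skipped but reported as successful.
///
/// A minimal usage sketch (assumes a locally running CAS server, no auth, and no cache):
///
/// ```ignore
/// let client = RemoteClient::new("http://localhost:8080", &None, &None, None, "session-id", false);
/// ```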
#[allow(clippy::too_many_arguments)]
pub fn new(
endpoint: &str,
auth: &Option<AuthConfig>,
cache_config: &Option<CacheConfig>,
shard_cache_directory: Option<PathBuf>,
session_id: &str,
dry_run: bool,
) -> Self {
// Use the disk cache if a cache_config is provided.
let chunk_cache = if let Some(cache_config) = cache_config {
if cache_config.cache_size == 0 {
info!("Chunk cache size set to 0, disabling chunk cache");
None
} else {
debug!(
"Using disk cache directory: {:?}, size: {}.",
cache_config.cache_directory, cache_config.cache_size
);
chunk_cache::get_cache(cache_config)
.log_error("failed to initialize cache, not using cache")
.ok()
}
} else {
None
};
Self {
endpoint: endpoint.to_string(),
dry_run,
authenticated_http_client: Arc::new(
http_client::build_auth_http_client(auth, RetryConfig::default(), session_id).unwrap(),
),
authenticated_http_client_no_retry: Arc::new(
http_client::build_auth_http_client_no_retry(auth, session_id).unwrap(),
),
conservative_authenticated_http_client: Arc::new(
http_client::build_auth_http_client(auth, RetryConfig::no429retry(), session_id).unwrap(),
),
http_client: Arc::new(http_client::build_http_client(RetryConfig::default(), session_id).unwrap()),
chunk_cache,
#[cfg(not(target_family = "wasm"))]
range_download_single_flight: Arc::new(Group::new()),
shard_cache_directory,
}
}
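/// Query the global dedup endpoint for a shard containing `chunk_hash` under `prefix`.
///
/// Returns `Ok(None)` if the server does not report a matching shard (any non-success status);
/// transport errors are surfaced as `CasClientError::Other`.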
async fn query_dedup_api(&self, prefix: &str, chunk_hash: &MerkleHash) -> Result<Option<Response>> {
// The API endpoint now only supports non-batched dedup requests and
// ignores the salt.
let key = Key {
prefix: prefix.into(),
hash: *chunk_hash,
};
let url = Url::parse(&format!("{0}/chunk/{key}", self.endpoint))?;
let response = self
.conservative_authenticated_http_client
.get(url)
.send()
.await
.map_err(|e| CasClientError::Other(format!("request failed with error {e}")))?;
// Dedup shard not found, return empty result
if !response.status().is_success() {
return Ok(None);
}
Ok(Some(response))
}
}
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
impl UploadClient for RemoteClient {
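/// Upload a serialized xorb to the CAS service.
///
/// The body is streamed with an explicit `Content-Length`, the send is wrapped in
/// `retry_wrapper` (the progress stream is reset for each attempt), and progress updates are
/// rescaled from compressed bytes sent to raw bytes so the tracker reflects the original data
/// size. Returns the number of serialized bytes uploaded; in dry-run mode nothing is sent.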
#[cfg(not(target_family = "wasm"))]
#[instrument(skip_all, name = "RemoteClient::upload_xorb", fields(key = Key{prefix : prefix.to_string(), hash : serialized_cas_object.hash}.to_string(),
xorb.len = serialized_cas_object.serialized_data.len(), xorb.num_chunks = serialized_cas_object.num_chunks
))]
async fn upload_xorb(
&self,
prefix: &str,
serialized_cas_object: SerializedCasObject,
upload_tracker: Option<Arc<CompletionTracker>>,
) -> Result<u64> {
let key = Key {
prefix: prefix.to_string(),
hash: serialized_cas_object.hash,
};
let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?;
let n_upload_bytes = serialized_cas_object.serialized_data.len() as u64;
// Backing out the incremental progress reporting for now until we figure out the middleware issue.
use crate::upload_progress_stream::UploadProgressStream;
let n_raw_bytes = serialized_cas_object.raw_num_bytes;
let xorb_hash = serialized_cas_object.hash;
let progress_callback = move |bytes_sent: u64| {
if let Some(utr) = upload_tracker.as_ref() {
// First, recalibrate the progress update, as the compressed (serialized) size differs from the raw data size.
let adjusted_update = (bytes_sent * n_raw_bytes) / n_upload_bytes;
utr.clone().register_xorb_upload_progress_background(xorb_hash, adjusted_update);
}
};
let upload_stream = UploadProgressStream::new(
serialized_cas_object.serialized_data,
*UPLOAD_REPORTING_BLOCK_SIZE,
progress_callback,
);
let xorb_uploaded = {
if !self.dry_run {
let client = self.authenticated_http_client_no_retry.clone();
let response = retry_wrapper(
move || {
let upload_stream = upload_stream.clone_with_reset();
let url = url.clone();
client
.post(url)
.with_extension(Api("cas::upload_xorb"))
.header(CONTENT_LENGTH, HeaderValue::from(n_upload_bytes)) // must be set because of streaming
.body(Body::wrap_stream(upload_stream))
.send()
},
RetryConfig::default(),
)
.await?;
let response_parsed: UploadXorbResponse = response.json().await?;
response_parsed.was_inserted
} else {
true
}
};
if !xorb_uploaded {
debug!("{key:?} not inserted into CAS.");
} else {
debug!("{key:?} inserted into CAS.");
}
Ok(n_upload_bytes)
}
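/// Simplified wasm variant of `upload_xorb`: sends the serialized data in a single request,
/// with no retry wrapper and no incremental progress tracking.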
#[cfg(target_family = "wasm")]
async fn upload_xorb(
&self,
prefix: &str,
serialized_cas_object: SerializedCasObject,
_upload_tracker: Option<Arc<CompletionTracker>>,
) -> Result<u64> {
let key = Key {
prefix: prefix.to_string(),
hash: serialized_cas_object.hash,
};
let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?;
let n_upload_bytes = serialized_cas_object.serialized_data.len() as u64;
let _response = self
.authenticated_http_client
.post(url)
.with_extension(Api("cas::upload_xorb"))
.body(serialized_cas_object.serialized_data)
.send()
.await
.process_error("upload_xorb")?;
Ok(n_upload_bytes)
}
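/// Check whether a xorb with the given prefix and hash exists on the server via a HEAD
/// request: 200 maps to `true`, 404 to `false`, and anything else is an internal error.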
async fn exists(&self, prefix: &str, hash: &MerkleHash) -> Result<bool> {
let key = Key {
prefix: prefix.to_string(),
hash: *hash,
};
let url = Url::parse(&format!("{}/xorb/{key}", self.endpoint))?;
let response = self.authenticated_http_client.head(url).send().await?;
match response.status() {
StatusCode::OK => Ok(true),
StatusCode::NOT_FOUND => Ok(false),
e => Err(CasClientError::internal(format!("unrecognized status code {e}"))),
}
}
fn use_xorb_footer(&self) -> bool {
false
}
}
#[cfg(not(target_family = "wasm"))]
#[async_trait::async_trait]
impl ReconstructionClient for RemoteClient {
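/// Reconstruct the file identified by `hash` (optionally only `byte_range`) into
/// `output_provider`, returning the number of bytes written.
///
/// Dispatches to the sequential-write path when `HF_XET_RECONSTRUCT_WRITE_SEQUENTIALLY=true`
/// is set, and to the parallel-write path otherwise.
///
/// A minimal usage sketch writing into an in-memory buffer (the same helper the unit tests
/// use; `client` and `file_hash` are assumed to be in scope):
///
/// ```ignore
/// let provider = BufferProvider::default();
/// let buf = provider.buf.clone();
/// let writer = OutputProvider::Buffer(provider);
/// let n_bytes = client.get_file(&file_hash, None, &writer, None).await?;
/// assert_eq!(n_bytes as usize, buf.value().len());
/// ```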
async fn get_file(
&self,
hash: &MerkleHash,
byte_range: Option<FileRange>,
output_provider: &OutputProvider,
progress_updater: Option<Arc<SingleItemProgressUpdater>>,
) -> Result<u64> {
// If the user has set the `HF_XET_RECONSTRUCT_WRITE_SEQUENTIALLY=true` env variable, then we
// should write the file to the output sequentially instead of in parallel.
if *RECONSTRUCT_WRITE_SEQUENTIALLY {
info!("reconstruct terms sequentially");
self.reconstruct_file_to_writer_segmented(hash, byte_range, output_provider, progress_updater)
.await
} else {
info!("reconstruct terms in parallel");
self.reconstruct_file_to_writer_segmented_parallel_write(
hash,
byte_range,
output_provider,
progress_updater,
)
.await
}
}
}
#[cfg(not(target_family = "wasm"))]
#[async_trait::async_trait]
impl Reconstruct for RemoteClient {
async fn get_reconstruction(
&self,
file_id: &MerkleHash,
bytes_range: Option<FileRange>,
) -> Result<Option<QueryReconstructionResponse>> {
get_reconstruction_with_endpoint_and_client(
&self.endpoint,
&self.authenticated_http_client,
file_id,
bytes_range,
)
.await
}
}
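/// Fetch the reconstruction info for `file_id` from `endpoint`, optionally limited to
/// `bytes_range` via a `Range` header. Returns `Ok(None)` when the server responds with
/// 416 Range Not Satisfiable.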
pub(crate) async fn get_reconstruction_with_endpoint_and_client(
endpoint: &str,
client: &ClientWithMiddleware,
file_id: &MerkleHash,
bytes_range: Option<FileRange>,
) -> Result<Option<QueryReconstructionResponse>> {
let url = Url::parse(&format!("{}/reconstruction/{}", endpoint, file_id.hex()))?;
let mut request = client.get(url).with_extension(Api("cas::get_reconstruction"));
if let Some(range) = bytes_range {
// convert exclusive-end to inclusive-end range
request = request.header(RANGE, HttpRange::from(range).range_header())
}
let response = request.send().await.process_error("get_reconstruction");
let Ok(response) = response else {
let e = response.unwrap_err();
// bytes_range not satisfiable
if let CasClientError::ReqwestError(e, _) = &e {
if let Some(StatusCode::RANGE_NOT_SATISFIABLE) = e.status() {
return Ok(None);
}
}
return Err(e);
};
let len = response.content_length();
debug!("file_id: {file_id} query_reconstruction len {len:?}");
let query_reconstruction_response: QueryReconstructionResponse = response
.json()
.await
.log_error("error json parsing QueryReconstructionResponse")?;
Ok(Some(query_reconstruction_response))
}
impl Client for RemoteClient {}
#[cfg(not(target_family = "wasm"))]
impl RemoteClient {
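/// Query reconstruction info for multiple files in a single request by appending a repeated
/// `file_id=<hex>` query parameter for each hash.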
#[instrument(skip_all, name = "RemoteClient::batch_get_reconstruction")]
async fn batch_get_reconstruction(
&self,
file_ids: impl Iterator<Item = &MerkleHash>,
) -> Result<BatchQueryReconstructionResponse> {
let mut url_str = format!("{}/reconstructions?", self.endpoint);
let mut is_first = true;
for hash in file_ids {
if is_first {
is_first = false;
} else {
url_str.push('&');
}
url_str.push_str("file_id=");
url_str.push_str(hash.hex().as_str());
}
let url: Url = url_str.parse()?;
let response = self
.authenticated_http_client
.get(url)
.with_extension(Api("cas::batch_get_reconstruction"))
.send()
.await
.process_error("batch_get_reconstruction")?;
let query_reconstruction_response: BatchQueryReconstructionResponse = response
.json()
.await
.log_error("error json parsing BatchQueryReconstructionResponse")?;
Ok(query_reconstruction_response)
}
// Segmented download such that the file reconstruction and fetch info is not queried in its entirety
// at the beginning of the download, but queried in segments. Range downloads are executed with
// a certain degree of parallelism, but writing out to storage is sequential. Ideal when the external
// storage uses HDDs.
#[instrument(skip_all, name = "RemoteClient::reconstruct_file_segmented", fields(file.hash = file_hash.hex()
))]
async fn reconstruct_file_to_writer_segmented(
&self,
file_hash: &MerkleHash,
byte_range: Option<FileRange>,
writer: &OutputProvider,
progress_updater: Option<Arc<SingleItemProgressUpdater>>,
) -> Result<u64> {
// queue size is inherently bounded by degree of concurrency.
let (task_tx, mut task_rx) = mpsc::unbounded_channel::<DownloadQueueItem<SequentialTermDownload>>();
let (running_downloads_tx, mut running_downloads_rx) =
mpsc::unbounded_channel::<JoinHandle<Result<(TermDownloadResult<Vec<u8>>, OwnedSemaphorePermit)>>>();
// derive the actual range to reconstruct
let file_reconstruct_range = byte_range.unwrap_or_else(FileRange::full);
let total_len = file_reconstruct_range.length();
// kick-start the download by enqueueing the fetch info task.
task_tx.send(DownloadQueueItem::Metadata(FetchInfo::new(
*file_hash,
file_reconstruct_range,
self.endpoint.clone(),
self.authenticated_http_client.clone(),
)))?;
// Start the queue processing logic
//
// If the queue item is `DownloadQueueItem::Metadata`, the dispatcher fetches the file reconstruction info
// of the first segment, whose size is proportional to `num_concurrent_range_gets`. Once fetched, term
// download tasks are enqueued and spawned with a degree of concurrency equal to `num_concurrent_range_gets`.
// Finally, a task that fetches the remainder of the file reconstruction info is enqueued, which will
// execute after the first of the above term download tasks finishes.
let chunk_cache = self.chunk_cache.clone();
let term_download_client = self.http_client.clone();
let range_download_single_flight = self.range_download_single_flight.clone();
let download_scheduler = DownloadSegmentLengthTuner::from_configurable_constants();
let download_scheduler_clone = download_scheduler.clone();
let queue_dispatcher: JoinHandle<Result<()>> = tokio::spawn(async move {
let mut remaining_total_len = total_len;
while let Some(item) = task_rx.recv().await {
match item {
DownloadQueueItem::End => {
// everything processed
debug!("download queue emptyed");
drop(running_downloads_tx);
break;
},
DownloadQueueItem::DownloadTask(term_download) => {
// acquire the permit before spawning the task, so that the number of
// active downloads stays limited.
let permit = DOWNLOAD_CONNECTION_CONCURRENCY_LIMITER.clone().acquire_owned().await?;
debug!("spawning 1 download task");
let future: JoinHandle<Result<(TermDownloadResult<Vec<u8>>, OwnedSemaphorePermit)>> =
tokio::spawn(async move {
let data = term_download.run().await?;
Ok((data, permit))
});
running_downloads_tx.send(future)?;
},
DownloadQueueItem::Metadata(fetch_info) => {
// query for the file info of the first segment
let segment_size = download_scheduler_clone.next_segment_size()?;
debug!("querying file info of size {segment_size}");
let (segment, maybe_remainder) = fetch_info.take_segment(segment_size);
let Some((offset_into_first_range, terms)) = segment.query().await? else {
// signal termination
task_tx.send(DownloadQueueItem::End)?;
continue;
};
let segment = Arc::new(segment);
// define the term download tasks
let mut remaining_segment_len = segment_size;
debug!("enqueueing {} download tasks", terms.len());
for (i, term) in terms.into_iter().enumerate() {
let skip_bytes = if i == 0 { offset_into_first_range } else { 0 };
let take = remaining_total_len
.min(remaining_segment_len)
.min(term.unpacked_length as u64 - skip_bytes);
let (individual_fetch_info, _) = segment.find((term.hash, term.range)).await?;
let download_task = SequentialTermDownload {
download: FetchTermDownload {
hash: term.hash.into(),
range: individual_fetch_info.range,
fetch_info: segment.clone(),
chunk_cache: chunk_cache.clone(),
client: term_download_client.clone(),
range_download_single_flight: range_download_single_flight.clone(),
},
term,
skip_bytes,
take,
};
remaining_total_len -= take;
remaining_segment_len -= take;
debug!("enqueueing {download_task:?}");
task_tx.send(DownloadQueueItem::DownloadTask(download_task))?;
}
// enqueue the fetch task for the remainder of the file info
if let Some(remainder) = maybe_remainder {
task_tx.send(DownloadQueueItem::Metadata(remainder))?;
} else {
task_tx.send(DownloadQueueItem::End)?;
}
},
}
}
Ok(())
});
let mut writer = writer.get_writer_at(0)?;
let mut total_written = 0;
while let Some(result) = running_downloads_rx.recv().await {
match result.await {
Ok(Ok((mut download_result, permit))) => {
let data = take(&mut download_result.payload);
writer.write_all(&data)?;
// drop the permit only after the data is written out so that downloaded payloads don't accumulate in memory unbounded
drop(permit);
if let Some(updater) = progress_updater.as_ref() {
updater.update(data.len() as u64).await;
}
total_written += data.len() as u64;
// Now inspect the download metrics and tune the download degree of concurrency
download_scheduler.tune_on(download_result)?;
},
Ok(Err(e)) => Err(e)?,
Err(e) => Err(anyhow!("{e:?}"))?,
}
}
writer.flush()?;
queue_dispatcher.await??;
Ok(total_written)
}
// Segmented download such that the file reconstruction and fetch info is not queried in its entirety
// at the beginning of the download, but queried in segments. Range downloads are executed with
// a certain degree of parallelism, as is writing out to storage. Ideal when the external
// storage is fast at seeks, e.g. RAM or SSDs.
#[instrument(skip_all, name = "RemoteClient::reconstruct_file_segmented_parallel", fields(file.hash = file_hash.hex()
))]
async fn reconstruct_file_to_writer_segmented_parallel_write(
&self,
file_hash: &MerkleHash,
byte_range: Option<FileRange>,
writer: &OutputProvider,
progress_updater: Option<Arc<SingleItemProgressUpdater>>,
) -> Result<u64> {
// queue size is inherently bounded by degree of concurrency.
let (task_tx, mut task_rx) =
mpsc::unbounded_channel::<DownloadQueueItem<FetchTermDownloadOnceAndWriteEverywhereUsed>>();
let mut running_downloads = JoinSet::<Result<TermDownloadResult<u64>>>::new();
// derive the actual range to reconstruct
let file_reconstruct_range = byte_range.unwrap_or_else(FileRange::full);
let base_write_negative_offset = file_reconstruct_range.start;
// kick-start the download by enqueueing the fetch info task.
task_tx.send(DownloadQueueItem::Metadata(FetchInfo::new(
*file_hash,
file_reconstruct_range,
self.endpoint.clone(),
self.authenticated_http_client.clone(),
)))?;
// Start the queue processing logic
//
// If the queue item is `DownloadQueueItem::Metadata`, the loop fetches the file reconstruction info
// of the first segment, whose size is proportional to `num_concurrent_range_gets`. Once fetched, term
// download tasks are enqueued and spawned with a degree of concurrency equal to `num_concurrent_range_gets`.
// Finally, a task that fetches the remainder of the file reconstruction info is enqueued, which will
// execute after the first of the above term download tasks finishes.
let term_download_client = self.http_client.clone();
let download_scheduler = DownloadSegmentLengthTuner::from_configurable_constants();
let process_result = move |result: TermDownloadResult<u64>,
total_written: &mut u64,
download_scheduler: &DownloadSegmentLengthTuner|
-> Result<u64> {
let write_len = result.payload;
*total_written += write_len;
// Now inspect the download metrics and tune the download degree of concurrency
download_scheduler.tune_on(result)?;
Ok(write_len)
};
let mut total_written = 0;
while let Some(item) = task_rx.recv().await {
// first try to join some tasks
while let Some(result) = running_downloads.try_join_next() {
let write_len = process_result(result??, &mut total_written, &download_scheduler)?;
if let Some(updater) = progress_updater.as_ref() {
updater.update(write_len).await;
}
}
match item {
DownloadQueueItem::End => {
// everything processed
debug!("download queue emptied");
break;
},
DownloadQueueItem::DownloadTask(term_download) => {
// acquire the permit before spawning the task, so that the number of
// active downloads stays limited.
let permit = DOWNLOAD_CONNECTION_CONCURRENCY_LIMITER.clone().acquire_owned().await?;
debug!("spawning 1 download task");
running_downloads.spawn(async move {
let data = term_download.run().await?;
drop(permit);
Ok(data)
});
},
DownloadQueueItem::Metadata(fetch_info) => {
// query for the file info of the first segment
let segment_size = download_scheduler.next_segment_size()?;
debug!("querying file info of size {segment_size}");
let (segment, maybe_remainder) = fetch_info.take_segment(segment_size);
let Some((offset_into_first_range, terms)) = segment.query().await? else {
// signal termination
task_tx.send(DownloadQueueItem::End)?;
continue;
};
let segment = Arc::new(segment);
// define the term download tasks
let tasks = map_fetch_info_into_download_tasks(
segment.clone(),
terms,
offset_into_first_range,
base_write_negative_offset,
self.chunk_cache.clone(),
term_download_client.clone(),
self.range_download_single_flight.clone(),
writer,
)
.await?;
debug!("enqueueing {} download tasks", tasks.len());
for task_def in tasks {
task_tx.send(DownloadQueueItem::DownloadTask(task_def))?;
}
// enqueue the fetch task for the remainder of the file info
if let Some(remainder) = maybe_remainder {
task_tx.send(DownloadQueueItem::Metadata(remainder))?;
} else {
task_tx.send(DownloadQueueItem::End)?;
}
},
}
}
while let Some(result) = running_downloads.join_next().await {
let write_len = process_result(result??, &mut total_written, &download_scheduler)?;
if let Some(updater) = progress_updater.as_ref() {
updater.update(write_len).await;
}
}
Ok(total_written)
}
}
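/// Group one segment's reconstruction terms into download tasks keyed by
/// `(xorb hash, fetch range)`, so that each xorb range is fetched once and then written to
/// every file location that uses it. Per-term `skip_bytes`, `take`, and writer offsets are
/// computed relative to the segment's position in the requested file range.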
#[cfg(not(target_family = "wasm"))]
#[allow(clippy::too_many_arguments)]
async fn map_fetch_info_into_download_tasks(
segment: Arc<FetchInfo>,
terms: Vec<CASReconstructionTerm>,
offset_into_first_range: u64,
base_write_negative_offset: u64,
chunk_cache: Option<Arc<dyn ChunkCache>>,
client: Arc<ClientWithMiddleware>,
range_download_single_flight: Arc<Group<DownloadRangeResult, CasClientError>>,
output_provider: &OutputProvider,
) -> Result<Vec<FetchTermDownloadOnceAndWriteEverywhereUsed>> {
// The actual segment length: for the last segment, the file_range end may exceed the total
// file length. In that case, the maximum length of this segment is the total unpacked length
// of all terms given minus the start offset.
let seg_len = segment
.file_range
.length()
.min(terms.iter().fold(0, |acc, term| acc + term.unpacked_length as u64) - offset_into_first_range);
let initial_writer_offset = segment.file_range.start - base_write_negative_offset;
let mut total_taken = 0;
let mut fetch_info_term_map: HashMap<(MerkleHash, ChunkRange), FetchTermDownloadOnceAndWriteEverywhereUsed> =
HashMap::new();
for (i, term) in terms.into_iter().enumerate() {
let (individual_fetch_info, _) = segment.find((term.hash, term.range)).await?;
let skip_bytes = if i == 0 { offset_into_first_range } else { 0 };
// amount to take is the minimum of the whole term after skipped bytes and the remainder of the segment
let take = (term.unpacked_length as u64 - skip_bytes).min(seg_len - total_taken);
let write_term = ChunkRangeWrite {
// term details
chunk_range: term.range,
unpacked_length: term.unpacked_length,
// write details
skip_bytes,
take,
writer_offset: initial_writer_offset + total_taken,
};
let task = fetch_info_term_map
.entry((term.hash.into(), individual_fetch_info.range))
.or_insert_with(|| FetchTermDownloadOnceAndWriteEverywhereUsed {
download: FetchTermDownload {
hash: term.hash.into(),
range: individual_fetch_info.range,
fetch_info: segment.clone(),
chunk_cache: chunk_cache.clone(),
client: client.clone(),
range_download_single_flight: range_download_single_flight.clone(),
},
writes: vec![],
output: output_provider.clone(),
});
task.writes.push(write_term);
total_taken += take;
}
let tasks = fetch_info_term_map.into_values().collect();
Ok(tasks)
}
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
impl RegistrationClient for RemoteClient {
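/// Upload a shard to the CAS service.
///
/// Uses PUT when `force_sync` is set and POST otherwise. Returns `Ok(false)` if the shard
/// already exists on the server and `Ok(true)` when a sync was performed; in dry-run mode
/// returns `Ok(true)` without contacting the server.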
#[instrument(skip_all, name = "RemoteClient::upload_shard", fields(shard.hash = hash.hex(), shard.len = shard_data.len()
))]
async fn upload_shard(
&self,
prefix: &str,
hash: &MerkleHash,
force_sync: bool,
shard_data: &[u8],
_salt: &[u8; 32],
) -> Result<bool> {
if self.dry_run {
return Ok(true);
}
let key = Key {
prefix: prefix.into(),
hash: *hash,
};
let url = Url::parse(&format!("{}/shard/{key}", self.endpoint))?;
let method = match force_sync {
true => FORCE_SYNC_METHOD,
false => NON_FORCE_SYNC_METHOD,
};
let response = self
.authenticated_http_client
.request(method, url)
.with_extension(Api("cas::upload_shard"))
.body(shard_data.to_vec())
.send()
.await
.process_error("upload_shard")?;
let response_parsed: UploadShardResponse =
response.json().await.log_error("error json decoding upload_shard response")?;
match response_parsed.result {
UploadShardResponseType::Exists => Ok(false),
UploadShardResponseType::SyncPerformed => Ok(true),
}
}
}
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
impl FileReconstructor<CasClientError> for RemoteClient {
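/// Fetch the full reconstruction info for `file_hash` and convert it into an `MDBFileInfo`
/// (one `FileDataSequenceEntry` per reconstruction term). The second tuple element (a shard
/// hash) is always `None` for the remote client.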
#[instrument(skip_all, name = "RemoteClient::get_file_reconstruction", fields(file.hash = file_hash.hex()
))]
async fn get_file_reconstruction_info(
&self,
file_hash: &MerkleHash,
) -> Result<Option<(MDBFileInfo, Option<MerkleHash>)>> {
let url = Url::parse(&format!("{}/reconstruction/{}", self.endpoint, file_hash.hex()))?;
let response = self
.authenticated_http_client
.get(url)
.with_extension(Api("cas::get_reconstruction_info"))
.send()
.await
.process_error("get_reconstruction_info")?;
let response_info: QueryReconstructionResponse = response.json().await?;
Ok(Some((
MDBFileInfo {
metadata: FileDataSequenceHeader::new(*file_hash, response_info.terms.len(), false, false),
segments: response_info
.terms
.into_iter()
.map(|ce| {
FileDataSequenceEntry::new(ce.hash.into(), ce.unpacked_length, ce.range.start, ce.range.end)
})
.collect(),
verification: vec![],
metadata_ext: None,
},
None,
)))
}
}
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
impl ShardDedupProber for RemoteClient {
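/// Download the global-dedup shard matching `chunk_hash` into the shard cache directory.
///
/// The response body is streamed through a `HashedWrite` so the resulting file can be named
/// by its content hash; returns the path of the written shard file, or `Ok(None)` if no
/// matching shard was found.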
#[instrument(skip_all, name = "RemoteClient::query_global_dedup")]
#[cfg(not(target_family = "wasm"))]
async fn query_for_global_dedup_shard(
&self,
prefix: &str,
chunk_hash: &MerkleHash,
_salt: &[u8; 32],
) -> Result<Option<PathBuf>> {
let Some(ref shard_cache_directory) = self.shard_cache_directory else {
return Err(CasClientError::ConfigurationError(
"Shard Write Directory not set; cannot download.".to_string(),
));
};
let Some(mut response) = self.query_dedup_api(prefix, chunk_hash).await? else {
return Ok(None);
};
let writer = SafeFileCreator::new_unnamed(shard_cache_directory)?;
// Compute the actual hash to use as the shard file name
let mut hashed_writer = HashedWrite::new(writer);
while let Some(chunk) = response.chunk().await? {
hashed_writer.write_all(&chunk)?;
}
hashed_writer.flush()?;
let shard_hash = hashed_writer.hash();
let file_path = shard_cache_directory.join(shard_file_name(&shard_hash));
let mut writer = hashed_writer.into_inner();
writer.set_dest_path(&file_path);
writer.close()?;
Ok(Some(file_path))
}
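/// Same as `query_for_global_dedup_shard`, but returns the shard bytes in memory instead of
/// writing them to the shard cache directory.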
async fn query_for_global_dedup_shard_in_memory(
&self,
prefix: &str,
chunk_hash: &MerkleHash,
_salt: &[u8; 32],
) -> Result<Option<Vec<u8>>> {
let Some(response) = self.query_dedup_api(prefix, chunk_hash).await? else {
return Ok(None);
};
Ok(Some(response.bytes().await?.to_vec()))
}
}
impl ShardClientInterface for RemoteClient {}
#[cfg(test)]
#[cfg(not(target_family = "wasm"))]
mod tests {
use std::collections::HashMap;
use anyhow::Result;
use cas_object::test_utils::*;
use cas_object::CompressionScheme;
use cas_types::{CASReconstructionFetchInfo, CASReconstructionTerm, ChunkRange};
use deduplication::constants::MAX_XORB_BYTES;
use httpmock::Method::GET;
use httpmock::MockServer;
use merkledb::constants::TARGET_CDC_CHUNK_SIZE;
use tracing_test::traced_test;
use xet_threadpool::ThreadPool;
use super::*;
use crate::output_provider::BufferProvider;
#[ignore = "requires a running CAS server"]
#[traced_test]
#[test]
fn test_basic_put() {
// Arrange
let prefix = PREFIX_DEFAULT;
let raw_xorb = build_raw_xorb(3, ChunkSize::Random(512, 10248));
let threadpool = ThreadPool::new().unwrap();
let client = RemoteClient::new(CAS_ENDPOINT, &None, &None, None, "", false);
let cas_object = build_and_verify_cas_object(raw_xorb, Some(CompressionScheme::LZ4));
// Act
let result = threadpool
.external_run_async_task(async move { client.upload_xorb(prefix, cas_object, None).await })
.unwrap();
// Assert
assert!(result.is_ok());
}
#[derive(Clone)]
struct TestCase {
file_hash: MerkleHash,
reconstruction_response: QueryReconstructionResponse,
file_range: FileRange,
expected_data: Vec<u8>,
expect_error: bool,
}
const NUM_CHUNKS: u32 = 128;
const CHUNK_SIZE: u32 = TARGET_CDC_CHUNK_SIZE as u32;
macro_rules! mock_no_match_range_header {
($range_to_compare:expr) => {
|req| {
let Some(h) = &req.headers else {
return false;
};
let Some((_range_header, range_value)) =
h.iter().find(|(k, _v)| k.eq_ignore_ascii_case(RANGE.as_str()))
else {
return false;
};
let Ok(range) = HttpRange::try_from(range_value.trim_start_matches("bytes=")) else {
return false;
};
range != $range_to_compare
}
};
}
#[test]
fn test_reconstruct_file_full_file() -> Result<()> {
// Arrange server
let server = MockServer::start();
let xorb_hash: MerkleHash = MerkleHash::default();
let (cas_object, chunks_serialized, raw_data, _raw_data_chunk_hash_and_boundaries) =
build_cas_object(NUM_CHUNKS, ChunkSize::Fixed(CHUNK_SIZE), CompressionScheme::ByteGrouping4LZ4);
// Workaround to make this variable const. Change this accordingly if the
// real values of the two static variables below change.
const FIRST_SEGMENT_SIZE: u64 = 16 * 64 * 1024 * 1024;
assert_eq!(FIRST_SEGMENT_SIZE, *NUM_RANGE_IN_SEGMENT_BASE as u64 * *MAX_XORB_BYTES as u64);
// Test case: full file reconstruction
const FIRST_SEGMENT_FILE_RANGE: FileRange = FileRange {
start: 0,
end: FIRST_SEGMENT_SIZE,
_marker: std::marker::PhantomData,
};
let test_case = TestCase {
file_hash: MerkleHash::from_hex(&format!("{:0>64}", "1"))?, // "0....1"
reconstruction_response: QueryReconstructionResponse {
offset_into_first_range: 0,
terms: vec![CASReconstructionTerm {
hash: xorb_hash.into(),
range: ChunkRange::new(0, NUM_CHUNKS),
unpacked_length: raw_data.len() as u32,
}],
fetch_info: HashMap::from([(
xorb_hash.into(),
vec![CASReconstructionFetchInfo {
range: ChunkRange::new(0, NUM_CHUNKS),
url: server.url(format!("/get_xorb/{xorb_hash}/")),
url_range: {
let (start, end) = cas_object.get_byte_offset(0, NUM_CHUNKS)?;
HttpRange::from(FileRange::new(start as u64, end as u64))
},
}],
)]),
},
file_range: FileRange::full(),
expected_data: raw_data,
expect_error: false,
};
// Arrange server mocks
let _mock_fi_416 = server.mock(|when, then| {
when.method(GET)
.path(format!("/reconstruction/{}", test_case.file_hash))
.matches(mock_no_match_range_header!(HttpRange::from(FIRST_SEGMENT_FILE_RANGE)));
then.status(416);
});
let _mock_fi_200 = server.mock(|when, then| {
let w = when.method(GET).path(format!("/reconstruction/{}", test_case.file_hash));
w.header(RANGE.as_str(), HttpRange::from(FIRST_SEGMENT_FILE_RANGE).range_header());
then.status(200).json_body_obj(&test_case.reconstruction_response);
});
for (k, v) in &test_case.reconstruction_response.fetch_info {
for term in v {
let data = FileRange::from(term.url_range);
let data = chunks_serialized[data.start as usize..data.end as usize].to_vec();
let _mock_data = server.mock(|when, then| {
when.method(GET)
.path(format!("/get_xorb/{k}/"))
.header(RANGE.as_str(), term.url_range.range_header());
then.status(200).body(&data);
});
}
}
test_reconstruct_file(test_case, &server.base_url())
}
#[test]
fn test_reconstruct_file_skip_front_bytes() -> Result<()> {
// Arrange server
let server = MockServer::start();
let xorb_hash: MerkleHash = MerkleHash::default();
let (cas_object, chunks_serialized, raw_data, _raw_data_chunk_hash_and_boundaries) =
build_cas_object(NUM_CHUNKS, ChunkSize::Fixed(CHUNK_SIZE), CompressionScheme::ByteGrouping4LZ4);
// Workaround to make this variable const. Change this accordingly if the
// real values of the two static variables below change.
const FIRST_SEGMENT_SIZE: u64 = 16 * 64 * 1024 * 1024;
assert_eq!(FIRST_SEGMENT_SIZE, *NUM_RANGE_IN_SEGMENT_BASE as u64 * *MAX_XORB_BYTES as u64);
// Test case: skip first 100 bytes
const SKIP_BYTES: u64 = 100;
const FIRST_SEGMENT_FILE_RANGE: FileRange = FileRange {
start: SKIP_BYTES,
end: SKIP_BYTES + FIRST_SEGMENT_SIZE,
_marker: std::marker::PhantomData,
};
let test_case = TestCase {
file_hash: MerkleHash::from_hex(&format!("{:0>64}", "1"))?, // "0....1"
reconstruction_response: QueryReconstructionResponse {
offset_into_first_range: SKIP_BYTES,
terms: vec![CASReconstructionTerm {
hash: xorb_hash.into(),
range: ChunkRange::new(0, NUM_CHUNKS),
unpacked_length: raw_data.len() as u32,
}],
fetch_info: HashMap::from([(
xorb_hash.into(),
vec![CASReconstructionFetchInfo {
range: ChunkRange::new(0, NUM_CHUNKS),
url: server.url(format!("/get_xorb/{xorb_hash}/")),
url_range: {
let (start, end) = cas_object.get_byte_offset(0, NUM_CHUNKS)?;
HttpRange::from(FileRange::new(start as u64, end as u64))
},
}],
)]),
},
file_range: FileRange::new(SKIP_BYTES, u64::MAX),
expected_data: raw_data[SKIP_BYTES as usize..].to_vec(),
expect_error: false,
};
// Arrange server mocks
let _mock_fi_416 = server.mock(|when, then| {
when.method(GET)
.path(format!("/reconstruction/{}", test_case.file_hash))
.matches(mock_no_match_range_header!(HttpRange::from(FIRST_SEGMENT_FILE_RANGE)));
then.status(416);
});
let _mock_fi_200 = server.mock(|when, then| {
let w = when.method(GET).path(format!("/reconstruction/{}", test_case.file_hash));
w.header(RANGE.as_str(), HttpRange::from(FIRST_SEGMENT_FILE_RANGE).range_header());
then.status(200).json_body_obj(&test_case.reconstruction_response);
});
for (k, v) in &test_case.reconstruction_response.fetch_info {
for term in v {
let data = FileRange::from(term.url_range);
let data = chunks_serialized[data.start as usize..data.end as usize].to_vec();
let _mock_data = server.mock(|when, then| {
when.method(GET)
.path(format!("/get_xorb/{k}/"))
.header(RANGE.as_str(), term.url_range.range_header());
then.status(200).body(&data);
});
}
}
test_reconstruct_file(test_case, &server.base_url())
}
#[test]
fn test_reconstruct_file_skip_back_bytes() -> Result<()> {
// Arrange server
let server = MockServer::start();
let xorb_hash: MerkleHash = MerkleHash::default();
let (cas_object, chunks_serialized, raw_data, _raw_data_chunk_hash_and_boundaries) =
build_cas_object(NUM_CHUNKS, ChunkSize::Fixed(CHUNK_SIZE), CompressionScheme::ByteGrouping4LZ4);
// Test case: skip last 100 bytes
const FILE_SIZE: u64 = NUM_CHUNKS as u64 * CHUNK_SIZE as u64;
const SKIP_BYTES: u64 = 100;
const FIRST_SEGMENT_FILE_RANGE: FileRange = FileRange {
start: 0,
end: FILE_SIZE - SKIP_BYTES,
_marker: std::marker::PhantomData,
};
let test_case = TestCase {
file_hash: MerkleHash::from_hex(&format!("{:0>64}", "1"))?, // "0....1"
reconstruction_response: QueryReconstructionResponse {
offset_into_first_range: 0,
terms: vec![CASReconstructionTerm {
hash: xorb_hash.into(),
range: ChunkRange::new(0, NUM_CHUNKS),
unpacked_length: raw_data.len() as u32,
}],
fetch_info: HashMap::from([(
xorb_hash.into(),
vec![CASReconstructionFetchInfo {
range: ChunkRange::new(0, NUM_CHUNKS),
url: server.url(format!("/get_xorb/{xorb_hash}/")),
url_range: {
let (start, end) = cas_object.get_byte_offset(0, NUM_CHUNKS)?;
HttpRange::from(FileRange::new(start as u64, end as u64))
},
}],
)]),
},
file_range: FileRange::new(0, FILE_SIZE - SKIP_BYTES),
expected_data: raw_data[..(FILE_SIZE - SKIP_BYTES) as usize].to_vec(),
expect_error: false,
};
// Arrange server mocks
let _mock_fi_416 = server.mock(|when, then| {
when.method(GET)
.path(format!("/reconstruction/{}", test_case.file_hash))
.matches(mock_no_match_range_header!(HttpRange::from(FIRST_SEGMENT_FILE_RANGE)));
then.status(416);
});
let _mock_fi_200 = server.mock(|when, then| {
let w = when.method(GET).path(format!("/reconstruction/{}", test_case.file_hash));
w.header(RANGE.as_str(), HttpRange::from(FIRST_SEGMENT_FILE_RANGE).range_header());
then.status(200).json_body_obj(&test_case.reconstruction_response);
});
for (k, v) in &test_case.reconstruction_response.fetch_info {
for term in v {
let data = FileRange::from(term.url_range);
let data = chunks_serialized[data.start as usize..data.end as usize].to_vec();
let _mock_data = server.mock(|when, then| {
when.method(GET)
.path(format!("/get_xorb/{k}/"))
.header(RANGE.as_str(), term.url_range.range_header());
then.status(200).body(&data);
});
}
}
test_reconstruct_file(test_case, &server.base_url())
}
#[test]
fn test_reconstruct_file_two_terms() -> Result<()> {
// Arrange server
let server = MockServer::start();
let xorb_hash_1: MerkleHash = MerkleHash::from_hex(&format!("{:0>64}", "1"))?; // "0....1"
let xorb_hash_2: MerkleHash = MerkleHash::from_hex(&format!("{:0>64}", "2"))?; // "0....2"
let (cas_object, chunks_serialized, raw_data, _raw_data_chunk_hash_and_boundaries) =
build_cas_object(NUM_CHUNKS, ChunkSize::Fixed(CHUNK_SIZE), CompressionScheme::ByteGrouping4LZ4);
// Test case: two terms and skip first and last 100 bytes
const FILE_SIZE: u64 = (NUM_CHUNKS - 1) as u64 * CHUNK_SIZE as u64;
const SKIP_BYTES: u64 = 100;
const FIRST_SEGMENT_FILE_RANGE: FileRange = FileRange {
start: SKIP_BYTES,
end: FILE_SIZE - SKIP_BYTES,
_marker: std::marker::PhantomData,
};
let test_case = TestCase {
file_hash: MerkleHash::from_hex(&format!("{:0>64}", "3"))?, // "0....3"
reconstruction_response: QueryReconstructionResponse {
offset_into_first_range: SKIP_BYTES,
terms: vec![
CASReconstructionTerm {
hash: xorb_hash_1.into(),
range: ChunkRange::new(0, 5),
unpacked_length: CHUNK_SIZE * 5,
},
CASReconstructionTerm {
hash: xorb_hash_2.into(),
range: ChunkRange::new(6, NUM_CHUNKS),
unpacked_length: CHUNK_SIZE * (NUM_CHUNKS - 6),
},
],
fetch_info: HashMap::from([
(
// this constructs the first term
xorb_hash_1.into(),
vec![CASReconstructionFetchInfo {
range: ChunkRange::new(0, 7),
url: server.url(format!("/get_xorb/{xorb_hash_1}/")),
url_range: {
let (start, end) = cas_object.get_byte_offset(0, 7)?;
HttpRange::from(FileRange::new(start as u64, end as u64))
},
}],
),
(
// this constructs the second term
xorb_hash_2.into(),
vec![CASReconstructionFetchInfo {
range: ChunkRange::new(4, NUM_CHUNKS),
url: server.url(format!("/get_xorb/{xorb_hash_2}/")),
url_range: {
let (start, end) = cas_object.get_byte_offset(4, NUM_CHUNKS)?;
HttpRange::from(FileRange::new(start as u64, end as u64))
},
}],
),
]),
},
file_range: FileRange::new(SKIP_BYTES, FILE_SIZE - SKIP_BYTES),
expected_data: [
&raw_data[SKIP_BYTES as usize..(5 * CHUNK_SIZE) as usize],
&raw_data[(6 * CHUNK_SIZE) as usize..(NUM_CHUNKS * CHUNK_SIZE) as usize - SKIP_BYTES as usize],
]
.concat(),
expect_error: false,
};
// Arrange server mocks
let _mock_fi_416 = server.mock(|when, then| {
when.method(GET)
.path(format!("/reconstruction/{}", test_case.file_hash))
.matches(mock_no_match_range_header!(HttpRange::from(FIRST_SEGMENT_FILE_RANGE)));
then.status(416);
});
let _mock_fi_200 = server.mock(|when, then| {
let w = when.method(GET).path(format!("/reconstruction/{}", test_case.file_hash));
w.header(RANGE.as_str(), HttpRange::from(FIRST_SEGMENT_FILE_RANGE).range_header());
then.status(200).json_body_obj(&test_case.reconstruction_response);
});
for (k, v) in &test_case.reconstruction_response.fetch_info {
for term in v {
let data = FileRange::from(term.url_range);
let data = chunks_serialized[data.start as usize..data.end as usize].to_vec();
let _mock_data = server.mock(|when, then| {
when.method(GET)
.path(format!("/get_xorb/{k}/"))
.header(RANGE.as_str(), term.url_range.range_header());
then.status(200).body(&data);
});
}
}
test_reconstruct_file(test_case, &server.base_url())
}
fn test_reconstruct_file(test_case: TestCase, endpoint: &str) -> Result<()> {
let threadpool = ThreadPool::new()?;
// test reconstruct and sequential write
let test = test_case.clone();
let client = RemoteClient::new(endpoint, &None, &None, None, "", false);
let provider = BufferProvider::default();
let buf = provider.buf.clone();
let writer = OutputProvider::Buffer(provider);
let resp = threadpool.external_run_async_task(async move {
client
.reconstruct_file_to_writer_segmented(&test.file_hash, Some(test.file_range), &writer, None)
.await
})?;
assert_eq!(test.expect_error, resp.is_err());
if !test.expect_error {
assert_eq!(test.expected_data.len() as u64, resp.unwrap());
assert_eq!(test.expected_data, buf.value());
}
// test reconstruct and parallel write
let test = test_case;
let client = RemoteClient::new(endpoint, &None, &None, None, "", false);
let provider = BufferProvider::default();
let buf = provider.buf.clone();
let writer = OutputProvider::Buffer(provider);
let resp = threadpool.external_run_async_task(async move {
client
.reconstruct_file_to_writer_segmented_parallel_write(
&test.file_hash,
Some(test.file_range),
&writer,
None,
)
.await
})?;
assert_eq!(test.expect_error, resp.is_err());
if !test.expect_error {
assert_eq!(test.expected_data.len() as u64, resp.unwrap());
let value = buf.value();
assert_eq!(&test.expected_data[..100], &value[..100]);
let idx = test.expected_data.len() - 100;
assert_eq!(&test.expected_data[idx..], &value[idx..]);
assert_eq!(test.expected_data, value);
}
Ok(())
}
}