crates/ratchet-core/src/storage/gpu_buffer.rs (185 lines of code) (raw):

use crate::{ gpu::{BufferDescriptor, WgpuDevice}, gpu::{BufferUsagesExt, PooledGPUBuffer}, storage::{CPUBuffer, DeviceStorage}, Device, DeviceError, Shape, TensorDType, }; use bytemuck::NoUninit; use wgpu::BufferUsages; use crate::DType; #[derive(Clone, Debug, derive_new::new)] pub struct GPUBuffer { pub(crate) inner: PooledGPUBuffer, pub(crate) alignment: usize, } impl GPUBuffer { pub fn from_slice<T: NoUninit>(data: &[T], shape: &Shape, device: &WgpuDevice) -> Self { assert_eq!(data.len(), shape.numel()); Self::from_bytes( bytemuck::cast_slice(data), std::mem::align_of::<T>(), device, ) } //We have to use from_bytes here, as buffers may be reused and we need to //ensure that the buffer is zeroed pub fn zeros<T: TensorDType>(shape: &Shape, device: &WgpuDevice) -> Self { Self::from_bytes( vec![0; shape.numel() * T::dt().size_of()].as_slice(), T::dt().size_of(), device, ) } /// # Safety /// /// We don't check the provided shape here. /// The caller should ensure that this data is laid out correctly. /// We also require that all of the elements have the same alignment. pub unsafe fn from_quantized<T: NoUninit>(data: &[T], device: &WgpuDevice) -> Self { let bytes: &[u8] = bytemuck::cast_slice(data); Self::from_bytes(bytes, std::mem::align_of::<T>(), device) } pub(crate) fn from_bytes(bytes: &[u8], alignment: usize, device: &WgpuDevice) -> Self { let inner = device .get_or_create_buffer_init( &BufferDescriptor::new(bytes.len() as _, BufferUsages::standard(), false), bytes.into(), ) .unwrap(); device.queue().submit(None); device.poll(wgpu::Maintain::Wait); Self { inner, alignment } } /// Returns true if the buffer has all the given usages. pub(crate) fn validate_usages(&self, usages: BufferUsages) -> Result<(), DeviceError> { match self.inner.usage().contains(usages) { true => Ok(()), false => Err(DeviceError::InvalidBufferUsage(self.inner.usage(), usages)), } } pub fn inner(&self) -> &PooledGPUBuffer { &self.inner } pub fn usage(&self) -> BufferUsages { self.inner.usage() } #[allow(unused)] pub fn deep_clone(&self, device: &WgpuDevice) -> Self { let clone = device .get_or_create_buffer( &BufferDescriptor::new(self.inner.size(), self.inner.usage(), false), true, ) .unwrap(); let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); encoder.copy_buffer_to_buffer(&self.inner, 0, &clone, 0, self.inner.size()); device.queue().submit(Some(encoder.finish())); device.poll(wgpu::Maintain::Wait); Self { inner: clone, alignment: self.alignment, } } pub fn from_disk<T: TensorDType, R: std::io::BufRead + std::io::Seek>( reader: &mut R, shape: &Shape, device: &Device, ) -> Result<Self, DeviceError> { //There is no faster way to do this CPUBuffer::from_disk::<T, R>(reader, shape)?.to_device(device) } pub fn trim_id(id: wgpu::Id<wgpu::Buffer>) -> Option<String> { let id = format!("{:?}", id); let trimmed = id.trim_start_matches("Id(").trim_end_matches(')'); if trimmed.len() > 12 && trimmed.chars().all(|c| c.is_numeric()) { Some(trimmed[12..].to_string()) } else { None } } #[cfg(feature = "plotting")] pub fn plot_fmt(&self) -> String { let id_string = Self::trim_id(self.inner().global_id()).unwrap_or_default(); format!("GPU:#{}\n{} bytes", id_string, self.inner.size()) } } #[cfg_attr(target_arch = "wasm32", async_trait::async_trait)] impl DeviceStorage for GPUBuffer { fn to_device(&self, _: &Device) -> Result<GPUBuffer, DeviceError> { Ok(self.clone()) } #[cfg(target_arch = "wasm32")] async fn to_cpu(&self, device: &Device) -> Result<CPUBuffer, DeviceError> { self.validate_usages(BufferUsages::COPY_SRC)?; let device = device.try_gpu()?; let buffer_slice = self.inner.slice(..); let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel(); let alignment = self.alignment; wgpu::util::DownloadBuffer::read_buffer( device, device.queue(), &buffer_slice, move |buffer| { tx.send(match buffer { Ok(db) => Ok(CPUBuffer::from_bytes(&db, alignment)), Err(error) => Err(error), }) .expect("Failed to send result of read_buffer"); }, ); device.poll(wgpu::Maintain::Wait); Ok(rx.receive().await.unwrap()?) } #[cfg(not(target_arch = "wasm32"))] fn to_cpu(&self, device: &Device) -> Result<CPUBuffer, DeviceError> { self.validate_usages(BufferUsages::COPY_SRC)?; let device = device.try_gpu()?; let storage = wgpu_buffer_to_cpu_buffer(&self.inner, self.alignment, device); Ok(storage) } fn n_bytes(&self) -> usize { self.inner.size() as usize } fn dump(&self, _: DType, _: bool) -> String { let mut result = String::new(); let id_string = Self::trim_id(self.inner().global_id()).unwrap_or_default(); result.push_str(&format!("GPU Buffer #{}\n", id_string)); result.push_str(&format!("Size: {} bytes\n", self.inner.size())); result } } #[cfg(target_arch = "wasm32")] pub async fn wgpu_buffer_to_cpu_buffer( src_buf: &wgpu::Buffer, alignment: usize, device: WgpuDevice, ) -> CPUBuffer { assert!(src_buf.usage().contains(wgpu::BufferUsages::COPY_SRC)); let buffer_slice = src_buf.slice(..); let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel(); wgpu::util::DownloadBuffer::read_buffer( &device, device.queue(), &buffer_slice, move |buffer| { tx.send(match buffer { Ok(db) => Ok(CPUBuffer::from_bytes(&db, alignment)), Err(error) => Err(error), }) .expect("Failed to send result of read_buffer"); }, ); device.poll(wgpu::Maintain::Wait); rx.receive().await.unwrap().unwrap() } #[cfg(not(target_arch = "wasm32"))] pub fn wgpu_buffer_to_cpu_buffer( src_buf: &wgpu::Buffer, alignment: usize, device: &WgpuDevice, ) -> CPUBuffer { assert!(src_buf.usage().contains(wgpu::BufferUsages::COPY_SRC)); let buffer_slice = src_buf.slice(..); let (tx, rx) = std::sync::mpsc::channel(); wgpu::util::DownloadBuffer::read_buffer(device, device.queue(), &buffer_slice, move |buffer| { tx.send(match buffer { Ok(db) => Ok(CPUBuffer::from_bytes(&db, alignment)), Err(error) => Err(error), }) .expect("Failed to send result of read_buffer"); }); device.poll(wgpu::Maintain::Wait); rx.recv().unwrap().unwrap() }