crates/ratchet-core/src/gpu/device.rs (283 lines of code) (raw):
use crate::{gpu::*, DType, Tensor, TensorId};
use rustc_hash::FxHashMap;
use std::{borrow::Cow, sync::Arc};
use wgpu::{Adapter, Limits};
use crate::DeviceError;
/// Largest buffer size requested from the adapter: `2^30 - 1` (1_073_741_823).
///
/// [`WgpuDevice::new`] passes this value as `max_buffer_size` and, narrowed to
/// `u32`, as `max_storage_buffer_binding_size` in its first device request.
/// wgpu documents both limits in bytes, which puts this one byte short of 1 GiB.
pub const MAX_BUFFER_SIZE: u64 = (1 << 30) - 1;
/// # Device
///
/// A device is a handle to a physical GPU.
/// It is used to create resources and submit commands to the GPU.
///
/// Currently, WebGPU doesn't support multiple devices.
/// Ordinal should always be 0.
///
/// The allocator, the pools, the wgpu device and the queue are all held in
/// `Arc`s, so the derived `Clone` yields another handle sharing them; the
/// limits and features are plain values that are cloned along.
#[derive(Clone)]
pub struct WgpuDevice {
/// Device index; `WgpuDevice::new` always stores 0 here.
ordinal: u32,
/// Allocator that every buffer creation / lookup method on this type forwards to.
buffer_allocator: Arc<BufferAllocator>,
/// Pool used by `get_or_create_bind_group`.
bind_group_pool: Arc<BindGroupPool>,
/// Pool used by `get_or_create_bind_group_layout` and `bind_group_layout_resources`.
bind_group_layout_pool: Arc<BindGroupLayoutPool>,
/// Pool used by `get_or_create_pipeline_layout` and `pipeline_layout_resources`.
pipeline_layout_pool: Arc<PipelineLayoutPool>,
/// Pool used by `get_or_create_compute_pipeline` and `pipeline_resources`.
compute_pipeline_pool: Arc<ComputePipelinePool>,
/// Pool used by `get_or_create_compute_module` and `kernel_module_resources`.
kernel_module_pool: Arc<KernelModulePool>,
/// Subset of the limits reported by the acquired wgpu device.
device_limits: DeviceLimits,
/// Feature flags derived from the acquired wgpu device; `WgpuDevice::new`
/// may clear individual flags based on environment variables.
device_features: DeviceFeatures,
/// The wrapped wgpu device; also the target of this type's `Deref` impl.
device: Arc<wgpu::Device>,
/// The queue returned together with `device` by the device request.
queue: Arc<wgpu::Queue>,
}
/// Dereferences to the wrapped `wgpu::Device`, so its methods can be called
/// directly on a `WgpuDevice`.
impl std::ops::Deref for WgpuDevice {
    type Target = wgpu::Device;

    /// Borrows the wgpu device stored inside this handle.
    fn deref(&self) -> &wgpu::Device {
        self.device.as_ref()
    }
}
/// Compact debug representation: the literal `wgpu:` followed by the ordinal.
impl std::fmt::Debug for WgpuDevice {
    /// Writes `wgpu:<ordinal>` to the formatter; no other field is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("wgpu:{}", self.ordinal))
    }
}
/// Two handles compare equal when their ordinals match and the wrapped wgpu
/// devices report the same global id.
impl PartialEq for WgpuDevice {
    fn eq(&self, other: &Self) -> bool {
        // Cheap integer comparison first; the global ids are only queried
        // when the ordinals already agree.
        if self.ordinal != other.ordinal {
            return false;
        }
        self.device.global_id() == other.device.global_id()
    }
}
impl WgpuDevice {
    /// Selects an adapter, requests a device/queue pair from it and builds the handle.
    ///
    /// The first request asks for `SHADER_F16` and `SUBGROUP` (plus
    /// `TIMESTAMP_QUERY` when the `gpu-profiling` cargo feature is enabled) and
    /// sets explicit `max_buffer_size`, `max_storage_buffer_binding_size` and
    /// `max_compute_invocations_per_workgroup` limits. If that request fails, the
    /// error is logged and a second request is made with exactly the limits and
    /// features the adapter reports.
    ///
    /// If the `RATCHET_FORCE_F32` or `RATCHET_DISABLE_SUBGROUPS` environment
    /// variable can be read, the matching flag in the stored [`DeviceFeatures`]
    /// is cleared. Only the recorded flags change; the wgpu device itself is
    /// left as acquired.
    ///
    /// # Errors
    ///
    /// Returns a [`DeviceError`] when no adapter could be selected or when the
    /// fallback device request fails as well.
    pub async fn new() -> Result<Self, DeviceError> {
        #[cfg(target_arch = "wasm32")]
        let adapter = Self::select_adapter().await?;
        #[cfg(not(target_arch = "wasm32"))]
        let adapter = Self::select_adapter()?;

        log::info!("Adapter: {:?}", adapter.get_info());
        log::info!("Active GPU: {}", adapter.get_info().name);

        let required_features = wgpu::Features::default()
            | wgpu::Features::SHADER_F16
            | wgpu::Features::SUBGROUP;
        #[cfg(feature = "gpu-profiling")]
        let required_features = required_features | wgpu::Features::TIMESTAMP_QUERY;

        let required_limits = Limits {
            max_buffer_size: MAX_BUFFER_SIZE,
            max_storage_buffer_binding_size: MAX_BUFFER_SIZE as u32,
            max_compute_invocations_per_workgroup: 1024,
            ..Default::default()
        };
        let mut device_descriptor = wgpu::DeviceDescriptor {
            label: Some("Ratchet"),
            required_features,
            required_limits,
            memory_hints: wgpu::MemoryHints::Performance,
        };

        let first_attempt = adapter.request_device(&device_descriptor, None).await;
        let (device, queue) = match first_attempt {
            Ok(acquired) => acquired,
            Err(e) => {
                // Retry once, this time asking only for what the adapter itself advertises.
                log::error!("Failed to acq. device, trying with reduced limits: {:?}", e);
                device_descriptor.required_limits = adapter.limits();
                device_descriptor.required_features = adapter.features();
                adapter.request_device(&device_descriptor, None).await?
            }
        };
        log::info!("Device: {:?}", device.limits());

        let device_limits = DeviceLimits::from(device.limits());
        let mut device_features = DeviceFeatures::from(device.features());
        if std::env::var("RATCHET_FORCE_F32").is_ok() {
            log::warn!("Forcing F32 precision");
            device_features.SHADER_F16 = false;
        }
        if std::env::var("RATCHET_DISABLE_SUBGROUPS").is_ok() {
            log::warn!("Disabling subgroup support");
            device_features.SUBGROUP = false;
        }
        log::warn!("Device features: {:?}", device_features);

        Ok(Self {
            queue: Arc::new(queue),
            ordinal: 0,
            buffer_allocator: Arc::new(BufferAllocator::new()),
            bind_group_pool: Arc::new(BindGroupPool::new()),
            bind_group_layout_pool: Arc::new(BindGroupLayoutPool::new()),
            pipeline_layout_pool: Arc::new(PipelineLayoutPool::new()),
            kernel_module_pool: Arc::new(KernelModulePool::new()),
            compute_pipeline_pool: Arc::new(ComputePipelinePool::new()),
            device: Arc::new(device),
            device_limits,
            device_features,
        })
    }

    /// Borrows the queue that was returned together with the device.
    pub(crate) fn queue(&self) -> &wgpu::Queue {
        self.queue.as_ref()
    }

    /// Returns the device ordinal stored at construction time.
    pub fn ordinal(&self) -> u32 {
        self.ordinal
    }

    /// wasm32 adapter selection: asks a default instance for a high-performance,
    /// non-fallback adapter without a compatible surface.
    ///
    /// # Errors
    ///
    /// Returns [`DeviceError::AdapterRequestFailed`] when the instance yields no adapter.
    #[cfg(target_arch = "wasm32")]
    async fn select_adapter() -> Result<Adapter, DeviceError> {
        let instance = wgpu::Instance::default();
        let options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        match instance.request_adapter(&options).await {
            Some(adapter) => Ok(adapter),
            None => Err(DeviceError::AdapterRequestFailed),
        }
    }

    /// Native adapter selection: enumerates the adapters of the backends chosen
    /// via the environment (falling back to `Backends::PRIMARY`) and keeps the one
    /// whose device type ranks highest.
    ///
    /// Ranking, best first: discrete GPU, other, integrated GPU, virtual GPU, CPU.
    /// `Iterator::max_by_key` keeps the last of several equally ranked adapters.
    ///
    /// # Errors
    ///
    /// Returns [`DeviceError::AdapterRequestFailed`] when no adapter is enumerated.
    #[cfg(not(target_arch = "wasm32"))]
    fn select_adapter() -> Result<Adapter, DeviceError> {
        use wgpu::DeviceType;

        /// Maps a device type to its selection priority (higher wins).
        fn rank(kind: DeviceType) -> i32 {
            match kind {
                DeviceType::DiscreteGpu => 5,
                DeviceType::Other => 4,
                DeviceType::IntegratedGpu => 3,
                DeviceType::VirtualGpu => 2,
                DeviceType::Cpu => 1,
            }
        }

        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            dx12_shader_compiler: wgpu::util::dx12_shader_compiler_from_env().unwrap_or_default(),
            ..Default::default()
        });
        let backends = wgpu::util::backend_bits_from_env().unwrap_or(wgpu::Backends::PRIMARY);

        instance
            .enumerate_adapters(backends)
            .into_iter()
            .max_by_key(|candidate| rank(candidate.get_info().device_type))
            .ok_or(DeviceError::AdapterRequestFailed)
    }

    /// Borrows the feature flags recorded for this device.
    pub fn features(&self) -> &DeviceFeatures {
        &self.device_features
    }

    /// Borrows the limits recorded for this device.
    pub fn limits(&self) -> &DeviceLimits {
        &self.device_limits
    }
}
impl WgpuDevice {
    /// Forwards `desc` and `contents` to the buffer allocator's
    /// `create_buffer_init`, passing this device along.
    ///
    /// This wrapper always returns `Ok`; the `Result` is part of the signature only.
    pub fn get_or_create_buffer_init(
        &self,
        desc: &BufferDescriptor,
        contents: Cow<'_, [u8]>,
    ) -> Result<PooledGPUBuffer, DeviceError> {
        let buffer = self.buffer_allocator.create_buffer_init(desc, contents, self);
        Ok(buffer)
    }

    /// Forwards `cpu_uniform` to the buffer allocator's `create_uniform_init`.
    pub fn create_uniform_init(&self, cpu_uniform: CpuUniform) -> PooledGPUBuffer {
        let allocator = &self.buffer_allocator;
        allocator.create_uniform_init(cpu_uniform, self)
    }

    /// Forwards `desc` and the `immediate` flag to the buffer allocator's
    /// `create_buffer`.
    ///
    /// This wrapper always returns `Ok`.
    pub fn get_or_create_buffer(
        &self,
        desc: &BufferDescriptor,
        immediate: bool,
    ) -> Result<PooledGPUBuffer, DeviceError> {
        let buffer = self.buffer_allocator.create_buffer(desc, self, immediate);
        Ok(buffer)
    }

    /// Looks `handle` up in the buffer allocator.
    ///
    /// This wrapper always returns `Ok`.
    pub fn get_buffer(&self, handle: GpuBufferHandle) -> Result<PooledGPUBuffer, DeviceError> {
        let buffer = self.buffer_allocator.get(handle);
        Ok(buffer)
    }

    /// Delegates to the bind group pool's `get_or_create` for `desc`.
    ///
    /// This wrapper always returns `Ok`.
    pub fn get_or_create_bind_group(
        &self,
        desc: &BindGroupDescriptor,
    ) -> Result<GpuBindGroup, PoolError> {
        let bind_group = self.bind_group_pool.get_or_create(desc, self);
        Ok(bind_group)
    }

    /// Delegates to the bind group layout pool's `get_or_create` for `desc`.
    ///
    /// This wrapper always returns `Ok`.
    pub fn get_or_create_bind_group_layout(
        &self,
        desc: &BindGroupLayoutDescriptor,
    ) -> Result<BindGroupLayoutHandle, PoolError> {
        let layout = self.bind_group_layout_pool.get_or_create(desc, self);
        Ok(layout)
    }

    /// Delegates to the pipeline layout pool's `get_or_create` for `desc`.
    ///
    /// This wrapper always returns `Ok`.
    pub fn get_or_create_pipeline_layout(
        &self,
        desc: &PipelineLayoutDescriptor,
    ) -> Result<PipelineLayoutHandle, PoolError> {
        let layout = self.pipeline_layout_pool.get_or_create(desc, self);
        Ok(layout)
    }

    /// Delegates to the compute pipeline pool's `get_or_create` for `desc`.
    ///
    /// This wrapper always returns `Ok`.
    pub fn get_or_create_compute_pipeline(
        &self,
        desc: &ComputePipelineDescriptor,
    ) -> Result<ComputePipelineHandle, PoolError> {
        let pipeline = self.compute_pipeline_pool.get_or_create(desc, self);
        Ok(pipeline)
    }

    /// Delegates to the kernel module pool's `get_or_create`.
    ///
    /// All arguments, including `device`, are handed to the pool unchanged;
    /// `self` is only used to reach the pool.
    pub fn get_or_create_compute_module<K: Kernel + ?Sized>(
        &self,
        desc: &KernelModuleDesc,
        kernel: &K,
        inplace: bool,
        dst: &Tensor,
        workgroup_size: &WorkgroupSize,
        device: &WgpuDevice,
    ) -> KernelModuleHandle {
        let modules = &self.kernel_module_pool;
        modules.get_or_create(desc, kernel, inplace, dst, workgroup_size, device)
    }

    /// Returns the accessor produced by the kernel module pool's `resources()`.
    pub fn kernel_module_resources(
        &self,
    ) -> StaticResourcePoolReadLockAccessor<'_, KernelModuleHandle, wgpu::ShaderModule> {
        self.kernel_module_pool.resources()
    }

    /// Returns the accessor produced by the bind group layout pool's `resources()`.
    pub fn bind_group_layout_resources(
        &self,
    ) -> StaticResourcePoolReadLockAccessor<'_, BindGroupLayoutHandle, wgpu::BindGroupLayout> {
        self.bind_group_layout_pool.resources()
    }

    /// Returns the accessor produced by the pipeline layout pool's `resources()`.
    pub fn pipeline_layout_resources(
        &self,
    ) -> StaticResourcePoolReadLockAccessor<'_, PipelineLayoutHandle, wgpu::PipelineLayout> {
        self.pipeline_layout_pool.resources()
    }

    /// Returns the accessor produced by the compute pipeline pool's `resources()`.
    pub fn pipeline_resources(
        &self,
    ) -> StaticResourcePoolReadLockAccessor<'_, ComputePipelineHandle, wgpu::ComputePipeline> {
        self.compute_pipeline_pool.resources()
    }

    /// Allocates all buffers required for storage of activations.
    /// Additionally, allocates buffer for leaf node, the tensor upon which resolve was called.
    ///
    /// `execution_order` and `device` are passed straight to the buffer
    /// allocator's `allocate_cfg`, whose result is returned as is.
    pub fn allocate_cfg(
        &self,
        execution_order: &[&Tensor],
        device: &WgpuDevice,
    ) -> Result<FxHashMap<TensorId, PooledGPUBuffer>, DeviceError> {
        let allocator = &self.buffer_allocator;
        allocator.allocate_cfg(execution_order, device)
    }

    /// Calls the buffer allocator's `begin_pass` with the fixed argument `0`.
    pub fn begin_pass(&self) {
        self.buffer_allocator.begin_pass(0);
    }

    /// Borrows the recorded feature flags (the same field `features` exposes).
    pub fn compute_features(&self) -> &DeviceFeatures {
        &self.device_features
    }

    /// Borrows the recorded limits (the same field `limits` exposes).
    pub fn compute_limits(&self) -> &DeviceLimits {
        &self.device_limits
    }
}
/// The subset of `wgpu::Limits` that this crate keeps for a device.
///
/// Built from a full `wgpu::Limits` value through the `From` impl, which copies
/// the three identically named fields and discards everything else.
///
/// Derives `Debug` (like [`DeviceFeatures`]) so the limits can be logged.
#[derive(Clone, Debug)]
pub struct DeviceLimits {
    /// Value of `wgpu::Limits::max_bind_groups`.
    pub max_bind_groups: u32,
    /// Value of `wgpu::Limits::max_storage_buffer_binding_size`.
    pub max_storage_buffer_binding_size: u32,
    /// Value of `wgpu::Limits::max_compute_invocations_per_workgroup`.
    pub max_compute_invocations_per_workgroup: u32,
}
impl From<wgpu::Limits> for DeviceLimits {
    /// Copies the three limits tracked by [`DeviceLimits`] out of `limits`;
    /// every other wgpu limit is dropped.
    fn from(limits: wgpu::Limits) -> Self {
        Self {
            max_bind_groups: limits.max_bind_groups,
            max_storage_buffer_binding_size: limits.max_storage_buffer_binding_size,
            max_compute_invocations_per_workgroup: limits.max_compute_invocations_per_workgroup,
        }
    }
}
/// Boolean flags for the two wgpu features this crate tracks per device.
///
/// The field names mirror the `wgpu::Features` constants they are filled from
/// in the `From<wgpu::Features>` impl. `WgpuDevice::new` may clear either flag
/// afterwards based on environment variables.
#[derive(Clone, Debug)]
pub struct DeviceFeatures {
/// Whether `wgpu::Features::SHADER_F16` is considered available;
/// `compute_precision` selects `DType::F16` when this is set.
pub SHADER_F16: bool,
/// Whether `wgpu::Features::SUBGROUP` is considered available.
pub SUBGROUP: bool,
}
impl DeviceFeatures {
    /// Picks the data type to compute in: `DType::F16` when `SHADER_F16` is
    /// set, `DType::F32` otherwise.
    pub fn compute_precision(&self) -> DType {
        match self.SHADER_F16 {
            true => DType::F16,
            false => DType::F32,
        }
    }
}
impl From<wgpu::Features> for DeviceFeatures {
    /// Sets each flag to whether `features` contains the wgpu feature of the
    /// same name.
    fn from(features: wgpu::Features) -> Self {
        let shader_f16 = features.contains(wgpu::Features::SHADER_F16);
        let subgroup = features.contains(wgpu::Features::SUBGROUP);
        Self {
            SHADER_F16: shader_f16,
            SUBGROUP: subgroup,
        }
    }
}