integrations/virtiofs/src/virtiofs_util.rs (428 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. use std::cmp::min; use std::collections::VecDeque; use std::io::Read; use std::io::Write; use std::io::{self}; use std::mem::size_of; use std::mem::MaybeUninit; use std::ops::Deref; use std::ptr::copy_nonoverlapping; use virtio_queue::DescriptorChain; use vm_memory::bitmap::Bitmap; use vm_memory::bitmap::BitmapSlice; use vm_memory::Address; use vm_memory::ByteValued; use vm_memory::GuestMemory; use vm_memory::GuestMemoryMmap; use vm_memory::GuestMemoryRegion; use vm_memory::VolatileMemory; use vm_memory::VolatileSlice; use crate::buffer::ReadWriteAtVolatile; use crate::error::*; /// Used to consume and use data areas in shared memory between host and VMs. struct DescriptorChainConsumer<'a, B> { buffers: VecDeque<VolatileSlice<'a, B>>, bytes_consumed: usize, } impl<'a, B: BitmapSlice> DescriptorChainConsumer<'a, B> { #[cfg(test)] fn available_bytes(&self) -> usize { self.buffers.iter().fold(0, |count, vs| count + vs.len()) } fn bytes_consumed(&self) -> usize { self.bytes_consumed } fn consume<F>(&mut self, count: usize, f: F) -> Result<usize> where F: FnOnce(&[&VolatileSlice<B>]) -> Result<usize>, { let mut len = 0; let mut bufs = Vec::with_capacity(self.buffers.len()); for vs in &self.buffers { if len >= count { break; } bufs.push(vs); let remain = count - len; if remain < vs.len() { len += remain; } else { len += vs.len(); } } if bufs.is_empty() { return Ok(0); } let bytes_consumed = f(&bufs)?; let total_bytes_consumed = self.bytes_consumed .checked_add(bytes_consumed) .ok_or(new_vhost_user_fs_error( "the combined length of all the buffers in DescriptorChain would overflow", None, ))?; let mut remain = bytes_consumed; while let Some(vs) = self.buffers.pop_front() { if remain < vs.len() { self.buffers.push_front(vs.offset(remain).unwrap()); break; } remain -= vs.len(); } self.bytes_consumed = total_bytes_consumed; Ok(bytes_consumed) } fn split_at(&mut self, offset: usize) -> Result<DescriptorChainConsumer<'a, B>> { let mut remain = offset; let pos = self.buffers.iter().position(|vs| { if remain < vs.len() { true } else { remain -= vs.len(); false } }); if let Some(at) = pos { let mut other = self.buffers.split_off(at); if remain > 0 { let front = other.pop_front().expect("empty VecDeque after split"); self.buffers.push_back( front .subslice(0, remain) .map_err(|_| new_vhost_user_fs_error("volatile memory error", None))?, ); other.push_front( front .offset(remain) .map_err(|_| new_vhost_user_fs_error("volatile memory error", None))?, ); } Ok(DescriptorChainConsumer { buffers: other, bytes_consumed: 0, }) } else if remain == 0 { Ok(DescriptorChainConsumer { buffers: VecDeque::new(), bytes_consumed: 0, }) } else { Err(new_vhost_user_fs_error( "DescriptorChain split is out of bounds", None, )) } } } /// Provides a high-level interface for reading data in shared memory sequences. pub struct Reader<'a, B = ()> { buffer: DescriptorChainConsumer<'a, B>, } impl<'a, B: Bitmap + BitmapSlice + 'static> Reader<'a, B> { pub fn new<M>( mem: &'a GuestMemoryMmap<B>, desc_chain: DescriptorChain<M>, ) -> Result<Reader<'a, B>> where M: Deref, M::Target: GuestMemory + Sized, { let mut len: usize = 0; let buffers = desc_chain .readable() .map(|desc| { len = len .checked_add(desc.len() as usize) .ok_or(new_vhost_user_fs_error( "the combined length of all the buffers in DescriptorChain would overflow", None, ))?; let region = mem.find_region(desc.addr()).ok_or(new_vhost_user_fs_error( "no memory region for this address range", None, ))?; let offset = desc .addr() .checked_sub(region.start_addr().raw_value()) .unwrap(); region .deref() .get_slice(offset.raw_value() as usize, desc.len() as usize) .map_err(|err| { new_vhost_user_fs_error("volatile memory error", Some(err.into())) }) }) .collect::<Result<VecDeque<VolatileSlice<'a, B>>>>()?; Ok(Reader { buffer: DescriptorChainConsumer { buffers, bytes_consumed: 0, }, }) } pub fn read_obj<T: ByteValued>(&mut self) -> io::Result<T> { let mut obj = MaybeUninit::<T>::uninit(); let buf = unsafe { std::slice::from_raw_parts_mut(obj.as_mut_ptr() as *mut u8, size_of::<T>()) }; self.read_exact(buf)?; Ok(unsafe { obj.assume_init() }) } pub fn read_to_at<F: ReadWriteAtVolatile<B>>( &mut self, dst: F, count: usize, ) -> io::Result<usize> { self.buffer .consume(count, |bufs| dst.write_vectored_at_volatile(bufs)) .map_err(|err| err.into()) } #[cfg(test)] pub fn available_bytes(&self) -> usize { self.buffer.available_bytes() } #[cfg(test)] pub fn bytes_read(&self) -> usize { self.buffer.bytes_consumed() } } impl<B: BitmapSlice> io::Read for Reader<'_, B> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { self.buffer .consume(buf.len(), |bufs| { let mut rem = buf; let mut total = 0; for vs in bufs { let copy_len = min(rem.len(), vs.len()); unsafe { copy_nonoverlapping(vs.ptr_guard().as_ptr(), rem.as_mut_ptr(), copy_len); } rem = &mut rem[copy_len..]; total += copy_len; } Ok(total) }) .map_err(|err| err.into()) } } /// Provides a high-level interface for writing data in shared memory sequences. pub struct Writer<'a, B = ()> { buffer: DescriptorChainConsumer<'a, B>, } impl<'a, B: Bitmap + BitmapSlice + 'static> Writer<'a, B> { pub fn new<M>( mem: &'a GuestMemoryMmap<B>, desc_chain: DescriptorChain<M>, ) -> Result<Writer<'a, B>> where M: Deref, M::Target: GuestMemory + Sized, { let mut len: usize = 0; let buffers = desc_chain .writable() .map(|desc| { len = len .checked_add(desc.len() as usize) .ok_or(new_vhost_user_fs_error( "the combined length of all the buffers in DescriptorChain would overflow", None, ))?; let region = mem.find_region(desc.addr()).ok_or(new_vhost_user_fs_error( "no memory region for this address range", None, ))?; let offset = desc .addr() .checked_sub(region.start_addr().raw_value()) .unwrap(); region .deref() .get_slice(offset.raw_value() as usize, desc.len() as usize) .map_err(|err| { new_vhost_user_fs_error("volatile memory error", Some(err.into())) }) }) .collect::<Result<VecDeque<VolatileSlice<'a, B>>>>()?; Ok(Writer { buffer: DescriptorChainConsumer { buffers, bytes_consumed: 0, }, }) } pub fn split_at(&mut self, offset: usize) -> Result<Writer<'a, B>> { self.buffer.split_at(offset).map(|buffer| Writer { buffer }) } pub fn write_from_at<F: ReadWriteAtVolatile<B>>( &mut self, src: F, count: usize, ) -> io::Result<usize> { self.buffer .consume(count, |bufs| src.read_vectored_at_volatile(bufs)) .map_err(|err| err.into()) } #[cfg(test)] pub fn available_bytes(&self) -> usize { self.buffer.available_bytes() } pub fn bytes_written(&self) -> usize { self.buffer.bytes_consumed() } } impl<B: BitmapSlice> Write for Writer<'_, B> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { self.buffer .consume(buf.len(), |bufs| { let mut rem = buf; let mut total = 0; for vs in bufs { let copy_len = min(rem.len(), vs.len()); unsafe { copy_nonoverlapping(rem.as_ptr(), vs.ptr_guard_mut().as_ptr(), copy_len); } vs.bitmap().mark_dirty(0, copy_len); rem = &rem[copy_len..]; total += copy_len; } Ok(total) }) .map_err(|err| err.into()) } fn flush(&mut self) -> io::Result<()> { Ok(()) } } #[cfg(test)] mod tests { use super::*; use virtio_queue::Queue; use virtio_queue::QueueOwnedT; use virtio_queue::QueueT; use vm_memory::Bytes; use vm_memory::GuestAddress; use vm_memory::Le16; use vm_memory::Le32; use vm_memory::Le64; const VIRTQ_DESC_F_NEXT: u16 = 0x1; const VIRTQ_DESC_F_WRITE: u16 = 0x2; enum DescriptorType { Readable, Writable, } // Helper structure for testing, used to define the layout of the descriptor chain. #[derive(Copy, Clone, Debug, Default)] #[repr(C)] struct VirtqDesc { addr: Le64, len: Le32, flags: Le16, next: Le16, } // Helper structure for testing, used to define the layout of the available ring. #[derive(Copy, Clone, Debug, Default)] #[repr(C)] struct VirtqAvail { flags: Le16, idx: Le16, ring: Le16, } unsafe impl ByteValued for VirtqAvail {} unsafe impl ByteValued for VirtqDesc {} // Helper function for testing, used to create a descriptor chain with the specified descriptors. fn create_descriptor_chain( memory: &GuestMemoryMmap, descriptor_array_addr: GuestAddress, mut buffers_start_addr: GuestAddress, descriptors: Vec<(DescriptorType, u32)>, ) -> DescriptorChain<&GuestMemoryMmap> { let descriptors_len = descriptors.len(); for (index, (type_, size)) in descriptors.into_iter().enumerate() { let mut flags = 0; if let DescriptorType::Writable = type_ { flags |= VIRTQ_DESC_F_WRITE; } if index + 1 < descriptors_len { flags |= VIRTQ_DESC_F_NEXT; } let desc = VirtqDesc { addr: buffers_start_addr.raw_value().into(), len: size.into(), flags: flags.into(), next: (index as u16 + 1).into(), }; buffers_start_addr = buffers_start_addr.checked_add(size as u64).unwrap(); memory .write_obj( desc, descriptor_array_addr .checked_add((index * std::mem::size_of::<VirtqDesc>()) as u64) .unwrap(), ) .unwrap(); } let avail_ring = descriptor_array_addr .checked_add((descriptors_len * std::mem::size_of::<VirtqDesc>()) as u64) .unwrap(); let avail = VirtqAvail { flags: 0.into(), idx: 1.into(), ring: 0.into(), }; memory.write_obj(avail, avail_ring).unwrap(); let mut queue = Queue::new(4).unwrap(); queue .try_set_desc_table_address(descriptor_array_addr) .unwrap(); queue.try_set_avail_ring_address(avail_ring).unwrap(); queue.set_ready(true); queue.iter(memory).unwrap().next().unwrap() } #[test] fn simple_chain_reader_test() { let memory_start_addr = GuestAddress(0x0); let memory = GuestMemoryMmap::from_ranges(&[(memory_start_addr, 0x1000)]).unwrap(); let chain = create_descriptor_chain( &memory, GuestAddress(0x0), GuestAddress(0x100), vec![ (DescriptorType::Readable, 8), (DescriptorType::Readable, 16), (DescriptorType::Readable, 18), (DescriptorType::Readable, 64), ], ); let mut reader = Reader::new(&memory, chain).unwrap(); assert_eq!(reader.available_bytes(), 106); assert_eq!(reader.bytes_read(), 0); let mut buffer = [0; 64]; reader.read_exact(&mut buffer).unwrap(); assert_eq!(reader.available_bytes(), 42); assert_eq!(reader.bytes_read(), 64); assert_eq!(reader.read(&mut buffer).unwrap(), 42); assert_eq!(reader.available_bytes(), 0); assert_eq!(reader.bytes_read(), 106); } #[test] fn simple_chain_writer_test() { let memory_start_addr = GuestAddress(0x0); let memory = GuestMemoryMmap::from_ranges(&[(memory_start_addr, 0x1000)]).unwrap(); let chain = create_descriptor_chain( &memory, GuestAddress(0x0), GuestAddress(0x100), vec![ (DescriptorType::Writable, 8), (DescriptorType::Writable, 16), (DescriptorType::Writable, 18), (DescriptorType::Writable, 64), ], ); let mut writer = Writer::new(&memory, chain).unwrap(); assert_eq!(writer.available_bytes(), 106); assert_eq!(writer.bytes_written(), 0); let buffer = [0; 64]; writer.write_all(&buffer).unwrap(); assert_eq!(writer.available_bytes(), 42); assert_eq!(writer.bytes_written(), 64); assert_eq!(writer.write(&buffer).unwrap(), 42); assert_eq!(writer.available_bytes(), 0); assert_eq!(writer.bytes_written(), 106); } }