integrations/virtiofs/src/virtiofs_util.rs (428 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::cmp::min;
use std::collections::VecDeque;
use std::io::Read;
use std::io::Write;
use std::io::{self};
use std::mem::size_of;
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::ptr::copy_nonoverlapping;
use virtio_queue::DescriptorChain;
use vm_memory::bitmap::Bitmap;
use vm_memory::bitmap::BitmapSlice;
use vm_memory::Address;
use vm_memory::ByteValued;
use vm_memory::GuestMemory;
use vm_memory::GuestMemoryMmap;
use vm_memory::GuestMemoryRegion;
use vm_memory::VolatileMemory;
use vm_memory::VolatileSlice;
use crate::buffer::ReadWriteAtVolatile;
use crate::error::*;
/// Used to consume and use data areas in shared memory between host and VMs.
///
/// Holds the guest-memory slices that back one side (readable or writable) of a
/// descriptor chain, in chain order, together with a running count of the bytes
/// consumed through [`DescriptorChainConsumer::consume`].
struct DescriptorChainConsumer<'a, B> {
    /// Slices that have not been (fully) consumed yet, front first.
    buffers: VecDeque<VolatileSlice<'a, B>>,
    /// Total number of bytes consumed so far.
    bytes_consumed: usize,
}

impl<'a, B: BitmapSlice> DescriptorChainConsumer<'a, B> {
    /// Returns the number of bytes left in the remaining slices.
    #[cfg(test)]
    fn available_bytes(&self) -> usize {
        self.buffers.iter().map(|vs| vs.len()).sum()
    }

    /// Returns the total number of bytes consumed so far.
    fn bytes_consumed(&self) -> usize {
        self.bytes_consumed
    }

    /// Offers at most `count` bytes of the remaining slices to `f`, then advances
    /// past the number of bytes `f` reports having used.
    ///
    /// The slices handed to `f` never cover more than `count` bytes in total: when
    /// only part of a slice is needed, `f` receives a truncated view of it. Fewer
    /// than `count` bytes are offered when the chain runs out. Returns `Ok(0)`
    /// without calling `f` when there is nothing to offer (e.g. `count == 0`).
    ///
    /// # Errors
    /// Propagates any error from `f`. Fails if the running consumed-byte count
    /// would overflow `usize`, or if truncating a slice fails.
    fn consume<F>(&mut self, count: usize, f: F) -> Result<usize>
    where
        F: FnOnce(&[&VolatileSlice<B>]) -> Result<usize>,
    {
        // Declared before `bufs` because `bufs` may end up borrowing it.
        let mut tail: Option<VolatileSlice<'a, B>> = None;
        let mut bufs = Vec::with_capacity(self.buffers.len());
        let mut len = 0;
        for vs in &self.buffers {
            if len >= count {
                break;
            }
            let remain = count - len;
            if remain < vs.len() {
                // Only the first `remain` bytes of this slice are requested; expose
                // just those so `f` cannot read/write past `count`.
                tail = Some(
                    vs.subslice(0, remain)
                        .map_err(|_| new_vhost_user_fs_error("volatile memory error", None))?,
                );
                break;
            }
            bufs.push(vs);
            len += vs.len();
        }
        if let Some(vs) = tail.as_ref() {
            bufs.push(vs);
        }
        if bufs.is_empty() {
            return Ok(0);
        }
        let bytes_consumed = f(&bufs)?;
        let total_bytes_consumed = self
            .bytes_consumed
            .checked_add(bytes_consumed)
            .ok_or_else(|| {
                new_vhost_user_fs_error(
                    "the combined length of all the buffers in DescriptorChain would overflow",
                    None,
                )
            })?;
        // Drop fully consumed slices and keep the unconsumed remainder of a
        // partially consumed one at the front.
        let mut remain = bytes_consumed;
        while let Some(vs) = self.buffers.pop_front() {
            if remain < vs.len() {
                self.buffers.push_front(
                    vs.offset(remain)
                        .expect("offset is strictly less than the slice length"),
                );
                break;
            }
            remain -= vs.len();
        }
        self.bytes_consumed = total_bytes_consumed;
        Ok(bytes_consumed)
    }

    /// Splits the remaining slices at byte `offset`.
    ///
    /// `self` keeps the first `offset` bytes; everything after them is moved into
    /// the returned consumer, whose consumed-byte count starts at zero. A slice
    /// straddling `offset` is cut in two. Splitting exactly at the end yields an
    /// empty consumer.
    ///
    /// # Errors
    /// Fails if `offset` is larger than the number of remaining bytes, or if
    /// cutting the straddling slice fails.
    fn split_at(&mut self, offset: usize) -> Result<DescriptorChainConsumer<'a, B>> {
        // Find the slice containing byte `offset`; `remain` ends up as the offset
        // inside that slice (or the overshoot past the end if none contains it).
        let mut remain = offset;
        let pos = self.buffers.iter().position(|vs| {
            if remain < vs.len() {
                true
            } else {
                remain -= vs.len();
                false
            }
        });
        if let Some(at) = pos {
            let mut other = self.buffers.split_off(at);
            if remain > 0 {
                let front = other.pop_front().expect("empty VecDeque after split");
                self.buffers.push_back(
                    front
                        .subslice(0, remain)
                        .map_err(|_| new_vhost_user_fs_error("volatile memory error", None))?,
                );
                other.push_front(
                    front
                        .offset(remain)
                        .map_err(|_| new_vhost_user_fs_error("volatile memory error", None))?,
                );
            }
            Ok(DescriptorChainConsumer {
                buffers: other,
                bytes_consumed: 0,
            })
        } else if remain == 0 {
            Ok(DescriptorChainConsumer {
                buffers: VecDeque::new(),
                bytes_consumed: 0,
            })
        } else {
            Err(new_vhost_user_fs_error(
                "DescriptorChain split is out of bounds",
                None,
            ))
        }
    }
}
/// Provides a high-level interface for reading data in shared memory sequences.
///
/// Wraps the readable descriptors of a virtio descriptor chain and presents them
/// as a single sequential byte stream (see its `io::Read` implementation).
pub struct Reader<'a, B = ()> {
    buffer: DescriptorChainConsumer<'a, B>,
}

impl<'a, B: Bitmap + BitmapSlice + 'static> Reader<'a, B> {
    /// Builds a `Reader` over the readable descriptors of `desc_chain`, resolving
    /// each descriptor to a slice of `mem`.
    ///
    /// # Errors
    /// Fails if the summed descriptor lengths overflow `usize`, if a descriptor's
    /// address is not inside any region of `mem`, or if the region's `get_slice`
    /// rejects the descriptor's range.
    pub fn new<M>(
        mem: &'a GuestMemoryMmap<B>,
        desc_chain: DescriptorChain<M>,
    ) -> Result<Reader<'a, B>>
    where
        M: Deref,
        M::Target: GuestMemory + Sized,
    {
        // Only tracked to reject chains whose total length would overflow.
        let mut len: usize = 0;
        let buffers = desc_chain
            .readable()
            .map(|desc| {
                len = len.checked_add(desc.len() as usize).ok_or_else(|| {
                    new_vhost_user_fs_error(
                        "the combined length of all the buffers in DescriptorChain would overflow",
                        None,
                    )
                })?;
                let region = mem.find_region(desc.addr()).ok_or_else(|| {
                    new_vhost_user_fs_error("no memory region for this address range", None)
                })?;
                // `find_region` returned the region containing `desc.addr()`, so the
                // address cannot precede the region start.
                let offset = desc
                    .addr()
                    .checked_sub(region.start_addr().raw_value())
                    .expect("descriptor address precedes the start of its memory region");
                region
                    .deref()
                    .get_slice(offset.raw_value() as usize, desc.len() as usize)
                    .map_err(|err| {
                        new_vhost_user_fs_error("volatile memory error", Some(err.into()))
                    })
            })
            .collect::<Result<VecDeque<VolatileSlice<'a, B>>>>()?;
        Ok(Reader {
            buffer: DescriptorChainConsumer {
                buffers,
                bytes_consumed: 0,
            },
        })
    }

    /// Reads a `T` by filling its `size_of::<T>()` bytes from the stream in order.
    ///
    /// # Errors
    /// Fails like `read_exact` does, e.g. when fewer than `size_of::<T>()` bytes
    /// remain.
    pub fn read_obj<T: ByteValued>(&mut self) -> io::Result<T> {
        // Start from zeroed (hence initialized) memory: `Read::read_exact` must not
        // be given a `&mut [u8]` over uninitialized bytes.
        let mut obj = MaybeUninit::<T>::zeroed();
        // SAFETY: `obj` owns `size_of::<T>()` bytes, all initialized to zero above,
        // and `buf` does not outlive `obj`.
        let buf =
            unsafe { std::slice::from_raw_parts_mut(obj.as_mut_ptr() as *mut u8, size_of::<T>()) };
        self.read_exact(buf)?;
        // SAFETY: every byte of `obj` is initialized (zeroed, then overwritten by
        // `read_exact`), and `ByteValued` types accept any bit pattern.
        Ok(unsafe { obj.assume_init() })
    }

    /// Lets `dst` take up to `count` bytes from the stream through
    /// `write_vectored_at_volatile`; the stream advances by the number of bytes
    /// `dst` reports.
    pub fn read_to_at<F: ReadWriteAtVolatile<B>>(
        &mut self,
        dst: F,
        count: usize,
    ) -> io::Result<usize> {
        self.buffer
            .consume(count, |bufs| dst.write_vectored_at_volatile(bufs))
            .map_err(|err| err.into())
    }

    /// Number of bytes still available to read.
    #[cfg(test)]
    pub fn available_bytes(&self) -> usize {
        self.buffer.available_bytes()
    }

    /// Number of bytes read so far.
    #[cfg(test)]
    pub fn bytes_read(&self) -> usize {
        self.buffer.bytes_consumed()
    }
}
impl<B: BitmapSlice> io::Read for Reader<'_, B> {
    /// Copies bytes from the remaining readable descriptors into `buf`, one slice
    /// after another, and returns how many bytes were copied.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let wanted = buf.len();
        let copied = self.buffer.consume(wanted, |slices| {
            let mut filled = 0;
            for slice in slices {
                let dst = &mut buf[filled..];
                let n = min(dst.len(), slice.len());
                // SAFETY: `n` is bounded by both the destination length and the
                // source slice length; `buf` is an exclusive borrow, assumed not to
                // overlap the guest-memory slice.
                unsafe { copy_nonoverlapping(slice.ptr_guard().as_ptr(), dst.as_mut_ptr(), n) };
                filled += n;
            }
            Ok(filled)
        });
        copied.map_err(Into::into)
    }
}
/// Provides a high-level interface for writing data in shared memory sequences.
///
/// Wraps the writable descriptors of a virtio descriptor chain and presents them
/// as a single sequential byte sink (see its `io::Write` implementation).
pub struct Writer<'a, B = ()> {
    buffer: DescriptorChainConsumer<'a, B>,
}

impl<'a, B: Bitmap + BitmapSlice + 'static> Writer<'a, B> {
    /// Builds a `Writer` over the writable descriptors of `desc_chain`, resolving
    /// each descriptor to a slice of `mem`.
    ///
    /// # Errors
    /// Fails if the summed descriptor lengths overflow `usize`, if a descriptor's
    /// address is not inside any region of `mem`, or if the region's `get_slice`
    /// rejects the descriptor's range.
    pub fn new<M>(
        mem: &'a GuestMemoryMmap<B>,
        desc_chain: DescriptorChain<M>,
    ) -> Result<Writer<'a, B>>
    where
        M: Deref,
        M::Target: GuestMemory + Sized,
    {
        // Only tracked to reject chains whose total length would overflow.
        let mut len: usize = 0;
        let buffers = desc_chain
            .writable()
            .map(|desc| {
                len = len.checked_add(desc.len() as usize).ok_or_else(|| {
                    new_vhost_user_fs_error(
                        "the combined length of all the buffers in DescriptorChain would overflow",
                        None,
                    )
                })?;
                let region = mem.find_region(desc.addr()).ok_or_else(|| {
                    new_vhost_user_fs_error("no memory region for this address range", None)
                })?;
                // `find_region` returned the region containing `desc.addr()`, so the
                // address cannot precede the region start.
                let offset = desc
                    .addr()
                    .checked_sub(region.start_addr().raw_value())
                    .expect("descriptor address precedes the start of its memory region");
                region
                    .deref()
                    .get_slice(offset.raw_value() as usize, desc.len() as usize)
                    .map_err(|err| {
                        new_vhost_user_fs_error("volatile memory error", Some(err.into()))
                    })
            })
            .collect::<Result<VecDeque<VolatileSlice<'a, B>>>>()?;
        Ok(Writer {
            buffer: DescriptorChainConsumer {
                buffers,
                bytes_consumed: 0,
            },
        })
    }

    /// Splits this writer at byte `offset`: `self` keeps the first `offset`
    /// writable bytes and the returned writer covers the rest, with its own
    /// written-byte count starting at zero.
    ///
    /// # Errors
    /// Fails if `offset` exceeds the remaining writable bytes.
    pub fn split_at(&mut self, offset: usize) -> Result<Writer<'a, B>> {
        self.buffer.split_at(offset).map(|buffer| Writer { buffer })
    }

    /// Lets `src` fill up to `count` bytes of the remaining writable space through
    /// `read_vectored_at_volatile`; the writer advances by the number of bytes
    /// `src` reports.
    pub fn write_from_at<F: ReadWriteAtVolatile<B>>(
        &mut self,
        src: F,
        count: usize,
    ) -> io::Result<usize> {
        self.buffer
            .consume(count, |bufs| src.read_vectored_at_volatile(bufs))
            .map_err(|err| err.into())
    }

    /// Number of writable bytes still available.
    #[cfg(test)]
    pub fn available_bytes(&self) -> usize {
        self.buffer.available_bytes()
    }

    /// Number of bytes written so far.
    pub fn bytes_written(&self) -> usize {
        self.buffer.bytes_consumed()
    }
}
impl<B: BitmapSlice> Write for Writer<'_, B> {
    /// Copies as much of `buf` as fits into the remaining writable descriptors,
    /// marking each written range dirty in the slice's bitmap, and returns the
    /// number of bytes written.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let wanted = buf.len();
        let written = self.buffer.consume(wanted, |slices| {
            let mut done = 0;
            for slice in slices {
                let src = &buf[done..];
                let n = min(src.len(), slice.len());
                // SAFETY: `n` is bounded by both the source length and the
                // destination slice length; `buf` is assumed not to overlap the
                // guest-memory slice.
                unsafe { copy_nonoverlapping(src.as_ptr(), slice.ptr_guard_mut().as_ptr(), n) };
                slice.bitmap().mark_dirty(0, n);
                done += n;
            }
            Ok(done)
        });
        written.map_err(Into::into)
    }

    /// `write` copies straight into guest memory, so there is nothing to flush.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use virtio_queue::Queue;
    use virtio_queue::QueueOwnedT;
    use virtio_queue::QueueT;
    use vm_memory::Bytes;
    use vm_memory::GuestAddress;
    use vm_memory::Le16;
    use vm_memory::Le32;
    use vm_memory::Le64;

    // Descriptor flag bits that `create_descriptor_chain` writes into `VirtqDesc::flags`.
    const VIRTQ_DESC_F_NEXT: u16 = 0x1;
    const VIRTQ_DESC_F_WRITE: u16 = 0x2;

    // Chooses whether a test descriptor is given the `VIRTQ_DESC_F_WRITE` flag.
    enum DescriptorType {
        Readable,
        Writable,
    }

    // Helper structure for testing, used to define the layout of the descriptor chain.
    #[derive(Copy, Clone, Debug, Default)]
    #[repr(C)]
    struct VirtqDesc {
        addr: Le64,
        len: Le32,
        flags: Le16,
        next: Le16,
    }

    // Helper structure for testing, used to define the layout of the available ring.
    // Only one ring entry is modelled, which is enough for the single chain that
    // `create_descriptor_chain` publishes.
    #[derive(Copy, Clone, Debug, Default)]
    #[repr(C)]
    struct VirtqAvail {
        flags: Le16,
        idx: Le16,
        ring: Le16,
    }

    // SAFETY: both structs are `#[repr(C)]` and consist only of little-endian
    // integer fields laid out without padding, so any byte pattern is a valid value.
    unsafe impl ByteValued for VirtqAvail {}
    unsafe impl ByteValued for VirtqDesc {}

    // Helper function for testing, used to create a descriptor chain with the specified descriptors.
    //
    // Writes one `VirtqDesc` per entry of `descriptors` into a table at
    // `descriptor_array_addr`. The buffers are placed back to back starting at
    // `buffers_start_addr`, and every descriptor except the last is linked to the
    // next one via `VIRTQ_DESC_F_NEXT`. The available ring is placed right after the
    // table with `idx = 1` and `ring[0] = 0`, so the queue exposes exactly one chain,
    // which is returned. The queue is created with size 4; presumably callers must
    // pass at most 4 descriptors (TODO confirm).
    fn create_descriptor_chain(
        memory: &GuestMemoryMmap,
        descriptor_array_addr: GuestAddress,
        mut buffers_start_addr: GuestAddress,
        descriptors: Vec<(DescriptorType, u32)>,
    ) -> DescriptorChain<&GuestMemoryMmap> {
        let descriptors_len = descriptors.len();
        for (index, (type_, size)) in descriptors.into_iter().enumerate() {
            let mut flags = 0;
            if let DescriptorType::Writable = type_ {
                flags |= VIRTQ_DESC_F_WRITE;
            }
            if index + 1 < descriptors_len {
                flags |= VIRTQ_DESC_F_NEXT;
            }
            let desc = VirtqDesc {
                addr: buffers_start_addr.raw_value().into(),
                len: size.into(),
                flags: flags.into(),
                next: (index as u16 + 1).into(),
            };
            // The next buffer starts right after this one.
            buffers_start_addr = buffers_start_addr.checked_add(size as u64).unwrap();
            memory
                .write_obj(
                    desc,
                    descriptor_array_addr
                        .checked_add((index * std::mem::size_of::<VirtqDesc>()) as u64)
                        .unwrap(),
                )
                .unwrap();
        }
        // The available ring directly follows the descriptor table.
        let avail_ring = descriptor_array_addr
            .checked_add((descriptors_len * std::mem::size_of::<VirtqDesc>()) as u64)
            .unwrap();
        let avail = VirtqAvail {
            flags: 0.into(),
            idx: 1.into(),
            ring: 0.into(),
        };
        memory.write_obj(avail, avail_ring).unwrap();
        let mut queue = Queue::new(4).unwrap();
        queue
            .try_set_desc_table_address(descriptor_array_addr)
            .unwrap();
        queue.try_set_avail_ring_address(avail_ring).unwrap();
        queue.set_ready(true);
        queue.iter(memory).unwrap().next().unwrap()
    }

    #[test]
    fn simple_chain_reader_test() {
        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemoryMmap::from_ranges(&[(memory_start_addr, 0x1000)]).unwrap();
        let chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (DescriptorType::Readable, 8),
                (DescriptorType::Readable, 16),
                (DescriptorType::Readable, 18),
                (DescriptorType::Readable, 64),
            ],
        );
        let mut reader = Reader::new(&memory, chain).unwrap();
        // 8 + 16 + 18 + 64 = 106 readable bytes in total.
        assert_eq!(reader.available_bytes(), 106);
        assert_eq!(reader.bytes_read(), 0);
        let mut buffer = [0; 64];
        // 64 bytes span the 8/16/18-byte buffers plus the first 22 bytes of the last one.
        reader.read_exact(&mut buffer).unwrap();
        assert_eq!(reader.available_bytes(), 42);
        assert_eq!(reader.bytes_read(), 64);
        // Only 42 bytes remain, so a 64-byte read comes back short.
        assert_eq!(reader.read(&mut buffer).unwrap(), 42);
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 106);
    }

    #[test]
    fn simple_chain_writer_test() {
        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemoryMmap::from_ranges(&[(memory_start_addr, 0x1000)]).unwrap();
        let chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (DescriptorType::Writable, 8),
                (DescriptorType::Writable, 16),
                (DescriptorType::Writable, 18),
                (DescriptorType::Writable, 64),
            ],
        );
        let mut writer = Writer::new(&memory, chain).unwrap();
        // 8 + 16 + 18 + 64 = 106 writable bytes in total.
        assert_eq!(writer.available_bytes(), 106);
        assert_eq!(writer.bytes_written(), 0);
        let buffer = [0; 64];
        // 64 bytes fill the 8/16/18-byte buffers plus the first 22 bytes of the last one.
        writer.write_all(&buffer).unwrap();
        assert_eq!(writer.available_bytes(), 42);
        assert_eq!(writer.bytes_written(), 64);
        // Only 42 bytes of space remain, so a 64-byte write comes back short.
        assert_eq!(writer.write(&buffer).unwrap(), 42);
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 106);
    }
}