nvme_driver: allocate different DMA memory sizes if not bounce buffering #1306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 3 commits, May 15, 2025
27 changes: 19 additions & 8 deletions openhcl/underhill_core/src/nvme_manager.rs
@@ -314,14 +314,21 @@ impl NvmeManagerWorker {
.await
.map_err(InnerError::Vfio)?;

let driver =
nvme_driver::NvmeDriver::new(&self.driver_source, self.vp_count, device)
.instrument(tracing::info_span!(
"nvme_driver_init",
pci_id = entry.key()
))
.await
.map_err(InnerError::DeviceInitFailed)?;
// TODO: For now, any isolation means use bounce buffering. This
// needs to change when we have nvme devices that support DMA to
// confidential memory.
let driver = nvme_driver::NvmeDriver::new(
&self.driver_source,
self.vp_count,
device,
self.is_isolated,
)
.instrument(tracing::info_span!(
"nvme_driver_init",
pci_id = entry.key()
))
.await
.map_err(InnerError::DeviceInitFailed)?;

entry.insert(driver)
}
@@ -385,11 +392,15 @@ impl NvmeManagerWorker {
.instrument(tracing::info_span!("vfio_device_restore", pci_id))
.await?;

// TODO: For now, any isolation means use bounce buffering. This
// needs to change when we have nvme devices that support DMA to
// confidential memory.
let nvme_driver = nvme_driver::NvmeDriver::restore(
&self.driver_source,
saved_state.cpu_count,
vfio_device,
&disk.driver_state,
self.is_isolated,
)
.instrument(tracing::info_span!("nvme_driver_restore"))
.await?;
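The TODO comments above capture the policy this change adopts: on a hardware-isolated (confidential) VM, the device cannot yet DMA directly into encrypted guest memory, so any isolation currently implies bounce buffering through shared pages. A minimal, self-contained sketch of that decision (the real code simply forwards the worker's `is_isolated` field as the new `bounce_buffer` argument; the helper name here is hypothetical):

```rust
/// Hypothetical helper mirroring this PR's policy: any isolation forces
/// bounce buffering. Expected to change once NVMe devices can DMA to
/// confidential memory.
fn use_bounce_buffering(is_isolated: bool) -> bool {
    is_isolated
}

fn main() {
    assert!(use_bounce_buffering(true)); // isolated: double buffer via shared pages
    assert!(!use_bounce_buffering(false)); // not isolated: direct DMA, smaller allocation
}
```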
@@ -64,7 +64,7 @@ impl FuzzNvmeDriver {
.unwrap();

let device = FuzzEmulatedDevice::new(nvme, msi_set, mem.dma_client());
let nvme_driver = NvmeDriver::new(&driver_source, cpu_count, device).await?; // TODO: [use-arbitrary-input]
let nvme_driver = NvmeDriver::new(&driver_source, cpu_count, device, false).await?; // TODO: [use-arbitrary-input]
let namespace = nvme_driver.namespace(1).await?; // TODO: [use-arbitrary-input]

Ok(Self {
Expand Down
45 changes: 38 additions & 7 deletions vm/devices/storage/disk_nvme/nvme_driver/src/driver.rs
@@ -70,6 +70,7 @@ pub struct NvmeDriver<T: DeviceBacking> {
namespaces: Vec<Arc<Namespace>>,
/// Keeps the controller connected (CC.EN==1) while servicing.
nvme_keepalive: bool,
bounce_buffer: bool,
}

#[derive(Inspect)]
@@ -84,6 +85,7 @@ struct DriverWorkerTask<T: DeviceBacking> {
io_issuers: Arc<IoIssuers>,
#[inspect(skip)]
recv: mesh::Receiver<NvmeWorkerRequest>,
bounce_buffer: bool,
}

#[derive(Inspect)]
@@ -123,14 +125,21 @@ impl IoQueue {
registers: Arc<DeviceRegisters<impl DeviceBacking>>,
mem_block: MemoryBlock,
saved_state: &IoQueueSavedState,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let IoQueueSavedState {
cpu,
iv,
queue_data,
} = saved_state;
let queue =
QueuePair::restore(spawner, interrupt, registers.clone(), mem_block, queue_data)?;
let queue = QueuePair::restore(
spawner,
interrupt,
registers.clone(),
mem_block,
queue_data,
bounce_buffer,
)?;

Ok(Self {
queue,
@@ -169,9 +178,10 @@ impl<T: DeviceBacking> NvmeDriver<T> {
driver_source: &VmTaskDriverSource,
cpu_count: u32,
device: T,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let pci_id = device.id().to_owned();
let mut this = Self::new_disabled(driver_source, cpu_count, device)
let mut this = Self::new_disabled(driver_source, cpu_count, device, bounce_buffer)
.instrument(tracing::info_span!("nvme_new_disabled", pci_id))
.await?;
match this
@@ -197,6 +207,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
driver_source: &VmTaskDriverSource,
cpu_count: u32,
mut device: T,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let driver = driver_source.simple();
let bar0 = Bar0(
@@ -245,6 +256,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
io: Vec::new(),
io_issuers: io_issuers.clone(),
recv,
bounce_buffer,
})),
admin: None,
identify: None,
@@ -253,6 +265,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
rescan_event: Default::default(),
namespaces: vec![],
nvme_keepalive: false,
bounce_buffer,
})
}

@@ -285,6 +298,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
admin_cqes,
interrupt0,
worker.registers.clone(),
self.bounce_buffer,
)
.context("failed to create admin queue pair")?;

@@ -541,6 +555,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
cpu_count: u32,
mut device: T,
saved_state: &NvmeDriverSavedState,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let driver = driver_source.simple();
let bar0_mapping = device
@@ -571,6 +586,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
io: Vec::new(),
io_issuers: io_issuers.clone(),
recv,
bounce_buffer,
})),
admin: None, // Updated below.
identify: Some(Arc::new(
@@ -582,6 +598,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
rescan_event: Default::default(),
namespaces: vec![],
nvme_keepalive: true,
bounce_buffer,
};

let task = &mut this.task.as_mut().unwrap();
@@ -610,8 +627,15 @@ impl<T: DeviceBacking> NvmeDriver<T> {
.find(|mem| mem.len() == a.mem_len && a.base_pfn == mem.pfns()[0])
.expect("unable to find restored mem block")
.to_owned();
QueuePair::restore(driver.clone(), interrupt0, registers.clone(), mem_block, a)
.unwrap()
QueuePair::restore(
driver.clone(),
interrupt0,
registers.clone(),
mem_block,
a,
bounce_buffer,
)
.unwrap()
})
.unwrap();

@@ -657,8 +681,14 @@ impl<T: DeviceBacking> NvmeDriver<T> {
})
.expect("unable to find restored mem block")
.to_owned();
let q =
IoQueue::restore(driver.clone(), interrupt, registers.clone(), mem_block, q)?;
let q = IoQueue::restore(
driver.clone(),
interrupt,
registers.clone(),
mem_block,
q,
bounce_buffer,
)?;
let issuer = IoIssuer {
issuer: q.queue.issuer().clone(),
cpu: q.cpu,
@@ -866,6 +896,7 @@ impl<T: DeviceBacking> DriverWorkerTask<T> {
state.qsize,
interrupt,
self.registers.clone(),
self.bounce_buffer,
)
.with_context(|| format!("failed to create io queue pair {qid}"))?;

71 changes: 56 additions & 15 deletions vm/devices/storage/disk_nvme/nvme_driver/src/queue_pair.rs
@@ -166,11 +166,18 @@ impl PendingCommands {
}

impl QueuePair {
pub const MAX_SQ_ENTRIES: u16 = (PAGE_SIZE / 64) as u16; // Maximum SQ size in entries.
pub const MAX_CQ_ENTRIES: u16 = (PAGE_SIZE / 16) as u16; // Maximum CQ size in entries.
const SQ_SIZE: usize = PAGE_SIZE; // Submission Queue size in bytes.
const CQ_SIZE: usize = PAGE_SIZE; // Completion Queue size in bytes.
const PER_QUEUE_PAGES: usize = 128;
/// Maximum SQ size in entries.
pub const MAX_SQ_ENTRIES: u16 = (PAGE_SIZE / 64) as u16;
/// Maximum CQ size in entries.
pub const MAX_CQ_ENTRIES: u16 = (PAGE_SIZE / 16) as u16;
/// Submission Queue size in bytes.
const SQ_SIZE: usize = PAGE_SIZE;
/// Completion Queue size in bytes.
const CQ_SIZE: usize = PAGE_SIZE;
/// Number of pages per queue if bounce buffering.
const PER_QUEUE_PAGES_BOUNCE_BUFFER: usize = 128;
/// Number of pages per queue if not bounce buffering.
const PER_QUEUE_PAGES_NO_BOUNCE_BUFFER: usize = 64;
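To make these constants concrete: with 4 KiB pages, 64-byte submission entries, and 16-byte completion entries (sizes fixed by the NVMe spec), one page holds 64 SQEs or 256 CQEs, and the per-queue pool is 512 KiB with bounce buffering versus 256 KiB without. A self-contained sanity check of the arithmetic (PAGE_SIZE assumed to be 4096, matching the driver):

```rust
const PAGE_SIZE: usize = 4096; // assumed, matching the driver

fn main() {
    assert_eq!(PAGE_SIZE / 64, 64); // MAX_SQ_ENTRIES: one page of 64-byte SQEs
    assert_eq!(PAGE_SIZE / 16, 256); // MAX_CQ_ENTRIES: one page of 16-byte CQEs
    assert_eq!(128 * PAGE_SIZE, 512 * 1024); // pool with bounce buffering: 512 KiB
    assert_eq!(64 * PAGE_SIZE, 256 * 1024); // pool without: 256 KiB
}
```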

pub fn new(
spawner: impl SpawnDriver,
@@ -180,9 +187,15 @@
cq_entries: u16, // Requested CQ size in entries.
interrupt: DeviceInterrupt,
registers: Arc<DeviceRegisters<impl DeviceBacking>>,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let total_size =
QueuePair::SQ_SIZE + QueuePair::CQ_SIZE + QueuePair::PER_QUEUE_PAGES * PAGE_SIZE;
let total_size = QueuePair::SQ_SIZE
+ QueuePair::CQ_SIZE
+ if bounce_buffer {
QueuePair::PER_QUEUE_PAGES_BOUNCE_BUFFER * PAGE_SIZE
} else {
QueuePair::PER_QUEUE_PAGES_NO_BOUNCE_BUFFER * PAGE_SIZE
};
let dma_client = device.dma_client();
let mem = dma_client
.allocate_dma_buffer(total_size)
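Spelling out the new `total_size` arithmetic (again assuming 4 KiB pages): the per-queue-pair DMA block is one page each for the SQ and CQ plus the allocator pool, so it shrinks from 520 KiB to 264 KiB when bounce buffering is off, saving 256 KiB per queue pair. A self-contained check:

```rust
const PAGE_SIZE: usize = 4096; // assumed, matching the driver
const SQ_SIZE: usize = PAGE_SIZE;
const CQ_SIZE: usize = PAGE_SIZE;

// Mirrors the `total_size` expression in QueuePair::new.
fn total_size(bounce_buffer: bool) -> usize {
    let pool_pages = if bounce_buffer { 128 } else { 64 };
    SQ_SIZE + CQ_SIZE + pool_pages * PAGE_SIZE
}

fn main() {
    assert_eq!(total_size(true), 520 * 1024); // 532,480 bytes
    assert_eq!(total_size(false), 264 * 1024); // 270,336 bytes
    assert_eq!(total_size(true) - total_size(false), 256 * 1024); // saved per queue pair
}
```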
@@ -192,7 +205,15 @@
assert!(cq_entries <= Self::MAX_CQ_ENTRIES);

QueuePair::new_or_restore(
spawner, qid, sq_entries, cq_entries, interrupt, registers, mem, None,
spawner,
qid,
sq_entries,
cq_entries,
interrupt,
registers,
mem,
None,
bounce_buffer,
)
}

@@ -206,6 +227,7 @@
registers: Arc<DeviceRegisters<impl DeviceBacking>>,
mem: MemoryBlock,
saved_state: Option<&QueueHandlerSavedState>,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
// MemoryBlock is either allocated or restored prior calling here.
let sq_mem_block = mem.subblock(0, QueuePair::SQ_SIZE);
@@ -239,13 +261,30 @@
}
});

// Page allocator uses remaining part of the buffer for dynamic allocation.
const _: () = assert!(
QueuePair::PER_QUEUE_PAGES * PAGE_SIZE >= 128 * 1024 + PAGE_SIZE,
"not enough room for an ATAPI IO plus a PRP list"
);
let alloc: PageAllocator =
PageAllocator::new(mem.subblock(data_offset, QueuePair::PER_QUEUE_PAGES * PAGE_SIZE));
// Convert the queue pages to bytes, and assert that queue size is large
// enough.
const fn pages_to_size_bytes(pages: usize) -> usize {
let size = pages * PAGE_SIZE;
assert!(
size >= 128 * 1024 + PAGE_SIZE,
"not enough room for an ATAPI IO plus a PRP list"
);
size
}

// Page allocator uses remaining part of the buffer for dynamic
// allocation. The length of the page allocator depends on if bounce
// buffering / double buffering is needed.
//
// NOTE: Do not remove the `const` blocks below. This is to force
// compile time evaluation of the assertion described above.
let alloc_len = if bounce_buffer {
const { pages_to_size_bytes(QueuePair::PER_QUEUE_PAGES_BOUNCE_BUFFER) }
} else {
const { pages_to_size_bytes(QueuePair::PER_QUEUE_PAGES_NO_BOUNCE_BUFFER) }
};

let alloc = PageAllocator::new(mem.subblock(data_offset, alloc_len));

Ok(Self {
task,
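The NOTE in the diff above is the subtle part: wrapping each call in a `const { ... }` block forces `pages_to_size_bytes`, and with it the `assert!`, to evaluate at compile time, so shrinking either page-count constant below the 128 KiB + one page floor becomes a build error rather than a runtime panic. A standalone illustration of the mechanism (inline `const` blocks need Rust 1.79+; constants inlined for the example):

```rust
const PAGE_SIZE: usize = 4096;

const fn pages_to_size_bytes(pages: usize) -> usize {
    let size = pages * PAGE_SIZE;
    // Runs at compile time when the call site is a `const` block.
    assert!(
        size >= 128 * 1024 + PAGE_SIZE,
        "not enough room for an ATAPI IO plus a PRP list"
    );
    size
}

fn main() {
    // 64 pages = 256 KiB, above the ~132 KiB floor: compiles and runs.
    let ok = const { pages_to_size_bytes(64) };
    println!("pool size: {ok} bytes");

    // 32 pages = 128 KiB is below the floor: uncommenting fails the build.
    // let too_small = const { pages_to_size_bytes(32) };
}
```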
@@ -302,6 +341,7 @@ impl QueuePair {
registers: Arc<DeviceRegisters<impl DeviceBacking>>,
mem: MemoryBlock,
saved_state: &QueuePairSavedState,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let QueuePairSavedState {
mem_len: _, // Used to restore DMA buffer before calling this.
@@ -321,6 +361,7 @@
registers,
mem,
Some(handler_data),
bounce_buffer,
)
}
}
8 changes: 4 additions & 4 deletions vm/devices/storage/disk_nvme/nvme_driver/src/tests.rs
@@ -78,7 +78,7 @@ async fn test_nvme_ioqueue_max_mqes(driver: DefaultDriver) {
let cap: Cap = Cap::new().with_mqes_z(max_u16);
device.set_mock_response_u64(Some((0, cap.into())));

let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device).await;
let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device, false).await;
assert!(driver.is_ok());
}

@@ -113,7 +113,7 @@ async fn test_nvme_ioqueue_invalid_mqes(driver: DefaultDriver) {
// Setup mock response at offset 0
let cap: Cap = Cap::new().with_mqes_z(0);
device.set_mock_response_u64(Some((0, cap.into())));
let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device).await;
let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device, false).await;

assert!(driver.is_err());
}
@@ -150,7 +150,7 @@ async fn test_nvme_driver(driver: DefaultDriver, allow_dma: bool) {
.await
.unwrap();
let device = NvmeTestEmulatedDevice::new(nvme, msi_set, dma_client.clone());
let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device)
let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device, false)
.await
.unwrap();
let namespace = driver.namespace(1).await.unwrap();
@@ -266,7 +266,7 @@ async fn test_nvme_save_restore_inner(driver: DefaultDriver) {
.unwrap();

let device = NvmeTestEmulatedDevice::new(nvme_ctrl, msi_x, dma_client.clone());
let mut nvme_driver = NvmeDriver::new(&driver_source, CPU_COUNT, device)
let mut nvme_driver = NvmeDriver::new(&driver_source, CPU_COUNT, device, false)
Contributor review comment: TODO: add a test case with true

.await
.unwrap();
let _ns1 = nvme_driver.namespace(1).await.unwrap();