nvme_driver: allocate different DMA memory sizes if not bounce buffering #1306

Merged · 3 commits · May 15, 2025
Changes from 1 commit
27 changes: 19 additions & 8 deletions openhcl/underhill_core/src/nvme_manager.rs
@@ -314,14 +314,21 @@ impl NvmeManagerWorker {
.await
.map_err(InnerError::Vfio)?;

let driver =
nvme_driver::NvmeDriver::new(&self.driver_source, self.vp_count, device)
.instrument(tracing::info_span!(
"nvme_driver_init",
pci_id = entry.key()
))
.await
.map_err(InnerError::DeviceInitFailed)?;
// TODO: For now, any isolation means use bounce buffering. This
// needs to change when we have nvme devices that support DMA to
// confidential memory.
let driver = nvme_driver::NvmeDriver::new(
&self.driver_source,
self.vp_count,
device,
self.is_isolated,
)
.instrument(tracing::info_span!(
"nvme_driver_init",
pci_id = entry.key()
))
.await
.map_err(InnerError::DeviceInitFailed)?;

entry.insert(driver)
}
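As background for the TODO above: in an isolated (confidential) VM the device cannot DMA into encrypted guest memory, so transfers are staged through shared, host-visible pages. Below is a minimal, generic sketch of that staging pattern; every name in it is invented for illustration and none of it is OpenHCL's actual implementation.

```rust
/// Illustrative bounce buffer: data is copied between confidential memory
/// and a shared staging buffer around each device transfer.
struct BounceBuffer {
    shared: Vec<u8>, // stands in for host-visible, DMA-able memory
}

impl BounceBuffer {
    /// Stage a write: copy from confidential memory into the shared buffer,
    /// which the device then DMAs from.
    fn stage_write(&mut self, confidential_src: &[u8]) -> &[u8] {
        self.shared[..confidential_src.len()].copy_from_slice(confidential_src);
        &self.shared[..confidential_src.len()]
    }

    /// Complete a read: copy what the device DMAed into the shared buffer
    /// back into confidential memory.
    fn complete_read(&self, confidential_dst: &mut [u8]) {
        confidential_dst.copy_from_slice(&self.shared[..confidential_dst.len()]);
    }
}

fn main() {
    let mut bb = BounceBuffer { shared: vec![0u8; 4096] };
    let secret = [1u8, 2, 3, 4];
    assert_eq!(bb.stage_write(&secret), &secret);
    let mut readback = [0u8; 4];
    bb.complete_read(&mut readback);
    assert_eq!(readback, secret);
}
```

The extra copy is why the isolated path reserves more per-queue DMA memory in this PR: bounce buffering needs room to stage I/O payloads, while direct DMA does not.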
@@ -385,11 +392,15 @@ impl NvmeManagerWorker {
.instrument(tracing::info_span!("vfio_device_restore", pci_id))
.await?;

// TODO: For now, any isolation means use bounce buffering. This
// needs to change when we have nvme devices that support DMA to
// confidential memory.
let nvme_driver = nvme_driver::NvmeDriver::restore(
&self.driver_source,
saved_state.cpu_count,
vfio_device,
&disk.driver_state,
self.is_isolated,
)
.instrument(tracing::info_span!("nvme_driver_restore"))
.await?;
@@ -64,7 +64,7 @@ impl FuzzNvmeDriver {
.unwrap();

let device = FuzzEmulatedDevice::new(nvme, msi_set, mem.dma_client());
let nvme_driver = NvmeDriver::new(&driver_source, cpu_count, device).await?; // TODO: [use-arbitrary-input]
let nvme_driver = NvmeDriver::new(&driver_source, cpu_count, device, false).await?; // TODO: [use-arbitrary-input]
let namespace = nvme_driver.namespace(1).await?; // TODO: [use-arbitrary-input]

Ok(Self {
45 changes: 38 additions & 7 deletions vm/devices/storage/disk_nvme/nvme_driver/src/driver.rs
@@ -70,6 +70,7 @@ pub struct NvmeDriver<T: DeviceBacking> {
namespaces: Vec<Arc<Namespace>>,
/// Keeps the controller connected (CC.EN==1) while servicing.
nvme_keepalive: bool,
bounce_buffer: bool,
}

#[derive(Inspect)]
@@ -84,6 +85,7 @@ struct DriverWorkerTask<T: DeviceBacking> {
io_issuers: Arc<IoIssuers>,
#[inspect(skip)]
recv: mesh::Receiver<NvmeWorkerRequest>,
bounce_buffer: bool,
}

#[derive(Inspect)]
@@ -123,14 +125,21 @@ impl IoQueue {
registers: Arc<DeviceRegisters<impl DeviceBacking>>,
mem_block: MemoryBlock,
saved_state: &IoQueueSavedState,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let IoQueueSavedState {
cpu,
iv,
queue_data,
} = saved_state;
let queue =
QueuePair::restore(spawner, interrupt, registers.clone(), mem_block, queue_data)?;
let queue = QueuePair::restore(
spawner,
interrupt,
registers.clone(),
mem_block,
queue_data,
bounce_buffer,
)?;

Ok(Self {
queue,
@@ -169,9 +178,10 @@ impl<T: DeviceBacking> NvmeDriver<T> {
driver_source: &VmTaskDriverSource,
cpu_count: u32,
device: T,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let pci_id = device.id().to_owned();
let mut this = Self::new_disabled(driver_source, cpu_count, device)
let mut this = Self::new_disabled(driver_source, cpu_count, device, bounce_buffer)
.instrument(tracing::info_span!("nvme_new_disabled", pci_id))
.await?;
match this
@@ -197,6 +207,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
driver_source: &VmTaskDriverSource,
cpu_count: u32,
mut device: T,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let driver = driver_source.simple();
let bar0 = Bar0(
@@ -245,6 +256,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
io: Vec::new(),
io_issuers: io_issuers.clone(),
recv,
bounce_buffer,
})),
admin: None,
identify: None,
@@ -253,6 +265,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
rescan_event: Default::default(),
namespaces: vec![],
nvme_keepalive: false,
bounce_buffer,
})
}

@@ -285,6 +298,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
admin_cqes,
interrupt0,
worker.registers.clone(),
self.bounce_buffer,
)
.context("failed to create admin queue pair")?;

@@ -541,6 +555,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
cpu_count: u32,
mut device: T,
saved_state: &NvmeDriverSavedState,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let driver = driver_source.simple();
let bar0_mapping = device
@@ -571,6 +586,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
io: Vec::new(),
io_issuers: io_issuers.clone(),
recv,
bounce_buffer,
})),
admin: None, // Updated below.
identify: Some(Arc::new(
@@ -582,6 +598,7 @@ impl<T: DeviceBacking> NvmeDriver<T> {
rescan_event: Default::default(),
namespaces: vec![],
nvme_keepalive: true,
bounce_buffer,
};

let task = &mut this.task.as_mut().unwrap();
@@ -610,8 +627,15 @@ impl<T: DeviceBacking> NvmeDriver<T> {
.find(|mem| mem.len() == a.mem_len && a.base_pfn == mem.pfns()[0])
.expect("unable to find restored mem block")
.to_owned();
QueuePair::restore(driver.clone(), interrupt0, registers.clone(), mem_block, a)
.unwrap()
QueuePair::restore(
driver.clone(),
interrupt0,
registers.clone(),
mem_block,
a,
bounce_buffer,
)
.unwrap()
})
.unwrap();

@@ -657,8 +681,14 @@ impl<T: DeviceBacking> NvmeDriver<T> {
})
.expect("unable to find restored mem block")
.to_owned();
let q =
IoQueue::restore(driver.clone(), interrupt, registers.clone(), mem_block, q)?;
let q = IoQueue::restore(
driver.clone(),
interrupt,
registers.clone(),
mem_block,
q,
bounce_buffer,
)?;
let issuer = IoIssuer {
issuer: q.queue.issuer().clone(),
cpu: q.cpu,
@@ -866,6 +896,7 @@ impl<T: DeviceBacking> DriverWorkerTask<T> {
state.qsize,
interrupt,
self.registers.clone(),
self.bounce_buffer,
)
.with_context(|| format!("failed to create io queue pair {qid}"))?;

65 changes: 51 additions & 14 deletions vm/devices/storage/disk_nvme/nvme_driver/src/queue_pair.rs
@@ -166,11 +166,18 @@ impl PendingCommands {
}

impl QueuePair {
pub const MAX_SQ_ENTRIES: u16 = (PAGE_SIZE / 64) as u16; // Maximum SQ size in entries.
pub const MAX_CQ_ENTRIES: u16 = (PAGE_SIZE / 16) as u16; // Maximum CQ size in entries.
const SQ_SIZE: usize = PAGE_SIZE; // Submission Queue size in bytes.
const CQ_SIZE: usize = PAGE_SIZE; // Completion Queue size in bytes.
const PER_QUEUE_PAGES: usize = 128;
/// Maximum SQ size in entries.
pub const MAX_SQ_ENTRIES: u16 = (PAGE_SIZE / 64) as u16;
/// Maximum CQ size in entries.
pub const MAX_CQ_ENTRIES: u16 = (PAGE_SIZE / 16) as u16;
/// Submission Queue size in bytes.
const SQ_SIZE: usize = PAGE_SIZE;
/// Completion Queue size in bytes.
const CQ_SIZE: usize = PAGE_SIZE;
/// Number of pages per queue if bounce buffering.
const PER_QUEUE_PAGES_BOUNCE_BUFFER: usize = 128;
/// Number of pages per queue if not bounce buffering.
const PER_QUEUE_PAGES_NO_BOUNCE_BUFFER: usize = 64;

pub fn new(
spawner: impl SpawnDriver,
@@ -180,9 +187,15 @@
cq_entries: u16, // Requested CQ size in entries.
interrupt: DeviceInterrupt,
registers: Arc<DeviceRegisters<impl DeviceBacking>>,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let total_size =
QueuePair::SQ_SIZE + QueuePair::CQ_SIZE + QueuePair::PER_QUEUE_PAGES * PAGE_SIZE;
let total_size = QueuePair::SQ_SIZE
+ QueuePair::CQ_SIZE
+ if bounce_buffer {
QueuePair::PER_QUEUE_PAGES_BOUNCE_BUFFER * PAGE_SIZE
} else {
QueuePair::PER_QUEUE_PAGES_NO_BOUNCE_BUFFER * PAGE_SIZE
};
let dma_client = device.dma_client();
let mem = dma_client
.allocate_dma_buffer(total_size)
@@ -192,7 +205,15 @@
assert!(cq_entries <= Self::MAX_CQ_ENTRIES);

QueuePair::new_or_restore(
spawner, qid, sq_entries, cq_entries, interrupt, registers, mem, None,
spawner,
qid,
sq_entries,
cq_entries,
interrupt,
registers,
mem,
None,
bounce_buffer,
)
}
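To put numbers on the branch above (an editor's back-of-envelope check, assuming 4 KiB pages, which the queue constants imply via MAX_SQ_ENTRIES = PAGE_SIZE / 64 for 64-byte submission entries):

```rust
// Per-queue-pair DMA footprint in each mode, assuming PAGE_SIZE = 4 KiB.
const PAGE_SIZE: usize = 4096;
const SQ_SIZE: usize = PAGE_SIZE; // one page for the submission queue
const CQ_SIZE: usize = PAGE_SIZE; // one page for the completion queue

fn main() {
    let bounce = SQ_SIZE + CQ_SIZE + 128 * PAGE_SIZE; // bounce-buffering mode
    let direct = SQ_SIZE + CQ_SIZE + 64 * PAGE_SIZE; // direct-DMA mode
    assert_eq!(bounce, 520 * 1024); // 520 KiB per queue pair
    assert_eq!(direct, 264 * 1024); // 264 KiB per queue pair
    assert_eq!(bounce - direct, 256 * 1024); // saved when DMA is direct
}
```

Since the driver appears to create I/O queues per CPU (note the cpu_count plumbing in this diff), that 256 KiB saving scales with the VP count when bounce buffering is off.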

@@ -206,6 +227,7 @@ impl QueuePair {
registers: Arc<DeviceRegisters<impl DeviceBacking>>,
mem: MemoryBlock,
saved_state: Option<&QueueHandlerSavedState>,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
// MemoryBlock is either allocated or restored prior to calling here.
let sq_mem_block = mem.subblock(0, QueuePair::SQ_SIZE);
@@ -240,12 +262,25 @@ impl QueuePair {
});

// Page allocator uses remaining part of the buffer for dynamic allocation.
const _: () = assert!(
QueuePair::PER_QUEUE_PAGES * PAGE_SIZE >= 128 * 1024 + PAGE_SIZE,
"not enough room for an ATAPI IO plus a PRP list"
);
let alloc: PageAllocator =
PageAllocator::new(mem.subblock(data_offset, QueuePair::PER_QUEUE_PAGES * PAGE_SIZE));
let alloc = if bounce_buffer {
    const _: () = assert!(
        QueuePair::PER_QUEUE_PAGES_BOUNCE_BUFFER * PAGE_SIZE >= 128 * 1024 + PAGE_SIZE,
        "not enough room for an ATAPI IO plus a PRP list"
    );
    PageAllocator::new(mem.subblock(
        data_offset,
        QueuePair::PER_QUEUE_PAGES_BOUNCE_BUFFER * PAGE_SIZE,
    ))
} else {
    const _: () = assert!(
        QueuePair::PER_QUEUE_PAGES_NO_BOUNCE_BUFFER * PAGE_SIZE >= 128 * 1024 + PAGE_SIZE,
        "not enough room for an ATAPI IO plus a PRP list"
    );
    PageAllocator::new(mem.subblock(
        data_offset,
        QueuePair::PER_QUEUE_PAGES_NO_BOUNCE_BUFFER * PAGE_SIZE,
    ))
};

[Review thread on `let alloc = if bounce_buffer {`]

Contributor: Feels like this expression can be shortened? E.g. use one PageAllocator::new after checking all the asserts.

Contributor: This assert is probably not needed at all?

Author: Hmm, the assert is checking that the minimum size is large enough for the given commands, right? I don't know enough to say whether the assert should be removed or not, but it still seems useful?

Contributor: Sorry, what I meant is: we are asserting that one constant is greater than another constant. That can be done statically, outside of this function, I think. But okay, leave it.

Author: This is a static assert, since it's a const evaluation at compile time. A bit confusing, but it doesn't require taking a dependency on static_assert, which I think does the same thing under the hood.

[Review thread on `const _: () = assert!(`]

Member: Can we package this up as a local const fn? E.g.,

```rust
const fn pages_to_size(pages: usize) -> usize {
    let size = pages * PAGE_SIZE;
    assert!(size >= ...);
    size
}
// Note: using const {} to ensure the internal assertion checks are at compile time
let alloc_len = if bounce_buffer {
    const { pages_to_size(QueuePair::PER_QUEUE_PAGES_BOUNCE_BUFFER) }
} else {
    const { pages_to_size(QueuePair::PER_QUEUE_PAGES_NO_BOUNCE_BUFFER) }
};
```

Author: Neat, can do.
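The thread above settles that `const _: () = assert!(...)` is a compile-time check. Here is a standalone sketch (editor's illustration, not code from this PR) showing both forms under discussion:

```rust
const PAGE_SIZE: usize = 4096; // assumed value, for illustration only
const PER_QUEUE_PAGES: usize = 64;

// Item-level static assert: evaluated during compilation, so a false
// condition is a build error rather than a runtime panic. No extra
// static-assertion crate is required.
const _: () = assert!(
    PER_QUEUE_PAGES * PAGE_SIZE >= 128 * 1024,
    "not enough room for the page allocator"
);

// The reviewer's suggested packaging: a const fn that checks and returns
// the size in one place.
const fn pages_to_size(pages: usize) -> usize {
    let size = pages * PAGE_SIZE;
    assert!(size >= 128 * 1024);
    size
}

fn main() {
    // A `const { ... }` block (stable since Rust 1.79) forces the call,
    // including its assert, to be evaluated at compile time.
    let len = const { pages_to_size(PER_QUEUE_PAGES) };
    println!("allocator region: {len} bytes");
}
```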

Ok(Self {
task,
@@ -302,6 +337,7 @@ impl QueuePair {
registers: Arc<DeviceRegisters<impl DeviceBacking>>,
mem: MemoryBlock,
saved_state: &QueuePairSavedState,
bounce_buffer: bool,
) -> anyhow::Result<Self> {
let QueuePairSavedState {
mem_len: _, // Used to restore DMA buffer before calling this.
Expand All @@ -321,6 +357,7 @@ impl QueuePair {
registers,
mem,
Some(handler_data),
bounce_buffer,
)
}
}
8 changes: 4 additions & 4 deletions vm/devices/storage/disk_nvme/nvme_driver/src/tests.rs
@@ -78,7 +78,7 @@ async fn test_nvme_ioqueue_max_mqes(driver: DefaultDriver) {
let cap: Cap = Cap::new().with_mqes_z(max_u16);
device.set_mock_response_u64(Some((0, cap.into())));

let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device).await;
let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device, false).await;
assert!(driver.is_ok());
}

@@ -113,7 +113,7 @@ async fn test_nvme_ioqueue_invalid_mqes(driver: DefaultDriver) {
// Setup mock response at offset 0
let cap: Cap = Cap::new().with_mqes_z(0);
device.set_mock_response_u64(Some((0, cap.into())));
let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device).await;
let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device, false).await;

assert!(driver.is_err());
}
@@ -150,7 +150,7 @@ async fn test_nvme_driver(driver: DefaultDriver, allow_dma: bool) {
.await
.unwrap();
let device = NvmeTestEmulatedDevice::new(nvme, msi_set, dma_client.clone());
let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device)
let driver = NvmeDriver::new(&driver_source, CPU_COUNT, device, false)
.await
.unwrap();
let namespace = driver.namespace(1).await.unwrap();
@@ -266,7 +266,7 @@ async fn test_nvme_save_restore_inner(driver: DefaultDriver) {
.unwrap();

let device = NvmeTestEmulatedDevice::new(nvme_ctrl, msi_x, dma_client.clone());
let mut nvme_driver = NvmeDriver::new(&driver_source, CPU_COUNT, device)
let mut nvme_driver = NvmeDriver::new(&driver_source, CPU_COUNT, device, false)
.await
.unwrap();

[Review thread on the added `false` argument]

Contributor: TODO: add a test case with true
let _ns1 = nvme_driver.namespace(1).await.unwrap();
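A sketch of the test case requested in the thread above. This is an editor's illustration only: it mirrors the setup of test_nvme_driver from this file, elides the harness wiring (controller, MSI set, DMA client, emulated device), and flips the new flag to true so the driver exercises the larger bounce-buffer allocation.

```rust
#[async_test]
async fn test_nvme_driver_bounce_buffer(driver: DefaultDriver) {
    // Harness setup elided; identical to test_nvme_driver above.
    // ...
    let nvme_driver = NvmeDriver::new(&driver_source, CPU_COUNT, device, true)
        .await
        .unwrap();
    // Exercise a namespace through the bounce-buffered queues.
    let namespace = nvme_driver.namespace(1).await.unwrap();
    // ... reads and writes as in the existing test ...
}
```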