Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 70 additions & 27 deletions ostd/src/sync/rcu/non_null/either.rs
Original file line number Diff line number Diff line change
@@ -1,54 +1,80 @@
// SPDX-License-Identifier: MPL-2.0
use vstd::prelude::*;

use core::{marker::PhantomData, ptr::NonNull};

use super::NonNullPtr;
use super::{NonNullPtr, SmartPtrPointsTo, ptr_mut_from_nonnull};
use crate::util::Either;

verus! {
// If both `L` and `R` have at least one alignment bit (i.e., their alignments are at least 2), we
// can use the alignment bit to indicate whether a pointer is `L` or `R`, so it's possible to
// implement `NonNullPtr` for `Either<L, R>`.
unsafe impl<L: NonNullPtr, R: NonNullPtr> NonNullPtr for Either<L, R> {
type Target = PhantomData<Self>;

/*
type Ref<'a>
= Either<L::Ref<'a>, R::Ref<'a>>
where
Self: 'a;
*/

// const ALIGN_BITS: u32 = min(L::ALIGN_BITS, R::ALIGN_BITS)
// .checked_sub(1)
// .expect("`L` and `R` alignments should be at least 2 to pack `Either` into one pointer");
fn ALIGN_BITS() -> u32 {
// min(L::ALIGN_BITS(), R::ALIGN_BITS())
// .checked_sub(1)
// .expect("`L` and `R` alignments should be at least 2 to pack `Either` into one pointer")
min(L::ALIGN_BITS(), R::ALIGN_BITS()).checked_sub(1).unwrap_or(0)
}

const ALIGN_BITS: u32 = min(L::ALIGN_BITS, R::ALIGN_BITS)
.checked_sub(1)
.expect("`L` and `R` alignments should be at least 2 to pack `Either` into one pointer");

fn into_raw(self) -> NonNull<Self::Target> {
// fn into_raw(self) -> NonNull<Self::Target> {
#[verifier::external_body]
fn into_raw(self) -> (NonNull<Self::Target>, Tracked<SmartPtrPointsTo<Self::Target>>) {
match self {
Self::Left(left) => left.into_raw().cast(),
Self::Right(right) => right
.into_raw()
.map_addr(|addr| addr | (1 << Self::ALIGN_BITS))
.cast(),
Self::Left(left) => {
let (ptr, Tracked(_perm)) = left.into_raw();
(ptr.cast(), Tracked::assume_new())
}
Self::Right(right) => {
let (ptr, Tracked(_perm)) = right.into_raw();
let ptr = ptr
.map_addr(|addr| addr | (1 << Self::ALIGN_BITS()))
.cast();
(ptr, Tracked::assume_new())
}
}
}

unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self {
// unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self {
#[verifier::external_body]
unsafe fn from_raw(
ptr: NonNull<Self::Target>,
perm: Tracked<SmartPtrPointsTo<Self::Target>>,
) -> Self {
let _ = perm;
// SAFETY: The caller ensures that the pointer comes from `Self::into_raw`, which
// guarantees that `real_ptr` is a non-null pointer.
let (is_right, real_ptr) = unsafe { remove_bits(ptr, 1 << Self::ALIGN_BITS) };
let (is_right, real_ptr) = unsafe { remove_bits(ptr, 1 << Self::ALIGN_BITS()) };

if is_right == 0 {
// SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `L::into_raw`. Other
// safety requirements are upheld by the caller.
Either::Left(unsafe { L::from_raw(real_ptr.cast()) })
Either::Left(unsafe { L::from_raw(real_ptr.cast(), Tracked::assume_new()) })
} else {
// SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `R::into_raw`. Other
// safety requirements are upheld by the caller.
Either::Right(unsafe { R::from_raw(real_ptr.cast()) })
Either::Right(unsafe { R::from_raw(real_ptr.cast(), Tracked::assume_new()) })
}
}

/*
unsafe fn raw_as_ref<'a>(raw: NonNull<Self::Target>) -> Self::Ref<'a> {
// SAFETY: The caller ensures that the pointer comes from `Self::into_raw`, which
// guarantees that `real_ptr` is a non-null pointer.
let (is_right, real_ptr) = unsafe { remove_bits(raw, 1 << Self::ALIGN_BITS) };
let (is_right, real_ptr) = unsafe { remove_bits(raw, 1 << Self::ALIGN_BITS()) };

if is_right == 0 {
// SAFETY: `Self::into_raw` guarantees that `real_ptr` comes from `L::into_raw`. Other
Expand All @@ -65,10 +91,19 @@ unsafe impl<L: NonNullPtr, R: NonNullPtr> NonNullPtr for Either<L, R> {
match ptr_ref {
Either::Left(left) => L::ref_as_raw(left).cast(),
Either::Right(right) => R::ref_as_raw(right)
.map_addr(|addr| addr | (1 << Self::ALIGN_BITS))
.map_addr(|addr| addr | (1 << Self::ALIGN_BITS()))
.cast(),
}
}
*/

open spec fn match_points_to_type(perm: SmartPtrPointsTo<Self::Target>) -> bool {
true
}

open spec fn ptr_mut_spec(self) -> *mut Self::Target {
arbitrary()
}
}

// A `min` implementation for use in constant evaluation.
Expand All @@ -84,17 +119,25 @@ const fn min(a: u32, b: u32) -> u32 {
///
/// The caller must ensure that removing the bits from the non-null pointer will result in another
/// non-null pointer.
unsafe fn remove_bits<T>(ptr: NonNull<T>, bits: usize) -> (usize, NonNull<T>) {
use core::num::NonZeroUsize;

let removed_bits = ptr.addr().get() & bits;
let result_ptr = ptr.map_addr(|addr|
// SAFETY: The safety is upheld by the caller.
unsafe { NonZeroUsize::new_unchecked(addr.get() & !bits) });
unsafe fn remove_bits<T>(ptr: NonNull<T>, bits: usize) -> (usize, NonNull<T>)
requires
(ptr_mut_from_nonnull(ptr)@.addr & bits) < ptr_mut_from_nonnull(ptr)@.addr,
(ptr_mut_from_nonnull(ptr)@.addr & !bits) != 0,
{
// let removed_bits = ptr.addr().get() & bits;
let raw = ptr.as_ptr();
let Tracked(exposed) = vstd::raw_ptr::expose_provenance(raw);
let addr = raw as usize;
let removed_bits = addr & bits;
// let result_ptr = ptr.map_addr(|addr|
// // SAFETY: The safety is upheld by the caller.
// unsafe { NonZeroUsize::new_unchecked(addr.get() & !bits) });
let result_raw = vstd::raw_ptr::with_exposed_provenance(addr & !bits, Tracked(exposed));
let result_ptr = unsafe { NonNull::new_unchecked(result_raw) };

(removed_bits, result_ptr)
}

}
#[cfg(ktest)]
mod test {
use alloc::{boxed::Box, sync::Arc};
Expand All @@ -107,8 +150,8 @@ mod test {

#[ktest]
fn alignment() {
assert_eq!(<Either32 as NonNullPtr>::ALIGN_BITS, 1);
assert_eq!(<Either16 as NonNullPtr>::ALIGN_BITS, 0);
assert_eq!(<Either32 as NonNullPtr>::ALIGN_BITS(), 1);
assert_eq!(<Either16 as NonNullPtr>::ALIGN_BITS(), 0);
}

#[ktest]
Expand Down
2 changes: 1 addition & 1 deletion ostd/src/sync/rcu/non_null/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use alloc::{boxed::Box, sync::Arc};
use vstd::prelude::*;
use vstd_extra::prelude::*;

//mod either;
mod either;

//use core::simd::ptr;
use core::{marker::PhantomData, mem::ManuallyDrop, ops::Deref, ptr::NonNull};
Expand Down
Loading