Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mega pr, do not merge] types: rewrite inference API to use a slab allocator for type bounds #228

Closed
wants to merge 16 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
types: remove set and get methods from BoundRef
Once our BoundRefs start requiring an inference context to access their
data, we won't be able to call .set() and .get() on them individually.
Remove these methods, and instead add them on Context.

Doing this means that everywhere we currently call .get and .set, we
need a context available. To achieve this, we add the context to Type,
and swap the fmt::Debug/fmt::Display impls for Type and Bound so that
Type is the primary one (since it has the context object available).

The change to use BoundRef in the finalization code means we can change
our occurs-check from directly using Arc::<Bound>::as_ptr to using a
more "principled" OccursCheckId object yielded from the BoundRef. This
in turn means that we no longer need to use Arc<Bound> anywhere, and
can instead directly use Bound (which is cheap to clone and doesn't need
to be wrapped in an Arc, except when we are using Arc to obtain a
pointer-id for use in the occurs check).

Converting Arc<Bound> to Bound in turn lets us remove a bunch of
Arc::new and Arc::clone calls throughout.

Again, "API only", but there's a lot going on here.
apoelstra committed Jun 30, 2024
commit f3ec9d2a02769f3950aad11377137407c785b481
3 changes: 2 additions & 1 deletion src/node/redeem.rs
Original file line number Diff line number Diff line change
@@ -290,7 +290,8 @@ impl<J: Jet> RedeemNode<J> {
data: &PostOrderIterItem<&ConstructNode<J>>,
_: &NoWitness,
) -> Result<Arc<Value>, Self::Error> {
let target_ty = data.node.data.arrow().target.finalize()?;
let arrow = data.node.data.arrow();
let target_ty = arrow.target.finalize()?;
self.bits.read_value(&target_ty).map_err(Error::from)
}

8 changes: 4 additions & 4 deletions src/types/arrow.rs
Original file line number Diff line number Diff line change
@@ -125,15 +125,15 @@ impl Arrow {
if let Some(lchild_arrow) = lchild_arrow {
ctx.bind(
&lchild_arrow.source,
Arc::new(Bound::Product(a, c.shallow_clone())),
Bound::Product(a, c.shallow_clone()),
"case combinator: left source = A × C",
)?;
ctx.unify(&target, &lchild_arrow.target, "").unwrap();
}
if let Some(rchild_arrow) = rchild_arrow {
ctx.bind(
&rchild_arrow.source,
Arc::new(Bound::Product(b, c)),
Bound::Product(b, c),
"case combinator: left source = B × C",
)?;
ctx.unify(
@@ -168,12 +168,12 @@ impl Arrow {

ctx.bind(
&lchild_arrow.source,
Arc::new(prod_256_a),
prod_256_a,
"disconnect combinator: left source = 2^256 × A",
)?;
ctx.bind(
&lchild_arrow.target,
Arc::new(prod_b_c),
prod_b_c,
"disconnect combinator: left target = B × C",
)?;

74 changes: 64 additions & 10 deletions src/types/context.rs
Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@
use std::fmt;
use std::sync::{Arc, Mutex};

use crate::dag::{Dag, DagLike};

use super::bound_mutex::BoundMutex;
use super::{Bound, Error, Final, Type};

@@ -123,11 +125,37 @@ impl Context {
}
}

/// Accesses a bound.
///
/// # Panics
///
/// Panics if passed a `BoundRef` that was not allocated by this context.
pub fn get(&self, bound: &BoundRef) -> Bound {
bound.assert_matches_context(self);
bound.index.get().shallow_clone()
}

/// Reassigns a bound to a different bound.
///
/// # Panics
///
/// Panics if called on a complete type. This is a sanity-check to avoid
/// replacing already-completed types, which can cause inefficiencies in
/// the union-bound algorithm (and if our replacement changes the type,
/// this is probably a bug).
///
/// Also panics if passed a `BoundRef` that was not allocated by this context.
pub fn reassign_non_complete(&self, bound: BoundRef, new: Bound) {
bound.assert_matches_context(self);
bound.index.set(new)
}

/// Binds the type to a given bound. If this fails, attach the provided
/// hint to the error.
///
/// Fails if the type has an existing incompatible bound.
pub fn bind(&self, existing: &Type, new: Arc<Bound>, hint: &'static str) -> Result<(), Error> {
pub fn bind(&self, existing: &Type, new: Bound, hint: &'static str) -> Result<(), Error> {
existing.bind(new, hint)
}

@@ -139,7 +167,7 @@ impl Context {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BoundRef {
context: *const Mutex<Vec<Bound>>,
// Will become an index into the context in a later commit, but for
@@ -156,16 +184,18 @@ impl BoundRef {
);
}

pub fn get(&self) -> Arc<Bound> {
self.index.get()
}

pub fn set(&self, new: Arc<Bound>) {
self.index.set(new)
pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> {
self.index.bind(bound, hint)
}

pub fn bind(&self, bound: Arc<Bound>, hint: &'static str) -> Result<(), Error> {
self.index.bind(bound, hint)
/// Creates an "occurs-check ID" which is just a copy of the [`BoundRef`]
/// with `PartialEq` and `Eq` implemented in terms of underlying pointer
/// equality.
pub fn occurs_check_id(&self) -> OccursCheckId {
OccursCheckId {
context: self.context,
index: Arc::as_ptr(&self.index),
}
}
}

@@ -185,3 +215,27 @@ impl super::PointerLike for BoundRef {
}
}
}

impl<'ctx> DagLike for (&'ctx Context, BoundRef) {
type Node = BoundRef;
fn data(&self) -> &BoundRef {
&self.1
}

fn as_dag_node(&self) -> Dag<Self> {
match self.0.get(&self.1) {
Bound::Free(..) | Bound::Complete(..) => Dag::Nullary,
Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => {
Dag::Binary((self.0, ty1.bound.root()), (self.0, ty2.bound.root()))
}
}
}
}

#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct OccursCheckId {
context: *const Mutex<Vec<Bound>>,
// Will become an index into the context in a later commit, but for
// now we set it to an Arc<BoundMutex> to preserve semantics.
index: *const BoundMutex,
}
247 changes: 132 additions & 115 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
@@ -105,7 +105,7 @@ pub enum Error {
hint: &'static str,
},
/// A type is recursive (i.e., occurs within itself), violating the "occurs check"
OccursCheck { infinite_bound: Arc<Bound> },
OccursCheck { infinite_bound: Bound },
/// Attempted to combine two nodes which had different type inference
/// contexts. This is probably a programming error.
InferenceContextMismatch,
@@ -156,7 +156,7 @@ mod bound_mutex {
/// Source or target type of a Simplicity expression
pub struct BoundMutex {
/// The type's status according to the union-bound algorithm.
inner: Mutex<Arc<Bound>>,
inner: Mutex<Bound>,
}

impl fmt::Debug for BoundMutex {
@@ -168,32 +168,32 @@ mod bound_mutex {
impl BoundMutex {
pub fn new(bound: Bound) -> Self {
BoundMutex {
inner: Mutex::new(Arc::new(bound)),
inner: Mutex::new(bound),
}
}

pub fn get(&self) -> Arc<Bound> {
Arc::clone(&self.inner.lock().unwrap())
pub fn get(&self) -> Bound {
self.inner.lock().unwrap().shallow_clone()
}

pub fn set(&self, new: Arc<Bound>) {
pub fn set(&self, new: Bound) {
let mut lock = self.inner.lock().unwrap();
assert!(
!matches!(**lock, Bound::Complete(..)),
!matches!(*lock, Bound::Complete(..)),
"tried to modify finalized type",
);
*lock = new;
}

pub fn bind(&self, bound: Arc<Bound>, hint: &'static str) -> Result<(), Error> {
pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> {
let existing_bound = self.get();
let bind_error = || Error::Bind {
existing_bound: existing_bound.shallow_clone(),
new_bound: bound.shallow_clone(),
hint,
};

match (existing_bound.as_ref(), bound.as_ref()) {
match (&existing_bound, &bound) {
// Binding a free type to anything is a no-op
(_, Bound::Free(_)) => Ok(()),
// Free types are simply dropped and replaced by the new bound
@@ -226,8 +226,8 @@ mod bound_mutex {
CompleteBound::Sum(ref comp1, ref comp2),
Bound::Sum(ref ty1, ref ty2),
) => {
ty1.bind(Arc::new(Bound::Complete(Arc::clone(comp1))), hint)?;
ty2.bind(Arc::new(Bound::Complete(Arc::clone(comp2))), hint)
ty1.bind(Bound::Complete(Arc::clone(comp1)), hint)?;
ty2.bind(Bound::Complete(Arc::clone(comp2)), hint)
}
_ => Err(bind_error()),
}
@@ -244,11 +244,11 @@ mod bound_mutex {
// It also gives the user access to more information about the type,
// prior to finalization.
if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) {
self.set(Arc::new(Bound::Complete(if let Bound::Sum(..) = *bound {
self.set(Bound::Complete(if let Bound::Sum(..) = bound {
Final::sum(data1, data2)
} else {
Final::product(data1, data2)
})));
}));
}
Ok(())
}
@@ -263,7 +263,7 @@ mod bound_mutex {
}

/// The state of a [`Type`] based on all constraints currently imposed on it.
#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum Bound {
/// Fully-unconstrained type
Free(String),
@@ -275,6 +275,25 @@ pub enum Bound {
Product(Type, Type),
}

impl fmt::Display for Bound {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Bound::Free(s) => f.write_str(s),
Bound::Complete(comp) => comp.fmt(f),
Bound::Sum(ty1, ty2) => {
ty1.fmt(f)?;
f.write_str(" + ")?;
ty2.fmt(f)
}
Bound::Product(ty1, ty2) => {
ty1.fmt(f)?;
f.write_str(" × ")?;
ty2.fmt(f)
}
}
}
}

impl Bound {
/// Clones the `Bound`.
///
@@ -302,86 +321,15 @@ impl Bound {
}
}

const MAX_DISPLAY_DEPTH: usize = 64;

impl fmt::Debug for Bound {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let arc = Arc::new(self.shallow_clone());
for data in arc.verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH)) {
if data.depth == MAX_DISPLAY_DEPTH {
if data.n_children_yielded == 0 {
f.write_str("...")?;
}
continue;
}
match (&*data.node, data.n_children_yielded) {
(Bound::Free(ref s), _) => f.write_str(s)?,
(Bound::Complete(ref comp), _) => fmt::Debug::fmt(comp, f)?,
(Bound::Sum(..), 0) | (Bound::Product(..), 0) => f.write_str("(")?,
(Bound::Sum(..), 2) | (Bound::Product(..), 2) => f.write_str(")")?,
(Bound::Sum(..), _) => f.write_str(" + ")?,
(Bound::Product(..), _) => f.write_str(" × ")?,
}
}
Ok(())
}
}

impl fmt::Display for Bound {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let arc = Arc::new(self.shallow_clone());
for data in arc.verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH)) {
if data.depth == MAX_DISPLAY_DEPTH {
if data.n_children_yielded == 0 {
f.write_str("...")?;
}
continue;
}
match (&*data.node, data.n_children_yielded) {
(Bound::Free(ref s), _) => f.write_str(s)?,
(Bound::Complete(ref comp), _) => fmt::Display::fmt(comp, f)?,
(Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
if data.index > 0 {
f.write_str("(")?;
}
}
(Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
if data.index > 0 {
f.write_str(")")?
}
}
(Bound::Sum(..), _) => f.write_str(" + ")?,
(Bound::Product(..), _) => f.write_str(" × ")?,
}
}
Ok(())
}
}

impl DagLike for Arc<Bound> {
type Node = Bound;
fn data(&self) -> &Bound {
self
}

fn as_dag_node(&self) -> Dag<Self> {
match **self {
Bound::Free(..) | Bound::Complete(..) => Dag::Nullary,
Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => {
Dag::Binary(ty1.bound.root().get(), ty2.bound.root().get())
}
}
}
}

/// Source or target type of a Simplicity expression.
///
/// Internally this type is essentially just a refcounted pointer; it is
/// therefore quite cheap to clone, but be aware that cloning will not
/// actually create a new independent type, just a second pointer to the
/// first one.
#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct Type {
ctx: Context,
/// A set of constraints, which is maintained by the union-bound algorithm and
/// is progressively tightened as type inference proceeds.
bound: UbElement<BoundRef>,
@@ -391,13 +339,15 @@ impl Type {
/// Return an unbound type with the given name
pub fn free(ctx: &Context, name: String) -> Self {
Type {
ctx: ctx.shallow_clone(),
bound: UbElement::new(ctx.alloc_free(name)),
}
}

/// Create the unit type.
pub fn unit(ctx: &Context) -> Self {
Type {
ctx: ctx.shallow_clone(),
bound: UbElement::new(ctx.alloc_unit()),
}
}
@@ -412,20 +362,23 @@ impl Type {
/// Create the sum of the given `left` and `right` types.
pub fn sum(ctx: &Context, left: Self, right: Self) -> Self {
Type {
ctx: ctx.shallow_clone(),
bound: UbElement::new(ctx.alloc_sum(left, right)),
}
}

/// Create the product of the given `left` and `right` types.
pub fn product(ctx: &Context, left: Self, right: Self) -> Self {
Type {
ctx: ctx.shallow_clone(),
bound: UbElement::new(ctx.alloc_product(left, right)),
}
}

/// Create a complete type.
pub fn complete(ctx: &Context, final_data: Arc<Final>) -> Self {
Type {
ctx: ctx.shallow_clone(),
bound: UbElement::new(ctx.alloc_complete(final_data)),
}
}
@@ -442,7 +395,7 @@ impl Type {
/// hint to the error.
///
/// Fails if the type has an existing incompatible bound.
fn bind(&self, bound: Arc<Bound>, hint: &'static str) -> Result<(), Error> {
fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> {
let root = self.bound.root();
root.bind(bound, hint)
}
@@ -452,13 +405,13 @@ impl Type {
/// Fails if the bounds on the two types are incompatible
fn unify(&self, other: &Self, hint: &'static str) -> Result<(), Error> {
self.bound.unify(&other.bound, |x_bound, y_bound| {
x_bound.bind(y_bound.get(), hint)
x_bound.bind(self.ctx.get(y_bound), hint)
})
}

/// Accessor for this type's bound
pub fn bound(&self) -> Arc<Bound> {
self.bound.root().get()
pub fn bound(&self) -> Bound {
self.ctx.get(&self.bound.root())
}

/// Accessor for the TMR of this type, if it is final
@@ -468,7 +421,7 @@ impl Type {

/// Accessor for the data of this type, if it is complete
pub fn final_data(&self) -> Option<Arc<Final>> {
if let Bound::Complete(ref data) = *self.bound.root().get() {
if let Bound::Complete(ref data) = self.bound() {
Some(Arc::clone(data))
} else {
None
@@ -481,55 +434,57 @@ impl Type {
/// complete, since its children may have been unified to a complete type. To
/// ensure a type is complete, call [`Type::finalize`].
pub fn is_final(&self) -> bool {
matches!(*self.bound.root().get(), Bound::Complete(..))
self.final_data().is_some()
}

/// Attempts to finalize the type. Returns its TMR on success.
pub fn finalize(&self) -> Result<Arc<Final>, Error> {
use context::OccursCheckId;

/// Helper type for the occurs-check.
enum OccursCheckStack {
Iterate(Arc<Bound>),
Complete(*const Bound),
Iterate(BoundRef),
Complete(OccursCheckId),
}

// Done with sharing tracker. Actual algorithm follows.
let root = self.bound.root();
let bound = root.get();
if let Bound::Complete(ref data) = *bound {
let bound = self.ctx.get(&root);
if let Bound::Complete(ref data) = bound {
return Ok(Arc::clone(data));
}

// First, do occurs-check to ensure that we have no infinitely sized types.
let mut stack = vec![OccursCheckStack::Iterate(Arc::clone(&bound))];
let mut stack = vec![OccursCheckStack::Iterate(root)];
let mut in_progress = HashSet::new();
let mut completed = HashSet::new();
while let Some(top) = stack.pop() {
let bound = match top {
OccursCheckStack::Complete(ptr) => {
in_progress.remove(&ptr);
completed.insert(ptr);
OccursCheckStack::Complete(id) => {
in_progress.remove(&id);
completed.insert(id);
continue;
}
OccursCheckStack::Iterate(b) => b,
};

let ptr = bound.as_ref() as *const _;
if completed.contains(&ptr) {
let id = bound.occurs_check_id();
if completed.contains(&id) {
// Once we have iterated through a type, we don't need to check it again.
// Without this shortcut the occurs-check would take exponential time.
continue;
}
if !in_progress.insert(ptr) {
if !in_progress.insert(id) {
return Err(Error::OccursCheck {
infinite_bound: bound,
infinite_bound: self.ctx.get(&bound),
});
}

stack.push(OccursCheckStack::Complete(ptr));
if let Some(child) = bound.right_child() {
stack.push(OccursCheckStack::Complete(id));
if let Some((_, child)) = (&self.ctx, bound.shallow_clone()).right_child() {
stack.push(OccursCheckStack::Iterate(child));
}
if let Some(child) = bound.left_child() {
if let Some((_, child)) = (&self.ctx, bound).left_child() {
stack.push(OccursCheckStack::Iterate(child));
}
}
@@ -539,8 +494,8 @@ impl Type {
let mut finalized = vec![];
for data in self.shallow_clone().post_order_iter::<NoSharing>() {
let bound = data.node.bound.root();
let bound_get = bound.get();
let final_data = match *bound_get {
let bound_get = self.ctx.get(&bound);
let final_data = match bound_get {
Bound::Free(_) => Final::unit(),
Bound::Complete(ref arc) => Arc::clone(arc),
Bound::Sum(..) => Final::sum(
@@ -553,9 +508,9 @@ impl Type {
),
};

if !matches!(*bound_get, Bound::Complete(..)) {
// set() ok because we are if-guarded on this variable not being complete
bound.set(Arc::new(Bound::Complete(Arc::clone(&final_data))));
if !matches!(bound_get, Bound::Complete(..)) {
self.ctx
.reassign_non_complete(bound, Bound::Complete(Arc::clone(&final_data)));
}
finalized.push(final_data);
}
@@ -576,9 +531,71 @@ impl Type {
}
}

const MAX_DISPLAY_DEPTH: usize = 64;

impl fmt::Debug for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
for data in (&self.ctx, self.bound.root())
.verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH))
{
if data.depth == MAX_DISPLAY_DEPTH {
if data.n_children_yielded == 0 {
f.write_str("...")?;
}
continue;
}
let bound = data.node.0.get(&data.node.1);
match (bound, data.n_children_yielded) {
(Bound::Free(ref s), _) => f.write_str(s)?,
(Bound::Complete(ref comp), _) => fmt::Debug::fmt(comp, f)?,
(Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
if data.index > 0 {
f.write_str("(")?;
}
}
(Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
if data.index > 0 {
f.write_str(")")?
}
}
(Bound::Sum(..), _) => f.write_str(" + ")?,
(Bound::Product(..), _) => f.write_str(" × ")?,
}
}
Ok(())
}
}

impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.bound.root().get(), f)
for data in (&self.ctx, self.bound.root())
.verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH))
{
if data.depth == MAX_DISPLAY_DEPTH {
if data.n_children_yielded == 0 {
f.write_str("...")?;
}
continue;
}
let bound = data.node.0.get(&data.node.1);
match (bound, data.n_children_yielded) {
(Bound::Free(ref s), _) => f.write_str(s)?,
(Bound::Complete(ref comp), _) => fmt::Display::fmt(comp, f)?,
(Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
if data.index > 0 {
f.write_str("(")?;
}
}
(Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
if data.index > 0 {
f.write_str(")")?
}
}
(Bound::Sum(..), _) => f.write_str(" + ")?,
(Bound::Product(..), _) => f.write_str(" × ")?,
}
}
Ok(())
}
}

@@ -589,7 +606,7 @@ impl DagLike for Type {
}

fn as_dag_node(&self) -> Dag<Self> {
match *self.bound.root().get() {
match self.bound() {
Bound::Free(..) | Bound::Complete(..) => Dag::Nullary,
Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => {
Dag::Binary(ty1.shallow_clone(), ty2.shallow_clone())