Skip to content

Commit 2e24c49

Browse files
committed
types: pull unify and bind into inference context
Pulls the unify and bind methods out of Type and BoundMutex and implement them on a new private LockedContext struct. This locks the entire inference context for the duration of a bind or unify operation, and because it only locks inside of non-recursive methods, it is impossible to deadlock. This is "API-only" in the sense that the actual type bounds continue to be represented by free-floating Arcs, but it has a semantic change in that binds and unifications now happen atomically (due to the continuously held lock on the context) which fixes a likely class of bugs wherein if you try to unify related variables from multiple threads at once, the old code probably would due weird things, due to the very local locking and total lack of other synchronization. The next commit will finally delete BoundMutex, move the bounds into the actual context object, and you will see the point of all these massive code lifts :).
1 parent a26cf7a commit 2e24c49

File tree

2 files changed

+110
-103
lines changed

2 files changed

+110
-103
lines changed

src/types/context.rs

+108-8
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
//!
1616
1717
use std::fmt;
18-
use std::sync::{Arc, Mutex};
18+
use std::sync::{Arc, Mutex, MutexGuard};
1919

2020
use crate::dag::{Dag, DagLike};
2121

2222
use super::bound_mutex::BoundMutex;
23-
use super::{Bound, Error, Final, Type};
23+
use super::{Bound, CompleteBound, Error, Final, Type};
2424

2525
/// Type inference context, or handle to a context.
2626
///
@@ -155,14 +155,24 @@ impl Context {
155155
///
156156
/// Fails if the type has an existing incompatible bound.
157157
pub fn bind(&self, existing: &Type, new: Bound, hint: &'static str) -> Result<(), Error> {
158-
existing.bind(new, hint)
158+
let existing_root = existing.bound.root();
159+
let lock = self.lock();
160+
lock.bind(existing_root, new, hint)
159161
}
160162

161163
/// Unify the type with another one.
162164
///
163165
/// Fails if the bounds on the two types are incompatible
164166
pub fn unify(&self, ty1: &Type, ty2: &Type, hint: &'static str) -> Result<(), Error> {
165-
ty1.unify(ty2, hint)
167+
let lock = self.lock();
168+
lock.unify(ty1, ty2, hint)
169+
}
170+
171+
/// Locks the underlying slab mutex.
172+
fn lock(&self) -> LockedContext {
173+
LockedContext {
174+
slab: self.slab.lock().unwrap(),
175+
}
166176
}
167177
}
168178

@@ -183,10 +193,6 @@ impl BoundRef {
183193
);
184194
}
185195

186-
pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> {
187-
self.index.bind(bound, hint)
188-
}
189-
190196
/// Creates an "occurs-check ID" which is just a copy of the [`BoundRef`]
191197
/// with `PartialEq` and `Eq` implemented in terms of underlying pointer
192198
/// equality.
@@ -238,3 +244,97 @@ pub struct OccursCheckId {
238244
// now we set it to an Arc<BoundMutex> to preserve semantics.
239245
index: *const BoundMutex,
240246
}
247+
248+
/// Structure representing an inference context with its slab allocator mutex locked.
249+
///
250+
/// This type is never exposed outside of this module and should only exist
251+
/// ephemerally within function calls into this module.
252+
struct LockedContext<'ctx> {
253+
slab: MutexGuard<'ctx, Vec<Bound>>,
254+
}
255+
256+
impl<'ctx> LockedContext<'ctx> {
257+
/// Unify the type with another one.
258+
///
259+
/// Fails if the bounds on the two types are incompatible
260+
fn unify(&self, existing: &Type, other: &Type, hint: &'static str) -> Result<(), Error> {
261+
existing.bound.unify(&other.bound, |x_bound, y_bound| {
262+
self.bind(x_bound, y_bound.index.get(), hint)
263+
})
264+
}
265+
266+
fn bind(&self, existing: BoundRef, new: Bound, hint: &'static str) -> Result<(), Error> {
267+
let existing_bound = existing.index.get();
268+
let bind_error = || Error::Bind {
269+
existing_bound: existing_bound.shallow_clone(),
270+
new_bound: new.shallow_clone(),
271+
hint,
272+
};
273+
274+
match (&existing_bound, &new) {
275+
// Binding a free type to anything is a no-op
276+
(_, Bound::Free(_)) => Ok(()),
277+
// Free types are simply dropped and replaced by the new bound
278+
(Bound::Free(_), _) => {
279+
// Free means non-finalized, so set() is ok.
280+
existing.index.set(new);
281+
Ok(())
282+
}
283+
// Binding complete->complete shouldn't ever happen, but if so, we just
284+
// compare the two types and return a pass/fail
285+
(Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
286+
if existing_final == new_final {
287+
Ok(())
288+
} else {
289+
Err(bind_error())
290+
}
291+
}
292+
// Binding an incomplete to a complete type requires recursion.
293+
(Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => {
294+
match (complete.bound(), incomplete) {
295+
// A unit might match a Bound::Free(..) or a Bound::Complete(..),
296+
// and both cases were handled above. So this is an error.
297+
(CompleteBound::Unit, _) => Err(bind_error()),
298+
(
299+
CompleteBound::Product(ref comp1, ref comp2),
300+
Bound::Product(ref ty1, ref ty2),
301+
)
302+
| (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => {
303+
let bound1 = ty1.bound.root();
304+
let bound2 = ty2.bound.root();
305+
self.bind(bound1, Bound::Complete(Arc::clone(comp1)), hint)?;
306+
self.bind(bound2, Bound::Complete(Arc::clone(comp2)), hint)
307+
}
308+
_ => Err(bind_error()),
309+
}
310+
}
311+
(Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
312+
| (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
313+
self.unify(x1, y1, hint)?;
314+
self.unify(x2, y2, hint)?;
315+
// This type was not complete, but it may be after unification, giving us
316+
// an opportunity to finaliize it. We do this eagerly to make sure that
317+
// "complete" (no free children) is always equivalent to "finalized" (the
318+
// bound field having variant Bound::Complete(..)), even during inference.
319+
//
320+
// It also gives the user access to more information about the type,
321+
// prior to finalization.
322+
if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) {
323+
existing
324+
.index
325+
.set(Bound::Complete(if let Bound::Sum(..) = existing_bound {
326+
Final::sum(data1, data2)
327+
} else {
328+
Final::product(data1, data2)
329+
}));
330+
}
331+
Ok(())
332+
}
333+
(x, y) => Err(Error::Bind {
334+
existing_bound: x.shallow_clone(),
335+
new_bound: y.shallow_clone(),
336+
hint,
337+
}),
338+
}
339+
}
340+
}

src/types/mod.rs

+2-95
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ impl fmt::Display for Error {
149149
impl std::error::Error for Error {}
150150

151151
mod bound_mutex {
152-
use super::{Bound, CompleteBound, Error, Final};
152+
use super::Bound;
153153
use std::fmt;
154-
use std::sync::{Arc, Mutex};
154+
use std::sync::Mutex;
155155

156156
/// Source or target type of a Simplicity expression
157157
pub struct BoundMutex {
@@ -184,81 +184,6 @@ mod bound_mutex {
184184
);
185185
*lock = new;
186186
}
187-
188-
pub fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> {
189-
let existing_bound = self.get();
190-
let bind_error = || Error::Bind {
191-
existing_bound: existing_bound.shallow_clone(),
192-
new_bound: bound.shallow_clone(),
193-
hint,
194-
};
195-
196-
match (&existing_bound, &bound) {
197-
// Binding a free type to anything is a no-op
198-
(_, Bound::Free(_)) => Ok(()),
199-
// Free types are simply dropped and replaced by the new bound
200-
(Bound::Free(_), _) => {
201-
// Free means non-finalized, so set() is ok.
202-
self.set(bound);
203-
Ok(())
204-
}
205-
// Binding complete->complete shouldn't ever happen, but if so, we just
206-
// compare the two types and return a pass/fail
207-
(Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
208-
if existing_final == new_final {
209-
Ok(())
210-
} else {
211-
Err(bind_error())
212-
}
213-
}
214-
// Binding an incomplete to a complete type requires recursion.
215-
(Bound::Complete(complete), incomplete)
216-
| (incomplete, Bound::Complete(complete)) => {
217-
match (complete.bound(), incomplete) {
218-
// A unit might match a Bound::Free(..) or a Bound::Complete(..),
219-
// and both cases were handled above. So this is an error.
220-
(CompleteBound::Unit, _) => Err(bind_error()),
221-
(
222-
CompleteBound::Product(ref comp1, ref comp2),
223-
Bound::Product(ref ty1, ref ty2),
224-
)
225-
| (
226-
CompleteBound::Sum(ref comp1, ref comp2),
227-
Bound::Sum(ref ty1, ref ty2),
228-
) => {
229-
ty1.bind(Bound::Complete(Arc::clone(comp1)), hint)?;
230-
ty2.bind(Bound::Complete(Arc::clone(comp2)), hint)
231-
}
232-
_ => Err(bind_error()),
233-
}
234-
}
235-
(Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
236-
| (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
237-
x1.unify(y1, hint)?;
238-
x2.unify(y2, hint)?;
239-
// This type was not complete, but it may be after unification, giving us
240-
// an opportunity to finaliize it. We do this eagerly to make sure that
241-
// "complete" (no free children) is always equivalent to "finalized" (the
242-
// bound field having variant Bound::Complete(..)), even during inference.
243-
//
244-
// It also gives the user access to more information about the type,
245-
// prior to finalization.
246-
if let (Some(data1), Some(data2)) = (y1.final_data(), y2.final_data()) {
247-
self.set(Bound::Complete(if let Bound::Sum(..) = bound {
248-
Final::sum(data1, data2)
249-
} else {
250-
Final::product(data1, data2)
251-
}));
252-
}
253-
Ok(())
254-
}
255-
(x, y) => Err(Error::Bind {
256-
existing_bound: x.shallow_clone(),
257-
new_bound: y.shallow_clone(),
258-
hint,
259-
}),
260-
}
261-
}
262187
}
263188
}
264189

@@ -391,24 +316,6 @@ impl Type {
391316
self.clone()
392317
}
393318

394-
/// Binds the type to a given bound. If this fails, attach the provided
395-
/// hint to the error.
396-
///
397-
/// Fails if the type has an existing incompatible bound.
398-
fn bind(&self, bound: Bound, hint: &'static str) -> Result<(), Error> {
399-
let root = self.bound.root();
400-
root.bind(bound, hint)
401-
}
402-
403-
/// Unify the type with another one.
404-
///
405-
/// Fails if the bounds on the two types are incompatible
406-
fn unify(&self, other: &Self, hint: &'static str) -> Result<(), Error> {
407-
self.bound.unify(&other.bound, |x_bound, y_bound| {
408-
x_bound.bind(self.ctx.get(y_bound), hint)
409-
})
410-
}
411-
412319
/// Accessor for this type's bound
413320
pub fn bound(&self) -> Bound {
414321
self.ctx.get(&self.bound.root())

0 commit comments

Comments
 (0)