Commit 8c17b94

Merge #231: types: use slab allocator for type bounds
c22dbaa types: drop BoundMutex and instead use references into the type context slab (Andrew Poelstra)
2e24c49 types: pull unify and bind into inference context (Andrew Poelstra)
a26cf7a types: remove set and get methods from BoundRef (Andrew Poelstra)
33f58fa types: introduce BoundRef type, use in place of Arc<BoundMutex> in union-bound (Andrew Poelstra)
021316c types: abstract pointer type in union-bound algorithm (Andrew Poelstra)
eccc332 types: add &Context to recursive type constructors (Andrew Poelstra)
65b35a9 types: add &Context to type constructors (Andrew Poelstra)
8e08900 types: make `bind` and `unify` go through Context (Andrew Poelstra)
8eeab8f types: introduce inference context object, thread it through the API (Andrew Poelstra)
9b0790e cmr: pull Constructible impl on Cmr into an impl on an auxiliary type (Andrew Poelstra)

Pull request description:

Our existing type inference engine assumes a "global" set of type bounds, which has two bad consequences. One is that if you are constructing multiple programs, there is no way to "firewall" their type bounds, so nothing prevents you from accidentally combining type variables from one program with type variables from another; you just need to be careful. The other is that if you construct infinitely sized types, which are represented as reference cycles, the existing inference engine will leak memory.

To fix this, we need to stop allocating type bounds as untethered `Arc`s and instead use a slab allocator, which allows all bounds to be dropped at once, regardless of their circularity. This should also improve memory locality and speed, as well as reduce the total amount of locking and potential mutex contention when type inference is done in a multithreaded context.

This is a 2000-line diff, but the vast majority of the changes are "API-only": moving types around and threading new parameters through dozens or hundreds of call sites. I did my best to break everything up into commits such that the big-diff commits don't do much of anything and the real changes happen in the small-diff ones, to make review easier.

By itself, this PR does **not** fix the issue of reference cycles, because it stores an `Arc<Context>` inside the recursive `Type` type itself. Future PRs will:

* Take a single mutex lock during calls to the top-level `bind` and `unify` methods, so that these happen atomically, including all recursive calls.
* Add another intermediate type under `Type` which eliminates the `Arc<Context>` and its potential for circular references. Along the way, make the `Bound` type private; it is not really used outside the types module anyway.
* Do "checkpointing" during type inference so that node construction becomes atomic; this is #226, which is **not** fixed by this PR.
* (Maybe) move node allocation into the type inference context so that nodes can be slab-allocated as well, which would address #229 "for free" without us figuring out a non-recursive `Drop` impl for `Arc<Node<N>>`.

ACKs for top commit:
  uncomputable:
    ACK c22dbaa

Tree-SHA512: 0fd2fdd9fe3634068d67279d517573df04fafa60b70e432f59417880982ad22e893822362973f946f1deb6279080aec1efdd942dfd8adad81bbddc7d55077336
2 parents b55b8e7 + c22dbaa commit 8c17b94

25 files changed: +1144 −595 lines

jets-bench/benches/elements/data_structures.rs (+3 −3)

```diff
@@ -4,9 +4,8 @@
 use bitcoin::secp256k1;
 use elements::Txid;
 use rand::{thread_rng, RngCore};
-pub use simplicity::hashes::sha256;
 use simplicity::{
-    bitcoin, elements, hashes::Hash, hex::FromHex, types::Type, BitIter, Error, Value,
+    bitcoin, elements, hashes::Hash, hex::FromHex, types::{self, Type}, BitIter, Error, Value,
 };
 use std::sync::Arc;
 
@@ -57,7 +56,8 @@ pub fn var_len_buf_from_slice(v: &[u8], mut n: usize) -> Result<Arc<Value>, Erro
     assert!(n < 16);
     assert!(v.len() < (1 << (n + 1)));
     let mut iter = BitIter::new(v.iter().copied());
-    let types = Type::powers_of_two(n); // size n + 1
+    let ctx = types::Context::new();
+    let types = Type::powers_of_two(&ctx, n); // size n + 1
     let mut res = None;
     while n > 0 {
         let v = if v.len() >= (1 << (n + 1)) {
```

jets-bench/benches/elements/main.rs (+3 −3)

```diff
@@ -93,8 +93,8 @@ impl ElementsBenchEnvType {
 }
 
 fn jet_arrow(jet: Elements) -> (Arc<types::Final>, Arc<types::Final>) {
-    let src_ty = jet.source_ty().to_type().final_data().unwrap();
-    let tgt_ty = jet.target_ty().to_type().final_data().unwrap();
+    let src_ty = jet.source_ty().to_final();
+    let tgt_ty = jet.target_ty().to_final();
     (src_ty, tgt_ty)
 }
 
@@ -302,7 +302,7 @@ fn bench(c: &mut Criterion) {
     let keypair = bitcoin::key::Keypair::new(&secp_ctx, &mut thread_rng());
     let xpk = bitcoin::key::XOnlyPublicKey::from_keypair(&keypair);
 
-    let msg = bitcoin::secp256k1::Message::from_slice(&rand::random::<[u8; 32]>()).unwrap();
+    let msg = bitcoin::secp256k1::Message::from_digest_slice(&rand::random::<[u8; 32]>()).unwrap();
     let sig = secp_ctx.sign_schnorr(&msg, &keypair);
     let xpk_value = Value::u256_from_slice(&xpk.0.serialize());
     let sig_value = Value::u512_from_slice(sig.as_ref());
```

src/bit_encoding/bitwriter.rs (+2 −1)

```diff
@@ -117,12 +117,13 @@ mod tests {
     use super::*;
     use crate::jet::Core;
     use crate::node::CoreConstructible;
+    use crate::types;
     use crate::ConstructNode;
     use std::sync::Arc;
 
     #[test]
     fn vec() {
-        let program = Arc::<ConstructNode<Core>>::unit();
+        let program = Arc::<ConstructNode<Core>>::unit(&types::Context::new());
         let _ = write_to_vec(|w| program.encode(w));
     }
```

src/bit_encoding/decode.rs (+8 −6)

```diff
@@ -12,6 +12,7 @@ use crate::node::{
     ConstructNode, CoreConstructible, DisconnectConstructible, JetConstructible, NoWitness,
     WitnessConstructible,
 };
+use crate::types;
 use crate::{BitIter, FailEntropy, Value};
 use std::collections::HashSet;
 use std::sync::Arc;
@@ -178,6 +179,7 @@
         return Err(Error::TooManyNodes(len));
     }
 
+    let inference_context = types::Context::new();
     let mut nodes = Vec::with_capacity(len);
     for _ in 0..len {
         let new_node = decode_node(bits, nodes.len())?;
@@ -195,8 +197,8 @@
         }
 
         let new = match nodes[data.node.0] {
-            DecodeNode::Unit => Node(ArcNode::unit()),
-            DecodeNode::Iden => Node(ArcNode::iden()),
+            DecodeNode::Unit => Node(ArcNode::unit(&inference_context)),
+            DecodeNode::Iden => Node(ArcNode::iden(&inference_context)),
             DecodeNode::InjL(i) => Node(ArcNode::injl(converted[i].get()?)),
             DecodeNode::InjR(i) => Node(ArcNode::injr(converted[i].get()?)),
             DecodeNode::Take(i) => Node(ArcNode::take(converted[i].get()?)),
@@ -222,16 +224,16 @@
                 converted[i].get()?,
                 &Some(Arc::clone(converted[j].get()?)),
             )?),
-            DecodeNode::Witness => Node(ArcNode::witness(NoWitness)),
-            DecodeNode::Fail(entropy) => Node(ArcNode::fail(entropy)),
+            DecodeNode::Witness => Node(ArcNode::witness(&inference_context, NoWitness)),
+            DecodeNode::Fail(entropy) => Node(ArcNode::fail(&inference_context, entropy)),
             DecodeNode::Hidden(cmr) => {
                 if !hidden_set.insert(cmr) {
                     return Err(Error::SharingNotMaximal);
                 }
                 Hidden(cmr)
             }
-            DecodeNode::Jet(j) => Node(ArcNode::jet(j)),
-            DecodeNode::Word(ref w) => Node(ArcNode::const_word(Arc::clone(w))),
+            DecodeNode::Jet(j) => Node(ArcNode::jet(&inference_context, j)),
+            DecodeNode::Word(ref w) => Node(ArcNode::const_word(&inference_context, Arc::clone(w))),
         };
         converted.push(new);
     }
```

src/human_encoding/named_node.rs (+47 −27)

```diff
@@ -9,7 +9,7 @@ use crate::node::{
     self, Commit, CommitData, CommitNode, Converter, Inner, NoDisconnect, NoWitness, Node, Witness,
     WitnessData,
 };
-use crate::node::{Construct, ConstructData, Constructible};
+use crate::node::{Construct, ConstructData, Constructible as _, CoreConstructible as _};
 use crate::types;
 use crate::types::arrow::{Arrow, FinalArrow};
 use crate::{encode, Value, WitnessNode};
@@ -116,6 +116,7 @@ impl<J: Jet> NamedCommitNode<J> {
         struct Populator<'a, J: Jet> {
             witness_map: &'a HashMap<Arc<str>, Arc<Value>>,
             disconnect_map: &'a HashMap<Arc<str>, Arc<NamedCommitNode<J>>>,
+            inference_context: types::Context,
            phantom: PhantomData<J>,
        }
 
@@ -153,17 +154,16 @@ impl<J: Jet> NamedCommitNode<J> {
                // Like witness nodes (see above), disconnect nodes may be pruned later.
                // The finalization will detect missing branches and throw an error.
                let maybe_commit = self.disconnect_map.get(hole_name);
-                // FIXME: Recursive call of to_witness_node
-                // We cannot introduce a stack
-                // because we are implementing methods of the trait Converter
-                // which are used Marker::convert().
+                // FIXME: recursive call to convert
+                // We cannot introduce a stack because we are implementing the Converter
+                // trait and do not have access to the actual algorithm used for conversion
+                // in order to save its state.
                //
                // OTOH, if a user writes a program with so many disconnected expressions
                // that there is a stack overflow, it's his own fault :)
-                // This would fail in a fuzz test.
-                let witness = maybe_commit.map(|commit| {
-                    commit.to_witness_node(self.witness_map, self.disconnect_map)
-                });
+                // This may fail in a fuzz test.
+                let witness = maybe_commit
+                    .map(|commit| commit.convert::<InternalSharing, _, _>(self).unwrap());
                Ok(witness)
            }
        }
@@ -181,13 +181,15 @@
                let inner = inner
                    .map(|node| node.cached_data())
                    .map_witness(|maybe_value| maybe_value.clone());
-                Ok(WitnessData::from_inner(inner).expect("types are already finalized"))
+                Ok(WitnessData::from_inner(&self.inference_context, inner)
+                    .expect("types are already finalized"))
            }
        }
 
        self.convert::<InternalSharing, _, _>(&mut Populator {
            witness_map: witness,
            disconnect_map: disconnect,
+            inference_context: types::Context::new(),
            phantom: PhantomData,
        })
        .unwrap()
@@ -245,13 +247,15 @@ pub struct NamedConstructData<J> {
 impl<J: Jet> NamedConstructNode<J> {
     /// Construct a named construct node from parts.
     pub fn new(
+        inference_context: &types::Context,
         name: Arc<str>,
         position: Position,
         user_source_types: Arc<[types::Type]>,
         user_target_types: Arc<[types::Type]>,
         inner: node::Inner<Arc<Self>, J, Arc<Self>, WitnessOrHole>,
     ) -> Result<Self, types::Error> {
         let construct_data = ConstructData::from_inner(
+            inference_context,
             inner
                 .as_ref()
                 .map(|data| &data.cached_data().internal)
@@ -295,6 +299,11 @@ impl<J: Jet> NamedConstructNode<J> {
         self.cached_data().internal.arrow()
     }
 
+    /// Accessor for the node's type inference context.
+    pub fn inference_context(&self) -> &types::Context {
+        self.cached_data().internal.inference_context()
+    }
+
     /// Finalizes the types of the underlying [`crate::ConstructNode`].
     pub fn finalize_types_main(&self) -> Result<Arc<NamedCommitNode<J>>, ErrorSet> {
         self.finalize_types_inner(true)
@@ -386,17 +395,23 @@ impl<J: Jet> NamedConstructNode<J> {
             .map_disconnect(|_| &NoDisconnect)
             .copy_witness();
 
+        let ctx = data.node.inference_context();
+
         if !self.for_main {
             // For non-`main` fragments, treat the ascriptions as normative, and apply them
             // before finalizing the type.
             let arrow = data.node.arrow();
             for ty in data.node.cached_data().user_source_types.as_ref() {
-                if let Err(e) = arrow.source.unify(ty, "binding source type annotation") {
+                if let Err(e) =
+                    ctx.unify(&arrow.source, ty, "binding source type annotation")
+                {
                     self.errors.add(data.node.position(), e);
                 }
             }
             for ty in data.node.cached_data().user_target_types.as_ref() {
-                if let Err(e) = arrow.target.unify(ty, "binding target type annotation") {
+                if let Err(e) =
+                    ctx.unify(&arrow.target, ty, "binding target type annotation")
+                {
                     self.errors.add(data.node.position(), e);
                 }
             }
@@ -413,15 +428,19 @@ impl<J: Jet> NamedConstructNode<J> {
         if self.for_main {
             // For `main`, only apply type ascriptions *after* inference has completely
             // determined the type.
-            let source_ty = types::Type::complete(Arc::clone(&commit_data.arrow().source));
+            let source_ty =
+                types::Type::complete(ctx, Arc::clone(&commit_data.arrow().source));
             for ty in data.node.cached_data().user_source_types.as_ref() {
-                if let Err(e) = source_ty.unify(ty, "binding source type annotation") {
+                if let Err(e) = ctx.unify(&source_ty, ty, "binding source type annotation")
+                {
                     self.errors.add(data.node.position(), e);
                 }
             }
-            let target_ty = types::Type::complete(Arc::clone(&commit_data.arrow().target));
+            let target_ty =
+                types::Type::complete(ctx, Arc::clone(&commit_data.arrow().target));
             for ty in data.node.cached_data().user_target_types.as_ref() {
-                if let Err(e) = target_ty.unify(ty, "binding target type annotation") {
+                if let Err(e) = ctx.unify(&target_ty, ty, "binding target type annotation")
+                {
                     self.errors.add(data.node.position(), e);
                 }
             }
@@ -442,22 +461,23 @@ impl<J: Jet> NamedConstructNode<J> {
         };
 
         if for_main {
-            let unit_ty = types::Type::unit();
+            let ctx = self.inference_context();
+            let unit_ty = types::Type::unit(ctx);
             if self.cached_data().user_source_types.is_empty() {
-                if let Err(e) = self
-                    .arrow()
-                    .source
-                    .unify(&unit_ty, "setting root source to unit")
-                {
+                if let Err(e) = ctx.unify(
+                    &self.arrow().source,
+                    &unit_ty,
+                    "setting root source to unit",
+                ) {
                     finalizer.errors.add(self.position(), e);
                 }
             }
             if self.cached_data().user_target_types.is_empty() {
-                if let Err(e) = self
-                    .arrow()
-                    .target
-                    .unify(&unit_ty, "setting root source to unit")
-                {
+                if let Err(e) = ctx.unify(
+                    &self.arrow().target,
+                    &unit_ty,
+                    "setting root target to unit",
+                ) {
                     finalizer.errors.add(self.position(), e);
                 }
             }
```

src/human_encoding/parse/ast.rs (+18 −7)

```diff
@@ -82,14 +82,25 @@ pub enum Type {
 
 impl Type {
     /// Convert to a Simplicity type
-    pub fn reify(self) -> types::Type {
+    pub fn reify(self, ctx: &types::Context) -> types::Type {
         match self {
-            Type::Name(s) => types::Type::free(s),
-            Type::One => types::Type::unit(),
-            Type::Two => types::Type::sum(types::Type::unit(), types::Type::unit()),
-            Type::Product(left, right) => types::Type::product(left.reify(), right.reify()),
-            Type::Sum(left, right) => types::Type::sum(left.reify(), right.reify()),
-            Type::TwoTwoN(n) => types::Type::two_two_n(n as usize), // cast OK as we are only using tiny numbers
+            Type::Name(s) => types::Type::free(ctx, s),
+            Type::One => types::Type::unit(ctx),
+            Type::Two => {
+                let unit_ty = types::Type::unit(ctx);
+                types::Type::sum(ctx, unit_ty.shallow_clone(), unit_ty)
+            }
+            Type::Product(left, right) => {
+                let left = left.reify(ctx);
+                let right = right.reify(ctx);
+                types::Type::product(ctx, left, right)
+            }
+            Type::Sum(left, right) => {
+                let left = left.reify(ctx);
+                let right = right.reify(ctx);
+                types::Type::sum(ctx, left, right)
+            }
+            Type::TwoTwoN(n) => types::Type::two_two_n(ctx, n as usize), // cast OK as we are only using tiny numbers
         }
     }
 }
```

src/human_encoding/parse/mod.rs (+5 −3)

```diff
@@ -7,7 +7,7 @@ mod ast;
 use crate::dag::{Dag, DagLike, InternalSharing};
 use crate::jet::Jet;
 use crate::node;
-use crate::types::Type;
+use crate::types::{self, Type};
 use std::collections::HashMap;
 use std::mem;
 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -181,6 +181,7 @@ pub fn parse<J: Jet + 'static>(
     program: &str,
 ) -> Result<HashMap<Arc<str>, Arc<NamedCommitNode<J>>>, ErrorSet> {
     let mut errors = ErrorSet::new();
+    let inference_context = types::Context::new();
     // **
     // Step 1: Read expressions into HashMap, checking for dupes and illegal names.
     // **
@@ -205,10 +206,10 @@
             }
         }
         if let Some(ty) = line.arrow.0 {
-            entry.add_source_type(ty.reify());
+            entry.add_source_type(ty.reify(&inference_context));
         }
         if let Some(ty) = line.arrow.1 {
-            entry.add_target_type(ty.reify());
+            entry.add_target_type(ty.reify(&inference_context));
         }
     }
 
@@ -485,6 +486,7 @@
             .unwrap_or_else(|| Arc::from(namer.assign_name(inner.as_ref()).as_str()));
 
         let node = NamedConstructNode::new(
+            &inference_context,
             Arc::clone(&name),
             data.node.position,
             Arc::clone(&data.node.user_source_types),
```

src/jet/elements/tests.rs (+2 −1)

```diff
@@ -5,6 +5,7 @@ use std::sync::Arc;
 use crate::jet::elements::{ElementsEnv, ElementsUtxo};
 use crate::jet::Elements;
 use crate::node::{ConstructNode, JetConstructible};
+use crate::types;
 use crate::{BitMachine, Cmr, Value};
 use elements::secp256k1_zkp::Tweak;
 use elements::taproot::ControlBlock;
@@ -99,7 +100,7 @@ fn test_ffi_env() {
         BlockHash::all_zeros(),
     );
 
-    let prog = Arc::<ConstructNode<_>>::jet(Elements::LockTime);
+    let prog = Arc::<ConstructNode<_>>::jet(&types::Context::new(), Elements::LockTime);
     assert_eq!(
         BitMachine::test_exec(prog, &env).expect("executing"),
         Value::u32(100),
```
