Skip to content

Commit 05f2431

Browse files
Add Tr-Tr merging in with_huffman_tree_eff
Add generic TapTree cost function
1 parent af6a3c0 commit 05f2431

File tree

1 file changed

+106
-65
lines changed

1 file changed

+106
-65
lines changed

src/policy/concrete.rs

+106-65
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ use {Error, ForEach, ForEachKey, MiniscriptKey};
4444
#[cfg(feature = "compiler")]
4545
type PolicyTapCache<Pk> = BTreeMap<TapTree<Pk>, (Policy<Pk>, f64)>;
4646

47+
/// [`Miniscript`] -> leaf probability in policy cache
48+
#[cfg(feature = "compiler")]
49+
type MsTapCache<Pk> = BTreeMap<TapTree<Pk>, f64>;
50+
4751
/// Concrete policy which corresponds directly to a Miniscript structure,
4852
/// and whose disjunctions are annotated with satisfaction probabilities
4953
/// to assist the compiler
@@ -174,83 +178,116 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
174178
Ok(node)
175179
}
176180

177-
/// [`TapTree::Leaf`] average satisfaction cost + script size
181+
/// Average satisfaction cost for [`TapTree`]
178182
#[cfg(feature = "compiler")]
179-
fn tr_node_cost(ms: &Arc<Miniscript<Pk, Tap>>, prob: f64, cost: &f64) -> OrdF64 {
180-
OrdF64(prob * (ms.script_size() as f64 + cost))
183+
fn taptree_cost(
184+
tr: &TapTree<Pk>,
185+
ms_cache: &MsTapCache<Pk>,
186+
policy_cache: &PolicyTapCache<Pk>,
187+
depth: u32,
188+
) -> f64 {
189+
match *tr {
190+
TapTree::Tree(ref l, ref r) => {
191+
Self::taptree_cost(l, ms_cache, policy_cache, depth + 1)
192+
+ Self::taptree_cost(r, ms_cache, policy_cache, depth + 1)
193+
}
194+
TapTree::Leaf(ref ms) => {
195+
let prob = ms_cache
196+
.get(&TapTree::Leaf(Arc::clone(ms)))
197+
.expect("Probability should exist for the given ms");
198+
let sat_cost = policy_cache
199+
.get(&TapTree::Leaf(Arc::clone(ms)))
200+
.expect("Cost should exist for the given ms")
201+
.1;
202+
prob * (ms.script_size() as f64 + sat_cost + 32.0 * depth as f64)
203+
}
204+
}
181205
}
182206

183207
/// Create a [`TapTree`] from the miniscript as leaf nodes
184208
#[cfg(feature = "compiler")]
185209
fn with_huffman_tree_eff(
186-
ms: Vec<(OrdF64, (Arc<Miniscript<Pk, Tap>>, f64))>,
210+
ms: Vec<Arc<Miniscript<Pk, Tap>>>,
187211
policy_cache: &mut PolicyTapCache<Pk>,
212+
ms_cache: &mut MsTapCache<Pk>,
188213
) -> Result<TapTree<Pk>, Error> {
189214
let mut node_weights = BinaryHeap::<(Reverse<OrdF64>, OrdF64, TapTree<Pk>)>::new(); // (cost, branch_prob, tree)
190-
for (prob, script) in ms {
191-
let wt = Self::tr_node_cost(&script.0, prob.0, &script.1);
192-
node_weights.push((Reverse(wt), prob, TapTree::Leaf(Arc::clone(&script.0))));
215+
for script in ms {
216+
let wt = OrdF64(Self::taptree_cost(
217+
&TapTree::Leaf(Arc::clone(&script)),
218+
ms_cache,
219+
policy_cache,
220+
0,
221+
));
222+
let prob = OrdF64(
223+
*ms_cache
224+
.get(&TapTree::Leaf(Arc::clone(&script)))
225+
.expect("Probability should exist for the given ms"),
226+
);
227+
node_weights.push((Reverse(wt), prob, TapTree::Leaf(Arc::clone(&script))));
193228
}
194229
if node_weights.is_empty() {
195230
return Err(errstr("Empty Miniscript compilation"));
196231
}
197232
while node_weights.len() > 1 {
198-
let (prev_cost1, p1, s1) = node_weights.pop().expect("len must atleast be two");
199-
let (prev_cost2, p2, s2) = node_weights.pop().expect("len must atleast be two");
200-
201-
match (s1, s2) {
202-
(TapTree::Leaf(ms1), TapTree::Leaf(ms2)) => {
203-
// Retrieve the respective policies
204-
let (left_pol, _c1) = policy_cache
205-
.get(&TapTree::Leaf(Arc::clone(&ms1)))
206-
.ok_or_else(|| errstr("No corresponding policy found"))?
207-
.clone();
208-
209-
let (right_pol, _c2) = policy_cache
210-
.get(&TapTree::Leaf(Arc::clone(&ms2)))
211-
.ok_or_else(|| errstr("No corresponding policy found"))?
212-
.clone();
213-
214-
let parent_policy = Policy::Or(vec![
215-
((p1.0 * 1e4).round() as usize, left_pol),
216-
((p2.0 * 1e4).round() as usize, right_pol),
217-
]);
218-
219-
let (parent_compilation, cost) =
220-
compiler::best_compilation_sat::<Pk, Tap>(&parent_policy)?;
221-
222-
let parent_cost = Self::tr_node_cost(&parent_compilation, p1.0 + p2.0, &cost);
223-
let children_cost =
224-
OrdF64((prev_cost1.0).0 + (prev_cost2.0).0 + 32. * (p1.0 + p2.0));
225-
226-
policy_cache.remove(&TapTree::Leaf(Arc::clone(&ms1)));
227-
policy_cache.remove(&TapTree::Leaf(Arc::clone(&ms2)));
228-
let p = p1.0 + p2.0;
229-
node_weights.push(if parent_cost > children_cost {
230-
(
231-
Reverse(children_cost),
232-
OrdF64(p),
233-
TapTree::Tree(
234-
Arc::from(TapTree::Leaf(ms1)),
235-
Arc::from(TapTree::Leaf(ms2)),
236-
),
237-
)
238-
} else {
239-
let node = TapTree::Leaf(Arc::from(parent_compilation));
240-
policy_cache.insert(node.clone(), (parent_policy, parent_cost.0));
241-
(Reverse(parent_cost), OrdF64(p), node)
242-
});
243-
}
244-
(ms1, ms2) => {
245-
let p = p1.0 + p2.0;
246-
let cost = OrdF64((prev_cost1.0).0 + (prev_cost2.0).0 + 32.0);
247-
node_weights.push((
248-
Reverse(cost),
249-
OrdF64(p),
250-
TapTree::Tree(Arc::from(ms1), Arc::from(ms2)),
251-
));
252-
}
253-
}
233+
let (_prev_cost1, p1, ms1) = node_weights.pop().expect("len must atleast be two");
234+
let (_prev_cost2, p2, ms2) = node_weights.pop().expect("len must atleast be two");
235+
236+
// Retrieve the respective policies
237+
let (left_pol, _c1) = policy_cache
238+
.get(&ms1)
239+
.ok_or_else(|| errstr("No corresponding policy found"))?
240+
.clone();
241+
242+
let (right_pol, _c2) = policy_cache
243+
.get(&ms2)
244+
.ok_or_else(|| errstr("No corresponding policy found"))?
245+
.clone();
246+
247+
let parent_policy = Policy::Or(vec![
248+
((p1.0 * 1e4).round() as usize, left_pol),
249+
((p2.0 * 1e4).round() as usize, right_pol),
250+
]);
251+
252+
let (parent_compilation, sat_cost) =
253+
compiler::best_compilation_sat::<Pk, Tap>(&parent_policy)?;
254+
255+
let p = p1.0 + p2.0;
256+
ms_cache.insert(TapTree::Leaf(Arc::clone(&parent_compilation)), p);
257+
policy_cache.insert(
258+
TapTree::Leaf(Arc::clone(&parent_compilation)),
259+
(parent_policy.clone(), sat_cost),
260+
);
261+
262+
let parent_cost = OrdF64(Self::taptree_cost(
263+
&TapTree::Leaf(Arc::clone(&parent_compilation)),
264+
ms_cache,
265+
policy_cache,
266+
0,
267+
));
268+
let children_cost = OrdF64(
269+
Self::taptree_cost(&ms1, ms_cache, policy_cache, 0)
270+
+ Self::taptree_cost(&ms2, ms_cache, policy_cache, 0),
271+
);
272+
273+
node_weights.push(if parent_cost > children_cost {
274+
ms_cache.insert(
275+
TapTree::Tree(Arc::from(ms1.clone()), Arc::from(ms2.clone())),
276+
p,
277+
);
278+
policy_cache.insert(
279+
TapTree::Tree(Arc::from(ms1.clone()), Arc::from(ms2.clone())),
280+
(parent_policy, sat_cost),
281+
);
282+
(
283+
Reverse(children_cost),
284+
OrdF64(p),
285+
TapTree::Tree(Arc::from(ms1), Arc::from(ms2)),
286+
)
287+
} else {
288+
let node = TapTree::Leaf(Arc::from(parent_compilation));
289+
(Reverse(parent_cost), OrdF64(p), node)
290+
});
254291
}
255292
debug_assert!(node_weights.len() == 1);
256293
let node = node_weights
@@ -301,6 +338,7 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
301338
#[cfg(feature = "compiler")]
302339
fn compile_tr_efficient(&self) -> Result<TapTree<Pk>, Error> {
303340
let mut policy_cache = PolicyTapCache::<Pk>::new();
341+
let mut ms_cache = MsTapCache::<Pk>::new();
304342
let leaf_compilations: Vec<_> = self
305343
.to_tapleaf_prob_vec(1.0)
306344
.into_iter()
@@ -311,10 +349,13 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
311349
TapTree::Leaf(Arc::clone(&compilation.0)),
312350
(policy.clone(), compilation.1), // (policy, sat_cost)
313351
);
314-
(OrdF64(prob), compilation) // (branch_prob, comp=(ms, sat_cost))
352+
ms_cache.insert(TapTree::Leaf(Arc::from(compilation.0.clone())), prob);
353+
compilation.0 // (branch_prob, comp=(ms, sat_cost))
315354
})
316355
.collect();
317-
let taptree = Self::with_huffman_tree_eff(leaf_compilations, &mut policy_cache).unwrap();
356+
let taptree =
357+
Self::with_huffman_tree_eff(leaf_compilations, &mut policy_cache, &mut ms_cache)
358+
.unwrap();
318359
Ok(taptree)
319360
}
320361

0 commit comments

Comments
 (0)