diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py index bb0805c4d5..116d66321a 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py @@ -68,10 +68,12 @@ def unflatten_tree(tree, xs): def map_tree_up_to(shallow, fn, tree, *rest): """`map_tree` with recursion depth defined by depth of `shallow`.""" - def wrapper(_, *rest): - return fn(*rest) + def wrapper(x, *rest): + return None if x is None else fn(*rest) - return tree_util.tree_map(wrapper, shallow, tree, *rest) + return tree_util.tree_map( + wrapper, shallow, tree, *rest, is_leaf=lambda x: x is None + ) def get_shallow_tree(is_leaf, tree):