@@ -127,8 +127,8 @@ def map_prefix(
127
127
) -> tp .Any : ...
128
128
129
129
def check_consistent_aliasing (
130
- node : tuple [ tp .Any , ...] ,
131
- prefix : tuple [ tp .Any , ...] ,
130
+ node : tp .Any ,
131
+ prefix : tp .Any ,
132
132
/ ,
133
133
* ,
134
134
node_prefixes : dict [tp .Any , list [tuple [PathParts , tp .Any ]]] | None = None ,
@@ -279,7 +279,9 @@ def to_tree(
279
279
with graph .split_context (ctxtag ) as split_ctx :
280
280
return jax .tree .map (
281
281
lambda x : split_fn (split_ctx , (), prefix , x )
282
- if map_non_graph_nodes or graph .is_graph_node (x )
282
+ if map_non_graph_nodes
283
+ or graph .is_graph_node (x )
284
+ or isinstance (x , variablelib .Variable )
283
285
else x ,
284
286
tree ,
285
287
)
@@ -296,7 +298,7 @@ def to_tree(
296
298
297
299
with graph .split_context (ctxtag ) as split_ctx :
298
300
for (keypath , leaf ), leaf_prefix in zip (leaf_keys , leaf_prefixes ):
299
- if graph .is_graph_node (leaf ):
301
+ if graph .is_graph_node (leaf ) or isinstance ( leaf , variablelib . Variable ) :
300
302
if check_aliasing :
301
303
check_consistent_aliasing (
302
304
leaf , leaf_prefix , node_prefixes = node_prefixes
@@ -343,7 +345,9 @@ def from_tree(
343
345
with graph .merge_context (is_inner , ctxtag ) as merge_ctx :
344
346
return jax .tree .map (
345
347
lambda x : merge_fn (merge_ctx , (), prefix , x )
346
- if map_non_graph_nodes or is_node_leaf (x )
348
+ if map_non_graph_nodes
349
+ or is_node_leaf (x )
350
+ or isinstance (x , variablelib .Variable )
347
351
else x ,
348
352
tree ,
349
353
is_leaf = is_leaf ,
@@ -362,12 +366,21 @@ def from_tree(
362
366
363
367
with graph .merge_context (is_inner , ctxtag ) as merge_ctx :
364
368
for (keypath , leaf ), leaf_prefix in zip (leaf_keys , leaf_prefixes ):
365
- if map_non_graph_nodes or is_node_leaf (leaf ):
369
+ if (
370
+ map_non_graph_nodes
371
+ or is_node_leaf (leaf )
372
+ or isinstance (leaf , variablelib .Variable )
373
+ ):
366
374
leaf = merge_fn (merge_ctx , keypath , leaf_prefix , leaf )
367
375
leaves_out .append (leaf )
368
376
369
377
pytree_out = jax .tree .unflatten (treedef , leaves_out )
370
378
return pytree_out
371
379
372
380
def clear_non_graph_nodes (tree ):
373
- return jax .tree .map (lambda x : x if graph .is_graph_node (x ) else None , tree )
381
+ return jax .tree .map (
382
+ lambda x : x
383
+ if graph .is_graph_node (x ) or isinstance (x , variablelib .Variable )
384
+ else None ,
385
+ tree ,
386
+ )
0 commit comments