 from functools import partial
 from typing import Union, cast

-from pytensor.compile.function import function
-from pytensor.compile.function.pfunc import rebuild_collect_shared
+from pytensor.compile import get_default_mode, insert_deepcopy
+from pytensor.compile.function.pfunc import pfunc, rebuild_collect_shared
+from pytensor.compile.function.types import add_supervisor_to_fgraph
+from pytensor.compile.io import In, Out
+from pytensor.compile.mode import Mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType, Rop, grad
@@ -433,6 +436,8 @@ def __init__(
         assert isinstance(name, str), "name must be None or string object"
         self.name = name
         self.destroy_map = destroy_map if destroy_map is not None else {}
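+        # Cache for the rewritten inner fgraph, built lazily by _prepare_fgraph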
+        self._prepared_fgraph = None

     def __eq__(self, other):
         # TODO: recognize a copy
@@ -847,14 +852,57 @@ def infer_shape(self, fgraph, node, shapes):

         return ret

+    def _prepare_fgraph(self, impl):
+        if self._prepared_fgraph is None:
+            mode = get_default_mode()
+            if impl == "py":
+                mode = mode.excluding("cxx")
+            rewriter = mode.optimizer
+
+            # We are cloning fgraph too many times, but one of the existing tests checks for this
+            # TestOpFromGraph.test_outputs_consistency
+            fgraph = self.fgraph.clone()
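+            # Wrap the inputs as non-borrowed and non-mutable so the compiled
+            # inner function can never alias or overwrite the caller's storage.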
+            self._wrapped_inputs = [
+                In(inp, borrow=False, mutable=False) for inp in fgraph.inputs
+            ]
+            # These are just temporary because the graph rewrite may change them
+            temp_wrapped_outputs = [
+                Out(out, borrow=True) for out in self.fgraph.outputs
+            ]
+            add_supervisor_to_fgraph(
+                fgraph,
+                self._wrapped_inputs,
+                accept_inplace=False,
+            )
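+            # The supervisor installed above prevents the rewriter below from
+            # introducing inplace operations that destroy the protected inputs.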
+            rewriter(fgraph)
+            insert_deepcopy(fgraph, self._wrapped_inputs, temp_wrapped_outputs)
+            self._wrapped_outputs = [Out(out, borrow=True) for out in fgraph.outputs]
+            self._prepared_fgraph = fgraph
+
+        return self._prepared_fgraph, self._wrapped_inputs, self._wrapped_outputs
+
     @property
     def fn(self):
-        """Lazily compile the inner function graph."""
         if getattr(self, "_fn", None) is not None:
             return self._fn

-        self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
-        self._fn.trust_input = True
+        fgraph, wrapped_inputs, wrapped_outputs = self._prepare_fgraph(impl=None)
+
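+        # The prepared fgraph was already rewritten above, so compile it with the
+        # default linker but with no further rewrites (optimizer=None).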
+        self._fn = pfunc(
+            wrapped_inputs,
+            wrapped_outputs,
+            mode=Mode(linker=get_default_mode().linker, optimizer=None),
+            accept_inplace=True,
+            on_unused_input="ignore",
+            fgraph=fgraph,
+            trust_input=True,
+        )

         return self._fn

@@ -871,6 +919,42 @@ def clone(self):
         res.fgraph = res.fgraph.clone()
         return res

+    def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
+        from pytensor.link.c.basic import CLinker
+        from pytensor.link.vm import VMLinker
+
+        fg, _, _ = self._prepare_fgraph(impl)
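+        # Translate the outer node's no_recycling outputs into the matching
+        # outputs of the inner fgraph before handing them to the linker.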
+        fg_no_recycling = [
+            new_o
+            for (new_o, old_o) in zip(fg.outputs, node.outputs, strict=True)
+            if old_o in no_recycling
+        ]
+
+        node_input_storage = [storage_map[r] for r in node.inputs]
+        node_output_storage = [storage_map[r] for r in node.outputs]
+
+        def create_thunk(linker):
+            linker.accept(fg, no_recycling=fg_no_recycling)
+            thunk, i, o = linker.make_thunk(
+                input_storage=node_input_storage,
+                output_storage=node_output_storage,
+            )
+            return thunk
+
+        if impl != "py":
+            try:
+                # We default to CLinker because it generates code for the whole graph that the compiler can reason about.
+                # The VMLinker, by contrast, compiles each node separately and calls them in a pre-defined VM.
+                # It also has less overhead.
+                return create_thunk(linker=CLinker())
+            except NotImplementedError:
+                # Some Op doesn't have a C implementation, so fall back to the VM
+                return create_thunk(VMLinker(use_cloop=True, c_thunks=True))
+        else:
+            return create_thunk(VMLinker(use_cloop=False, c_thunks=False))
+
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)