
Commit 99d37ce

Avoid PyTensor function overhead in OpFromGraph
Also provide a pure C implementation when all Ops allow it. The LazyLinker does not complain about thunks that return their outputs, since it can itself be a thunk. Adding a Python wrapper that hides the outputs incurs considerable overhead, and modifying the LazyLinker to optionally not return outputs seems unnecessarily complex.
1 parent 676296c commit 99d37ce
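
For context, a minimal sketch of the pattern this commit speeds up (illustrative only, not part of the commit): an `OpFromGraph` wraps an inner graph as a single Op, and previously every evaluation of that Op went through a fully-fledged compiled PyTensor function.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

# Wrap a small softmax-like inner graph as a single Op
x = pt.vector("x")
y = pt.exp(x)
softmax_op = OpFromGraph([x], [y / y.sum()])

# Before this commit, evaluating softmax_op paid the call overhead of a
# full inner pytensor.function; now the inner graph runs as a raw thunk
# (C code when every inner Op supports it).
fn = pytensor.function([x], softmax_op(x), mode="FAST_RUN")
print(fn(np.random.default_rng(0).normal(size=10)))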

File tree

5 files changed: +145 -40 lines changed

pytensor/compile/builders.py (+80 -5)

@@ -6,8 +6,11 @@
 from functools import partial
 from typing import Union, cast

-from pytensor.compile.function import function
-from pytensor.compile.function.pfunc import rebuild_collect_shared
+from pytensor.compile import get_default_mode, insert_deepcopy
+from pytensor.compile.function.pfunc import pfunc, rebuild_collect_shared
+from pytensor.compile.function.types import add_supervisor_to_fgraph
+from pytensor.compile.io import In, Out
+from pytensor.compile.mode import Mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType, Rop, grad
@@ -433,6 +436,7 @@ def __init__(
         assert isinstance(name, str), "name must be None or string object"
         self.name = name
         self.destroy_map = destroy_map if destroy_map is not None else {}
+        self._prepared_fgraph = None

     def __eq__(self, other):
         # TODO: recognize a copy
@@ -847,14 +851,51 @@ def infer_shape(self, fgraph, node, shapes):

         return ret

+    def _prepare_fgraph(self, impl):
+        if self._prepared_fgraph is None:
+            mode = get_default_mode()
+            if impl == "py":
+                mode = mode.excluding("cxx")
+            rewriter = mode.optimizer
+
+            # We are cloning fgraph too many times, but one of the existing tests
+            # checks for this: TestOpFromGraph.test_outputs_consistency
+            fgraph = self.fgraph.clone()
+            self._wrapped_inputs = [
+                In(inp, borrow=False, mutable=False) for inp in fgraph.inputs
+            ]
+            # These are just temporary because the graph rewrite may change them
+            temp_wrapped_outputs = [
+                Out(out, borrow=True) for out in self.fgraph.outputs
+            ]
+            add_supervisor_to_fgraph(
+                fgraph,
+                self._wrapped_inputs,
+                accept_inplace=False,
+            )
+            rewriter(fgraph)
+            insert_deepcopy(fgraph, self._wrapped_inputs, temp_wrapped_outputs)
+            self._wrapped_outputs = [Out(out, borrow=True) for out in fgraph.outputs]
+            self._prepared_fgraph = fgraph
+
+        return self._prepared_fgraph, self._wrapped_inputs, self._wrapped_outputs
+
     @property
     def fn(self):
-        """Lazily compile the inner function graph."""
         if getattr(self, "_fn", None) is not None:
             return self._fn

-        self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
-        self._fn.trust_input = True
+        fgraph, wrapped_inputs, wrapped_outputs = self._prepare_fgraph(impl=None)
+
+        self._fn = pfunc(
+            wrapped_inputs,
+            wrapped_outputs,
+            mode=Mode(linker=get_default_mode().linker, optimizer=None),
+            accept_inplace=True,
+            on_unused_input="ignore",
+            fgraph=fgraph,
+            trust_input=True,
+        )

         return self._fn

@@ -871,6 +912,40 @@ def clone(self):
         res.fgraph = res.fgraph.clone()
         return res

+    def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
+        from pytensor.link.c.basic import CLinker
+        from pytensor.link.vm import VMLinker
+
+        fg, _, _ = self._prepare_fgraph(impl)
+        fg_no_recycling = [
+            new_o
+            for (new_o, old_o) in zip(fg.outputs, node.outputs, strict=True)
+            if old_o in no_recycling
+        ]
+
+        node_input_storage = [storage_map[r] for r in node.inputs]
+        node_output_storage = [storage_map[r] for r in node.outputs]
+
+        def create_thunk(linker):
+            linker.accept(fg, no_recycling=fg_no_recycling)
+            thunk, i, o = linker.make_thunk(
+                input_storage=node_input_storage,
+                output_storage=node_output_storage,
+            )
+            return thunk
+
+        if impl != "py":
+            try:
+                # We default to CLinker because it generates code for the whole
+                # graph that the compiler can reason about, whereas the VMLinker
+                # compiles each node separately and calls them in a pre-defined VM.
+                # It also has less overhead.
+                return create_thunk(linker=CLinker())
+            except NotImplementedError:
+                # Some Op doesn't have a C implementation: VM it is
+                return create_thunk(VMLinker(use_cloop=True, c_thunks=True))
+        else:
+            return create_thunk(VMLinker(use_cloop=False, c_thunks=False))
+
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
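
A small sketch of the caching behavior introduced by `_prepare_fgraph` above (illustrative only; it pokes at the private helper added in this commit):

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
op = OpFromGraph([x], [pt.exp(x) / pt.exp(x).sum()])

# The rewritten inner fgraph is prepared once and cached on the Op, so both
# `op.fn` and `OpFromGraph.make_thunk` reuse the same prepared graph.
fg1, _, _ = op._prepare_fgraph(impl=None)
fg2, _, _ = op._prepare_fgraph(impl=None)
assert fg1 is fg2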

pytensor/link/c/c_code/lazylinker_c.c (+2 -15)

@@ -676,20 +676,7 @@ static int lazy_rec_eval(CLazyLinker *self, Py_ssize_t var_idx, PyObject *one,
     // rval is new ref
     if (rval) // pycall returned normally (no exception)
     {
-      if (rval == Py_None) {
-        Py_DECREF(rval); // ignore a return of None
-      } else if (PyList_Check(rval)) {
-        PyErr_SetString(PyExc_TypeError,
-                        "non-lazy thunk should return None, not list");
-        err = 1;
-        goto pyfail;
-      } else // don't know what it returned, but it wasn't right.
-      {
-        PyErr_SetObject(PyExc_TypeError, rval);
-        err = 1;
-        // We don't release rval since we put it in the error above
-        goto fail;
-      }
+      Py_DECREF(rval); // ignore whatever was returned
     } else // pycall returned NULL (internal error)
     {
       err = 1;
@@ -981,7 +968,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = {
 };

 static PyObject *get_version(PyObject *dummy, PyObject *args) {
-  PyObject *result = PyFloat_FromDouble(0.3);
+  PyObject *result = PyFloat_FromDouble(0.4);
   return result;
 }
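
The deleted branch enforced that non-lazy thunks return None. For reference, a minimal sketch of the thunk contract this relaxes (hypothetical names; storage cells are the usual one-element lists shared with the linker):

import numpy as np

def make_exp_thunk(input_storage, output_storage):
    def thunk():
        # Compute directly into the shared output storage cell
        output_storage[0][0] = np.exp(input_storage[0][0])
        # Returning the outputs, as linker-made thunks do, is now fine:
        # the LazyLinker Py_DECREFs and ignores whatever comes back.
        return output_storage

    return thunk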

pytensor/link/c/lazylinker_c.py (+1 -1)

@@ -14,7 +14,7 @@
 _logger = logging.getLogger(__file__)

 force_compile = False
-version = 0.3  # must match constant returned in function get_version()
+version = 0.4  # must match constant returned in function get_version()
 lazylinker_ext: ModuleType | None = None

pytensor/tensor/rewriting/basic.py (+1 -5)

@@ -1120,15 +1120,11 @@ def unconditional_constant_folding(fgraph, node):
         compute_map[o] = [False]

     thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
-    required = thunk()
-
-    # A node whose inputs are all provided should always return successfully
-    assert not required
+    thunk()

     rval = []
     for output in node.outputs:
         data = storage_map[output][0]
-        assert compute_map[output][0], (output, data)

         # TODO: `Type` itself should provide an interface for constructing
         # instances appropriate for a given constant.

tests/compile/test_builders.py (+61 -14)

@@ -4,6 +4,7 @@
 import pytest

 import pytensor.tensor as pt
+from pytensor import scan
 from pytensor.compile import shared
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
@@ -15,9 +16,10 @@
     grad,
     verify_grad,
 )
-from pytensor.graph.basic import equal_computations
+from pytensor.graph.basic import Apply, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.null_type import NullType, null_type
+from pytensor.graph.op import Op
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.graph.utils import MissingInputError
 from pytensor.printing import debugprint
@@ -537,17 +539,6 @@ def test_infer_shape(self):
         assert opt_res.shape_feature.shape_of[x] is None
         assert opt_res.shape_feature.shape_of[z][0].data == 2

-    @config.change_flags(compute_test_value="raise")
-    def test_compute_test_value(self):
-        x = scalar("x")
-        x.tag.test_value = np.array(1.0, dtype=config.floatX)
-        op = OpFromGraph([x], [x**3])
-        y = scalar("y")
-        y.tag.test_value = np.array(1.0, dtype=config.floatX)
-        f = op(y)
-        grad_f = grad(f, y)
-        assert grad_f.tag.test_value is not None
-
     def test_make_node_shared(self):
         """Make sure we can provide `OpFromGraph.make_node` new shared inputs and get a valid `OpFromGraph`."""

@@ -622,14 +613,15 @@ def test_outputs_consistency(self):
         """Make sure that `OpFromGraph.fn` doesn't change the value of `OpFromGraph.inner_outputs`."""

         x = scalar("x")
-        op = OpFromGraph([x], [x**2 / x], mode="FAST_RUN")
+        op = OpFromGraph([x], [x**2 / x])

         # Confirm that the inner-graph is as expected
         assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])

         # These outputs of the compiled `op.fgraph` should differ from the
         # original, uncompiled `op.fgraph` outputs
-        fn = op.fn
+        with config.change_flags(mode="FAST_RUN"):
+            fn = op.fn
         new_inputs = fn.maker.fgraph.inputs
         new_outputs = fn.maker.fgraph.outputs
         assert not equal_computations(new_outputs, [x**2 / x], new_inputs, [x])
@@ -740,3 +732,58 @@ def test_debugprint():

     for truth, out in zip(exp_res.split("\n"), lines, strict=True):
         assert truth.strip() == out.strip()
+
+
+@pytest.mark.parametrize("kind", ("ofg", "inlined", "scan"))
+@pytest.mark.parametrize("c_op", (True, False), ids=lambda x: f"c_op={x}")
+def test_benchmark(c_op, kind, benchmark):
+    class ExpWithoutC(Op):
+        def make_node(self, x):
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, output_storage):
+            output_storage[0][0] = np.exp(inputs[0])
+
+    exp_without_c = ExpWithoutC()
+
+    n = 25
+
+    def _f(x):
+        if isinstance(x, np.ndarray):
+            y = np.exp(x)
+        else:
+            if c_op:
+                y = pt.exp(x)
+            else:
+                y = exp_without_c(x)
+        y /= y.sum()
+        return y
+
+    x = pt.vector("x")
+
+    if kind == "ofg":
+        f = OpFromGraph([x], [_f(x)])
+    else:
+        f = _f
+
+    if kind == "scan":
+        # Scan is included as a reference for how bad the overhead can be
+        outs, _ = scan(fn=f, outputs_info=[x], n_steps=n)
+        out = outs[-1]
+    else:
+        out = x
+        for i in range(n):
+            out = f(out)
+
+    compiled_fn = function([x], out, trust_input=True, mode="FAST_RUN")
+    compiled_fn.vm.allow_gc = False
+
+    rng = np.random.default_rng(1)
+    x_test = rng.normal(size=(10,))
+
+    res = benchmark(compiled_fn, x_test)
+
+    expected_res = x_test
+    for i in range(n):
+        expected_res = _f(expected_res)
+    np.testing.assert_allclose(res, expected_res)
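
Assuming the pytest-benchmark plugin (which provides the `benchmark` fixture used above) is installed, the new benchmarks can be selected with something like `pytest tests/compile/test_builders.py -k test_benchmark`.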
