@@ -6,21 +6,24 @@
 
 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, jacobian
+from pytensor.gradient import grad, grad_not_implemented, jacobian
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.graph.replace import graph_replace
 from pytensor.graph.traversal import ancestors, truncated_graph_inputs
+from pytensor.scalar import ScalarType, ScalarVariable
 from pytensor.tensor.basic import (
     atleast_2d,
     concatenate,
+    scalar_from_tensor,
     tensor,
     tensor_from_scalar,
     zeros_like,
 )
 from pytensor.tensor.math import dot
 from pytensor.tensor.slinalg import solve
+from pytensor.tensor.type import DenseTensorType
 from pytensor.tensor.variable import TensorVariable, Variable
 
 
@@ -143,9 +146,9 @@ def _find_optimization_parameters(
 def _get_parameter_grads_from_vector(
     grad_wrt_args_vector: TensorVariable,
     x_star: TensorVariable,
-    args: Sequence[Variable],
+    args: Sequence[TensorVariable | ScalarVariable],
     output_grad: TensorVariable,
-) -> list[TensorVariable]:
+) -> list[TensorVariable | ScalarVariable]:
     """
     Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
     returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
@@ -160,7 +163,10 @@ def _get_parameter_grads_from_vector(
             (*x_star.shape, *arg_shape)
         )
 
-        grad_wrt_args.append(dot(output_grad, arg_grad))
+        grad_wrt_arg = dot(output_grad, arg_grad)
+        if isinstance(arg.type, ScalarType):
+            grad_wrt_arg = scalar_from_tensor(grad_wrt_arg)
+        grad_wrt_args.append(grad_wrt_arg)
 
         cursor += arg_size
 
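
For context (not part of the commit): gradients computed through the tensor graph come back as 0-d TensorVariables, but the gradient returned for each arg must match that arg's type, which is why a ScalarType arg needs the `scalar_from_tensor` round-trip added above. A minimal standalone sketch of that type round-trip (illustrative only, not code from this file):

    import pytensor.scalar as ps
    import pytensor.tensor as pt
    from pytensor.gradient import grad
    from pytensor.tensor.basic import scalar_from_tensor, tensor_from_scalar

    theta = ps.float64("theta")            # ScalarVariable parameter
    theta_t = tensor_from_scalar(theta)    # how it enters the tensor graph
    cost = (pt.dvector("x") * theta_t).sum()

    g = grad(cost, theta_t)                # gradient is a 0-d TensorVariable
    g_scalar = scalar_from_tensor(g)       # ScalarVariable, same type as theta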
@@ -267,12 +273,12 @@ def build_fn(self):
 def scalar_implict_optimization_grads(
     inner_fx: TensorVariable,
     inner_x: TensorVariable,
-    inner_args: Sequence[Variable],
-    args: Sequence[Variable],
+    inner_args: Sequence[TensorVariable | ScalarVariable],
+    args: Sequence[TensorVariable | ScalarVariable],
     x_star: TensorVariable,
     output_grad: TensorVariable,
     fgraph: FunctionGraph,
-) -> list[Variable]:
+) -> list[TensorVariable | ScalarVariable]:
     df_dx, *df_dthetas = grad(
         inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
     )
@@ -291,11 +297,11 @@ def scalar_implict_optimization_grads(
 def implict_optimization_grads(
     df_dx: TensorVariable,
     df_dtheta_columns: Sequence[TensorVariable],
-    args: Sequence[Variable],
+    args: Sequence[TensorVariable | ScalarVariable],
     x_star: TensorVariable,
     output_grad: TensorVariable,
     fgraph: FunctionGraph,
-) -> list[TensorVariable]:
+) -> list[TensorVariable | ScalarVariable]:
     r"""
     Compute gradients of an optimization problem with respect to its parameters.
 
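
For context (not part of the commit): both helpers rely on the implicit function theorem. With the solution satisfying f(x*(θ), θ) = 0 (the root condition, or the first-order condition of the minimization), the sensitivity of the solution to the parameters is

    \frac{\partial x^*}{\partial \theta} = -\left(\frac{\partial f}{\partial x}\right)^{-1} \frac{\partial f}{\partial \theta}

so `df_dx` and the `df_dtheta` columns are combined through a linear `solve`/`dot`, rather than by differentiating through the iterative solver itself.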
@@ -410,7 +416,19 @@ def perform(self, node, inputs, outputs):
         outputs[1][0] = np.bool_(res.success)
 
     def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs
         x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"Minimize gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _ = outputs
         output_grad, _ = output_grads
 
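
For context (not part of the commit): `grad_not_implemented` returns a placeholder gradient that only raises if it is actually used, so the guard above bails out of the whole node as soon as any input type falls outside the dense-tensor/scalar cases the gradient machinery handles. The same guard is repeated verbatim in the other Ops below. A standalone sketch of which input types pass the check (the sparse variable is only an illustrative counterexample, assuming `pytensor.sparse` is importable):

    import pytensor.scalar as ps
    import pytensor.tensor as pt
    from pytensor.scalar import ScalarType
    from pytensor.sparse import csr_matrix
    from pytensor.tensor.type import DenseTensorType

    x = pt.dvector("x")      # DenseTensorType  -> supported
    a = ps.float64("a")      # ScalarType       -> supported
    s = csr_matrix("s")      # SparseTensorType -> triggers grad_not_implemented

    unsupported = tuple(
        inp.type
        for inp in (x, a, s)
        if not isinstance(inp.type, DenseTensorType | ScalarType)
    )
    assert unsupported == (s.type,)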
@@ -560,7 +578,19 @@ def perform(self, node, inputs, outputs):
         outputs[1][0] = np.bool_(res.success)
 
     def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs
         x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"MinimizeOp gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _success = outputs
         output_grad, _ = output_grads
 
@@ -727,7 +757,19 @@ def perform(self, node, inputs, outputs):
         outputs[1][0] = np.bool_(res.converged)
 
     def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs
         x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"RootScalarOp gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _ = outputs
         output_grad, _ = output_grads
 
@@ -908,6 +950,17 @@ def perform(self, node, inputs, outputs):
     def L_op(self, inputs, outputs, output_grads):
         # TODO: Handle disconnected inputs
         x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"RootOp gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _ = outputs
         output_grad, _ = output_grads
 
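
Finally, a hypothetical end-to-end sketch of the case this commit targets: differentiating an optimization result whose objective depends on a raw ScalarVariable. It assumes a public `minimize(objective, x)` helper that returns the solution together with a success flag, as the two outputs handled in `perform` above suggest; the exact signature is an assumption, not taken from this diff.

    import pytensor.scalar as ps
    import pytensor.tensor as pt
    from pytensor.gradient import grad
    from pytensor.tensor.basic import tensor_from_scalar
    from pytensor.tensor.optimize import minimize

    x = pt.dscalar("x")
    a = ps.float64("a")                      # ScalarType parameter of the objective
    objective = (x - tensor_from_scalar(a)) ** 2

    x_star, success = minimize(objective, x)
    # The ScalarVariable `a` ends up as one of the Op's args; with the changes
    # above it takes the DenseTensorType | ScalarType path and gets a gradient
    # of matching scalar type instead of hitting an unhandled case.
    dxstar_da = grad(x_star, wrt=a)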