diff --git a/doc/library/compile/io.rst b/doc/library/compile/io.rst index 272d4754db..962d3ea7a5 100644 --- a/doc/library/compile/io.rst +++ b/doc/library/compile/io.rst @@ -35,12 +35,11 @@ The ``inputs`` argument to ``pytensor.function`` is a list, containing the ``Var can be set by ``kwarg``, and its value can be accessed by ``self.``. The default value is ``None``. - ``value``: literal or ``Container``. The initial/default value for this + ``value``: ``Container``. The initial value for this input. If update is ``None``, this input acts just like an argument with a default value in Python. If update is not ``None``, - changes to this - value will "stick around", whether due to an update or a user's - explicit action. + changes to this value will "stick around", whether due to an update + or a user's explicit action. ``update``: Variable instance. This expression Variable will replace ``value`` after each function call. The default value is @@ -73,18 +72,16 @@ The ``inputs`` argument to ``pytensor.function`` is a list, containing the ``Var overwriting its content without being aware of it). -Value: initial and default values ---------------------------------- +Update +------ -A non-None `value` argument makes an In() instance an optional parameter -of the compiled function. For example, in the following code we are -defining an arity-2 function ``inc``. +We can define an update to modify the value >>> import pytensor.tensor as pt >>> from pytensor import function >>> from pytensor.compile.io import In >>> u, x, s = pt.scalars('u', 'x', 's') ->>> inc = function([u, In(x, value=3), In(s, update=(s+x*u), value=10.0)], []) +>>> inc = function([u, In(x), In(s, update=(s+x*u)], []) Since we provided a ``value`` for ``s`` and ``x``, we can call it with just a value for ``u`` like this: diff --git a/pyproject.toml b/pyproject.toml index bbb64549e5..41169f38c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,12 +130,8 @@ exclude = ["doc/", "pytensor/_version.py"] docstring-code-format = true [tool.ruff.lint] -select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"] +select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"] ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"] -unfixable = [ - # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead - "B905", -] [tool.ruff.lint.isort] diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index a4a3d1840a..b669780607 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -873,7 +873,6 @@ def clone(self): def perform(self, node, inputs, outputs): variables = self.fn(*inputs) - assert len(variables) == len(outputs) - # strict=False because asserted above - for output, variable in zip(outputs, variables, strict=False): + # strict=None because we are in a hot loop + for output, variable in zip(outputs, variables): output[0] = variable diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 9cc85f3d24..050c84f97c 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -331,23 +331,35 @@ class Function: ``False``. When ``True``, the `Function` will skip all checks on the inputs. - Attributes - ---------- - finder - Dictionary mapping several kinds of things to containers. - - We set an entry in finder for: - - the index of the input - - the variable instance the input is based on - - the name of the input - - All entries map to the container or to DUPLICATE if an ambiguity - is detected. - inv_finder - Reverse lookup of `finder`. It maps containers to `SymbolicInput`\s. - """ + __slots__ = ( + "vm", + "input_storage", + "output_storage", + "indices", + "outputs", + "unpack_single", + "return_none", + "maker", + "profile", + "trust_input", + "name", + # Created inside __init__ + "_potential_aliased_input_groups", + "_named_inputs", + "_n_unnamed_inputs", + "_finder", + "_inv_finder", + "_has_updates", + "_n_returned_outputs", + "_input_storage_data", + "_update_input_storage", + "_clear_input_storage_data", + "_clear_output_storage_data", + "_nodes_with_inner_function", + ) + pickle_aliased_memory_strategy = "warn" """ How to deal with pickling finding aliased storage. @@ -368,10 +380,8 @@ def __init__( output_storage: list[Container], indices, outputs, - defaults, unpack_single: bool, return_none: bool, - output_keys, maker: "FunctionMaker", trust_input: bool = False, name: str | None = None, @@ -392,20 +402,11 @@ def __init__( tuple elements are used only by Kits, which are deprecated. outputs TODO - defaults - List of 3-tuples, one 3-tuple for each input. - Tuple element 0: ``bool``. Is this input required at each function - call? - Tuple element 1: ``bool``. Should this inputs value be reverted - after each call? - Tuple element 2: ``Any``. The value associated with this input. unpack_single For outputs lists of length 1, should the 0'th element be returned directly? return_none Whether the function should return ``None`` or not. - output_keys - TODO maker The `FunctionMaker` that created this instance. trust_input : bool, default False @@ -422,24 +423,16 @@ def __init__( self.output_storage = output_storage self.indices = indices self.outputs = outputs - self.defaults = defaults self.unpack_single = unpack_single self.return_none = return_none self.maker = maker self.profile = None # reassigned in FunctionMaker.create self.trust_input = trust_input # If True, we don't check the input parameter self.name = name - self.nodes_with_inner_function = [] - self.output_keys = output_keys - - if self.output_keys is not None: - warnings.warn("output_keys is deprecated.", FutureWarning) assert len(self.input_storage) == len(self.maker.fgraph.inputs) assert len(self.output_storage) == len(self.maker.fgraph.outputs) - self.has_defaults = any(refeed for _, refeed, _ in self.defaults) - # Group indexes of inputs that are potentially aliased to each other # Note: Historically, we only worried about aliasing inputs if they belonged to the same type, # even though there could be two distinct types that use the same kinds of underlying objects. @@ -475,41 +468,20 @@ def __init__( if len(group) > 1 ) - # We will be popping stuff off this `containers` object. It is a copy. - containers = list(self.input_storage) - finder = {} - inv_finder = {} - - # Store the list of names of named inputs. - named_inputs = [] - # Count the number of un-named inputs. - n_unnamed_inputs = 0 - + self._finder = finder = {} + self._inv_finder = inv_finder = {} + self._named_inputs = named_inputs = [] + self._n_unnamed_inputs = 0 # Initialize the storage # this loop works by modifying the elements (as variable c) of # self.input_storage inplace. - for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate( - zip(self.indices, defaults, strict=True) - ): + remaining_containers = self.input_storage.copy() + for i, (input, indices, sinputs) in enumerate(self.indices): if indices is None: - # containers is being used as a stack. Here we pop off - # the next one. - c = containers[0] + c = remaining_containers.pop(0) c.strict = getattr(input, "strict", False) c.allow_downcast = getattr(input, "allow_downcast", None) - - if value is not None: - # Always initialize the storage. - if isinstance(value, Container): - # There is no point in obtaining the current value - # stored in the container, since the container is - # shared. - # For safety, we make sure 'refeed' is False, since - # there is no need to refeed the default value. - assert not refeed - else: - c.value = value - c.required = required + c.required = input.value is None c.implicit = input.implicit # this is a count of how many times the input has been # provided (reinitialized to 0 on __call__) @@ -521,71 +493,10 @@ def __init__( else: finder[input.name] = DUPLICATE if input.name is None: - n_unnamed_inputs += 1 + self._n_unnamed_inputs += 1 else: named_inputs.append(input.name) inv_finder[c] = input - containers[:1] = [] - - self.finder = finder - self.inv_finder = inv_finder - - # this class is important in overriding the square-bracket notation: - # fn.value[x] - # self reference is available via the closure on the class - class ValueAttribute: - def __getitem__(self, item): - try: - s = finder[item] - except KeyError: - raise TypeError(f"Unknown input or state: {item}") - if s is DUPLICATE: - raise TypeError( - f"Ambiguous name: {item} - please check the " - "names of the inputs of your function " - "for duplicates." - ) - if isinstance(s, Container): - return s.value - else: - raise NotImplementedError - - def __setitem__(self, item, value): - try: - s = finder[item] - except KeyError: - # Print informative error message. - msg = get_info_on_inputs(named_inputs, n_unnamed_inputs) - raise TypeError(f"Unknown input or state: {item}. {msg}") - if s is DUPLICATE: - raise TypeError( - f"Ambiguous name: {item} - please check the " - "names of the inputs of your function " - "for duplicates." - ) - if isinstance(s, Container): - s.value = value - s.provided += 1 - else: - s(value) - - def __contains__(self, item): - return finder.__contains__(item) - - # this class is important in overriding the square-bracket notation: - # fn.container[x] - # self reference is available via the closure on the class - class ContainerAttribute: - def __getitem__(self, item): - return finder[item] - - def __contains__(self, item): - return finder.__contains__(item) - - # You cannot set the container - - self._value = ValueAttribute() - self._container = ContainerAttribute() update_storage = [ container @@ -595,27 +506,31 @@ def __contains__(self, item): if inp.update is not None ] # Updates are the last inner outputs that are not returned by Function.__call__ - self.n_returned_outputs = len(self.output_storage) - len(update_storage) + self._has_updates = len(update_storage) > 0 + self._n_returned_outputs = len(self.output_storage) - len(update_storage) # Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself - self.update_input_storage: tuple[int, Container] = () + self._update_input_storage: tuple[int, Container] = () if getattr(vm, "need_update_inputs", True): - self.update_input_storage = tuple( + self._update_input_storage = tuple( zip( - range(self.n_returned_outputs, len(output_storage)), + range(self._n_returned_outputs, len(output_storage)), update_storage, strict=True, ) ) + self._input_storage_data = tuple( + container.storage for container in input_storage + ) + # In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage # After the call, we want to erase (some of) these references, to allow Python to GC them if unused - # Required input containers are the non-default inputs, must always be provided again, so we GC them - self.clear_input_storage_data = tuple( + self._clear_input_storage_data = tuple( container.storage for container in input_storage if container.required ) # This is only done when `vm.allow_gc` is True, which can change at runtime. - self.clear_output_storage_data = tuple( + self._clear_output_storage_data = tuple( container.storage for container, variable in zip( self.output_storage, self.maker.fgraph.outputs, strict=True @@ -623,18 +538,11 @@ def __contains__(self, item): if variable.owner is not None # Not a constant output ) - for node in self.maker.fgraph.apply_nodes: - if isinstance(node.op, HasInnerGraph): - self.nodes_with_inner_function.append(node.op) - - def __contains__(self, item): - return self.value.__contains__(item) - - def __getitem__(self, item): - return self.value[item] - - def __setitem__(self, item, value): - self.value[item] = value + self._nodes_with_inner_function = [ + node + for node in self.maker.fgraph.apply_nodes + if isinstance(node.op, HasInnerGraph) + ] def __copy__(self): """ @@ -838,7 +746,6 @@ def checkSV(sv_ori, sv_rpl): # check that. accept_inplace=True, no_fgraph_prep=True, - output_keys=maker.output_keys, name=name, ).create(input_storage, storage_map=new_storage_map) @@ -862,26 +769,126 @@ def checkSV(sv_ori, sv_rpl): # to container, to make Function.value and Function.data work well. # Replace variable in new maker.inputs by the original ones. # So that user can swap SharedVariable in a swapped function - container = f_cpy.finder.pop(in_cpy.variable) + container = f_cpy._finder.pop(in_cpy.variable) if not swapped: - f_cpy.finder[in_ori.variable] = container + f_cpy._finder[in_ori.variable] = container in_cpy.variable = in_ori.variable else: - f_cpy.finder[swap[in_ori.variable]] = container + f_cpy._finder[swap[in_ori.variable]] = container in_cpy.variable = swap[in_ori.variable] f_cpy.trust_input = self.trust_input f_cpy.unpack_single = self.unpack_single return f_cpy - def _restore_defaults(self): - for i, (required, refeed, value) in enumerate(self.defaults): - if refeed: - if isinstance(value, Container): - value = value.storage[0] - self[i] = value + def _validate_inputs(self, args, kwargs): + input_storage = self.input_storage + + if len(args) + len(kwargs) > len(input_storage): + raise TypeError("Too many parameter passed to pytensor function") + + for arg_container in input_storage: + arg_container.provided = 0 + + # Set positional arguments + for arg_container, arg in zip(input_storage, args): + try: + arg_container.storage[0] = arg_container.type.filter( + arg, + strict=arg_container.strict, + allow_downcast=arg_container.allow_downcast, + ) + + except Exception as e: + i = input_storage.index(arg_container) + function_name = "pytensor function" + argument_name = "argument" + if self.name: + function_name += ' with name "' + self.name + '"' + if hasattr(arg, "name") and arg.name: + argument_name += ' with name "' + arg.name + '"' + where = get_variable_trace_string(self.maker.inputs[i].variable) + if len(e.args) == 1: + e.args = ( + "Bad input " + + argument_name + + " to " + + function_name + + f" at index {int(i)} (0-based). {where}" + + e.args[0], + ) + else: + e.args = ( + "Bad input " + + argument_name + + " to " + + function_name + + f" at index {int(i)} (0-based). {where}" + ) + e.args + raise + arg_container.provided += 1 - def __call__(self, *args, output_subset=None, **kwargs): + # Set keyword arguments + if kwargs: # for speed, skip the items for empty kwargs + for key, arg in kwargs.items(): + try: + kwarg_container = self._finder[key] + except KeyError: + # Print informative error message. + msg = get_info_on_inputs(self._named_inputs, self._n_unnamed_inputs) + raise TypeError(f"Unknown input: {key}. {msg}") + if kwarg_container is DUPLICATE: + raise TypeError( + f"Ambiguous name: {key} - please check the names of the inputs of your function for duplicates." + ) + kwarg_container.value = arg + kwarg_container.provided += 1 + + # Collect aliased inputs among the storage space + for potential_group in self._potential_aliased_input_groups: + args_share_memory: list[list[int]] = [] + for i in potential_group: + i_type = self.maker.inputs[i].variable.type + i_val = input_storage[i].storage[0] + + # Check if value is aliased with any of the values in one of the groups + for j_group in args_share_memory: + if any( + i_type.may_share_memory(input_storage[j].storage[0], i_val) + for j in j_group + ): + j_group.append(i) + break + else: # no break + # Create a new group + args_share_memory.append([i]) + + # Check for groups of more than one argument that share memory + for group in args_share_memory: + if len(group) > 1: + # copy all but the first + for i in group[1:]: + input_storage[i].storage[0] = copy.copy( + input_storage[i].storage[0] + ) + + # Check if inputs are missing, or if inputs were set more than once, or + # if we tried to provide inputs that are supposed to be implicit. + for arg_container in input_storage: + if arg_container.required and not arg_container.provided: + raise TypeError( + f"Missing input: {getattr(self._inv_finder[arg_container], 'variable', self._inv_finder[arg_container])}" + ) + if arg_container.provided > 1: + raise TypeError( + f"Multiple values for input: {getattr(self._inv_finder[arg_container], 'variable', self._inv_finder[arg_container])}" + ) + if arg_container.implicit and arg_container.provided > 0: + raise TypeError( + f"Tried to provide value for implicit input: {getattr(self._inv_finder[arg_container], 'variable', self._inv_finder[arg_container])}" + ) + + def __call__(self, *args, **kwargs): """ Evaluates value of a function on given arguments. @@ -909,134 +916,30 @@ def __call__(self, *args, output_subset=None, **kwargs): List of outputs on indices/keys from ``output_subset`` or all of them, if ``output_subset`` is not passed. """ - trust_input = self.trust_input - input_storage = self.input_storage - vm = self.vm - profile = self.profile - - if profile: + if self.profile: t0 = time.perf_counter() - if output_subset is not None: - warnings.warn("output_subset is deprecated.", FutureWarning) - if self.output_keys is not None: - output_subset = [self.output_keys.index(key) for key in output_subset] - # Reinitialize each container's 'provided' counter - if trust_input: - for arg_container, arg in zip(input_storage, args, strict=False): - arg_container.storage[0] = arg + if self.trust_input: + for storage_data, arg in zip(self._input_storage_data, args): + storage_data[0] = arg + if kwargs: # for speed, skip the items for empty kwargs + for k, arg in kwargs.items(): + self._finder[k].storage[0] = arg else: - for arg_container in input_storage: - arg_container.provided = 0 - - if len(args) + len(kwargs) > len(input_storage): - raise TypeError("Too many parameter passed to pytensor function") - - # Set positional arguments - for arg_container, arg in zip(input_storage, args, strict=False): - # See discussion about None as input - # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5 - if arg is None: - arg_container.storage[0] = arg - else: - try: - arg_container.storage[0] = arg_container.type.filter( - arg, - strict=arg_container.strict, - allow_downcast=arg_container.allow_downcast, - ) - - except Exception as e: - i = input_storage.index(arg_container) - function_name = "pytensor function" - argument_name = "argument" - if self.name: - function_name += ' with name "' + self.name + '"' - if hasattr(arg, "name") and arg.name: - argument_name += ' with name "' + arg.name + '"' - where = get_variable_trace_string(self.maker.inputs[i].variable) - if len(e.args) == 1: - e.args = ( - "Bad input " - + argument_name - + " to " - + function_name - + f" at index {int(i)} (0-based). {where}" - + e.args[0], - ) - else: - e.args = ( - "Bad input " - + argument_name - + " to " - + function_name - + f" at index {int(i)} (0-based). {where}" - ) + e.args - self._restore_defaults() - raise - arg_container.provided += 1 - - # Set keyword arguments - if kwargs: # for speed, skip the items for empty kwargs - for k, arg in kwargs.items(): - self[k] = arg - - if not trust_input: - # Collect aliased inputs among the storage space - for potential_group in self._potential_aliased_input_groups: - args_share_memory: list[list[int]] = [] - for i in potential_group: - i_type = self.maker.inputs[i].variable.type - i_val = input_storage[i].storage[0] - - # Check if value is aliased with any of the values in one of the groups - for j_group in args_share_memory: - if any( - i_type.may_share_memory(input_storage[j].storage[0], i_val) - for j in j_group - ): - j_group.append(i) - break - else: # no break - # Create a new group - args_share_memory.append([i]) - - # Check for groups of more than one argument that share memory - for group in args_share_memory: - if len(group) > 1: - # copy all but the first - for i in group[1:]: - input_storage[i].storage[0] = copy.copy( - input_storage[i].storage[0] - ) - - # Check if inputs are missing, or if inputs were set more than once, or - # if we tried to provide inputs that are supposed to be implicit. - for arg_container in input_storage: - if arg_container.required and not arg_container.provided: - self._restore_defaults() - raise TypeError( - f"Missing required input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}" - ) - if arg_container.provided > 1: - self._restore_defaults() - raise TypeError( - f"Multiple values for input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}" - ) - if arg_container.implicit and arg_container.provided > 0: - self._restore_defaults() - raise TypeError( - f"Tried to provide value for implicit input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}" - ) + self._validate_inputs(args, kwargs) # Do the actual work - if profile: - t0_fn = time.perf_counter() try: - outputs = vm() if output_subset is None else vm(output_subset=output_subset) + if self.profile: + t0_fn = time.perf_counter() + outputs = self.vm() + dt_fn = time.perf_counter() - t0_fn + self.maker.mode.fn_time += dt_fn + self.profile.vm_call_time += dt_fn + else: + outputs = self.vm() except Exception: - self._restore_defaults() if hasattr(self.vm, "position_of_error"): # this is a new vm-provided function or c linker # they need this because the exception manipulation @@ -1054,71 +957,39 @@ def __call__(self, *args, output_subset=None, **kwargs): # old-style linkers raise their own exceptions raise - if profile: - dt_fn = time.perf_counter() - t0_fn - self.maker.mode.fn_time += dt_fn - profile.vm_call_time += dt_fn - - # Retrieve the values that were computed if outputs is None: + # Not all VMs can return outputs directly (mainly CLinker?) outputs = [x.storage[0] for x in self.output_storage] # Set updates and filter them out from the returned outputs - for i, input_storage in self.update_input_storage: - input_storage.storage[0] = outputs[i] - outputs = outputs[: self.n_returned_outputs] + if self._has_updates: + for i, input_storage in self._update_input_storage: + input_storage.storage[0] = outputs[i] + outputs = outputs[: self._n_returned_outputs] # Remove input and output values from storage data - for storage_data in self.clear_input_storage_data: - storage_data[0] = None - if getattr(vm, "allow_gc", False): - for storage_data in self.clear_output_storage_data: + if self.vm.allow_gc: + for storage_data in self._clear_input_storage_data: + storage_data[0] = None + for storage_data in self._clear_output_storage_data: storage_data[0] = None - # Put default values back in the storage - if self.has_defaults: - self._restore_defaults() - - if profile: + if self.profile: + profile = self.profile dt_call = time.perf_counter() - t0 pytensor.compile.profiling.total_fct_exec_time += dt_call self.maker.mode.call_time += dt_call profile.fct_callcount += 1 profile.fct_call_time += dt_call - if hasattr(vm, "update_profile"): - vm.update_profile(profile) + if hasattr(self.vm, "update_profile"): + self.vm.update_profile(profile) if profile.ignore_first_call: profile.reset() profile.ignore_first_call = False - if self.return_none: - return None - - if output_subset is not None: - outputs = [outputs[i] for i in output_subset] - - if self.output_keys is None: - if self.unpack_single: - [out] = outputs - return out - else: - return outputs - else: - output_keys = self.output_keys - if output_subset is not None: - output_keys = [output_keys[i] for i in output_subset] - return dict(zip(output_keys, outputs, strict=True)) - - value = property( - lambda self: self._value, - None, # this property itself is not settable - doc="dictionary-like access to the values associated with Variables", - ) - container = property( - lambda self: self._container, - None, # this property itself is not settable - doc=("dictionary-like access to the containers associated with Variables"), - ) + return ( + outputs[0] if self.unpack_single else None if self.return_none else outputs + ) def free(self): """ @@ -1126,13 +997,16 @@ def free(self): """ # 1.no allow_gc return False # 2.has allow_gc, if allow_gc is False, return True - if not getattr(self.vm, "allow_gc", True): + if not self.vm.allow_gc: + for inp_storage in self._clear_input_storage_data: + inp_storage[0] = None + storage_map = self.vm.storage_map for key, value in storage_map.items(): if key.owner is not None: # Not a constant value[0] = None - for node in self.nodes_with_inner_function: + for node in self._nodes_with_inner_function: if hasattr(node.fn, "free"): node.fn.free() @@ -1157,17 +1031,7 @@ def dprint(self, **kwargs): # pickling/deepcopy support for Function def _pickle_Function(f): - # copy of the input storage list - ins = list(f.input_storage) - input_storage = [] - - # strict=False because we are in a hot loop - for (input, indices, inputs), (required, refeed, default) in zip( - f.indices, f.defaults, strict=False - ): - input_storage.append(ins[0]) - del ins[0] - + input_storage = f.input_storage.copy() inputs_data = [x.data for x in f.input_storage] # HACK to detect aliased storage. @@ -1521,6 +1385,9 @@ def __init__( no_fgraph_prep=False, trust_input=False, ): + if output_keys is not None: + raise ValueError("output_keys was deprecated") + # Save the provided mode, not the instantiated mode. # The instantiated mode don't pickle and if we unpickle an PyTensor # function and it get re-compiled, we want the current rewriter to be @@ -1561,6 +1428,20 @@ def __init__( # Wrap them in In or Out instances if needed. inputs = [self.wrap_in(i) for i in inputs] + + # Remove this after a while + if any( + ( + i.value is not None + and not isinstance(i.value, Container) + and i.update is None + ) + for i in inputs + ): + raise ValueError( + "Inputs with default values were deprecated. Use `functools.partial` instead." + ) + outputs = [self.wrap_out(o) for o in outputs] # Check if some input variables are unused @@ -1620,21 +1501,9 @@ def __init__( self.accept_inplace = accept_inplace self.function_builder = function_builder self.on_unused_input = on_unused_input # Used for the pickling/copy - self.output_keys = output_keys self.name = name self.trust_input = trust_input - self.required = [(i.value is None) for i in self.inputs] - self.refeed = [ - ( - i.value is not None - and not isinstance(i.value, Container) - and i.update is None - ) - for i in self.inputs - ] - if any(self.refeed): - warnings.warn("Inputs with default values are deprecated.", FutureWarning) def create(self, input_storage=None, storage_map=None): """ @@ -1652,7 +1521,6 @@ def create(self, input_storage=None, storage_map=None): input_storage = [None] * len(self.inputs) # list of independent one-element lists, will be passed to the linker input_storage_lists = [] - defaults = [] # The following loop is to fill in the input_storage_lists and # defaults lists. @@ -1679,35 +1547,15 @@ def create(self, input_storage=None, storage_map=None): ) input_storage_lists.append(input_storage_i.storage) - storage = input_storage[i].storage[0] - else: # Normal case: one new, independent storage unit input_storage_lists.append([input_storage_i]) - storage = input_storage_i - required = self.required[i] - refeed = self.refeed[i] - # sanity check-- if an input is required it should not - # need to be refed - assert not (required and refeed) # shared variables need neither be input by the user nor refed if input.shared: assert not required - assert not refeed - storage = None - - # if an input is required, it never need be refed - if required: - storage = None - - # make sure that we only store a value if we actually need it - if storage is not None: - assert refeed or not required - - defaults.append((required, refeed, storage)) # Get a function instance start_linker = time.perf_counter() @@ -1730,16 +1578,14 @@ def create(self, input_storage=None, storage_map=None): self.profile.import_time += import_time fn = self.function_builder( - _fn, - _i, - _o, - self.indices, - self.outputs, - defaults, - self.unpack_single, - self.return_none, - self.output_keys, - self, + vm=_fn, + input_storage=_i, + output_storage=_o, + indices=self.indices, + outputs=self.outputs, + unpack_single=self.unpack_single, + return_none=self.return_none, + maker=self, trust_input=self.trust_input, name=self.name, ) @@ -1809,7 +1655,7 @@ def orig_function( else: outputs = FunctionMaker.wrap_out(outputs) - defaults = [getattr(input, "value", None) for input in inputs] + shared_variable_containers = [getattr(input, "value", None) for input in inputs] if isinstance(mode, list | tuple): raise ValueError("We do not support the passing of multiple modes") @@ -1830,7 +1676,7 @@ def orig_function( trust_input=trust_input, ) with config.change_flags(compute_test_value="off"): - fn = m.create(defaults) + fn = m.create(shared_variable_containers) finally: if profile and fn: t2 = time.perf_counter() diff --git a/pytensor/compile/io.py b/pytensor/compile/io.py index 9ce0421235..07554929ff 100644 --- a/pytensor/compile/io.py +++ b/pytensor/compile/io.py @@ -182,9 +182,6 @@ def __init__( borrow=None, shared=False, ): - # if shared, an input's value comes from its persistent - # storage, not from a default stored in the function or from - # the caller self.shared = shared if borrow is None: @@ -204,6 +201,13 @@ def __init__( "overwritten.", ) + if value is not None and not isinstance(value, Container): + from pytensor.compile.sharedvalue import SharedVariable + + if not isinstance(value, SharedVariable): + # This is to catch use of old API to pass default values + raise ValueError("Inputs with default values are deprecated") + if implicit is None: from pytensor.compile.sharedvalue import SharedVariable diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py index 9cf34983f2..2ab9146576 100644 --- a/pytensor/link/basic.py +++ b/pytensor/link/basic.py @@ -373,7 +373,12 @@ def make_all( # The function that actually runs your program is one of the f's in streamline. f = streamline( - fgraph, thunks, order, post_thunk_old_storage, no_recycling=no_recycling + fgraph, + thunks, + order, + post_thunk_old_storage=post_thunk_old_storage, + no_recycling=no_recycling, + output_storage=output_storage, ) f.allow_gc = ( @@ -539,20 +544,21 @@ def make_thunk(self, **kwargs): def f(): for inputs in input_lists[1:]: - # strict=False because we are in a hot loop - for input1, input2 in zip(inputs0, inputs, strict=False): + # strict=None because we are in a hot loop + for input1, input2 in zip(inputs0, inputs): input2.storage[0] = copy(input1.storage[0]) for x in to_reset: x[0] = None pre(self, [input.data for input in input_lists[0]], order, thunk_groups) - # strict=False because we are in a hot loop - for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)): + # strict=None because we are in a hot loop + for i, (thunks, node) in enumerate(zip(thunk_groups, order)): try: wrapper(self.fgraph, i, node, *thunks) except Exception: raise_with_op(self.fgraph, node, *thunks) f.thunk_groups = thunk_groups + f.allow_gc = len(self.linkers) == 1 return f, inputs0, outputs0 @@ -668,10 +674,12 @@ def thunk( # since the error may come from any of them? raise_with_op(self.fgraph, output_nodes[0], thunk) - # strict=False because we are in a hot loop - for o_storage, o_val in zip(thunk_outputs, outputs, strict=False): + # strict=None because we are in a hot loop + for o_storage, o_val in zip(thunk_outputs, outputs): o_storage[0] = o_val + return outputs + thunk.inputs = thunk_inputs thunk.outputs = thunk_outputs thunk.lazy = False diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index d509bd1d76..1c08a21f04 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -1188,6 +1188,7 @@ def make_thunk( res = _CThunk(cthunk, init_tasks, tasks, error_storage, module) res.nodes = self.node_order + res.allow_gc = False return res, in_storage, out_storage def cmodule_key(self): @@ -1875,9 +1876,10 @@ def make_all( fgraph, thunks, order, - post_thunk_old_storage, + post_thunk_old_storage=post_thunk_old_storage, no_recycling=no_recycling, nice_errors=self.nice_errors, + output_storage=output_storage, ) f.allow_gc = self.allow_gc diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index a6a82ceebe..6a85bb77f1 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -312,10 +312,10 @@ def py_perform_return(inputs): else: def py_perform_return(inputs): - # strict=False because we are in a hot loop + # strict=None because we are in a hot loop return tuple( out_type.filter(out[0]) - for out_type, out in zip(output_types, py_perform(inputs), strict=False) + for out_type, out in zip(output_types, py_perform(inputs)) ) @numba_njit diff --git a/pytensor/link/numba/dispatch/cython_support.py b/pytensor/link/numba/dispatch/cython_support.py index 8dccf98836..422e4be406 100644 --- a/pytensor/link/numba/dispatch/cython_support.py +++ b/pytensor/link/numba/dispatch/cython_support.py @@ -166,10 +166,7 @@ def __wrapper_address__(self): def __call__(self, *args, **kwargs): # no strict argument because of the JIT # TODO: check - args = [ - dtype(arg) - for arg, dtype in zip(args, self._signature.arg_dtypes) # noqa: B905 - ] + args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)] if self.has_pyx_skip_dispatch(): output = self._pyfunc(*args[:-1], **kwargs) else: diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index 1f0a33e595..f7700acf47 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -186,7 +186,7 @@ def ravelmultiindex(*inp): new_arr = arr.T.astype(np.float64).copy() for i, b in enumerate(new_arr): # no strict argument to this zip because numba doesn't support it - for j, (d, v) in enumerate(zip(shape, b)): # noqa: B905 + for j, (d, v) in enumerate(zip(shape, b)): if v < 0 or v >= d: mode_fn(new_arr, i, j, v, d) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 7e1f6ded56..92f8a254f8 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -183,7 +183,7 @@ def block_diag(*arrs): r, c = 0, 0 # no strict argument because it is incompatible with numba - for arr, shape in zip(arrs, shapes): # noqa: B905 + for arr, shape in zip(arrs, shapes): rr, cc = shape out[r : r + rr, c : c + cc] = arr r += rr diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index ee9e183d16..5f471707a5 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -219,7 +219,7 @@ def advanced_subtensor_multiple_vector(x, *idxs): shape_aft = x_shape[after_last_axis:] out_shape = (*shape_bef, *idx_shape, *shape_aft) out_buffer = np.empty(out_shape, dtype=x.dtype) - for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 + for i, scalar_idxs in enumerate(zip(*vec_idxs)): out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)] return out_buffer @@ -253,7 +253,7 @@ def advanced_set_subtensor_multiple_vector(x, y, *idxs): y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:]) for outer in np.ndindex(x_shape[:first_axis]): - for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 + for i, scalar_idxs in enumerate(zip(*vec_idxs)): out[(*outer, *scalar_idxs)] = y[(*outer, i)] return out @@ -275,7 +275,7 @@ def advanced_inc_subtensor_multiple_vector(x, y, *idxs): y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:]) for outer in np.ndindex(x_shape[:first_axis]): - for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 + for i, scalar_idxs in enumerate(zip(*vec_idxs)): out[(*outer, *scalar_idxs)] += y[(*outer, i)] return out @@ -314,7 +314,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs): if not len(idxs) == len(vals): raise ValueError("The number of indices and values must match.") # no strict argument because incompatible with numba - for idx, val in zip(idxs, vals): # noqa: B905 + for idx, val in zip(idxs, vals): x[idx] = val return x else: @@ -342,7 +342,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs): raise ValueError("The number of indices and values must match.") # no strict argument because unsupported by numba # TODO: this doesn't come up in tests - for idx, val in zip(idxs, vals): # noqa: B905 + for idx, val in zip(idxs, vals): x[idx] += val return x diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index 9cbc3838dd..468422d01f 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -145,9 +145,11 @@ def streamline( fgraph: FunctionGraph, thunks: Sequence[Callable[[], None]], order: Sequence[Apply], + *, post_thunk_old_storage: list["StorageCellType"] | None = None, no_recycling: list["StorageCellType"] | None = None, nice_errors: bool = True, + output_storage: list["StorageCellType"], ) -> "BasicThunkType": """Construct a single thunk that runs a list of thunks. @@ -190,13 +192,14 @@ def streamline_default_f(): for x in no_recycling: x[0] = None try: - # strict=False because we are in a hot loop + # strict=None because we are in a hot loop for thunk, node, old_storage in zip( - thunks, order, post_thunk_old_storage, strict=False + thunks, order, post_thunk_old_storage ): thunk() for old_s in old_storage: old_s[0] = None + return [out[0] for out in output_storage] except Exception: raise_with_op(fgraph, node, thunk) @@ -207,11 +210,12 @@ def streamline_nice_errors_f(): for x in no_recycling: x[0] = None try: - # strict=False because we are in a hot loop - for thunk, node in zip(thunks, order, strict=False): + # strict=None because we are in a hot loop + for thunk, node in zip(thunks, order): thunk() except Exception: raise_with_op(fgraph, node, thunk) + return [out[0] for out in output_storage] f = streamline_nice_errors_f else: @@ -222,6 +226,7 @@ def streamline_fast_f(): x[0] = None for thunk in thunks: thunk() + return [out[0] for out in output_storage] f = streamline_fast_f return f diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 909fc47c27..5349b7d2cf 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -4416,8 +4416,8 @@ def make_node(self, *inputs): def perform(self, node, inputs, output_storage): outputs = self.py_perform_fn(*inputs) - # strict=False because we are in a hot loop - for storage, out_val in zip(output_storage, outputs, strict=False): + # strict=None because we are in a hot loop + for storage, out_val in zip(output_storage, outputs): storage[0] = out_val def grad(self, inputs, output_grads): diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index 1023e6a127..98b715cc0e 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -196,8 +196,8 @@ def perform(self, node, inputs, output_storage): for i in range(n_steps): carry = inner_fn(*carry, *constant) - # strict=False because we are in a hot loop - for storage, out_val in zip(output_storage, carry, strict=False): + # strict=None because we are in a hot loop + for storage, out_val in zip(output_storage, carry): storage[0] = out_val @property diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 2c3f404449..0a757da575 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -1353,20 +1353,8 @@ def prepare_fgraph(self, fgraph): preallocated_mitmot_outs.append(output_idx) - # Make it so that the input is automatically updated to - # the output value, possibly inplace, at the end of the - # function execution. Also, since an update is defined, - # a default value must also be (this is verified by - # DebugMode). - # TODO FIXME: Why do we need a "default value" here? - # This sounds like a serious design issue. - default_shape = tuple( - s if s is not None else 0 for s in inp.type.shape - ) - default_val = np.empty(default_shape, dtype=inp.type.dtype) wrapped_inp = In( variable=inp, - value=default_val, update=fgraph.outputs[output_idx], ) update_mapping[output_idx] = input_idx diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 214a7bdd3d..6939e6b155 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -1865,8 +1865,8 @@ def rng_fn(cls, rng, p, size): # to `p.shape[:-1]` in the call to `vsearchsorted` below. if len(size) < (p.ndim - 1): raise ValueError("`size` is incompatible with the shape of `p`") - # strict=False because we are in a hot loop - for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False): + # strict=None because we are in a hot loop + for s, ps in zip(reversed(size), reversed(p.shape[:-1])): if s == 1 and ps != 1: raise ValueError("`size` is incompatible with the shape of `p`") diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 23b4b50265..c91745b60b 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -44,8 +44,8 @@ def params_broadcast_shapes( max_fn = maximum if use_pytensor else max rev_extra_dims: list[int] = [] - # strict=False because we are in a hot loop - for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False): + # strict=None because we are in a hot loop + for ndim_param, param_shape in zip(ndims_params, param_shapes): # We need this in order to use `len` param_shape = tuple(param_shape) extras = tuple(param_shape[: (len(param_shape) - ndim_param)]) @@ -69,7 +69,7 @@ def max_bcast(x, y): (extra_dims + tuple(param_shape)[-ndim_param:]) if ndim_param > 0 else extra_dims - for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False) + for ndim_param, param_shape in zip(ndims_params, param_shapes) ] return bcast_shapes @@ -127,10 +127,9 @@ def broadcast_params( ) broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to - # strict=False because we are in a hot loop + # strict=None because we are in a hot loop bcast_params = [ - broadcast_to_fn(param, shape) - for shape, param in zip(shapes, params, strict=False) + broadcast_to_fn(param, shape) for shape, param in zip(shapes, params) ] return bcast_params diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 5a4cfdc52a..f82bc7a39b 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -447,10 +447,8 @@ def perform(self, node, inp, out_): raise AssertionError( f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}." ) - # strict=False because we are in a hot loop - if not all( - xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None - ): + # strict=None because we are in a hot loop + if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None): raise AssertionError( f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}." ) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index d0b6b5fe0a..0591634b05 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -199,11 +199,7 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray: # (do not try to convert the data) up_dtype = ps.upcast(self.dtype, data.dtype) if up_dtype == self.dtype: - # Bug in the following line when data is a - # scalar array, see - # http://projects.scipy.org/numpy/ticket/1611 - # data = data.astype(self.dtype) - data = np.asarray(data, dtype=self.dtype) + data = data.astype(self.dtype) if up_dtype != self.dtype: err_msg = ( f"{self} cannot store a value of dtype {data.dtype} without " @@ -261,17 +257,15 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray: " PyTensor C code does not support that.", ) - # strict=False because we are in a hot loop + # strict=None because we are in a hot loop if not all( ds == ts if ts is not None else True - for ds, ts in zip(data.shape, self.shape, strict=False) + for ds, ts in zip(data.shape, self.shape) ): raise TypeError( f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})" ) - if self.filter_checks_isfinite and not np.all(np.isfinite(data)): - raise ValueError("Non-finite elements not allowed") return data def filter_variable(self, other, allow_convert=True): diff --git a/tests/compile/function/test_function.py b/tests/compile/function/test_function.py index d1f94dd689..018f9a40f0 100644 --- a/tests/compile/function/test_function.py +++ b/tests/compile/function/test_function.py @@ -11,6 +11,7 @@ from pytensor.compile.function import function, function_dump from pytensor.compile.io import In from pytensor.configdefaults import config +from pytensor.link.basic import Container from pytensor.npy_2_compat import UintOverflowError from pytensor.tensor.type import ( bscalar, @@ -119,7 +120,9 @@ def test_in_mutable(self): def test_in_update(self): a = dscalar("a") - f = function([In(a, value=0.0, update=a + 1)], a, mode="FAST_RUN") + # A shared variable by any other name + c = Container(a, storage=[np.array(0.0)]) + f = function([In(a, value=c, implicit=True, update=a + 1)], a, mode="FAST_RUN") # Ensure that, through the executions of the function, the state of the # input is persistent and is updated as it should @@ -140,7 +143,8 @@ def test_in_update_shared(self): # updates in the same function behaves as expected shared_var = shared(1.0) a = dscalar("a") - a_wrapped = In(a, value=0.0, update=shared_var) + container = Container(a, storage=[np.array(0.0)]) + a_wrapped = In(a, value=container, update=shared_var) f = function([a_wrapped], [], updates={shared_var: a}, mode="FAST_RUN") # Ensure that, through the executions of the function, the state of diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py index 3e23b12f74..616237d5a7 100644 --- a/tests/compile/function/test_pfunc.py +++ b/tests/compile/function/test_pfunc.py @@ -59,7 +59,7 @@ def test_doc(self): a = lscalar() b = shared(1) f1 = pfunc([a], (a + b)) - f2 = pfunc([In(a, value=44)], a + b, updates={b: b + 1}) + f2 = pfunc([In(a)], a + b, updates={b: b + 1}) assert b.get_value() == 1 assert f1(3) == 4 assert f2(3) == 4 diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index 0990dbeca0..7e2110a26e 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -15,7 +15,7 @@ from pytensor.graph.basic import Constant from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter from pytensor.graph.utils import MissingInputError -from pytensor.link.vm import VMLinker +from pytensor.link.basic import Container from pytensor.printing import debugprint from pytensor.tensor.math import dot, tanh from pytensor.tensor.math import sum as pt_sum @@ -193,96 +193,6 @@ def test_naming_rule2(self): # got unexpected keyword argument 'x' f(5.0, x=9) - def test_naming_rule3(self): - a = scalar() # the a is for 'anonymous' (un-named). - x, s = scalars("xs") - - # x's name is not ignored (as in test_naming_rule2) because a has a default value. - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f = function([x, In(a, value=1.0), s], a / s + x) - assert f(9, 2, 4) == 9.5 # can specify all args in order - assert f(9, 2, s=4) == 9.5 # can give s as kwarg - assert f(9, s=4) == 9.25 # can give s as kwarg, get default a - assert f(x=9, s=4) == 9.25 # can give s as kwarg, omit a, x as kw - with pytest.raises(TypeError): - # got unexpected keyword argument 'a' - f(x=9, a=2, s=4) - with pytest.raises(TypeError): - # takes exactly 3 non-keyword arguments (0 given) - f() - with pytest.raises(TypeError): - # takes exactly 3 non-keyword arguments (1 given) - f(x=9) - - def test_naming_rule4(self): - a = scalar() # the a is for 'anonymous' (un-named). - x, s = scalars("xs") - - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f = function([x, In(a, value=1.0, name="a"), s], a / s + x) - - assert f(9, 2, 4) == 9.5 # can specify all args in order - assert f(9, 2, s=4) == 9.5 # can give s as kwarg - assert f(9, s=4) == 9.25 # can give s as kwarg, get default a - assert f(9, a=2, s=4) == 9.5 # can give s as kwarg, a as kwarg - assert f(x=9, a=2, s=4) == 9.5 # can give all kwargs - assert f(x=9, s=4) == 9.25 # can give all kwargs - with pytest.raises(TypeError): - # takes exactly 3 non-keyword arguments (0 given) - f() - with pytest.raises(TypeError): - # got multiple values for keyword argument 'x' - f(5.0, x=9) - - @pytest.mark.parametrize( - "mode", - [ - Mode( - linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False), - optimizer="fast_compile", - ), - Mode( - linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False), - optimizer="fast_run", - ), - Mode(linker="cvm", optimizer="fast_compile"), - Mode(linker="cvm", optimizer="fast_run"), - ], - ) - def test_state_access(self, mode): - a = scalar() - x, s = scalars("xs") - - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f = function( - [x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)], - s + a * x, - mode=mode, - ) - - assert f[a] == 1.0 - assert f[s] == 0.0 - - assert f(3.0) == 3.0 - assert f[s] == 3.0 - assert f(3.0, a=2.0) == 9.0 # 3.0 + 2*3.0 - - assert ( - f[a] == 1.0 - ) # state hasn't changed permanently, we just overrode it last line - assert f[s] == 9.0 - - f[a] = 5.0 - assert f[a] == 5.0 - assert f(3.0) == 24.0 # 9 + 3*5 - assert f[s] == 24.0 - def test_same_names(self): a, x, s = scalars("xxx") # implicit names would cause error. What do we do? @@ -300,9 +210,9 @@ def test_weird_names(self): def t(): f = function( [ - In(a, name={"adsf", ()}, value=1.0), - In(x, name=(), value=2.0), - In(s, name=scalar(), value=3.0), + In(a, name={"adsf", ()}), + In(x, name=()), + In(s, name=scalar()), ], a + x + s, ) @@ -315,47 +225,29 @@ def test_copy(self): a = scalar() x, s = scalars("xs") - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + f = function( + [ + x, + In(a, name="a"), + In(s, name="s"), + ], + s + a * x, + ) - g = copy.copy(f) + g = f.copy() assert f.unpack_single == g.unpack_single assert f.trust_input == g.trust_input - assert g.container[x].storage is not f.container[x].storage - assert g.container[a].storage is not f.container[a].storage - assert g.container[s].storage is not f.container[s].storage - - # Should not have been copied - assert g.value[a] is f.value[a] - - # Should have been copied because it is mutable - assert g.value[s] is not f.value[s] - - # Their contents should be equal, though - assert np.array_equal(g.value[s], f.value[s]) + assert g._finder[x].storage is not f._finder[x].storage + assert g._finder[a].storage is not f._finder[a].storage + assert g._finder[s].storage is not f._finder[s].storage - # They should be in sync, default value should be copied - assert np.array_equal(f(2, 1), g(2)) + assert g._finder[a].value is None and f._finder[a].value is None + assert g._finder[s].value is None and f._finder[s].value is None - # They should be in sync, default value should be copied - assert np.array_equal(f(2, 1), g(2)) - - # Put them out of sync - f(1, 2) - - # They should not be equal anymore - assert not np.array_equal(f(1, 2), g(1, 2)) + assert np.array_equal(f(2, 1, 0), g(2, 1, 0)) + assert np.array_equal(f(2, 1, 0), g(2, 1, 0)) def test_copy_share_memory(self): x = fscalar("x") @@ -519,88 +411,90 @@ def test_shared_state0(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) - g = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=f.container[s], update=s - a * x, mutable=True), - ], - s + a * x, - ) + f = function( + [ + x, + In(a, name="a"), + In( + s, + value=Container(s, storage=[np.array(0.0)]), + update=s + a * x, + mutable=True, + ), + ], + s + a * x, + ) + g = function( + [ + x, + In(a, name="a"), + In(s, value=f._finder[s], update=s - a * x, mutable=True), + ], + s + a * x, + ) f(1, 2) - assert f[s] == 2 - assert g[s] == 2 + assert f._finder[s].value == 2 + assert g._finder[s].value == 2 g(1, 2) - assert f[s] == 0 - assert g[s] == 0 + assert f._finder[s].value == 0 + assert g._finder[s].value == 0 def test_shared_state1(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) - g = function( - [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x - ) + f = function( + [ + x, + In(a, name="a"), + In( + s, + value=Container(s, storage=[np.array(0.0)]), + update=s + a * x, + mutable=True, + ), + ], + s + a * x, + ) + g = function([x, In(a, name="a"), In(s, value=f._finder[s])], s + a * x) f(1, 2) - assert f[s] == 2 - assert g[s] == 2 + assert f._finder[s].value == 2 + assert g._finder[s].value == 2 f(1, 2) g(1, 2) - assert f[s] == 4 - assert g[s] == 4 + assert f._finder[s].value == 4 + assert g._finder[s].value == 4 def test_shared_state2(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=False), - ], - s + a * x, - ) - g = function( - [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x - ) + f = function( + [ + x, + In(a, name="a"), + In( + s, + value=Container(s, storage=[np.array(0.0)]), + update=s + a * x, + mutable=False, + ), + ], + s + a * x, + ) + g = function([x, In(a, name="a"), In(s, value=f._finder[s])], s + a * x) f(1, 2) - assert f[s] == 2 - assert g[s] == 2 + assert f._finder[s].value == 2 + assert g._finder[s].value == 2 f(1, 2) - assert f[s] == 4 - assert g[s] == 4 + assert f._finder[s].value == 4 + assert g._finder[s].value == 4 g(1, 2) # has no effect on state - assert f[s] == 4 - assert g[s] == 4 + assert f._finder[s].value == 4 + assert g._finder[s].value == 4 def test_shared_state_not_implicit(self): # This test is taken from the documentation in @@ -608,18 +502,20 @@ def test_shared_state_not_implicit(self): # behavior is still intended the doc and the test should both be # updated accordingly. x, s = scalars("xs") - inc = function([x, In(s, update=(s + x), value=10.0)], []) + inc = function( + [x, In(s, update=(s + x), value=Container(s, storage=[np.array(10.0)]))], [] + ) dec = function( - [x, In(s, update=(s - x), value=inc.container[s], implicit=False)], [] + [x, In(s, update=(s - x), value=inc._finder[s], implicit=False)], [] ) - assert dec[s] is inc[s] - inc[s] = 2 - assert dec[s] == 2 + assert dec._finder[s].value is inc._finder[s].value + inc._finder[s].value = 2 + assert dec._finder[s].value == 2 dec(1) - assert inc[s] == 1 + assert inc._finder[s].value == 1 dec(1, 0) - assert inc[s] == -1 - assert dec[s] == -1 + assert inc._finder[s].value == -1 + assert dec._finder[s].value == -1 def test_constant_output(self): # Test that if the output is a constant, we respect the pytensor memory interface @@ -736,22 +632,6 @@ def test_free(self): if not isinstance(key, Constant): assert val[0] is None - def test_default_values(self): - # Check that default values are restored - # when an exception occurs in interactive mode. - - a, b = dscalars("a", "b") - c = a + b - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - funct = function([In(a, name="first"), In(b, value=1, name="second")], c) - x = funct(first=1) - try: - funct(second=2) - except TypeError: - assert funct(first=1) == x - def test_check_for_aliased_inputs(self): b = np.random.random((5, 4)) s1 = shared(b) @@ -802,78 +682,11 @@ def test_output_dictionary(self): # Tests that function works when outputs is a dictionary x = scalar() - with pytest.warns(FutureWarning, match="output_keys is deprecated."): - f = function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4}) - - outputs = f(10.0) - - assert outputs["a"] == 10.0 - assert outputs["b"] == 30.0 - assert outputs["1"] == 40.0 - assert outputs["c"] == 20.0 - - def test_input_named_variables(self): - # Tests that named variables work when outputs is a dictionary - - x = scalar("x") - y = scalar("y") - - with pytest.warns(FutureWarning, match="output_keys is deprecated."): - f = function([x, y], outputs={"a": x + y, "b": x * y}) - - assert f(2, 4) == {"a": 6, "b": 8} - assert f(2, y=4) == f(2, 4) - assert f(x=2, y=4) == f(2, 4) - - def test_output_order_sorted(self): - # Tests that the output keys are sorted correctly. - - x = scalar("x") - y = scalar("y") - z = scalar("z") - e1 = scalar("1") - e2 = scalar("2") - - with pytest.warns(FutureWarning, match="output_keys is deprecated."): - f = function( - [x, y, z, e1, e2], outputs={"x": x, "y": y, "z": z, "1": e1, "2": e2} - ) - - assert "1" in str(f.outputs[0]) - assert "2" in str(f.outputs[1]) - assert "x" in str(f.outputs[2]) - assert "y" in str(f.outputs[3]) - assert "z" in str(f.outputs[4]) - - def test_composing_function(self): - # Tests that one can compose two pytensor functions when the outputs are - # provided in a dictionary. - - x = scalar("x") - y = scalar("y") - - a = x + y - b = x * y - - with pytest.warns(FutureWarning, match="output_keys is deprecated."): - f = function([x, y], outputs={"a": a, "b": b}) - - a = scalar("a") - b = scalar("b") - - l = a + b - r = a * b - - g = function([a, b], outputs=[l, r]) - - result = g(**f(5, 7)) - - assert result[0] == 47.0 - assert result[1] == 420.0 + with pytest.raises(ValueError, match="output_keys was deprecated"): + function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4}) def test_output_list_still_works(self): # Test that function works if outputs is a list. - x = scalar("x") f = function([x], outputs=[x * 3, x * 2, x * 4, x]) @@ -911,17 +724,14 @@ def test_deepcopy(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f = function( - [ - x, - In(a, value=1.0, name="a", mutable=True), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + f = function( + [ + x, + In(a, name="a", mutable=True), + In(s, update=s + a * x, mutable=True), + ], + s + a * x, + ) try: g = copy.deepcopy(f) except NotImplementedError as e: @@ -933,12 +743,10 @@ def test_deepcopy(self): # print [(k, id(k)) for k in f.finder] # print [(k, id(k)) for k in g.finder] - assert g.container[0].storage is not f.container[0].storage - assert g.container[1].storage is not f.container[1].storage - assert g.container[2].storage is not f.container[2].storage - assert x not in g.container - assert x not in g.value - assert len(f.defaults) == len(g.defaults) + assert g._finder[0].storage is not f._finder[0].storage + assert g._finder[1].storage is not f._finder[1].storage + assert g._finder[2].storage is not f._finder[2].storage + assert x not in g._finder # Shared variable is the first input assert ( f._potential_aliased_input_groups @@ -947,45 +755,24 @@ def test_deepcopy(self): ) assert f.name == g.name assert f.maker.fgraph.name == g.maker.fgraph.name - # print(f"{f.defaults = }") - # print(f"{g.defaults = }") - for (f_req, f_feed, f_val), (g_req, g_feed, g_val) in zip( - f.defaults, g.defaults, strict=True - ): - assert f_req == g_req and f_feed == g_feed and f_val == g_val - - assert g.value[1] is not f.value[1] # should not have been copied - assert ( - g.value[2] is not f.value[2] - ) # should have been copied because it is mutable. - assert not (g.value[2] != f.value[2]).any() # its contents should be identical - - assert f(2, 1) == g( - 2 - ) # they should be in sync, default value should be copied. - assert f(2, 1) == g( - 2 - ) # they should be in sync, default value should be copied. - f(1, 2) # put them out of sync - assert f(1, 2) != g(1, 2) # they should not be equal anymore. - g(1, 2) # put them back in sync - assert f(3) == g(3) # They should be in sync again. + + assert g._finder[1].value is None and f._finder[1].value is None + assert g._finder[2].value is None and f._finder[2].value is None + + assert f(2, 1, 0) == g(2, 1, 0) def test_deepcopy_trust_input(self): a = dscalar() # the a is for 'anonymous' (un-named). x, s = dscalars("xs") - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + f = function( + [ + x, + In(a, name="a"), + In(s, update=s + a * x, mutable=True), + ], + s + a * x, + ) f.trust_input = True try: g = copy.deepcopy(f) @@ -995,35 +782,19 @@ def test_deepcopy_trust_input(self): else: raise assert f.trust_input is g.trust_input - f(np.asarray(2.0)) + f(np.array(2.0), np.array(1.0), np.array(0.0)) with pytest.raises((ValueError, AttributeError, InvalidValueError)): - f(2.0) - g(np.asarray(2.0)) + f(2.0, np.array(1.0), np.array(0.0)) + g(np.array(2.0), np.array(1.0), np.array(0.0)) with pytest.raises((ValueError, AttributeError, InvalidValueError)): - g(2.0) - - def test_output_keys(self): - x = vector() - with pytest.warns(FutureWarning, match="output_keys is deprecated."): - f = function([x], {"vec": x**2}) - o = f([2, 3, 4]) - assert isinstance(o, dict) - assert np.allclose(o["vec"], [4, 9, 16]) - with pytest.warns(FutureWarning, match="output_keys is deprecated."): - g = copy.deepcopy(f) - o = g([2, 3, 4]) - assert isinstance(o, dict) - assert np.allclose(o["vec"], [4, 9, 16]) + g(2.0, np.array(1.0), np.array(0.0)) def test_deepcopy_shared_container(self): # Ensure that shared containers remain shared after a deep copy. a, x = scalars("ax") - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - h = function([In(a, value=0.0)], a) - f = function([x, In(a, value=h.container[a], implicit=True)], x + a) + h = function([In(a, value=Container(a, storage=[np.array(0.0)]))], a) + f = function([x, In(a, value=h._finder[a], implicit=True)], x + a) try: memo = {} @@ -1037,26 +808,23 @@ def test_deepcopy_shared_container(self): return else: raise - h[a] = 1 - hc[ac] = 2 - assert f[a] == 1 - assert fc[ac] == 2 + h._finder[a].value = 1 + hc._finder[ac].value = 2 + assert f._finder[a].value == 1 + assert fc._finder[ac].value == 2 def test_pickle(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + f = function( + [ + x, + In(a, name="a"), + In(s, update=s + a * x, mutable=True), + ], + s + a * x, + ) try: # Note that here we also test protocol 0 on purpose, since it @@ -1072,26 +840,14 @@ def test_pickle(self): # print [(k, id(k)) for k in f.finder] # print [(k, id(k)) for k in g.finder] - assert g.container[0].storage is not f.container[0].storage - assert g.container[1].storage is not f.container[1].storage - assert g.container[2].storage is not f.container[2].storage - assert x not in g.container - assert x not in g.value + assert g._finder[0].storage is not f._finder[0].storage + assert g._finder[1].storage is not f._finder[1].storage + assert g._finder[2].storage is not f._finder[2].storage + assert x not in g._finder - assert g.value[1] is not f.value[1] # should not have been copied - assert ( - g.value[2] is not f.value[2] - ) # should have been copied because it is mutable. - assert not (g.value[2] != f.value[2]).any() # its contents should be identical - - assert f(2, 1) == g( - 2 - ) # they should be in sync, default value should be copied. - assert f(2, 1) == g( - 2 - ) # they should be in sync, default value should be copied. - f(1, 2) # put them out of sync - assert f(1, 2) != g(1, 2) # they should not be equal anymore. + assert g._finder[1].value is None and f._finder[1].value is None + assert g._finder[2].value is None and f._finder[2].value is None + assert f(2, 1, 0) == g(2, 1, 0) def test_optimizations_preserved(self): a = dvector() # the a is for 'anonymous' (un-named). @@ -1144,42 +900,38 @@ def test_multiple_functions(self): # some derived thing, whose inputs aren't all in the list list_of_things.append(a * x + s) - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f1 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + f1 = function( + [ + x, + In(a, name="a"), + In( + s, + value=Container(s, storage=[np.array(0.0)]), + update=s + a * x, + mutable=True, + ), + ], + s + a * x, + ) list_of_things.append(f1) # now put in a function sharing container with the previous one - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f2 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=f1.container[s], update=s + a * x, mutable=True), - ], - s + a * x, - ) + f2 = function( + [ + x, + In(a, name="a"), + In(s, value=f1._finder[s], update=s + a * x, mutable=True), + ], + s + a * x, + ) list_of_things.append(f2) - assert isinstance(f2.container[s].storage, list) - assert f2.container[s].storage is f1.container[s].storage + assert isinstance(f2._finder[s].storage, list) + assert f2._finder[s].storage is f1._finder[s].storage # now put in a function with non-scalar - v_value = np.asarray([2, 3, 4.0], dtype=config.floatX) - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - f3 = function([x, In(v, value=v_value)], x + v) + value = Container(v, storage=[np.asarray([2, 3, 4.0], dtype=config.floatX)]) + f3 = function([x, In(v, value=value)], x + v) list_of_things.append(f3) # try to pickle the entire things @@ -1214,18 +966,18 @@ def test_multiple_functions(self): assert nl[i] != ol[i] # looking at function number 1, input 's' - assert nl[4][nl[0]] is not ol[4][ol[0]] - assert nl[4][nl[0]] == ol[4][ol[0]] - assert nl[4](3) == ol[4](3) + assert nl[4]._finder[nl[0]].value is not ol[4]._finder[ol[0]].value + assert nl[4]._finder[nl[0]].value == ol[4]._finder[ol[0]].value + assert nl[4](3, 1) == ol[4](3, 1) # looking at function number 2, input 's' # make sure it's shared with the first function - assert ol[4].container[ol[0]].storage is ol[5].container[ol[0]].storage - assert nl[4].container[nl[0]].storage is nl[5].container[nl[0]].storage - assert nl[5](3) == ol[5](3) - assert nl[4].value[nl[0]] == 6 + assert ol[4]._finder[ol[0]].storage is ol[5]._finder[ol[0]].storage + assert nl[4]._finder[nl[0]].storage is nl[5]._finder[nl[0]].storage + assert nl[5](3, 1) == ol[5](3, 1) + assert nl[4]._finder[nl[0]].value == 6 - assert np.all(nl[6][nl[2]] == np.asarray([2, 3.0, 4])) + assert np.all(nl[6]._finder[nl[2]].value == np.array([2, 3.0, 4])) def test_broken_pickle_with_shared(self): saves = [] @@ -1279,7 +1031,7 @@ def exc_message(e): def test_pickle_class_with_functions(self): blah = SomethingToPickle() - assert blah.f2.container[blah.s].storage is blah.f1.container[blah.s].storage + assert blah.f2._finder[blah.s].storage is blah.f1._finder[blah.s].storage try: blah2 = copy.deepcopy(blah) @@ -1289,14 +1041,12 @@ def test_pickle_class_with_functions(self): else: raise - assert ( - blah2.f2.container[blah2.s].storage is blah2.f1.container[blah2.s].storage - ) + assert blah2.f2._finder[blah2.s].storage is blah2.f1._finder[blah2.s].storage - assert blah.f1[blah.s] == blah2.f1[blah2.s] + assert blah.f1._finder[blah.s].value == blah2.f1._finder[blah2.s].value - blah.f2(5) - assert blah.f1[blah.s] != blah2.f1[blah2.s] + blah.f2(5, 1) + assert blah.f1._finder[blah.s].value != blah2.f1._finder[blah2.s].value class SomethingToPickle: @@ -1311,29 +1061,28 @@ def __init__(self): self.e = a * x + s - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - self.f1 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + self.f1 = function( + [ + x, + In(a, name="a"), + In( + s, + value=Container(s, storage=[np.array(0.0)]), + update=s + a * x, + mutable=True, + ), + ], + s + a * x, + ) - with pytest.warns( - FutureWarning, match="Inputs with default values are deprecated." - ): - self.f2 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=self.f1.container[s], update=s + a * x, mutable=True), - ], - s + a * x, - ) + self.f2 = function( + [ + x, + In(a, name="a"), + In(s, value=self.f1._finder[s], update=s + a * x, mutable=True), + ], + s + a * x, + ) def test_empty_givens_updates(): @@ -1347,7 +1096,7 @@ def test_empty_givens_updates(): function([In(x)], y, updates={}) -@pytest.mark.parametrize("trust_input", [True, False]) +@pytest.mark.parametrize("trust_input", [True, False], ids=lambda x: f"trust_input={x}") def test_minimal_random_function_call_benchmark(trust_input, benchmark): rng = random_generator_type() x = normal(rng=rng, size=(100,)) @@ -1357,3 +1106,17 @@ def test_minimal_random_function_call_benchmark(trust_input, benchmark): rng_val = np.random.default_rng() benchmark(f, rng_val) + + +@pytest.mark.parametrize("trust_input", [True, False], ids=lambda x: f"trust_input={x}") +@pytest.mark.parametrize("linker", ["c", "cvm", "cvm_nogc"]) +def test_overhead_benchmark(trust_input, linker, benchmark): + x = pt.vector("x") + fn = function( + [In(x, borrow=True)], + Out(x, borrow=True), + trust_input=trust_input, + mode=Mode(linker=linker, optimizer=None), + ) + x_test = np.zeros(10) + benchmark(fn, x_test) diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py index fae76fab0d..845eb44e52 100644 --- a/tests/compile/test_debugmode.py +++ b/tests/compile/test_debugmode.py @@ -9,11 +9,9 @@ BadThunkOutput, BadViewMap, DebugMode, - InvalidValueError, StochasticOrder, ) from pytensor.compile.function import function -from pytensor.compile.mode import predefined_modes from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable from pytensor.graph.features import BadOptimization @@ -21,8 +19,8 @@ from pytensor.graph.rewriting.basic import node_rewriter from pytensor.graph.rewriting.db import EquilibriumDB from pytensor.link.c.op import COp -from pytensor.tensor.math import add, dot, log -from pytensor.tensor.type import TensorType, dvector, fmatrix, fvector, scalar, vector +from pytensor.tensor.math import add, dot +from pytensor.tensor.type import dvector, fmatrix, fvector, scalar from tests import unittest_tools as utt @@ -553,59 +551,6 @@ def perform(self, node, inp, out): # f([1,2,3,4],[5,6,7,8]) -class TestCheckIsfinite: - def setup_method(self): - self.old_ts = TensorType.filter_checks_isfinite - self.old_dm = predefined_modes["DEBUG_MODE"].check_isfinite - - def teardown_method(self): - TensorType.filter_checks_isfinite = self.old_ts - predefined_modes["DEBUG_MODE"].check_isfinite = self.old_dm - - def test_check_isfinite(self): - x = vector() - f = function([x], (x + 2) * 5, mode="DEBUG_MODE") - g = function([x], log(x), mode="DEBUG_MODE") - - # this should work - f(np.log([3, 4, 5]).astype(config.floatX)) - - # if TensorType.filter_checks_isfinite were true, these would raise - # ValueError - # if not, DebugMode will check internally, and raise InvalidValueError - # passing an invalid value as an input should trigger ValueError - with pytest.raises(InvalidValueError): - f(np.log([3, -4, 5]).astype(config.floatX)) - with pytest.raises(InvalidValueError): - f((np.asarray([0, 1.0, 0]) / 0).astype(config.floatX)) - with pytest.raises(InvalidValueError): - f((np.asarray([1.0, 1.0, 1.0]) / 0).astype(config.floatX)) - - # generating an invalid value internally should trigger - # InvalidValueError - with pytest.raises(InvalidValueError): - g(np.asarray([3, -4, 5], dtype=config.floatX)) - - # this should disable the exception - TensorType.filter_checks_isfinite = False - predefined_modes["DEBUG_MODE"].check_isfinite = False - # insert several Inf - f(np.asarray(np.asarray([1.0, 1.0, 1.0]) / 0, dtype=config.floatX)) - - def test_check_isfinite_disabled(self): - x = dvector() - f = function([x], (x + 2) * 5, mode=DebugMode(check_isfinite=False)) - - # nan should go through - f(np.log([3, -4, 5])) - - # inf should go through - infs = np.asarray([1.0, 1.0, 1.0]) / 0 - # print infs - f(infs) - return - - class BrokenCImplementationAdd(COp): __props__ = () @@ -804,20 +749,6 @@ def test_output_broadcast_tensor(self): f(v_val) -def test_function_dict(): - """Tests that debug mode works where outputs is a dictionary.""" - - x = scalar("x") - - f = function([x], outputs={"1": x, "2": 2 * x, "3": 3 * x}, mode="DEBUG_MODE") - - result = f(3.0) - - assert result["1"] == 3.0 - assert result["2"] == 6.0 - assert result["3"] == 9.0 - - def test_function_list(): """Tests that debug mode works where the outputs argument is a list.""" diff --git a/tests/link/test_vm.py b/tests/link/test_vm.py index dad7ed4fdd..37aada17c4 100644 --- a/tests/link/test_vm.py +++ b/tests/link/test_vm.py @@ -6,7 +6,6 @@ from pytensor.compile.function import function from pytensor.compile.io import In from pytensor.compile.mode import Mode, get_mode -from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph @@ -17,9 +16,8 @@ from pytensor.link.utils import map_storage from pytensor.link.vm import VM, Loop, Stack, VMLinker from pytensor.tensor.math import cosh, tanh -from pytensor.tensor.type import lscalar, scalar, scalars, vector, vectors +from pytensor.tensor.type import scalar, scalars, vector, vectors from pytensor.tensor.variable import TensorConstant -from tests import unittest_tools as utt class SomeOp(Op): @@ -202,71 +200,6 @@ def build_graph(x, depth=5): # print(f"{linker} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") -@pytest.mark.parametrize( - "linker", [VMLinker(allow_partial_eval=True, use_cloop=False), "cvm"] -) -def test_partial_function(linker): - x = scalar("input") - y = x**2 - f = function( - [x], [y + 7, y - 9, y / 14.0], mode=Mode(optimizer=None, linker=linker) - ) - - if linker == "cvm": - from pytensor.link.c.cvm import CVM - - assert isinstance(f.vm, CVM) - else: - assert isinstance(f.vm, Stack) - - assert f(3, output_subset=[0, 1, 2]) == f(3) - assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]] - - utt.assert_allclose(f(5), np.array([32.0, 16.0, 1.7857142857142858])) - - -@pytest.mark.parametrize( - "linker", [VMLinker(allow_partial_eval=True, use_cloop=False), "cvm"] -) -def test_partial_function_with_output_keys(linker): - x = scalar("input") - y = 3 * x - f = function( - [x], {"a": y * 5, "b": y - 7}, mode=Mode(optimizer=None, linker=linker) - ) - - assert f(5, output_subset=["a"])["a"] == f(5)["a"] - - -@pytest.mark.parametrize( - "linker", [VMLinker(allow_partial_eval=True, use_cloop=False), "cvm"] -) -def test_partial_function_with_updates(linker): - x = lscalar("input") - y = shared(np.asarray(1, "int64"), name="global") - - mode = Mode(optimizer=None, linker=linker) - - f = function( - [x], - [x, x + 34], - updates=[(y, x + 1)], - mode=mode, - ) - g = function( - [x], - [x - 6], - updates=[(y, y + 3)], - mode=mode, - ) - - assert f(3, output_subset=[]) == [] - assert y.get_value() == 4 - assert g(30, output_subset=[0]) == [24] - assert g(40, output_subset=[]) == [] - assert y.get_value() == 10 - - def test_allow_gc_cvm(): mode = config.mode if mode in ["DEBUG_MODE", "DebugMode"]: