diff --git a/.gitignore b/.gitignore index 787c4b0e..6e84dd5a 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,6 @@ benchmark/notebooks/.ipynb_checkpoints benchmark/scripts/__pycache__ benchmark/scripts/benchmarks-pypsa-eur/__pycache__ benchmark/scripts/leftovers/ + +# IDE +.idea/ diff --git a/doc/release_notes.rst b/doc/release_notes.rst index bfd188a0..1ae1f5e4 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -4,12 +4,14 @@ Release Notes .. Upcoming Version .. ---------------- + +* Improved variable/expression arithmetic methods so that they correctly handle types + Version 0.5.5 -------------- * Internally assign new data fields to expressions with a multiindexed-safe routine. - Version 0.5.4 -------------- diff --git a/linopy/common.py b/linopy/common.py index 3e2449f4..81d4e41d 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: from linopy.constraints import Constraint - from linopy.expressions import LinearExpression + from linopy.expressions import LinearExpression, QuadraticExpression from linopy.variables import Variable @@ -994,13 +994,13 @@ def check_common_keys_values(list_of_dicts: list[dict[str, Any]]) -> bool: def align( - *objects: LinearExpression | Variable | T_Alignable, + *objects: LinearExpression | QuadraticExpression | Variable | T_Alignable, join: JoinOptions = "inner", copy: bool = True, indexes: Any = None, exclude: str | Iterable[Hashable] = frozenset(), fill_value: Any = dtypes.NA, -) -> tuple[LinearExpression | Variable | T_Alignable, ...]: +) -> tuple[LinearExpression | QuadraticExpression | Variable | T_Alignable, ...]: """ Given any number of Variables, Expressions, Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -1055,13 +1055,13 @@ def align( """ - from linopy.expressions import LinearExpression + from linopy.expressions import LinearExpression, QuadraticExpression from linopy.variables import Variable finisher: list[partial[Any] | Callable[[Any], Any]] = [] das: list[Any] = [] for obj in objects: - if isinstance(obj, LinearExpression): + if isinstance(obj, (LinearExpression, QuadraticExpression)): finisher.append(partial(obj.__class__, model=obj.model)) das.append(obj.data) elif isinstance(obj, Variable): @@ -1090,7 +1090,14 @@ def align( return tuple([f(da) for f, da in zip(finisher, aligned)]) -LocT = TypeVar("LocT", "Dataset", "Variable", "LinearExpression", "Constraint") +LocT = TypeVar( + "LocT", + "Dataset", + "Variable", + "LinearExpression", + "QuadraticExpression", + "Constraint", +) class LocIndexer(Generic[LocT]): diff --git a/linopy/expressions.py b/linopy/expressions.py index 88594f75..954c8e80 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -9,13 +9,11 @@ import functools import logging +from abc import ABC, abstractmethod from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field from itertools import product, zip_longest -from typing import ( - TYPE_CHECKING, - Any, -) +from typing import TYPE_CHECKING, Any, TypeVar, overload from warnings import warn import numpy as np @@ -317,44 +315,7 @@ def sum(self, **kwargs: Any) -> LinearExpression: return LinearExpression(ds, self.model) -class LinearExpression: - """ - A linear expression consisting of terms of coefficients and variables. - - The LinearExpression class is a subclass of xarray.Dataset which allows to - apply most xarray functions on it. However most arithmetic operations are - overwritten. Like this you can easily expand and modify the linear - expression. - - Examples - -------- - >>> from linopy import Model - >>> import pandas as pd - >>> m = Model() - >>> x = m.add_variables(pd.Series([0, 0]), 1, name="x") - >>> y = m.add_variables(4, pd.Series([8, 10]), name="y") - - Combining expressions: - - >>> expr = 3 * x - >>> type(expr) - - - >>> other = 4 * y - >>> type(expr + other) - - - Multiplying: - - >>> type(3 * expr) - - - Summation over dimensions - - >>> type(expr.sum(dim="dim_0")) - - """ - +class BaseExpression(ABC): __slots__ = ("_data", "_model") __array_ufunc__ = None __array_priority__ = 10000 @@ -363,6 +324,13 @@ class LinearExpression: _fill_value = FILL_VALUE _data: Dataset + @property + @abstractmethod + def flat(self) -> pd.DataFrame: ... + + @abstractmethod + def to_polars(self) -> pl.DataFrame: ... + def __init__(self, data: Dataset | Any | None, model: Model) -> None: from linopy.model import Model @@ -487,72 +455,46 @@ def print(self, display_max_rows: int = 20, display_max_terms: int = 20) -> None ) print(self) - def __add__(self, other: SideLike) -> LinearExpression: - """ - Add an expression to others. + @abstractmethod + def __add__( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: ... - Note: If other is a numpy array or pandas object without axes names, - dimension names of self will be filled in other - """ - try: - if np.isscalar(other): - return self.assign(const=self.const + other) + @abstractmethod + def __radd__(self: GenericExpression, other: SideLike) -> GenericExpression: ... - other = as_expression(other, model=self.model, dims=self.coord_dims) - return merge([self, other], cls=self.__class__) - except TypeError: - return NotImplemented + @abstractmethod + def __sub__( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: ... - def __radd__(self, other: int) -> LinearExpression | NotImplementedType: - # This is needed for using python's sum function - return self if other == 0 else NotImplemented + @abstractmethod + def __rsub__(self: GenericExpression, other: SideLike) -> GenericExpression: ... - def __sub__(self, other: SideLike) -> LinearExpression: - """ - Subtract others from expression. + @abstractmethod + def __mul__( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: ... - Note: If other is a numpy array or pandas object without axes names, - dimension names of self will be filled in other - """ - try: - if np.isscalar(other): - return self.assign_multiindex_safe(const=self.const - other) + @abstractmethod + def __rmul__( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: ... - other = as_expression(other, model=self.model, dims=self.coord_dims) - return merge([self, -other], cls=self.__class__) - except TypeError: - return NotImplemented + @abstractmethod + def __matmul__( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: ... + + @abstractmethod + def __pow__(self, other: int) -> QuadraticExpression: ... - def __neg__(self) -> LinearExpression | QuadraticExpression: + def __neg__(self: GenericExpression) -> GenericExpression: """ Get the negative of the expression. """ return self.assign_multiindex_safe(coeffs=-self.coeffs, const=-self.const) - def __mul__( - self, - other: SideLike, - ) -> LinearExpression | QuadraticExpression: - """ - Multiply the expr by a factor. - """ - try: - if isinstance(other, QuadraticExpression): - raise TypeError( - "unsupported operand type(s) for *: " - f"{type(self)} and {type(other)}. " - "Higher order non-linear expressions are not yet supported." - ) - elif isinstance(other, (variables.Variable, variables.ScalarVariable)): - other = other.to_linexpr() - - if isinstance(other, (LinearExpression, ScalarLinearExpression)): - return self._multiply_by_linear_expression(other) - else: - return self._multiply_by_constant(other) - except TypeError: - return NotImplemented - def _multiply_by_linear_expression( self, other: LinearExpression | ScalarLinearExpression ) -> QuadraticExpression: @@ -566,66 +508,45 @@ def _multiply_by_linear_expression( # merge on factor dimension only returns v1 * v2 + c1 * c2 ds = other.data[["coeffs", "vars"]].sel(_term=0).broadcast_like(self.data) ds = assign_multiindex_safe(ds, const=other.const) - res = merge([self, ds], dim=FACTOR_DIM, cls=QuadraticExpression) + res = merge([self, ds], dim=FACTOR_DIM, cls=QuadraticExpression) # type: ignore # deal with cross terms c1 * v2 + c2 * v1 if self.has_constant: res = res + self.const * other.reset_const() if other.has_constant: res = res + self.reset_const() * other.const - return res # type: ignore + return res - def _multiply_by_constant(self, other: ConstantLike) -> LinearExpression: + def _multiply_by_constant( + self: GenericExpression, other: ConstantLike + ) -> GenericExpression: multiplier = as_dataarray(other, coords=self.coords, dims=self.coord_dims) coeffs = self.coeffs * multiplier assert all(coeffs.sizes[d] == s for d, s in self.coeffs.sizes.items()) const = self.const * multiplier return self.assign(coeffs=coeffs, const=const) - def __pow__(self, other: int) -> QuadraticExpression: - """ - Power of the expression with a coefficient. The only coefficient allowed is 2. - """ - if not other == 2: - raise ValueError("Power must be 2.") - return self * self # type: ignore - - def __rmul__(self, other: ConstantLike) -> LinearExpression | QuadraticExpression: - """ - Right-multiply the expr by a factor. - """ - return self.__mul__(other) - - def __matmul__( - self, other: LinearExpression | Variable | ndarray | DataArray - ) -> LinearExpression | QuadraticExpression: - """ - Matrix multiplication with other, similar to xarray dot. - """ - if not isinstance(other, (LinearExpression, variables.Variable)): - other = as_dataarray(other, coords=self.coords, dims=self.coord_dims) - - common_dims = list(set(self.coord_dims).intersection(other.dims)) - return (self * other).sum(dim=common_dims) - - def __div__( - self, other: Variable | ConstantLike - ) -> LinearExpression | QuadraticExpression: + def __div__(self: GenericExpression, other: SideLike) -> GenericExpression: try: if isinstance( - other, (LinearExpression, variables.Variable, variables.ScalarVariable) + other, + ( + variables.Variable, + variables.ScalarVariable, + LinearExpression, + ScalarLinearExpression, + QuadraticExpression, + ), ): raise TypeError( "unsupported operand type(s) for /: " f"{type(self)} and {type(other)}" "Non-linear expressions are not yet supported." ) - return self.__mul__(1 / other) + return self._multiply_by_constant(other=1 / other) except TypeError: return NotImplemented - def __truediv__( - self, other: Variable | ConstantLike - ) -> LinearExpression | QuadraticExpression: + def __truediv__(self: GenericExpression, other: SideLike) -> GenericExpression: return self.__div__(other) def __le__(self, rhs: SideLike) -> Constraint: @@ -647,27 +568,33 @@ def __lt__(self, other: Any) -> NotImplementedType: "Inequalities only ever defined for >= rather than >." ) - def add(self, other: SideLike) -> LinearExpression: + def add( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: """ Add an expression to others. """ return self.__add__(other) - def sub(self, other: SideLike) -> LinearExpression: + def sub( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: """ Subtract others from expression. """ return self.__sub__(other) - def mul(self, other: SideLike) -> LinearExpression | QuadraticExpression: + def mul( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: """ Multiply the expr by a factor. """ return self.__mul__(other) def div( - self, other: Variable | float | int - ) -> LinearExpression | QuadraticExpression: + self: GenericExpression, other: VariableLike | ConstantLike + ) -> GenericExpression | QuadraticExpression: """ Divide the expr by a factor. """ @@ -679,15 +606,17 @@ def pow(self, other: int) -> QuadraticExpression: """ return self.__pow__(other) - def dot(self, other: ndarray) -> LinearExpression: + def dot( + self: GenericExpression, other: ndarray + ) -> GenericExpression | QuadraticExpression: """ Matrix multiplication with other, similar to xarray dot. """ return self.__matmul__(other) def __getitem__( - self, selector: int | tuple[slice, list[int]] | slice - ) -> LinearExpression | QuadraticExpression: + self: GenericExpression, selector: int | tuple[slice, list[int]] | slice + ) -> GenericExpression: """ Get selection from the expression. This is a wrapper around the xarray __getitem__ method. It returns a @@ -830,42 +759,12 @@ def solution(self) -> DataArray: sol = (self.coeffs * vals).sum(TERM_DIM) + self.const return sol.rename("solution") - @classmethod - def _sum( - cls, - expr: LinearExpression | Dataset, - dim: DimsLike | None = None, - ) -> Dataset: - data = _expr_unwrap(expr) - - if isinstance(dim, str): - dim = [dim] - elif isinstance(dim, EllipsisType): - dim = None - - if dim is None: - vars = DataArray(data.vars.data.ravel(), dims=TERM_DIM) - coeffs = DataArray(data.coeffs.data.ravel(), dims=TERM_DIM) - const = data.const.sum() - ds = xr.Dataset({"vars": vars, "coeffs": coeffs, "const": const}) - else: - dim = [d for d in dim if d != TERM_DIM] - ds = ( - data[["coeffs", "vars"]] - .reset_index(dim, drop=True) - .rename({TERM_DIM: STACKED_TERM_DIM}) - .stack({TERM_DIM: [STACKED_TERM_DIM] + dim}, create_index=False) - ) - ds = assign_multiindex_safe(ds, const=data.const.sum(dim)) - - return ds - def sum( - self, + self: GenericExpression, dim: DimsLike | None = None, drop_zeros: bool = False, **kwargs: Any, - ) -> LinearExpression: + ) -> GenericExpression: """ Sum the expression over all or a subset of dimensions. @@ -949,228 +848,63 @@ def cumsum( dim_dict = {dim_name: self.data.sizes[dim_name] for dim_name in dim} return self.rolling(dim=dim_dict).sum(keep_attrs=keep_attrs, skipna=skipna) - @classmethod - def from_tuples( - cls, *tuples: tuple, model: Model | None = None - ) -> LinearExpression: + def to_constraint( + self, sign: SignLike, rhs: ConstantLike | VariableLike | ExpressionLike + ) -> Constraint: """ - Create a linear expression by using tuples of coefficients and - variables. - - The function internally checks that all variables in the tuples belong to the same - reference model. + Convert a linear expression to a constraint. Parameters ---------- - tuples : tuples of (coefficients, variables) - Each tuple represents one term in the resulting linear expression, - which can possibly span over multiple dimensions: - - * coefficients : int/float/array_like - The coefficient(s) in the term, if the coefficients array - contains dimensions which do not appear in - the variables, the variables are broadcasted. - * variables : str/array_like/linopy.Variable - The variable(s) going into the term. These may be referenced - by name. + sign : str, array-like + Sign(s) of the constraints. + rhs : constant, Variable, LinearExpression + Right-hand side of the constraint. Returns ------- - linopy.LinearExpression - - Examples - -------- - >>> from linopy import Model - >>> import pandas as pd - >>> m = Model() - >>> x = m.add_variables(pd.Series([0, 0]), 1) - >>> y = m.add_variables(4, pd.Series([8, 10])) - >>> expr = LinearExpression.from_tuples((10, x), (1, y)) - - This is the same as calling ``10*x + y`` but a bit more performant. + Constraint with strict separation of the linear expressions of variables + which are moved to the left-hand-side and constant values which are moved + to the right-hand side. """ - exprs = [] - for t in tuples: - if len(t) == 2: - # assume first element is coefficient and second is variable - c, v = t - if not isinstance(v, (variables.Variable, variables.ScalarVariable)): - raise TypeError("Expected variable as second element of tuple.") - expr = v.to_linexpr(c) - const = None - if model is None: - model = expr.model # TODO: Ensure equality of models - elif len(t) == 1: - # assume that the element is a constant - c, v = None, None - (const,) = as_dataarray(t) - if model is None: - raise ValueError("Model must be provided when using constants.") - expr = LinearExpression(const, model) - else: - raise ValueError("Expected tuples of length 1 or 2.") + all_to_lhs = (self - rhs).data + data = assign_multiindex_safe( + all_to_lhs[["coeffs", "vars"]], sign=sign, rhs=-all_to_lhs.const + ) + return constraints.Constraint(data, model=self.model) - exprs.append(expr) + def reset_const(self: GenericExpression) -> GenericExpression: + """ + Reset the constant of the linear expression to zero. + """ + return self.__class__(self.data[["coeffs", "vars"]], self.model) - return merge(exprs, cls=cls) if len(exprs) > 1 else exprs[0] + def isnull(self) -> DataArray: + """ + Get a boolean mask with true values where there is only missing values in an expression. - @classmethod - def from_rule( - cls, - model: Model, - rule: Callable, - coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None, - ) -> LinearExpression: + Returns + ------- + xr.DataArray """ - Create a linear expression from a rule and a set of coordinates. + helper_dims = set(self.vars.dims).intersection(HELPER_DIMS) + return (self.vars == -1).all(helper_dims) & self.const.isnull() - This functionality mirrors the assignment of linear expression as done by - Pyomo. + def where( + self: GenericExpression, + cond: DataArray, + other: LinearExpression + | int + | DataArray + | dict[str, float | int | DataArray] + | None = None, + **kwargs: Any, + ) -> GenericExpression: + """ + Filter variables based on a condition. - - Parameters - ---------- - model : linopy.Model - Passed to function `rule` as a first argument. - rule : callable - Function to be called for each combinations in `coords`. - The first argument of the function is the underlying `linopy.Model`. - The following arguments are given by the coordinates for accessing - the variables. The function has to return a - `ScalarLinearExpression`. Therefore use the `.at` accessor when - indexing variables. - coords : coordinate-like - Coordinates to processed by `xarray.DataArray`. - For each combination of coordinates, the function - given by `rule` is called. The order and size of coords has - to be same as the argument list followed by `model` in - function `rule`. - - - Returns - ------- - linopy.LinearExpression - - Examples - -------- - >>> from linopy import Model, LinearExpression - >>> m = Model() - >>> coords = pd.RangeIndex(10), ["a", "b"] - >>> x = m.add_variables(0, 100, coords) - >>> def bound(m, i, j): - ... if i % 2: - ... return (i - 1) * x.at[i - 1, j] - ... else: - ... return i * x.at[i, j] - ... - >>> expr = LinearExpression.from_rule(m, bound, coords) - >>> con = m.add_constraints(expr <= 10) - """ - if not isinstance(coords, DataArrayCoordinates): - coords = DataArray(coords=coords).coords - - # test output type - output = rule(model, *[c.values[0] for c in coords.values()]) - if not isinstance(output, ScalarLinearExpression) and output is not None: - msg = f"`rule` has to return ScalarLinearExpression not {type(output)}." - raise TypeError(msg) - - combinations = product(*[c.values for c in coords.values()]) - exprs = [] - placeholder = ScalarLinearExpression((np.nan,), (-1,), model) - exprs = [rule(model, *coord) or placeholder for coord in combinations] - return cls._from_scalarexpression_list(exprs, coords, model) - - @classmethod - def _from_scalarexpression_list( - cls, - exprs: list[ScalarLinearExpression], - coords: Mapping, - model: Model, - ) -> LinearExpression: - """ - Create a LinearExpression from a list of lists with different lengths. - """ - shape = list(map(len, coords.values())) - - coeffs = array(tuple(zip_longest(*(e.coeffs for e in exprs), fillvalue=nan))) - vars = array(tuple(zip_longest(*(e.vars for e in exprs), fillvalue=-1))) - - nterm = vars.shape[0] - coeffs = coeffs.reshape((nterm, *shape)) - vars = vars.reshape((nterm, *shape)) - - coeffdata = DataArray(coeffs, coords, dims=(TERM_DIM, *coords)) - vardata = DataArray(vars, coords, dims=(TERM_DIM, *coords)) - ds = Dataset({"coeffs": coeffdata, "vars": vardata}).transpose(..., TERM_DIM) - - return cls(ds, model) - - def to_quadexpr(self) -> QuadraticExpression: - """Convert LinearExpression to QuadraticExpression.""" - vars = self.data.vars.expand_dims(FACTOR_DIM) - fill_value = self._fill_value["vars"] - vars = xr.concat([vars, xr.full_like(vars, fill_value)], dim=FACTOR_DIM) - data = assign_multiindex_safe(self.data, vars=vars) - return QuadraticExpression(data, self.model) - - def to_constraint( - self, sign: SignLike, rhs: ConstantLike | VariableLike | ExpressionLike - ) -> Constraint: - """ - Convert a linear expression to a constraint. - - Parameters - ---------- - sign : str, array-like - Sign(s) of the constraints. - rhs : constant, Variable, LinearExpression - Right-hand side of the constraint. - - Returns - ------- - Constraint with strict separation of the linear expressions of variables - which are moved to the left-hand-side and constant values which are moved - to the right-hand side. - """ - all_to_lhs = (self - rhs).data - data = assign_multiindex_safe( - all_to_lhs[["coeffs", "vars"]], sign=sign, rhs=-all_to_lhs.const - ) - return constraints.Constraint(data, model=self.model) - - def reset_const(self) -> LinearExpression: - """ - Reset the constant of the linear expression to zero. - """ - return self.__class__(self.data[["coeffs", "vars"]], self.model) - - def isnull(self) -> DataArray: - """ - Get a boolean mask with true values where there is only missing values in an expression. - - Returns - ------- - xr.DataArray - """ - helper_dims = set(self.vars.dims).intersection(HELPER_DIMS) - return (self.vars == -1).all(helper_dims) & self.const.isnull() - - def where( - self, - cond: DataArray, - other: LinearExpression - | int - | DataArray - | dict[str, float | int | DataArray] - | None = None, - **kwargs: Any, - ) -> LinearExpression | QuadraticExpression: - """ - Filter variables based on a condition. - - This operation call ``xarray.Dataset.where`` but sets the default - fill value to -1 for variables and ensures preserving the linopy.LinearExpression type. + This operation call ``xarray.Dataset.where`` but sets the default + fill value to -1 for variables and ensures preserving the linopy.LinearExpression type. Parameters ---------- @@ -1187,7 +921,7 @@ def where( Returns ------- - linopy.LinearExpression + linopy.LinearExpression or linopy.QuadraticExpression """ # Cannot set `other` if drop=True _other: dict[str, float] | dict[str, int | float | DataArray] | DataArray | None @@ -1209,14 +943,14 @@ def where( return self.__class__(self.data.where(cond, other=_other, **kwargs), self.model) def fillna( - self, + self: GenericExpression, value: int | float | DataArray | Dataset | LinearExpression | dict[str, float | int | DataArray], - ) -> LinearExpression: + ) -> GenericExpression: """ Fill missing values with a given value. @@ -1232,15 +966,15 @@ def fillna( Returns ------- - linopy.LinearExpression - A new `linopy.LinearExpression` object with missing values filled with the given value. + linopy.LinearExpression or linopy.QuadraticExpression + A new object with missing values filled with the given value. """ value = _expr_unwrap(value) if isinstance(value, (DataArray, np.floating, np.integer, int, float)): value = {"const": value} return self.__class__(self.data.fillna(value), self.model) - def diff(self, dim: str, n: int = 1) -> LinearExpression: + def diff(self: GenericExpression, dim: str, n: int = 1) -> GenericExpression: """ Calculate the n-th order discrete difference along given axis. @@ -1256,7 +990,7 @@ def diff(self, dim: str, n: int = 1) -> LinearExpression: Returns ------- - linopy.LinearExpression + linopy.LinearExpression or linopy.QuadraticExpression """ return self - self.shift({dim: n}) @@ -1361,7 +1095,7 @@ def empty(self) -> EmptyDeprecationWrapper: """ return EmptyDeprecationWrapper(not self.size) - def densify_terms(self) -> LinearExpression: + def densify_terms(self: GenericExpression) -> GenericExpression: """ Move all non-zero term entries to the front and cut off all-zero entries in the term-axis. @@ -1392,7 +1126,7 @@ def densify_terms(self) -> LinearExpression: return self.__class__(data.sel({TERM_DIM: slice(0, nterm)}), self.model) - def sanitize(self) -> LinearExpression: + def sanitize(self: GenericExpression) -> GenericExpression: """ Sanitize LinearExpression by ensuring int dtype for variables. @@ -1403,13 +1137,277 @@ def sanitize(self) -> LinearExpression: if not np.issubdtype(self.vars.dtype, np.integer): return self.assign(vars=self.vars.fillna(-1).astype(int)) - return self + return self + + def equals(self, other: BaseExpression) -> bool: + return self.data.equals(_expr_unwrap(other)) + + def __iter__(self) -> Iterator[Hashable]: + return self.data.__iter__() + + @classmethod + def _sum( + cls, + expr: BaseExpression | Dataset, + dim: DimsLike | None = None, + ) -> Dataset: + data = _expr_unwrap(expr) + if cls is QuadraticExpression: + dim = dim or list(set(data.dims) - set(HELPER_DIMS)) + + if isinstance(dim, str): + dim = [dim] + elif isinstance(dim, EllipsisType): + dim = None + + if dim is None: + vars = DataArray(data.vars.data.ravel(), dims=TERM_DIM) + coeffs = DataArray(data.coeffs.data.ravel(), dims=TERM_DIM) + const = data.const.sum() + ds = xr.Dataset({"vars": vars, "coeffs": coeffs, "const": const}) + else: + dim = [d for d in dim if d != TERM_DIM] + ds = ( + data[["coeffs", "vars"]] + .reset_index(dim, drop=True) + .rename({TERM_DIM: STACKED_TERM_DIM}) + .stack({TERM_DIM: [STACKED_TERM_DIM] + dim}, create_index=False) + ) + ds = assign_multiindex_safe(ds, const=data.const.sum(dim)) + + return ds + + # Wrapped function which would convert variable to dataarray + assign = exprwrap(assign_multiindex_safe) + + assign_multiindex_safe = exprwrap(assign_multiindex_safe) + + assign_attrs = exprwrap(Dataset.assign_attrs) + + assign_coords = exprwrap(Dataset.assign_coords) + + astype = exprwrap(Dataset.astype) + + bfill = exprwrap(Dataset.bfill) + + broadcast_like = exprwrap(Dataset.broadcast_like) + + chunk = exprwrap(Dataset.chunk) + + drop = exprwrap(Dataset.drop) + + drop_vars = exprwrap(Dataset.drop_vars) + + drop_sel = exprwrap(Dataset.drop_sel) + + drop_isel = exprwrap(Dataset.drop_isel) + + expand_dims = exprwrap(Dataset.expand_dims) + + ffill = exprwrap(Dataset.ffill) + + sel = exprwrap(Dataset.sel) + + isel = exprwrap(Dataset.isel) + + shift = exprwrap(Dataset.shift) + + swap_dims = exprwrap(Dataset.swap_dims) + + set_index = exprwrap(Dataset.set_index) + + reindex = exprwrap(Dataset.reindex, fill_value=_fill_value) + + reindex_like = exprwrap(Dataset.reindex_like, fill_value=_fill_value) + + rename = exprwrap(Dataset.rename) + + reset_index = exprwrap(Dataset.reset_index) + + rename_dims = exprwrap(Dataset.rename_dims) + + roll = exprwrap(Dataset.roll) + + stack = exprwrap(Dataset.stack) + + unstack = exprwrap(Dataset.unstack) + + iterate_slices = iterate_slices + + +GenericExpression = TypeVar("GenericExpression", bound=BaseExpression) + + +class LinearExpression(BaseExpression): + """ + A linear expression consisting of terms of coefficients and variables. + + The LinearExpression class is a subclass of xarray.Dataset which allows to + apply most xarray functions on it. However most arithmetic operations are + overwritten. Like this you can easily expand and modify the linear + expression. + + Examples + -------- + >>> from linopy import Model + >>> import pandas as pd + >>> m = Model() + >>> x = m.add_variables(pd.Series([0, 0]), 1, name="x") + >>> y = m.add_variables(4, pd.Series([8, 10]), name="y") + + Combining expressions: + + >>> expr = 3 * x + >>> type(expr) + + + >>> other = 4 * y + >>> type(expr + other) + + + Multiplying: + + >>> type(3 * expr) + + + Summation over dimensions + + >>> type(expr.sum(dim="dim_0")) + + """ + + @overload + def __add__( + self, + other: ConstantLike | VariableLike | ScalarLinearExpression | LinearExpression, + ) -> LinearExpression: ... + + @overload + def __add__(self, other: QuadraticExpression) -> QuadraticExpression: ... + + def __add__( + self, + other: ConstantLike + | VariableLike + | ScalarLinearExpression + | LinearExpression + | QuadraticExpression, + ) -> LinearExpression | QuadraticExpression: + """ + Add an expression to others. + + Note: If other is a numpy array or pandas object without axes names, + dimension names of self will be filled in other + """ + if isinstance(other, QuadraticExpression): + return other.__add__(self) + + try: + if np.isscalar(other): + return self.assign(const=self.const + other) + + other = as_expression(other, model=self.model, dims=self.coord_dims) + return merge([self, other], cls=self.__class__) + except TypeError: + return NotImplemented + + def __radd__(self, other: ConstantLike) -> LinearExpression: + try: + return self + other + except TypeError: + return NotImplemented + + @overload + def __sub__( + self, + other: ConstantLike | VariableLike | ScalarLinearExpression | LinearExpression, + ) -> LinearExpression: ... + + @overload + def __sub__(self, other: QuadraticExpression) -> QuadraticExpression: ... + + def __sub__( + self, + other: ConstantLike + | VariableLike + | ScalarLinearExpression + | LinearExpression + | QuadraticExpression, + ) -> LinearExpression | QuadraticExpression: + try: + return self.__add__(-other) + except TypeError: + return NotImplemented + + def __rsub__(self, other: ConstantLike | Variable) -> LinearExpression: + try: + return (self * -1) + other + except TypeError: + return NotImplemented + + @overload + def __mul__(self, other: ConstantLike) -> LinearExpression: ... + + @overload + def __mul__(self, other: VariableLike | ExpressionLike) -> QuadraticExpression: ... + + def __mul__( + self, + other: SideLike, + ) -> LinearExpression | QuadraticExpression: + """ + Multiply the expr by a factor. + """ + if isinstance(other, QuadraticExpression): + return other.__rmul__(self) + + try: + if isinstance(other, (variables.Variable, variables.ScalarVariable)): + other = other.to_linexpr() + + if isinstance(other, (LinearExpression, ScalarLinearExpression)): + return self._multiply_by_linear_expression(other) + else: + return self._multiply_by_constant(other) + except TypeError: + return NotImplemented + + def __pow__(self, other: int) -> QuadraticExpression: + """ + Power of the expression with a coefficient. The only coefficient allowed is 2. + """ + if not other == 2: + raise ValueError("Power must be 2.") + return self * self # type: ignore + + def __rmul__(self, other: ConstantLike) -> LinearExpression: + """ + Right-multiply the expr by a factor. + """ + try: + return self * other + except TypeError: + return NotImplemented + + @overload + def __matmul__(self, other: ConstantLike) -> LinearExpression: ... - def equals(self, other: LinearExpression) -> bool: - return self.data.equals(_expr_unwrap(other)) + @overload + def __matmul__( + self, other: VariableLike | ExpressionLike + ) -> QuadraticExpression: ... - def __iter__(self) -> Iterator[Hashable]: - return self.data.__iter__() + def __matmul__( + self, other: ConstantLike | VariableLike | ExpressionLike + ) -> LinearExpression | QuadraticExpression: + """ + Matrix multiplication with other, similar to xarray dot. + """ + if not isinstance(other, (LinearExpression, variables.Variable)): + other = as_dataarray(other, coords=self.coords, dims=self.coord_dims) + + common_dims = list(set(self.coord_dims).intersection(other.dims)) + return (self * other).sum(dim=common_dims) @property def flat(self) -> pd.DataFrame: @@ -1435,6 +1433,14 @@ def mask_func(data: pd.DataFrame) -> pd.Series: check_has_nulls(df, name=self.type) return df + def to_quadexpr(self) -> QuadraticExpression: + """Convert LinearExpression to QuadraticExpression.""" + vars = self.data.vars.expand_dims(FACTOR_DIM) + fill_value = self._fill_value["vars"] + vars = xr.concat([vars, xr.full_like(vars, fill_value)], dim=FACTOR_DIM) + data = self.data.assign(vars=vars) + return QuadraticExpression(data, self.model) + def to_polars(self) -> pl.DataFrame: """ Convert the expression to a polars DataFrame. @@ -1453,65 +1459,204 @@ def to_polars(self) -> pl.DataFrame: check_has_nulls_polars(df, name=self.type) return df - # Wrapped function which would convert variable to dataarray - assign = exprwrap(assign_multiindex_safe) + @classmethod + def _from_scalarexpression_list( + cls, + exprs: list[ScalarLinearExpression], + coords: Mapping, + model: Model, + ) -> LinearExpression: + """ + Create a LinearExpression from a list of lists with different lengths. + """ + shape = list(map(len, coords.values())) - assign_multiindex_safe = exprwrap(assign_multiindex_safe) + coeffs = array(tuple(zip_longest(*(e.coeffs for e in exprs), fillvalue=nan))) + vars = array(tuple(zip_longest(*(e.vars for e in exprs), fillvalue=-1))) - assign_attrs = exprwrap(Dataset.assign_attrs) + nterm = vars.shape[0] + coeffs = coeffs.reshape((nterm, *shape)) + vars = vars.reshape((nterm, *shape)) - assign_coords = exprwrap(Dataset.assign_coords) + coeffdata = DataArray(coeffs, coords, dims=(TERM_DIM, *coords)) + vardata = DataArray(vars, coords, dims=(TERM_DIM, *coords)) + ds = Dataset({"coeffs": coeffdata, "vars": vardata}).transpose(..., TERM_DIM) - astype = exprwrap(Dataset.astype) + return cls(ds, model) - bfill = exprwrap(Dataset.bfill) + @classmethod + def from_rule( + cls, + model: Model, + rule: Callable, + coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None, + ) -> LinearExpression: + """ + Create a linear expression from a rule and a set of coordinates. - broadcast_like = exprwrap(Dataset.broadcast_like) + This functionality mirrors the assignment of linear expression as done by + Pyomo. - chunk = exprwrap(Dataset.chunk) - drop = exprwrap(Dataset.drop) + Parameters + ---------- + model : linopy.Model + Passed to function `rule` as a first argument. + rule : callable + Function to be called for each combinations in `coords`. + The first argument of the function is the underlying `linopy.Model`. + The following arguments are given by the coordinates for accessing + the variables. The function has to return a + `ScalarLinearExpression`. Therefore use the `.at` accessor when + indexing variables. + coords : coordinate-like + Coordinates to processed by `xarray.DataArray`. + For each combination of coordinates, the function + given by `rule` is called. The order and size of coords has + to be same as the argument list followed by `model` in + function `rule`. - drop_vars = exprwrap(Dataset.drop_vars) - drop_sel = exprwrap(Dataset.drop_sel) + Returns + ------- + linopy.LinearExpression - drop_isel = exprwrap(Dataset.drop_isel) + Examples + -------- + >>> from linopy import Model, LinearExpression + >>> m = Model() + >>> coords = pd.RangeIndex(10), ["a", "b"] + >>> x = m.add_variables(0, 100, coords) + >>> def bound(m, i, j): + ... if i % 2: + ... return (i - 1) * x.at[i - 1, j] + ... else: + ... return i * x.at[i, j] + ... + >>> expr = LinearExpression.from_rule(m, bound, coords) + >>> con = m.add_constraints(expr <= 10) + """ + if not isinstance(coords, DataArrayCoordinates): + coords = DataArray(coords=coords).coords - expand_dims = exprwrap(Dataset.expand_dims) + # test output type + output = rule(model, *[c.values[0] for c in coords.values()]) + if not isinstance(output, ScalarLinearExpression) and output is not None: + msg = f"`rule` has to return ScalarLinearExpression not {type(output)}." + raise TypeError(msg) - ffill = exprwrap(Dataset.ffill) + combinations = product(*[c.values for c in coords.values()]) - sel = exprwrap(Dataset.sel) + placeholder = ScalarLinearExpression((np.nan,), (-1,), model) + exprs = [rule(model, *coord) or placeholder for coord in combinations] + return cls._from_scalarexpression_list(exprs, coords, model) - isel = exprwrap(Dataset.isel) + @classmethod + def from_tuples( + cls, + *tuples: tuple[ConstantLike, str | Variable | ScalarVariable] | ConstantLike, + model: Model | None = None, + ) -> LinearExpression: + """ + Create a linear expression by using tuples of coefficients and + variables. - shift = exprwrap(Dataset.shift) + The function internally checks that all variables in the tuples belong to the same + reference model. - swap_dims = exprwrap(Dataset.swap_dims) + Parameters + ---------- + tuples : A list of elements. Each element is either: + * (coefficients, variables) + * constant - set_index = exprwrap(Dataset.set_index) + Each (coefficients, variables) tuple represents one term in the resulting linear expression, + which can possibly span over multiple dimensions: - reindex = exprwrap(Dataset.reindex, fill_value=_fill_value) + * coefficients : int/float/array_like + The coefficient(s) in the term, if the coefficients array + contains dimensions which do not appear in + the variables, the variables are broadcasted. + * variables : str/array_like/linopy.Variable + The variable(s) going into the term. These may be referenced + by name. + * constant: int/float/array_like + The constant value to add to the expression + model : The linopy.Model. If None this can be inferred from the provided variables - reindex_like = exprwrap(Dataset.reindex_like, fill_value=_fill_value) + Returns + ------- + linopy.LinearExpression - rename = exprwrap(Dataset.rename) + Examples + -------- + >>> from linopy import Model + >>> import pandas as pd + >>> m = Model() + >>> x = m.add_variables(pd.Series([0, 0]), 1) + >>> y = m.add_variables(4, pd.Series([8, 10])) + >>> expr = LinearExpression.from_tuples((10, x), (1, y), 1) - reset_index = exprwrap(Dataset.reset_index) + This is the same as calling ``10*x + y` + 1` but a bit more performant. + """ - rename_dims = exprwrap(Dataset.rename_dims) + def process_one( + t: tuple[ConstantLike, str | Variable | ScalarVariable] + | tuple[ConstantLike] + | ConstantLike, + ) -> LinearExpression: + nonlocal model - roll = exprwrap(Dataset.roll) + if isinstance(t, SUPPORTED_CONSTANT_TYPES): + if model is None: + raise ValueError("Model must be provided when using constants.") + expr = LinearExpression(t, model) + return expr - stack = exprwrap(Dataset.stack) + if not isinstance(t, tuple): + raise ValueError("Expected tuple or constant.") - unstack = exprwrap(Dataset.unstack) + if len(t) == 2: + # assume first element is coefficient and second is variable + c, v = t + if isinstance(v, variables.ScalarVariable): + if not isinstance(c, (int, float)): + raise TypeError( + "Expected int or float as coefficient of scalar variable (first element of tuple)." + ) + expr = v.to_linexpr(c) + elif isinstance(v, variables.Variable): + if not isinstance(c, SUPPORTED_CONSTANT_TYPES): + raise TypeError( + "Expected constant as coefficient of variable (first element of tuple)." + ) + expr = v.to_linexpr(c) + else: + raise TypeError("Expected variable as second element of tuple.") - iterate_slices = iterate_slices + if model is None: + model = expr.model # TODO: Ensure equality of models + return expr + + if len(t) == 1: + warn( + "Passing a single value tuple to LinearExpression.from_tuples is deprecated", + DeprecationWarning, + ) + # assume that the element is a constant + const = as_dataarray(t[0]) + if model is None: + raise ValueError("Model must be provided when using constants.") + return LinearExpression(const, model) + else: + raise ValueError("Expected tuples of length 2") + exprs = [process_one(t) for t in tuples] -class QuadraticExpression(LinearExpression): + return merge(exprs, cls=cls) if len(exprs) > 1 else exprs[0] + + +class QuadraticExpression(BaseExpression): """ A quadratic expression consisting of terms of coefficients and variables. @@ -1541,17 +1686,26 @@ def __init__(self, data: Dataset | None, model: Model) -> None: data = xr.Dataset(data.transpose(..., FACTOR_DIM, TERM_DIM)) self._data = data - def __mul__( - self, other: ConstantLike | VariableLike | ExpressionLike - ) -> QuadraticExpression: + @property + def type(self) -> str: + return "QuadraticExpression" + + @property + def shape(self) -> tuple[int, ...]: + # TODO Implement this + raise NotImplementedError( + f"{self.__class__.__name__} does not support shape property" + ) + + def __mul__(self, other: SideLike) -> QuadraticExpression: """ Multiply the expr by a factor. """ if isinstance( other, ( - LinearExpression, - QuadraticExpression, + BaseExpression, + ScalarLinearExpression, variables.Variable, variables.ScalarVariable, ), @@ -1561,15 +1715,15 @@ def __mul__( f"{type(self)} and {type(other)}. " "Higher order non-linear expressions are not yet supported." ) - return super().__mul__(other) # type: ignore + try: + return self._multiply_by_constant(other) + except TypeError: + return NotImplemented - @property - def type(self) -> str: - return "QuadraticExpression" + def __rmul__(self, other: SideLike) -> QuadraticExpression: + return self * other - def __add__( - self, other: ConstantLike | VariableLike | ExpressionLike - ) -> QuadraticExpression: + def __add__(self, other: SideLike) -> QuadraticExpression: """ Add an expression to others. @@ -1582,27 +1736,20 @@ def __add__( other = as_expression(other, model=self.model, dims=self.coord_dims) - if type(other) is LinearExpression: + if isinstance(other, LinearExpression): other = other.to_quadexpr() - return merge([self, other], cls=self.__class__) # type: ignore + + return merge([self, other], cls=self.__class__) except TypeError: return NotImplemented - def __radd__( - self, other: LinearExpression | int - ) -> LinearExpression | QuadraticExpression: + def __radd__(self, other: ConstantLike) -> QuadraticExpression: """ Add others to expression. """ - if type(other) is LinearExpression: - other = other.to_quadexpr() - return other.__add__(self) - elif other == 0: - return self - else: - return NotImplemented + return self.__add__(other) - def __sub__(self, other: SideLike | QuadraticExpression) -> QuadraticExpression: + def __sub__(self, other: SideLike) -> QuadraticExpression: """ Subtract others from expression. @@ -1616,20 +1763,45 @@ def __sub__(self, other: SideLike | QuadraticExpression) -> QuadraticExpression: other = as_expression(other, model=self.model, dims=self.coord_dims) if type(other) is LinearExpression: other = other.to_quadexpr() - return merge([self, -other], cls=self.__class__) # type: ignore + return merge([self, -other], cls=self.__class__) except TypeError: return NotImplemented - def __rsub__(self, other: LinearExpression) -> QuadraticExpression: + def __rsub__(self, other: SideLike) -> QuadraticExpression: """ Subtract expression from others. """ - if type(other) is LinearExpression: - other = other.to_quadexpr() - return other.__sub__(self) - else: + try: + return (self * -1) + other + except TypeError: return NotImplemented + def __pow__(self, other: SideLike) -> QuadraticExpression: + raise TypeError("Higher order non-linear expressions are not yet supported.") + + def __matmul__( + self, other: ConstantLike | VariableLike | ExpressionLike + ) -> QuadraticExpression: + """ + Matrix multiplication with other, similar to xarray dot. + """ + if isinstance( + other, + ( + BaseExpression, + ScalarLinearExpression, + variables.Variable, + variables.ScalarVariable, + ), + ): + raise TypeError( + "Higher order non-linear expressions are not yet supported." + ) + + other = as_dataarray(other, coords=self.coords, dims=self.coord_dims) + common_dims = list(set(self.coord_dims).intersection(other.dims)) + return (self * other).sum(dim=common_dims) + @property def solution(self) -> DataArray: """ @@ -1642,16 +1814,6 @@ def solution(self) -> DataArray: sol = (self.coeffs * vals.prod(FACTOR_DIM)).sum(TERM_DIM) + self.const return sol.rename("solution") - @classmethod - def _sum( - cls, - expr: Dataset | LinearExpression | QuadraticExpression, - dim: DimsLike | None = None, - ) -> Dataset: - data = _expr_unwrap(expr) - dim = dim or list(set(data.dims) - set(HELPER_DIMS)) - return LinearExpression._sum(expr, dim) - def to_constraint(self, sign: SignLike, rhs: SideLike) -> NotImplementedType: raise NotImplementedError( "Quadratic expressions cannot be used in constraints." @@ -1753,7 +1915,7 @@ def as_expression( ValueError If object cannot be converted to LinearExpression. """ - if isinstance(obj, LinearExpression): + if isinstance(obj, (LinearExpression, QuadraticExpression)): return obj elif isinstance(obj, (variables.Variable, variables.ScalarVariable)): return obj.to_linexpr() @@ -1773,9 +1935,9 @@ def merge( LinearExpression | QuadraticExpression | variables.Variable | Dataset ], dim: str = TERM_DIM, - cls: type[LinearExpression | QuadraticExpression] = LinearExpression, + cls: type[GenericExpression] = None, # type: ignore **kwargs: Any, -) -> LinearExpression | QuadraticExpression: +) -> GenericExpression: """ Merge multiple expression together. @@ -1785,7 +1947,6 @@ def merge( the coordinates of the first object as a basis which overrides the coordinates of the consecutive objects. - Parameters ---------- *exprs : tuple/list @@ -1803,6 +1964,13 @@ def merge( ------- res : linopy.LinearExpression """ + if cls is None: + warn( + "Using merge without specifying the class is deprecated", + DeprecationWarning, + ) + cls = LinearExpression + linopy_types = (variables.Variable, LinearExpression, QuadraticExpression) if not isinstance(exprs, list) and len(add_exprs): @@ -1823,6 +1991,13 @@ def merge( "Convert to QuadraticExpression first." ) + if cls is not QuadraticExpression and any( + type(e) is QuadraticExpression for e in exprs + ): + raise ValueError( + "Cannot merge linear and quadratic expressions to QuadraticExpression" + ) + if cls in linopy_types and dim in HELPER_DIMS: coord_dims = [ {k: v for k, v in e.sizes.items() if k not in HELPER_DIMS} for e in exprs @@ -1871,7 +2046,7 @@ class ScalarLinearExpression: A scalar linear expression container. In contrast to the LinearExpression class, a ScalarLinearExpression - only contains only one label. Use this class to create a constraint + only contains one label. Use this class to create a constraint in a rule. """ diff --git a/linopy/model.py b/linopy/model.py index 5f84e3bf..b71a8d93 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -12,7 +12,7 @@ from collections.abc import Callable, Mapping, Sequence from pathlib import Path from tempfile import NamedTemporaryFile, gettempdir -from typing import Any +from typing import Any, overload import numpy as np import pandas as pd @@ -881,17 +881,33 @@ def calculate_block_maps(self) -> None: blocks = replace_by_map(self.objective.vars, block_map) self.objective = self.objective.assign(blocks=blocks) + @overload def linexpr( - self, *args: tuple[ConstantLike, str | Variable | ScalarVariable] | Callable + self, *args: Sequence[Sequence | pd.Index | DataArray] | Mapping + ) -> LinearExpression: ... + + @overload + def linexpr( + self, *args: tuple[ConstantLike, str | Variable | ScalarVariable] | ConstantLike + ) -> LinearExpression: ... + + def linexpr( + self, + *args: tuple[ConstantLike, str | Variable | ScalarVariable] + | ConstantLike + | Callable + | Sequence[Sequence | pd.Index | DataArray] + | Mapping, ) -> LinearExpression: """ Create a linopy.LinearExpression from argument list. Parameters ---------- - args : tuples of (coefficients, variables) or tuples of - coordinates and a function - If args is a collection of coefficients-variables-tuples, the resulting + args : A mixture of tuples of (coefficients, variables) and constants + or a function and tuples of coordinates + + If args is a collection of coefficients-variables-tuples and constants, the resulting linear expression is built with the function LinearExpression.from_tuples. * coefficients : int/float/array_like The coefficient(s) in the term, if the coefficients array @@ -900,6 +916,8 @@ def linexpr( * variables : str/array_like/linopy.Variable The variable(s) going into the term. These may be referenced by name. + * constant: int/float/array_like + The constant value to add to the expression If args is a collection of coordinates with an appended function at the end, the function LinearExpression.from_rule is used to build the linear @@ -910,7 +928,7 @@ def linexpr( The first argument of the function is the underlying `linopy.Model`. The following arguments are given by the coordinates for accessing the variables. The function has to return a - `ScalarLinearExpression`. Therefore use the direct getter when + `ScalarLinearExpression`. Therefore, use the direct getter when indexing variables. * coords : coordinate-like Coordinates to be processed by `xarray.DataArray`. For each @@ -957,9 +975,14 @@ def linexpr( return LinearExpression.from_rule(self, rule, coords) # type: ignore if not isinstance(args, tuple): raise TypeError(f"Not supported type {args}.") - tuples = [ # type: ignore - (c, self.variables[v]) if isinstance(v, str) else (c, v) for (c, v) in args - ] + + tuples: list[tuple[ConstantLike, VariableLike] | ConstantLike] = [] + for arg in args: + if isinstance(arg, tuple): + c, v = arg + tuples.append((c, self.variables[v]) if isinstance(v, str) else (c, v)) + else: + tuples.append(arg) return LinearExpression.from_tuples(*tuples, model=self) @property diff --git a/linopy/monkey_patch_xarray.py b/linopy/monkey_patch_xarray.py index bd3bbce9..b2527898 100644 --- a/linopy/monkey_patch_xarray.py +++ b/linopy/monkey_patch_xarray.py @@ -26,6 +26,13 @@ def deco(func: Callable) -> Callable: def __mul__( da: DataArray, other: Any, unpatched_method: Callable ) -> DataArray | NotImplementedType: - if isinstance(other, (variables.Variable, expressions.LinearExpression)): + if isinstance( + other, + ( + variables.Variable, + expressions.LinearExpression, + expressions.QuadraticExpression, + ), + ): return NotImplemented return unpatched_method(da, other) diff --git a/linopy/testing.py b/linopy/testing.py index f4970c5f..dfa46081 100644 --- a/linopy/testing.py +++ b/linopy/testing.py @@ -13,13 +13,19 @@ def assert_varequal(a: Variable, b: Variable) -> None: return assert_equal(_var_unwrap(a), _var_unwrap(b)) -def assert_linequal(a: LinearExpression, b: LinearExpression) -> None: +def assert_linequal( + a: LinearExpression | QuadraticExpression, b: LinearExpression | QuadraticExpression +) -> None: """Assert that two linear expressions are equal.""" + assert isinstance(a, LinearExpression) + assert isinstance(b, LinearExpression) return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) -def assert_quadequal(a: QuadraticExpression, b: QuadraticExpression) -> None: - """Assert that two linear expressions are equal.""" +def assert_quadequal( + a: LinearExpression | QuadraticExpression, b: LinearExpression | QuadraticExpression +) -> None: + """Assert that two quadratic or linear expressions are equal.""" return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) diff --git a/linopy/variables.py b/linopy/variables.py index 961aca0b..695ceb85 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -51,11 +51,19 @@ ) from linopy.config import options from linopy.constants import HELPER_DIMS, TERM_DIM -from linopy.types import ConstantLike, DimsLike, NotImplementedType, SideLike +from linopy.types import ( + ConstantLike, + DimsLike, + ExpressionLike, + NotImplementedType, + SideLike, + VariableLike, +) if TYPE_CHECKING: from linopy.constraints import AnonymousScalarConstraint, Constraint from linopy.expressions import ( + GenericExpression, LinearExpression, LinearExpressionGroupby, QuadraticExpression, @@ -381,43 +389,55 @@ def __neg__(self) -> LinearExpression: """ return self.to_linexpr(-1) - def __mul__( - self, other: float | int | ndarray | Variable - ) -> LinearExpression | QuadraticExpression: + @overload + def __mul__(self, other: ConstantLike) -> LinearExpression: ... + + @overload + def __mul__(self, other: ExpressionLike | VariableLike) -> QuadraticExpression: ... + + def __mul__(self, other: SideLike) -> ExpressionLike: """ - Multiply variables with a coefficient. + Multiply variables with a coefficient, variable, or expression. """ try: - if isinstance( - other, (expressions.LinearExpression, Variable, ScalarVariable) - ): + if isinstance(other, (Variable, ScalarVariable)): return self.to_linexpr() * other return self.to_linexpr(other) except TypeError: return NotImplemented + def __rmul__(self, other: ConstantLike) -> LinearExpression: + """ + Right-multiply variables by a constant + """ + try: + return self * other + except TypeError: + return NotImplemented + def __pow__(self, other: int) -> QuadraticExpression: """ Power of the variables with a coefficient. The only coefficient allowed is 2. """ - if isinstance(other, int) and other == 2: + if not isinstance(other, int): + return NotImplemented + if other == 2: expr = self.to_linexpr() return expr._multiply_by_linear_expression(expr) - return NotImplemented + raise ValueError("Can only raise to the power of 2") - def __rmul__(self, other: float | DataArray | int | ndarray) -> LinearExpression: - """ - Right-multiply variables with a coefficient. - """ - try: - return self.to_linexpr(other) - except TypeError: - return NotImplemented + @overload + def __matmul__(self, other: ConstantLike) -> LinearExpression: ... + + @overload + def __matmul__( + self, other: VariableLike | ExpressionLike + ) -> QuadraticExpression: ... def __matmul__( - self, other: LinearExpression | ndarray | Variable - ) -> QuadraticExpression | LinearExpression: + self, other: ConstantLike | VariableLike | ExpressionLike + ) -> LinearExpression | QuadraticExpression: """ Matrix multiplication of variables with a coefficient. """ @@ -448,9 +468,18 @@ def __truediv__( except TypeError: return NotImplemented + @overload def __add__( - self, other: int | QuadraticExpression | LinearExpression | Variable - ) -> QuadraticExpression | LinearExpression: + self, other: ConstantLike | Variable | ScalarLinearExpression + ) -> LinearExpression: ... + + @overload + def __add__(self, other: GenericExpression) -> GenericExpression: ... + + def __add__( + self, + other: ConstantLike | Variable | ScalarLinearExpression | GenericExpression, + ) -> LinearExpression | GenericExpression: """ Add variables to linear expressions or other variables. """ @@ -459,13 +488,24 @@ def __add__( except TypeError: return NotImplemented - def __radd__(self, other: int) -> Variable | NotImplementedType: - # This is needed for using python's sum function - return self if other == 0 else NotImplemented + def __radd__(self, other: ConstantLike) -> LinearExpression: + try: + return self + other + except TypeError: + return NotImplemented + @overload def __sub__( - self, other: QuadraticExpression | LinearExpression | Variable - ) -> QuadraticExpression | LinearExpression: + self, other: ConstantLike | Variable | ScalarLinearExpression + ) -> LinearExpression: ... + + @overload + def __sub__(self, other: GenericExpression) -> GenericExpression: ... + + def __sub__( + self, + other: ConstantLike | Variable | ScalarLinearExpression | GenericExpression, + ) -> LinearExpression | GenericExpression: """ Subtract linear expressions or other variables from the variables. """ @@ -474,6 +514,15 @@ def __sub__( except TypeError: return NotImplemented + def __rsub__(self, other: ConstantLike) -> LinearExpression: + """ + Subtract linear expressions or other variables from the variables. + """ + try: + return self.to_linexpr(-1) + other + except TypeError: + return NotImplemented + def __le__(self, other: SideLike) -> Constraint: return self.to_linexpr().__le__(other) @@ -1467,8 +1516,7 @@ class ScalarVariable: """ A scalar variable container. - In contrast to the Variable class, a ScalarVariable only contains - only one label. Use this class to create a expression or constraint + In contrast to the Variable class, a ScalarVariable only contains one label. Use this class to create a expression or constraint in a rule. """ @@ -1520,7 +1568,7 @@ def to_scalar_linexpr(self, coeff: int | float = 1) -> ScalarLinearExpression: raise TypeError(f"Coefficient must be a numeric value, got {type(coeff)}.") return expressions.ScalarLinearExpression((coeff,), (self.label,), self.model) - def to_linexpr(self, coeff: int = 1) -> LinearExpression: + def to_linexpr(self, coeff: int | float = 1) -> LinearExpression: return self.to_scalar_linexpr(coeff).to_linexpr() def __neg__(self) -> ScalarLinearExpression: @@ -1540,6 +1588,8 @@ def __mul__(self, coeff: int | float) -> ScalarLinearExpression: return self.to_scalar_linexpr(coeff) def __rmul__(self, coeff: int | float) -> ScalarLinearExpression: + if isinstance(coeff, (Variable, ScalarVariable)): + return NotImplemented return self.to_scalar_linexpr(coeff) def __div__(self, coeff: int | float) -> ScalarLinearExpression: diff --git a/test/test_compatible_arithmetrics.py b/test/test_compatible_arithmetrics.py index 0b5829cf..1d1618ba 100644 --- a/test/test_compatible_arithmetrics.py +++ b/test/test_compatible_arithmetrics.py @@ -4,9 +4,10 @@ import pandas as pd import pytest import xarray as xr +from xarray.testing import assert_equal from linopy import LESS_EQUAL, Model, Variable -from linopy.testing import assert_linequal +from linopy.testing import assert_linequal, assert_quadequal class SomeOtherDatatype: @@ -94,17 +95,18 @@ def test_arithmetric_operations_variable(m: Model) -> None: rng = np.random.default_rng() data = xr.DataArray(rng.random(x.shape), coords=x.coords) other_datatype = SomeOtherDatatype(data.copy()) - assert_linequal(x + data, x + other_datatype) # type: ignore - assert_linequal(x - data, x - other_datatype) # type: ignore - assert_linequal(x * data, x * other_datatype) # type: ignore + assert_linequal(x + data, x + other_datatype) + assert_linequal(x - data, x - other_datatype) + assert_linequal(x * data, x * other_datatype) assert_linequal(x / data, x / other_datatype) # type: ignore assert_linequal(data * x, other_datatype * x) # type: ignore - assert x.__add__(object()) is NotImplemented # type: ignore - assert x.__sub__(object()) is NotImplemented # type: ignore - assert x.__mul__(object()) is NotImplemented # type: ignore + assert x.__add__(object()) is NotImplemented + assert x.__sub__(object()) is NotImplemented + assert x.__mul__(object()) is NotImplemented assert x.__truediv__(object()) is NotImplemented # type: ignore assert x.__pow__(object()) is NotImplemented # type: ignore - assert x.__pow__(3) is NotImplemented + with pytest.raises(ValueError): + x.__pow__(3) def test_arithmetric_operations_expr(m: Model) -> None: @@ -123,6 +125,17 @@ def test_arithmetric_operations_expr(m: Model) -> None: assert expr.__truediv__(object()) is NotImplemented +def test_arithmetric_operations_vars_and_expr(m: Model) -> None: + x = m.variables["x"] + x_expr = x * 1.0 + + assert_quadequal(x**2, x_expr**2) + assert_quadequal(x**2 + x, x + x**2) + assert_quadequal(x**2 * 2, x**2 * 2) + with pytest.raises(TypeError): + _ = x**2 * x + + def test_arithmetric_operations_con(m: Model) -> None: c = m.constraints["c"] x = m.variables["x"] @@ -133,7 +146,8 @@ def test_arithmetric_operations_con(m: Model) -> None: assert_linequal(c.lhs - data, c.lhs - other_datatype) assert_linequal(c.lhs * data, c.lhs * other_datatype) assert_linequal(c.lhs / data, c.lhs / other_datatype) - assert_linequal(c.rhs + data, c.rhs + other_datatype) # type: ignore - assert_linequal(c.rhs - data, c.rhs - other_datatype) # type: ignore - assert_linequal(c.rhs * data, c.rhs * other_datatype) # type: ignore - assert_linequal(c.rhs / data, c.rhs / other_datatype) # type: ignore + + assert_equal(c.rhs + data, c.rhs + other_datatype) + assert_equal(c.rhs - data, c.rhs - other_datatype) + assert_equal(c.rhs * data, c.rhs * other_datatype) + assert_equal(c.rhs / data, c.rhs / other_datatype) diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 3ed9482f..cec12882 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -14,10 +14,11 @@ import xarray as xr from xarray.testing import assert_equal -from linopy import LinearExpression, Model, Variable, merge +from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge from linopy.constants import HELPER_DIMS, TERM_DIM from linopy.expressions import ScalarLinearExpression -from linopy.testing import assert_linequal +from linopy.testing import assert_linequal, assert_quadequal +from linopy.variables import ScalarVariable @pytest.fixture @@ -158,6 +159,13 @@ def test_linexpr_with_scalars(m: Model) -> None: assert_equal(expr.coeffs, target) +def test_linexpr_with_variables_and_constants( + m: Model, x: Variable, y: Variable +) -> None: + expr = m.linexpr((10, x), (1, y), 2) + assert (expr.const == 2).all() + + def test_linexpr_with_series(m: Model, v: Variable) -> None: lhs = pd.Series(np.arange(20)), v expr = m.linexpr(lhs) @@ -187,6 +195,9 @@ def test_linear_expression_with_multiplication(x: Variable) -> None: expr2 = x.mul(1) assert_linequal(expr, expr2) + expr3 = expr.mul(1) + assert_linequal(expr, expr3) + expr = x / 1 assert isinstance(expr, LinearExpression) @@ -196,6 +207,9 @@ def test_linear_expression_with_multiplication(x: Variable) -> None: expr2 = x.div(1) assert_linequal(expr, expr2) + expr3 = expr.div(1) + assert_linequal(expr, expr3) + expr = np.array([1, 2]) * x assert isinstance(expr, LinearExpression) @@ -208,6 +222,17 @@ def test_linear_expression_with_multiplication(x: Variable) -> None: expr = pd.Series([1, 2], index=pd.RangeIndex(2, name="dim_0")) * x assert isinstance(expr, LinearExpression) + quad = x * x + assert isinstance(quad, QuadraticExpression) + + with pytest.raises(TypeError): + quad * quad + + expr = x * 1 + assert isinstance(expr, LinearExpression) + assert expr.__mul__(object()) is NotImplemented + assert expr.__rmul__(object()) is NotImplemented + def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> None: expr = 10 * x + y @@ -225,6 +250,20 @@ def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> expr2 = x.add(y) assert_linequal(expr, expr2) + expr3 = (x * 1).add(y) + assert_linequal(expr, expr3) + + expr3 = x + (x * x) + assert isinstance(expr3, QuadraticExpression) + + +def test_linear_expression_with_raddition(m: Model, x: Variable) -> None: + expr = x * 1.0 + expr_2: LinearExpression = 10.0 + expr + assert isinstance(expr, LinearExpression) + expr_3: LinearExpression = expr + 10.0 + assert_linequal(expr_2, expr_3) + def test_linear_expression_with_subtraction(m: Model, x: Variable, y: Variable) -> None: expr = x - y @@ -234,11 +273,24 @@ def test_linear_expression_with_subtraction(m: Model, x: Variable, y: Variable) expr2 = x.sub(y) assert_linequal(expr, expr2) + expr3: LinearExpression = x * 1 + expr4 = expr3.sub(y) + assert_linequal(expr, expr4) + expr = -x - 8 * y assert isinstance(expr, LinearExpression) assert_linequal(expr, m.linexpr((-1, "x"), (-8, "y"))) +def test_linear_expression_rsubtraction(x: Variable, y: Variable) -> None: + expr = x * 1.0 + expr_2: LinearExpression = 10.0 - expr + assert isinstance(expr_2, LinearExpression) + expr_3: LinearExpression = (expr - 10.0) * -1 + assert_linequal(expr_2, expr_3) + assert expr.__rsub__(object()) is NotImplemented + + def test_linear_expression_with_constant(m: Model, x: Variable, y: Variable) -> None: expr = x + 1 assert isinstance(expr, LinearExpression) @@ -277,7 +329,10 @@ def test_linear_expression_with_errors(m: Model, x: Variable) -> None: x / (1 * x) with pytest.raises(TypeError): - m.linexpr((10, x.labels), (1, "y")) # type: ignore + m.linexpr((10, x.labels), (1, "y")) + + with pytest.raises(TypeError): + m.linexpr(a=2) # type: ignore def test_linear_expression_from_rule(m: Model, x: Variable, y: Variable) -> None: @@ -322,6 +377,9 @@ def test_linear_expression_addition(x: Variable, y: Variable, z: Variable) -> No assert (res.coords["dim_1"] == other.coords["dim_1"]).all() assert res.data.notnull().all().to_array().all() + res2 = expr.add(other) + assert_linequal(res, res2) + assert isinstance(x - expr, LinearExpression) assert isinstance(x + expr, LinearExpression) @@ -442,6 +500,18 @@ def test_linear_expression_sum_warn_unknown_kwargs(z: Variable) -> None: (1 * z).sum(unknown_kwarg="dim_0") +def test_linear_expression_power(x: Variable) -> None: + expr: LinearExpression = x * 1.0 + qd_expr = expr**2 + assert isinstance(qd_expr, QuadraticExpression) + + qd_expr2 = expr.pow(2) + assert_quadequal(qd_expr, qd_expr2) + + with pytest.raises(ValueError): + expr**3 + + def test_linear_expression_multiplication( x: Variable, y: Variable, z: Variable ) -> None: @@ -1008,6 +1078,47 @@ def test_linear_expression_rolling_from_variable(v: Variable) -> None: assert rolled.nterm == 2 +def test_linear_expression_from_tuples(x: Variable, y: Variable) -> None: + expr = LinearExpression.from_tuples((10, x), (1, y)) + assert isinstance(expr, LinearExpression) + + with pytest.warns(DeprecationWarning): + expr2 = LinearExpression.from_tuples((10, x), (1,)) + assert isinstance(expr2, LinearExpression) + assert (expr2.const == 1).all() + + expr3 = LinearExpression.from_tuples((10, x), 1) + assert isinstance(expr3, LinearExpression) + assert_linequal(expr2, expr3) + + expr4 = LinearExpression.from_tuples((10, x), (1, y), 1) + assert isinstance(expr4, LinearExpression) + assert (expr4.const == 1).all() + + expr5 = LinearExpression.from_tuples(1, model=x.model) + assert isinstance(expr5, LinearExpression) + + +def test_linear_expression_from_tuples_bad_calls( + m: Model, x: Variable, y: Variable +) -> None: + with pytest.raises(ValueError): + LinearExpression.from_tuples((10, x), (1, y), x) + + with pytest.raises(ValueError): + LinearExpression.from_tuples((10, x, 3), (1, y), 1) + + sv = ScalarVariable(label=0, model=m) + with pytest.raises(TypeError): + LinearExpression.from_tuples((np.array([1, 1]), sv)) + + with pytest.raises(TypeError): + LinearExpression.from_tuples((x, x)) + + with pytest.raises(ValueError): + LinearExpression.from_tuples(10) + + def test_linear_expression_sanitize(x: Variable, y: Variable, z: Variable) -> None: expr = 10 * x + y + z assert isinstance(expr.sanitize(), LinearExpression) @@ -1017,24 +1128,25 @@ def test_merge(x: Variable, y: Variable, z: Variable) -> None: expr1 = (10 * x + y).sum("dim_0") expr2 = z.sum("dim_0") - res = merge([expr1, expr2]) + res = merge([expr1, expr2], cls=LinearExpression) assert res.nterm == 6 - res = merge([expr1, expr2]) - assert res.nterm == 6 + with pytest.warns(DeprecationWarning): + res: LinearExpression = merge([expr1, expr2]) # type: ignore + assert res.nterm == 6 # now concat with same length of terms expr1 = z.sel(dim_0=0).sum("dim_1") expr2 = z.sel(dim_0=1).sum("dim_1") - res = merge([expr1, expr2], dim="dim_1") + res = merge([expr1, expr2], dim="dim_1", cls=LinearExpression) assert res.nterm == 3 # now with different length of terms expr1 = z.sel(dim_0=0, dim_1=slice(0, 1)).sum("dim_1") expr2 = z.sel(dim_0=1).sum("dim_1") - res = merge([expr1, expr2], dim="dim_1") + res = merge([expr1, expr2], dim="dim_1", cls=LinearExpression) assert res.nterm == 3 assert res.sel(dim_1=0).vars[2].item() == -1 diff --git a/test/test_optimization.py b/test/test_optimization.py index 0a6a0fb4..b34650dc 100644 --- a/test/test_optimization.py +++ b/test/test_optimization.py @@ -386,7 +386,7 @@ def test_default_setting_expression_sol_accessor( qexpr = 4 * x**2 assert_equal(qexpr.solution, 4 * x.solution**2) - qexpr = 4 * x * y + qexpr = 4 * (x * y) # type: ignore assert_equal(qexpr.solution, 4 * x.solution * y.solution) diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index f2ae7c8a..6a41d94f 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -9,7 +9,7 @@ from linopy import Model, Variable, merge from linopy.constants import FACTOR_DIM, TERM_DIM -from linopy.expressions import QuadraticExpression +from linopy.expressions import LinearExpression, QuadraticExpression from linopy.testing import assert_quadequal @@ -41,6 +41,13 @@ def test_quadratic_expression_from_variables_multiplication( assert quad_expr.data.sizes[FACTOR_DIM] == 2 +def test_adding_quadratic_expressions(x: Variable) -> None: + quad_expr = x * x + double_quad = quad_expr + quad_expr + assert isinstance(double_quad, QuadraticExpression) + assert double_quad.__add__(object()) is NotImplemented + + def test_quadratic_expression_from_variables_power(x: Variable) -> None: power_expr = x**2 target: QuadraticExpression = x * x # type: ignore @@ -116,6 +123,18 @@ def test_matmul_expr_and_expr(x: Variable, y: Variable, z: Variable) -> None: assert expr.nterm == 6 assert_quadequal(expr, target) + with pytest.raises(TypeError): + (x**2) @ (y**2) + + +def test_matmul_with_const(x: Variable) -> None: + expr = x * x + const = DataArray([2.0, 1.0], dims=["dim_0"]) + expr2 = expr @ const + assert isinstance(expr2, QuadraticExpression) + assert expr2.nterm == 2 + assert expr2.data.sizes[FACTOR_DIM] == 2 + def test_quadratic_expression_dot_and_matmul(x: Variable, y: Variable) -> None: matmul_expr: QuadraticExpression = 10 * x @ y # type: ignore @@ -144,8 +163,12 @@ def test_quadratic_expression_raddition(x: Variable, y: Variable) -> None: assert (expr.const == 5).all() assert expr.nterm == 2 - with pytest.raises(TypeError): - 5 + x * y + x + expr_2 = 5 + x * y + x + assert isinstance(expr_2, QuadraticExpression) + assert (expr_2.const == 5).all() + assert expr_2.nterm == 2 + + assert_quadequal(expr, expr_2) def test_quadratic_expression_subtraction(x: Variable, y: Variable) -> None: @@ -153,6 +176,7 @@ def test_quadratic_expression_subtraction(x: Variable, y: Variable) -> None: assert isinstance(expr, QuadraticExpression) assert (expr.const == -5).all() assert expr.nterm == 2 + assert expr.__sub__(object()) is NotImplemented def test_quadratic_expression_rsubtraction(x: Variable, y: Variable) -> None: @@ -161,6 +185,11 @@ def test_quadratic_expression_rsubtraction(x: Variable, y: Variable) -> None: assert (expr.const == -5).all() assert expr.nterm == 2 + expr2 = 5 - x * y + assert isinstance(expr2, QuadraticExpression) + assert (expr2.const == 5).all() + assert expr2.nterm == 1 + def test_quadratic_expression_sum(x: Variable, y: Variable) -> None: base_expr = x * y + x + 5 @@ -188,6 +217,10 @@ def test_quadratic_expression_wrong_multiplication(x: Variable, y: Variable) -> with pytest.raises(TypeError): x * x * y + quad = x * x + with pytest.raises(TypeError): + quad * quad + def merge_raise_deprecation_warning(x: Variable, y: Variable) -> None: expr: QuadraticExpression = x * y # type: ignore @@ -198,16 +231,18 @@ def merge_raise_deprecation_warning(x: Variable, y: Variable) -> None: def test_merge_linear_expression_and_quadratic_expression( x: Variable, y: Variable ) -> None: - linexpr = 10 * x + y + 5 - quadexpr = x * y + linexpr: LinearExpression = 10 * x + y + 5 + quadexpr: QuadraticExpression = x * y # type: ignore + merge([linexpr.to_quadexpr(), quadexpr], cls=QuadraticExpression) with pytest.raises(ValueError): merge([linexpr, quadexpr], cls=QuadraticExpression) - with pytest.warns(DeprecationWarning): - merge(linexpr, quadexpr, cls=QuadraticExpression) # type: ignore - linexpr = linexpr.to_quadexpr() - merged_expr = merge([linexpr, quadexpr], cls=QuadraticExpression) + with pytest.warns(DeprecationWarning): + merge(quadexpr, quadexpr, cls=QuadraticExpression) # type: ignore + + quadexpr_2 = linexpr.to_quadexpr() + merged_expr = merge([quadexpr_2, quadexpr], cls=QuadraticExpression) assert isinstance(merged_expr, QuadraticExpression) assert merged_expr.nterm == 3 assert merged_expr.const.sum() == 10 @@ -217,6 +252,12 @@ def test_merge_linear_expression_and_quadratic_expression( first_term = merged_expr.data.isel({TERM_DIM: 0}) assert (first_term.vars.isel({FACTOR_DIM: 1}) == -1).all() + qdexpr = merge([x**2, y**2], cls=QuadraticExpression) + assert isinstance(qdexpr, QuadraticExpression) + + with pytest.raises(ValueError): + merge([x**2, y**2], cls=LinearExpression) + def test_quadratic_expression_loc(x: Variable) -> None: expr = x * x @@ -287,3 +328,16 @@ def test_matrices_matrix_mixed_linear_and_quadratic( def test_quadratic_to_constraint(x: Variable, y: Variable) -> None: with pytest.raises(NotImplementedError): x * y <= 10 + + +def test_power_of_three(x: Variable) -> None: + with pytest.raises(TypeError): + x * x * x + with pytest.raises(TypeError): + (x * 1) * (x * x) + with pytest.raises(TypeError): + (x * x) * (x * 1) + with pytest.raises(ValueError): + x**3 + with pytest.raises(TypeError): + (x * x) * (x * x) diff --git a/test/test_typing.py b/test/test_typing.py new file mode 100644 index 00000000..99a27033 --- /dev/null +++ b/test/test_typing.py @@ -0,0 +1,25 @@ +import xarray as xr + +import linopy + + +def test_operations_with_data_arrays_are_typed_correctly() -> None: + m = linopy.Model() + + a: xr.DataArray = xr.DataArray([1, 2, 3]) + + v: linopy.Variable = m.add_variables(lower=0.0, name="v") + e: linopy.LinearExpression = v * 1.0 + q = v * v + + _ = a * v + _ = v * a + _ = v + a + + _ = a * e + _ = e * a + _ = e + a + + _ = a * q + _ = q * a + _ = q + a diff --git a/test/test_variable.py b/test/test_variable.py index 065dd1ce..6b0e4595 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -17,6 +17,7 @@ import linopy import linopy.variables from linopy import Model +from linopy.testing import assert_linequal @pytest.fixture @@ -305,3 +306,39 @@ def test_variable_iterate_slices(x: linopy.Variable) -> None: for s in slices: assert isinstance(s, linopy.variables.Variable) assert s.size <= 2 + + +def test_variable_addition(x: linopy.Variable) -> None: + expr1 = x + 1 + assert isinstance(expr1, linopy.expressions.LinearExpression) + expr2 = 1 + x + assert isinstance(expr2, linopy.expressions.LinearExpression) + assert_linequal(expr1, expr2) + + assert x.__radd__(object()) is NotImplemented + assert x.__add__(object()) is NotImplemented + + +def test_variable_subtraction(x: linopy.Variable) -> None: + expr1 = -x + 1 + assert isinstance(expr1, linopy.expressions.LinearExpression) + expr2 = 1 - x + assert isinstance(expr2, linopy.expressions.LinearExpression) + assert_linequal(expr1, expr2) + + assert x.__rsub__(object()) is NotImplemented + assert x.__sub__(object()) is NotImplemented + + +def test_variable_multiplication(x: linopy.Variable) -> None: + expr1 = x * 2 + assert isinstance(expr1, linopy.expressions.LinearExpression) + expr2 = 2 * x + assert isinstance(expr2, linopy.expressions.LinearExpression) + assert_linequal(expr1, expr2) + + expr3 = x * x + assert isinstance(expr3, linopy.expressions.QuadraticExpression) + + assert x.__rmul__(object()) is NotImplemented + assert x.__mul__(object()) is NotImplemented diff --git a/test/test_variables.py b/test/test_variables.py index a55e92bb..3984b091 100644 --- a/test/test_variables.py +++ b/test/test_variables.py @@ -13,6 +13,7 @@ import linopy from linopy import Model from linopy.testing import assert_varequal +from linopy.variables import ScalarVariable @pytest.fixture @@ -115,3 +116,9 @@ def test_variables_get_name_by_label(m: Model) -> None: with pytest.raises(ValueError): m.variables.get_name_by_label("anystring") # type: ignore + + +def test_scalar_variable(m: Model) -> None: + x = ScalarVariable(label=0, model=m) + assert isinstance(x, ScalarVariable) + assert x.__rmul__(x) is NotImplemented # type: ignore