From b4c136d046339d8f35898968852c239cbe2b1144 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Thu, 1 May 2025 13:51:43 +0200 Subject: [PATCH 01/21] Fixed typing of arithmetic methods --- doc/release_notes.rst | 1 + linopy/expressions.py | 74 +++++++++++++++++++------------------------ linopy/variables.py | 50 ++++++++++++++--------------- test/test_typing.py | 32 +++++++++++++++++++ 4 files changed, 91 insertions(+), 66 deletions(-) create mode 100644 test/test_typing.py diff --git a/doc/release_notes.rst b/doc/release_notes.rst index cbe20019..75869ae1 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -12,6 +12,7 @@ Upcoming Version gap tolerance. * Improve the mapping of termination conditions for the SCIP solver * Treat GLPK's `integer undefined` status as not-OK +* Fixed variable/expression arithmetic methods so that they correctly handle types Version 0.5.3 -------------- diff --git a/linopy/expressions.py b/linopy/expressions.py index ec266af6..e34f212f 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -12,10 +12,7 @@ from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field from itertools import product, zip_longest -from typing import ( - TYPE_CHECKING, - Any, -) +from typing import TYPE_CHECKING, Any from warnings import warn import numpy as np @@ -487,13 +484,16 @@ def print(self, display_max_rows: int = 20, display_max_terms: int = 20) -> None ) print(self) - def __add__(self, other: SideLike) -> LinearExpression: + def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression: """ Add an expression to others. Note: If other is a numpy array or pandas object without axes names, dimension names of self will be filled in other """ + if isinstance(other, QuadraticExpression): + return other.__add__(self) + try: if np.isscalar(other): return self.assign(const=self.const + other) @@ -503,9 +503,8 @@ def __add__(self, other: SideLike) -> LinearExpression: except TypeError: return NotImplemented - def __radd__(self, other: int) -> LinearExpression | NotImplementedType: - # This is needed for using python's sum function - return self if other == 0 else NotImplemented + def __radd__(self, other: ConstantLike) -> LinearExpression: + return self.__add__(other) def __sub__(self, other: SideLike) -> LinearExpression: """ @@ -514,6 +513,9 @@ def __sub__(self, other: SideLike) -> LinearExpression: Note: If other is a numpy array or pandas object without axes names, dimension names of self will be filled in other """ + if isinstance(other, QuadraticExpression): + return other.__rsub__(self) + try: if np.isscalar(other): return self.assign_multiindex_safe(const=self.const - other) @@ -523,7 +525,7 @@ def __sub__(self, other: SideLike) -> LinearExpression: except TypeError: return NotImplemented - def __neg__(self) -> LinearExpression | QuadraticExpression: + def __neg__(self) -> LinearExpression: """ Get the negative of the expression. """ @@ -536,14 +538,11 @@ def __mul__( """ Multiply the expr by a factor. """ + if isinstance(other, QuadraticExpression): + return other.__rmul__(self) # type: ignore + try: - if isinstance(other, QuadraticExpression): - raise TypeError( - "unsupported operand type(s) for *: " - f"{type(self)} and {type(other)}. " - "Higher order non-linear expressions are not yet supported." - ) - elif isinstance(other, (variables.Variable, variables.ScalarVariable)): + if isinstance(other, (variables.Variable, variables.ScalarVariable)): other = other.to_linexpr() if isinstance(other, (LinearExpression, ScalarLinearExpression)): @@ -593,7 +592,7 @@ def __pow__(self, other: int) -> QuadraticExpression: raise ValueError("Power must be 2.") return self * self # type: ignore - def __rmul__(self, other: ConstantLike) -> LinearExpression | QuadraticExpression: + def __rmul__(self, other: ConstantLike) -> LinearExpression: """ Right-multiply the expr by a factor. """ @@ -1545,9 +1544,7 @@ def __init__(self, data: Dataset | None, model: Model) -> None: data = xr.Dataset(data.transpose(..., FACTOR_DIM, TERM_DIM)) self._data = data - def __mul__( - self, other: ConstantLike | VariableLike | ExpressionLike - ) -> QuadraticExpression: + def __mul__(self, other: ConstantLike) -> QuadraticExpression: """ Multiply the expr by a factor. """ @@ -1567,13 +1564,14 @@ def __mul__( ) return super().__mul__(other) # type: ignore + def __rmul__(self, other: ConstantLike) -> QuadraticExpression: + return self.__mul__(other) + @property def type(self) -> str: return "QuadraticExpression" - def __add__( - self, other: ConstantLike | VariableLike | ExpressionLike - ) -> QuadraticExpression: + def __add__(self, other: SideLike) -> QuadraticExpression: """ Add an expression to others. @@ -1592,21 +1590,13 @@ def __add__( except TypeError: return NotImplemented - def __radd__( - self, other: LinearExpression | int - ) -> LinearExpression | QuadraticExpression: + def __radd__(self, other: ConstantLike) -> QuadraticExpression: """ Add others to expression. """ - if type(other) is LinearExpression: - other = other.to_quadexpr() - return other.__add__(self) - elif other == 0: - return self - else: - return NotImplemented + return other.__add__(self) - def __sub__(self, other: SideLike | QuadraticExpression) -> QuadraticExpression: + def __sub__(self, other: SideLike) -> QuadraticExpression: """ Subtract others from expression. @@ -1624,15 +1614,17 @@ def __sub__(self, other: SideLike | QuadraticExpression) -> QuadraticExpression: except TypeError: return NotImplemented - def __rsub__(self, other: LinearExpression) -> QuadraticExpression: + def __rsub__(self, other: SideLike) -> QuadraticExpression: """ Subtract expression from others. """ - if type(other) is LinearExpression: - other = other.to_quadexpr() - return other.__sub__(self) - else: - return NotImplemented + return self.__neg__().__add__(other) + + def __neg__(self) -> QuadraticExpression: + """ + Get the negative of the expression. + """ + return super().__neg__() # type: ignore @property def solution(self) -> DataArray: @@ -1875,7 +1867,7 @@ class ScalarLinearExpression: A scalar linear expression container. In contrast to the LinearExpression class, a ScalarLinearExpression - only contains only one label. Use this class to create a constraint + only contains one label. Use this class to create a constraint in a rule. """ diff --git a/linopy/variables.py b/linopy/variables.py index 1c50441e..c2b89188 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -382,22 +382,24 @@ def __neg__(self) -> LinearExpression: """ return self.to_linexpr(-1) - def __mul__( - self, other: float | int | ndarray | Variable - ) -> LinearExpression | QuadraticExpression: + def __mul__(self, other: SideLike) -> LinearExpression | QuadraticExpression: """ - Multiply variables with a coefficient. + Multiply variables with a coefficient, variable, or expression. """ try: - if isinstance( - other, (expressions.LinearExpression, Variable, ScalarVariable) - ): + if isinstance(other, (Variable, ScalarVariable)): return self.to_linexpr() * other return self.to_linexpr(other) except TypeError: return NotImplemented + def __rmul__(self, other: ConstantLike) -> LinearExpression: + """ + Right-multiply variables by a constant + """ + return self.to_linexpr(other) + def __pow__(self, other: int) -> QuadraticExpression: """ Power of the variables with a coefficient. The only coefficient allowed is 2. @@ -407,15 +409,6 @@ def __pow__(self, other: int) -> QuadraticExpression: return expr._multiply_by_linear_expression(expr) return NotImplemented - def __rmul__(self, other: float | DataArray | int | ndarray) -> LinearExpression: - """ - Right-multiply variables with a coefficient. - """ - try: - return self.to_linexpr(other) - except TypeError: - return NotImplemented - def __matmul__( self, other: LinearExpression | ndarray | Variable ) -> QuadraticExpression | LinearExpression: @@ -449,9 +442,7 @@ def __truediv__( except TypeError: return NotImplemented - def __add__( - self, other: int | QuadraticExpression | LinearExpression | Variable - ) -> QuadraticExpression | LinearExpression: + def __add__(self, other: SideLike) -> LinearExpression: """ Add variables to linear expressions or other variables. """ @@ -460,13 +451,13 @@ def __add__( except TypeError: return NotImplemented - def __radd__(self, other: int) -> Variable | NotImplementedType: - # This is needed for using python's sum function - return self if other == 0 else NotImplemented + def __radd__(self, other: ConstantLike) -> LinearExpression: + try: + return self.__add__(other) + except ValueError: + return NotImplemented - def __sub__( - self, other: QuadraticExpression | LinearExpression | Variable - ) -> QuadraticExpression | LinearExpression: + def __sub__(self, other: SideLike) -> LinearExpression: """ Subtract linear expressions or other variables from the variables. """ @@ -475,6 +466,15 @@ def __sub__( except TypeError: return NotImplemented + def __rsub__(self, other: ConstantLike) -> LinearExpression: + """ + Subtract linear expressions or other variables from the variables. + """ + try: + return self.to_linexpr(-1) + other + except TypeError: + return NotImplemented + def __le__(self, other: SideLike) -> Constraint: return self.to_linexpr().__le__(other) diff --git a/test/test_typing.py b/test/test_typing.py new file mode 100644 index 00000000..ca145f5c --- /dev/null +++ b/test/test_typing.py @@ -0,0 +1,32 @@ +import xarray as xr +from mypy import api + +import linopy + + +def test_operations_with_data_arrays_are_typed_correctly() -> None: + m = linopy.Model() + + a: xr.DataArray = xr.DataArray([1, 2, 3]) + + v: linopy.Variable = m.add_variables(lower=0.0, name="v") + e: linopy.LinearExpression = v * 1.0 + q = v * v + assert isinstance(q, linopy.QuadraticExpression) + + _ = a * v + _ = v * a + _ = v + a + + _ = a * e + _ = e * a + _ = e + a + + _ = a * q + _ = q * a + _ = q + a + + # Get the path of this file + file_path = __file__ + result = api.run([file_path]) + assert result[2] == 0, "Mypy returned issues: " + result[0] From 4cb1f3791a2d588ef2ebdfeb8155dcce3a89c196 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Fri, 2 May 2025 20:21:16 +0200 Subject: [PATCH 02/21] changes based on pr comments --- linopy/expressions.py | 4 ++-- test/test_compatible_arithmetrics.py | 25 ++++++++++++++++++------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index e34f212f..4666bfae 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -539,7 +539,7 @@ def __mul__( Multiply the expr by a factor. """ if isinstance(other, QuadraticExpression): - return other.__rmul__(self) # type: ignore + return other.__rmul__(self) try: if isinstance(other, (variables.Variable, variables.ScalarVariable)): @@ -1594,7 +1594,7 @@ def __radd__(self, other: ConstantLike) -> QuadraticExpression: """ Add others to expression. """ - return other.__add__(self) + return self.__add__(other) def __sub__(self, other: SideLike) -> QuadraticExpression: """ diff --git a/test/test_compatible_arithmetrics.py b/test/test_compatible_arithmetrics.py index 0b5829cf..2c028717 100644 --- a/test/test_compatible_arithmetrics.py +++ b/test/test_compatible_arithmetrics.py @@ -94,14 +94,14 @@ def test_arithmetric_operations_variable(m: Model) -> None: rng = np.random.default_rng() data = xr.DataArray(rng.random(x.shape), coords=x.coords) other_datatype = SomeOtherDatatype(data.copy()) - assert_linequal(x + data, x + other_datatype) # type: ignore - assert_linequal(x - data, x - other_datatype) # type: ignore - assert_linequal(x * data, x * other_datatype) # type: ignore - assert_linequal(x / data, x / other_datatype) # type: ignore + assert_linequal(x + data, x + other_datatype) + assert_linequal(x - data, x - other_datatype) + assert_linequal(x * data, x * other_datatype) + assert_linequal(x / data, x / other_datatype) assert_linequal(data * x, other_datatype * x) # type: ignore - assert x.__add__(object()) is NotImplemented # type: ignore - assert x.__sub__(object()) is NotImplemented # type: ignore - assert x.__mul__(object()) is NotImplemented # type: ignore + assert x.__add__(object()) is NotImplemented + assert x.__sub__(object()) is NotImplemented + assert x.__mul__(object()) is NotImplemented assert x.__truediv__(object()) is NotImplemented # type: ignore assert x.__pow__(object()) is NotImplemented # type: ignore assert x.__pow__(3) is NotImplemented @@ -123,6 +123,17 @@ def test_arithmetric_operations_expr(m: Model) -> None: assert expr.__truediv__(object()) is NotImplemented +def test_arithmetric_operations_vars_and_expr(m: Model) -> None: + x = m.variables["x"] + x_expr = x * 1.0 + + assert_linequal(x**2, x_expr**2) + assert_linequal(x**2 + x, x + x**2) + assert_linequal(x**2 * 2, x**2 * 2) + with pytest.raises(TypeError): + _ = x**2 * x # type: ignore + + def test_arithmetric_operations_con(m: Model) -> None: c = m.constraints["c"] x = m.variables["x"] From 4be198ff83622195a047bb148cf625541a56523a Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Fri, 2 May 2025 21:34:28 +0200 Subject: [PATCH 03/21] Changes to typing --- linopy/expressions.py | 49 ++++++++++++++++++++-------- linopy/variables.py | 21 ++++++++++-- test/test_compatible_arithmetrics.py | 4 +-- test/test_optimization.py | 2 +- test/test_typing.py | 11 +++---- 5 files changed, 62 insertions(+), 25 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 4666bfae..6a663597 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -12,7 +12,7 @@ from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field from itertools import product, zip_longest -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar, overload from warnings import warn import numpy as np @@ -531,10 +531,18 @@ def __neg__(self) -> LinearExpression: """ return self.assign_multiindex_safe(coeffs=-self.coeffs, const=-self.const) + @overload def __mul__( - self, + self: GenericLinearExpression, other: ConstantLike + ) -> GenericLinearExpression: ... + + @overload + def __mul__(self, other: VariableLike | ExpressionLike) -> QuadraticExpression: ... + + def __mul__( + self: GenericLinearExpression, other: SideLike, - ) -> LinearExpression | QuadraticExpression: + ) -> GenericLinearExpression | QuadraticExpression: """ Multiply the expr by a factor. """ @@ -577,7 +585,9 @@ def _multiply_by_linear_expression( res = res + self.reset_const() * other.const return res # type: ignore - def _multiply_by_constant(self, other: ConstantLike) -> LinearExpression: + def _multiply_by_constant( + self: GenericLinearExpression, other: ConstantLike + ) -> GenericLinearExpression: multiplier = as_dataarray(other, coords=self.coords, dims=self.coord_dims) coeffs = self.coeffs * multiplier assert all(coeffs.sizes[d] == s for d, s in self.coeffs.sizes.items()) @@ -592,7 +602,9 @@ def __pow__(self, other: int) -> QuadraticExpression: raise ValueError("Power must be 2.") return self * self # type: ignore - def __rmul__(self, other: ConstantLike) -> LinearExpression: + def __rmul__( + self: GenericLinearExpression, other: ConstantLike + ) -> GenericLinearExpression: """ Right-multiply the expr by a factor. """ @@ -611,11 +623,18 @@ def __matmul__( return (self * other).sum(dim=common_dims) def __div__( - self, other: Variable | ConstantLike - ) -> LinearExpression | QuadraticExpression: + self: GenericLinearExpression, other: SideLike + ) -> GenericLinearExpression: try: if isinstance( - other, (LinearExpression, variables.Variable, variables.ScalarVariable) + other, + ( + variables.Variable, + variables.ScalarVariable, + LinearExpression, + ScalarLinearExpression, + QuadraticExpression, + ), ): raise TypeError( "unsupported operand type(s) for /: " @@ -627,8 +646,8 @@ def __div__( return NotImplemented def __truediv__( - self, other: Variable | ConstantLike - ) -> LinearExpression | QuadraticExpression: + self: GenericLinearExpression, other: SideLike + ) -> GenericLinearExpression: return self.__div__(other) def __le__(self, rhs: SideLike) -> Constraint: @@ -1514,6 +1533,9 @@ def to_polars(self) -> pl.DataFrame: iterate_slices = iterate_slices +GenericLinearExpression = TypeVar("GenericLinearExpression", bound=LinearExpression) + + class QuadraticExpression(LinearExpression): """ A quadratic expression consisting of terms of coefficients and variables. @@ -1544,7 +1566,7 @@ def __init__(self, data: Dataset | None, model: Model) -> None: data = xr.Dataset(data.transpose(..., FACTOR_DIM, TERM_DIM)) self._data = data - def __mul__(self, other: ConstantLike) -> QuadraticExpression: + def __mul__(self, other: SideLike) -> QuadraticExpression: """ Multiply the expr by a factor. """ @@ -1553,6 +1575,7 @@ def __mul__(self, other: ConstantLike) -> QuadraticExpression: ( LinearExpression, QuadraticExpression, + ScalarLinearExpression, variables.Variable, variables.ScalarVariable, ), @@ -1562,9 +1585,9 @@ def __mul__(self, other: ConstantLike) -> QuadraticExpression: f"{type(self)} and {type(other)}. " "Higher order non-linear expressions are not yet supported." ) - return super().__mul__(other) # type: ignore + return super().__mul__(other) - def __rmul__(self, other: ConstantLike) -> QuadraticExpression: + def __rmul__(self, other: SideLike) -> QuadraticExpression: return self.__mul__(other) @property diff --git a/linopy/variables.py b/linopy/variables.py index c2b89188..c4d30152 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -52,7 +52,14 @@ ) from linopy.config import options from linopy.constants import HELPER_DIMS, TERM_DIM -from linopy.types import ConstantLike, DimsLike, NotImplementedType, SideLike +from linopy.types import ( + ConstantLike, + DimsLike, + ExpressionLike, + NotImplementedType, + SideLike, + VariableLike, +) if TYPE_CHECKING: from linopy.constraints import AnonymousScalarConstraint, Constraint @@ -382,7 +389,13 @@ def __neg__(self) -> LinearExpression: """ return self.to_linexpr(-1) - def __mul__(self, other: SideLike) -> LinearExpression | QuadraticExpression: + @overload + def __mul__(self, other: ConstantLike) -> LinearExpression: ... + + @overload + def __mul__(self, other: ExpressionLike | VariableLike) -> QuadraticExpression: ... + + def __mul__(self, other: SideLike) -> ExpressionLike: """ Multiply variables with a coefficient, variable, or expression. """ @@ -398,7 +411,7 @@ def __rmul__(self, other: ConstantLike) -> LinearExpression: """ Right-multiply variables by a constant """ - return self.to_linexpr(other) + return self * other def __pow__(self, other: int) -> QuadraticExpression: """ @@ -1539,6 +1552,8 @@ def __mul__(self, coeff: int | float) -> ScalarLinearExpression: return self.to_scalar_linexpr(coeff) def __rmul__(self, coeff: int | float) -> ScalarLinearExpression: + if isinstance(coeff, Variable): + return NotImplemented return self.to_scalar_linexpr(coeff) def __div__(self, coeff: int | float) -> ScalarLinearExpression: diff --git a/test/test_compatible_arithmetrics.py b/test/test_compatible_arithmetrics.py index 2c028717..8cb01c31 100644 --- a/test/test_compatible_arithmetrics.py +++ b/test/test_compatible_arithmetrics.py @@ -97,7 +97,7 @@ def test_arithmetric_operations_variable(m: Model) -> None: assert_linequal(x + data, x + other_datatype) assert_linequal(x - data, x - other_datatype) assert_linequal(x * data, x * other_datatype) - assert_linequal(x / data, x / other_datatype) + assert_linequal(x / data, x / other_datatype) # type: ignore assert_linequal(data * x, other_datatype * x) # type: ignore assert x.__add__(object()) is NotImplemented assert x.__sub__(object()) is NotImplemented @@ -131,7 +131,7 @@ def test_arithmetric_operations_vars_and_expr(m: Model) -> None: assert_linequal(x**2 + x, x + x**2) assert_linequal(x**2 * 2, x**2 * 2) with pytest.raises(TypeError): - _ = x**2 * x # type: ignore + _ = x**2 * x def test_arithmetric_operations_con(m: Model) -> None: diff --git a/test/test_optimization.py b/test/test_optimization.py index 0a6a0fb4..b34650dc 100644 --- a/test/test_optimization.py +++ b/test/test_optimization.py @@ -386,7 +386,7 @@ def test_default_setting_expression_sol_accessor( qexpr = 4 * x**2 assert_equal(qexpr.solution, 4 * x.solution**2) - qexpr = 4 * x * y + qexpr = 4 * (x * y) # type: ignore assert_equal(qexpr.solution, 4 * x.solution * y.solution) diff --git a/test/test_typing.py b/test/test_typing.py index ca145f5c..e9afa55b 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -5,6 +5,11 @@ def test_operations_with_data_arrays_are_typed_correctly() -> None: + # Get the path of this file + file_path = __file__ + result = api.run([file_path]) + assert result[2] == 0, "Mypy returned issues: " + result[0] + m = linopy.Model() a: xr.DataArray = xr.DataArray([1, 2, 3]) @@ -12,7 +17,6 @@ def test_operations_with_data_arrays_are_typed_correctly() -> None: v: linopy.Variable = m.add_variables(lower=0.0, name="v") e: linopy.LinearExpression = v * 1.0 q = v * v - assert isinstance(q, linopy.QuadraticExpression) _ = a * v _ = v * a @@ -25,8 +29,3 @@ def test_operations_with_data_arrays_are_typed_correctly() -> None: _ = a * q _ = q * a _ = q + a - - # Get the path of this file - file_path = __file__ - result = api.run([file_path]) - assert result[2] == 0, "Mypy returned issues: " + result[0] From 93bfdfac4bfee7a5c82af2aa4e2180c729d1200e Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Mon, 5 May 2025 19:34:09 +0200 Subject: [PATCH 04/21] Further typing changes --- .gitignore | 3 +++ linopy/expressions.py | 32 +++++++++++++++---------------- linopy/variables.py | 20 +++++++++++++++++-- test/test_linear_expression.py | 8 ++++++++ test/test_quadratic_expression.py | 8 ++++++-- 5 files changed, 51 insertions(+), 20 deletions(-) diff --git a/.gitignore b/.gitignore index 603fe6ed..13866fb1 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ benchmark/notebooks/.ipynb_checkpoints benchmark/scripts/__pycache__ benchmark/scripts/benchmarks-pypsa-eur/__pycache__ benchmark/scripts/leftovers/ + +# IDE +.idea/ diff --git a/linopy/expressions.py b/linopy/expressions.py index 6a663597..0a39f8da 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -484,6 +484,14 @@ def print(self, display_max_rows: int = 20, display_max_terms: int = 20) -> None ) print(self) + @overload + def __add__( + self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression + ) -> LinearExpression: ... + + @overload + def __add__(self, other: QuadraticExpression) -> QuadraticExpression: ... + def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression: """ Add an expression to others. @@ -506,24 +514,16 @@ def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression: def __radd__(self, other: ConstantLike) -> LinearExpression: return self.__add__(other) - def __sub__(self, other: SideLike) -> LinearExpression: - """ - Subtract others from expression. - - Note: If other is a numpy array or pandas object without axes names, - dimension names of self will be filled in other - """ - if isinstance(other, QuadraticExpression): - return other.__rsub__(self) + @overload + def __sub__( + self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression + ) -> LinearExpression: ... - try: - if np.isscalar(other): - return self.assign_multiindex_safe(const=self.const - other) + @overload + def __sub__(self, other: QuadraticExpression) -> QuadraticExpression: ... - other = as_expression(other, model=self.model, dims=self.coord_dims) - return merge([self, -other], cls=self.__class__) - except TypeError: - return NotImplemented + def __sub__(self, other: SideLike) -> LinearExpression | QuadraticExpression: + return self.__add__(-other) def __neg__(self) -> LinearExpression: """ diff --git a/linopy/variables.py b/linopy/variables.py index c4d30152..84307e07 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -455,7 +455,15 @@ def __truediv__( except TypeError: return NotImplemented - def __add__(self, other: SideLike) -> LinearExpression: + @overload + def __add__( + self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression + ) -> LinearExpression: ... + + @overload + def __add__(self, other: QuadraticExpression) -> QuadraticExpression: ... + + def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression: """ Add variables to linear expressions or other variables. """ @@ -470,7 +478,15 @@ def __radd__(self, other: ConstantLike) -> LinearExpression: except ValueError: return NotImplemented - def __sub__(self, other: SideLike) -> LinearExpression: + @overload + def __sub__( + self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression + ) -> LinearExpression: ... + + @overload + def __sub__(self, other: QuadraticExpression) -> QuadraticExpression: ... + + def __sub__(self, other: SideLike) -> LinearExpression | QuadraticExpression: """ Subtract linear expressions or other variables from the variables. """ diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 3ed9482f..69d824ea 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -226,6 +226,14 @@ def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> assert_linequal(expr, expr2) +def test_linear_expression_with_raddition(m: Model, x: Variable): + expr = x * 1.0 + expr_2: LinearExpression = 10.0 + expr # type: ignore + assert isinstance(expr, LinearExpression) + expr_3: LinearExpression = expr + 10.0 # type: ignore + assert_linequal(expr_2, expr_3) + + def test_linear_expression_with_subtraction(m: Model, x: Variable, y: Variable) -> None: expr = x - y assert isinstance(expr, LinearExpression) diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index f2ae7c8a..fa99f6fb 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -144,8 +144,12 @@ def test_quadratic_expression_raddition(x: Variable, y: Variable) -> None: assert (expr.const == 5).all() assert expr.nterm == 2 - with pytest.raises(TypeError): - 5 + x * y + x + expr_2 = 5 + x * y + x + assert isinstance(expr_2, QuadraticExpression) + assert (expr_2.const == 5).all() + assert expr_2.nterm == 2 + + assert_quadequal(expr, expr_2) def test_quadratic_expression_subtraction(x: Variable, y: Variable) -> None: From 410d0da99217510c165d01c1dbfd3a0885203cb8 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Mon, 5 May 2025 19:43:20 +0200 Subject: [PATCH 05/21] fixed test --- linopy/expressions.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 0a39f8da..33da37b0 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -523,7 +523,10 @@ def __sub__( def __sub__(self, other: QuadraticExpression) -> QuadraticExpression: ... def __sub__(self, other: SideLike) -> LinearExpression | QuadraticExpression: - return self.__add__(-other) + try: + return self.__add__(-other) + except TypeError: + return NotImplemented def __neg__(self) -> LinearExpression: """ @@ -1641,7 +1644,10 @@ def __rsub__(self, other: SideLike) -> QuadraticExpression: """ Subtract expression from others. """ - return self.__neg__().__add__(other) + try: + return self.__neg__() + other + except TypeError: + return NotImplemented def __neg__(self) -> QuadraticExpression: """ From e7107e3ebdcbd53eaabece9b6e46bd8e57636131 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Mon, 5 May 2025 20:14:10 +0200 Subject: [PATCH 06/21] added tests --- linopy/expressions.py | 10 +++++++-- linopy/variables.py | 7 ++++-- test/test_linear_expression.py | 8 ++++++- test/test_quadratic_expression.py | 13 +++++++++++ test/test_variable.py | 37 +++++++++++++++++++++++++++++++ 5 files changed, 70 insertions(+), 5 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 33da37b0..0da7dd34 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -512,7 +512,10 @@ def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression: return NotImplemented def __radd__(self, other: ConstantLike) -> LinearExpression: - return self.__add__(other) + try: + return self.__add__(other) + except TypeError: + return NotImplemented @overload def __sub__( @@ -611,7 +614,10 @@ def __rmul__( """ Right-multiply the expr by a factor. """ - return self.__mul__(other) + try: + return self.__mul__(other) + except TypeError: + return NotImplemented def __matmul__( self, other: LinearExpression | Variable | ndarray | DataArray diff --git a/linopy/variables.py b/linopy/variables.py index 84307e07..cdbebd2e 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -411,7 +411,10 @@ def __rmul__(self, other: ConstantLike) -> LinearExpression: """ Right-multiply variables by a constant """ - return self * other + try: + return self * other + except TypeError: + return NotImplemented def __pow__(self, other: int) -> QuadraticExpression: """ @@ -475,7 +478,7 @@ def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression: def __radd__(self, other: ConstantLike) -> LinearExpression: try: return self.__add__(other) - except ValueError: + except TypeError: return NotImplemented @overload diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 69d824ea..be7efe6a 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -14,7 +14,7 @@ import xarray as xr from xarray.testing import assert_equal -from linopy import LinearExpression, Model, Variable, merge +from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge from linopy.constants import HELPER_DIMS, TERM_DIM from linopy.expressions import ScalarLinearExpression from linopy.testing import assert_linequal @@ -208,6 +208,9 @@ def test_linear_expression_with_multiplication(x: Variable) -> None: expr = pd.Series([1, 2], index=pd.RangeIndex(2, name="dim_0")) * x assert isinstance(expr, LinearExpression) + assert expr.__mul__(object()) is NotImplemented + assert expr.__rmul__(object()) is NotImplemented # type: ignore + def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> None: expr = 10 * x + y @@ -225,6 +228,9 @@ def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> expr2 = x.add(y) assert_linequal(expr, expr2) + expr3 = x + (x * x) + assert isinstance(expr3, QuadraticExpression) + def test_linear_expression_with_raddition(m: Model, x: Variable): expr = x * 1.0 diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index fa99f6fb..d8a1f556 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -291,3 +291,16 @@ def test_matrices_matrix_mixed_linear_and_quadratic( def test_quadratic_to_constraint(x: Variable, y: Variable) -> None: with pytest.raises(NotImplementedError): x * y <= 10 + + +def test_power_of_three(x: Variable) -> None: + with pytest.raises(TypeError): + x * x * x + with pytest.raises(TypeError): + (x * 1) * (x * x) + with pytest.raises(TypeError): + (x * x) * (x * 1) + with pytest.raises(TypeError): + x**3 + with pytest.raises(TypeError): + (x * x) * (x * x) diff --git a/test/test_variable.py b/test/test_variable.py index 065dd1ce..6b0e4595 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -17,6 +17,7 @@ import linopy import linopy.variables from linopy import Model +from linopy.testing import assert_linequal @pytest.fixture @@ -305,3 +306,39 @@ def test_variable_iterate_slices(x: linopy.Variable) -> None: for s in slices: assert isinstance(s, linopy.variables.Variable) assert s.size <= 2 + + +def test_variable_addition(x: linopy.Variable) -> None: + expr1 = x + 1 + assert isinstance(expr1, linopy.expressions.LinearExpression) + expr2 = 1 + x + assert isinstance(expr2, linopy.expressions.LinearExpression) + assert_linequal(expr1, expr2) + + assert x.__radd__(object()) is NotImplemented + assert x.__add__(object()) is NotImplemented + + +def test_variable_subtraction(x: linopy.Variable) -> None: + expr1 = -x + 1 + assert isinstance(expr1, linopy.expressions.LinearExpression) + expr2 = 1 - x + assert isinstance(expr2, linopy.expressions.LinearExpression) + assert_linequal(expr1, expr2) + + assert x.__rsub__(object()) is NotImplemented + assert x.__sub__(object()) is NotImplemented + + +def test_variable_multiplication(x: linopy.Variable) -> None: + expr1 = x * 2 + assert isinstance(expr1, linopy.expressions.LinearExpression) + expr2 = 2 * x + assert isinstance(expr2, linopy.expressions.LinearExpression) + assert_linequal(expr1, expr2) + + expr3 = x * x + assert isinstance(expr3, linopy.expressions.QuadraticExpression) + + assert x.__rmul__(object()) is NotImplemented + assert x.__mul__(object()) is NotImplemented From ddbfb0ed1a3c7689674f94e79ffa94a27beb5e81 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Tue, 6 May 2025 00:12:51 +0200 Subject: [PATCH 07/21] Went down a rabbit hole --- linopy/expressions.py | 1054 ++++++++++++++++------------- linopy/testing.py | 6 +- linopy/variables.py | 31 +- test/test_linear_expression.py | 19 +- test/test_quadratic_expression.py | 28 +- 5 files changed, 640 insertions(+), 498 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 0da7dd34..7a25f3fd 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -9,10 +9,11 @@ import functools import logging +from abc import ABC, abstractmethod from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field from itertools import product, zip_longest -from typing import TYPE_CHECKING, Any, TypeVar, overload +from typing import TYPE_CHECKING, Any, Type, TypeVar, overload from warnings import warn import numpy as np @@ -314,44 +315,7 @@ def sum(self, **kwargs: Any) -> LinearExpression: return LinearExpression(ds, self.model) -class LinearExpression: - """ - A linear expression consisting of terms of coefficients and variables. - - The LinearExpression class is a subclass of xarray.Dataset which allows to - apply most xarray functions on it. However most arithmetic operations are - overwritten. Like this you can easily expand and modify the linear - expression. - - Examples - -------- - >>> from linopy import Model - >>> import pandas as pd - >>> m = Model() - >>> x = m.add_variables(pd.Series([0, 0]), 1, name="x") - >>> y = m.add_variables(4, pd.Series([8, 10]), name="y") - - Combining expressions: - - >>> expr = 3 * x - >>> type(expr) - - - >>> other = 4 * y - >>> type(expr + other) - - - Multiplying: - - >>> type(3 * expr) - - - Summation over dimensions - - >>> type(expr.sum(dim="dim_0")) - - """ - +class BaseExpression(ABC): __slots__ = ("_data", "_model") __array_ufunc__ = None __array_priority__ = 10000 @@ -360,6 +324,13 @@ class LinearExpression: _fill_value = FILL_VALUE _data: Dataset + @property + @abstractmethod + def flat(self) -> pd.DataFrame: ... + + @abstractmethod + def to_polars(self) -> pl.DataFrame: ... + def __init__(self, data: Dataset | Any | None, model: Model) -> None: from linopy.model import Model @@ -484,88 +455,46 @@ def print(self, display_max_rows: int = 20, display_max_terms: int = 20) -> None ) print(self) - @overload + @abstractmethod def __add__( - self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression - ) -> LinearExpression: ... - - @overload - def __add__(self, other: QuadraticExpression) -> QuadraticExpression: ... - - def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression: - """ - Add an expression to others. + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: ... - Note: If other is a numpy array or pandas object without axes names, - dimension names of self will be filled in other - """ - if isinstance(other, QuadraticExpression): - return other.__add__(self) + @abstractmethod + def __radd__(self: GenericExpression, other: SideLike) -> GenericExpression: ... - try: - if np.isscalar(other): - return self.assign(const=self.const + other) + @abstractmethod + def __sub__( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: ... - other = as_expression(other, model=self.model, dims=self.coord_dims) - return merge([self, other], cls=self.__class__) - except TypeError: - return NotImplemented + @abstractmethod + def __rsub__(self: GenericExpression, other: SideLike) -> GenericExpression: ... - def __radd__(self, other: ConstantLike) -> LinearExpression: - try: - return self.__add__(other) - except TypeError: - return NotImplemented + @abstractmethod + def __mul__( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: ... - @overload - def __sub__( - self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression - ) -> LinearExpression: ... + @abstractmethod + def __rmul__( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: ... - @overload - def __sub__(self, other: QuadraticExpression) -> QuadraticExpression: ... + @abstractmethod + def __matmul__( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: ... - def __sub__(self, other: SideLike) -> LinearExpression | QuadraticExpression: - try: - return self.__add__(-other) - except TypeError: - return NotImplemented + @abstractmethod + def __pow__(self, other: int) -> QuadraticExpression: ... - def __neg__(self) -> LinearExpression: + def __neg__(self: GenericExpression) -> GenericExpression: """ Get the negative of the expression. """ return self.assign_multiindex_safe(coeffs=-self.coeffs, const=-self.const) - @overload - def __mul__( - self: GenericLinearExpression, other: ConstantLike - ) -> GenericLinearExpression: ... - - @overload - def __mul__(self, other: VariableLike | ExpressionLike) -> QuadraticExpression: ... - - def __mul__( - self: GenericLinearExpression, - other: SideLike, - ) -> GenericLinearExpression | QuadraticExpression: - """ - Multiply the expr by a factor. - """ - if isinstance(other, QuadraticExpression): - return other.__rmul__(self) - - try: - if isinstance(other, (variables.Variable, variables.ScalarVariable)): - other = other.to_linexpr() - - if isinstance(other, (LinearExpression, ScalarLinearExpression)): - return self._multiply_by_linear_expression(other) - else: - return self._multiply_by_constant(other) - except TypeError: - return NotImplemented - def _multiply_by_linear_expression( self, other: LinearExpression | ScalarLinearExpression ) -> QuadraticExpression: @@ -583,57 +512,24 @@ def _multiply_by_linear_expression( .broadcast_like(self.data) .assign(const=other.const) ) - res = merge([self, ds], dim=FACTOR_DIM, cls=QuadraticExpression) + res = merge([self, ds], dim=FACTOR_DIM, cls=QuadraticExpression) # type: ignore # deal with cross terms c1 * v2 + c2 * v1 if self.has_constant: res = res + self.const * other.reset_const() if other.has_constant: res = res + self.reset_const() * other.const - return res # type: ignore + return res def _multiply_by_constant( - self: GenericLinearExpression, other: ConstantLike - ) -> GenericLinearExpression: + self: GenericExpression, other: ConstantLike + ) -> GenericExpression: multiplier = as_dataarray(other, coords=self.coords, dims=self.coord_dims) coeffs = self.coeffs * multiplier assert all(coeffs.sizes[d] == s for d, s in self.coeffs.sizes.items()) const = self.const * multiplier return self.assign(coeffs=coeffs, const=const) - def __pow__(self, other: int) -> QuadraticExpression: - """ - Power of the expression with a coefficient. The only coefficient allowed is 2. - """ - if not other == 2: - raise ValueError("Power must be 2.") - return self * self # type: ignore - - def __rmul__( - self: GenericLinearExpression, other: ConstantLike - ) -> GenericLinearExpression: - """ - Right-multiply the expr by a factor. - """ - try: - return self.__mul__(other) - except TypeError: - return NotImplemented - - def __matmul__( - self, other: LinearExpression | Variable | ndarray | DataArray - ) -> LinearExpression | QuadraticExpression: - """ - Matrix multiplication with other, similar to xarray dot. - """ - if not isinstance(other, (LinearExpression, variables.Variable)): - other = as_dataarray(other, coords=self.coords, dims=self.coord_dims) - - common_dims = list(set(self.coord_dims).intersection(other.dims)) - return (self * other).sum(dim=common_dims) - - def __div__( - self: GenericLinearExpression, other: SideLike - ) -> GenericLinearExpression: + def __div__(self: GenericExpression, other: SideLike) -> GenericExpression: try: if isinstance( other, @@ -650,13 +546,11 @@ def __div__( f"{type(self)} and {type(other)}" "Non-linear expressions are not yet supported." ) - return self.__mul__(1 / other) + return self._multiply_by_constant(other=1 / other) except TypeError: return NotImplemented - def __truediv__( - self: GenericLinearExpression, other: SideLike - ) -> GenericLinearExpression: + def __truediv__(self: GenericExpression, other: SideLike) -> GenericExpression: return self.__div__(other) def __le__(self, rhs: SideLike) -> Constraint: @@ -678,27 +572,33 @@ def __lt__(self, other: Any) -> NotImplementedType: "Inequalities only ever defined for >= rather than >." ) - def add(self, other: SideLike) -> LinearExpression: + def add( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: """ Add an expression to others. """ return self.__add__(other) - def sub(self, other: SideLike) -> LinearExpression: + def sub( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: """ Subtract others from expression. """ return self.__sub__(other) - def mul(self, other: SideLike) -> LinearExpression | QuadraticExpression: + def mul( + self: GenericExpression, other: SideLike + ) -> GenericExpression | QuadraticExpression: """ Multiply the expr by a factor. """ return self.__mul__(other) def div( - self, other: Variable | float | int - ) -> LinearExpression | QuadraticExpression: + self: GenericExpression, other: VariableLike | ConstantLike + ) -> GenericExpression | QuadraticExpression: """ Divide the expr by a factor. """ @@ -710,15 +610,17 @@ def pow(self, other: int) -> QuadraticExpression: """ return self.__pow__(other) - def dot(self, other: ndarray) -> LinearExpression: + def dot( + self: GenericExpression, other: ndarray + ) -> GenericExpression | QuadraticExpression: """ Matrix multiplication with other, similar to xarray dot. """ return self.__matmul__(other) def __getitem__( - self, selector: int | tuple[slice, list[int]] | slice - ) -> LinearExpression | QuadraticExpression: + self: GenericExpression, selector: int | tuple[slice, list[int]] | slice + ) -> GenericExpression: """ Get selection from the expression. This is a wrapper around the xarray __getitem__ method. It returns a @@ -861,42 +763,12 @@ def solution(self) -> DataArray: sol = (self.coeffs * vals).sum(TERM_DIM) + self.const return sol.rename("solution") - @classmethod - def _sum( - cls, - expr: LinearExpression | Dataset, - dim: DimsLike | None = None, - ) -> Dataset: - data = _expr_unwrap(expr) - - if isinstance(dim, str): - dim = [dim] - elif isinstance(dim, EllipsisType): - dim = None - - if dim is None: - vars = DataArray(data.vars.data.ravel(), dims=TERM_DIM) - coeffs = DataArray(data.coeffs.data.ravel(), dims=TERM_DIM) - const = data.const.sum() - ds = xr.Dataset({"vars": vars, "coeffs": coeffs, "const": const}) - else: - dim = [d for d in dim if d != TERM_DIM] - ds = ( - data[["coeffs", "vars"]] - .reset_index(dim, drop=True) - .rename({TERM_DIM: STACKED_TERM_DIM}) - .stack({TERM_DIM: [STACKED_TERM_DIM] + dim}, create_index=False) - ) - ds = assign_multiindex_safe(ds, const=data.const.sum(dim)) - - return ds - def sum( - self, + self: GenericExpression, dim: DimsLike | None = None, drop_zeros: bool = False, **kwargs: Any, - ) -> LinearExpression: + ) -> GenericExpression: """ Sum the expression over all or a subset of dimensions. @@ -980,215 +852,50 @@ def cumsum( dim_dict = {dim_name: self.data.sizes[dim_name] for dim_name in dim} return self.rolling(dim=dim_dict).sum(keep_attrs=keep_attrs, skipna=skipna) - @classmethod - def from_tuples( - cls, *tuples: tuple, model: Model | None = None - ) -> LinearExpression: + def to_constraint( + self, sign: SignLike, rhs: ConstantLike | VariableLike | ExpressionLike + ) -> Constraint: """ - Create a linear expression by using tuples of coefficients and - variables. - - The function internally checks that all variables in the tuples belong to the same - reference model. + Convert a linear expression to a constraint. Parameters ---------- - tuples : tuples of (coefficients, variables) - Each tuple represents one term in the resulting linear expression, - which can possibly span over multiple dimensions: - - * coefficients : int/float/array_like - The coefficient(s) in the term, if the coefficients array - contains dimensions which do not appear in - the variables, the variables are broadcasted. - * variables : str/array_like/linopy.Variable - The variable(s) going into the term. These may be referenced - by name. + sign : str, array-like + Sign(s) of the constraints. + rhs : constant, Variable, LinearExpression + Right-hand side of the constraint. Returns ------- - linopy.LinearExpression - - Examples - -------- - >>> from linopy import Model - >>> import pandas as pd - >>> m = Model() - >>> x = m.add_variables(pd.Series([0, 0]), 1) - >>> y = m.add_variables(4, pd.Series([8, 10])) - >>> expr = LinearExpression.from_tuples((10, x), (1, y)) - - This is the same as calling ``10*x + y`` but a bit more performant. + Constraint with strict separation of the linear expressions of variables + which are moved to the left-hand-side and constant values which are moved + to the right-hand side. """ - exprs = [] - for t in tuples: - if len(t) == 2: - # assume first element is coefficient and second is variable - c, v = t - if not isinstance(v, (variables.Variable, variables.ScalarVariable)): - raise TypeError("Expected variable as second element of tuple.") - expr = v.to_linexpr(c) - const = None - if model is None: - model = expr.model # TODO: Ensure equality of models - elif len(t) == 1: - # assume that the element is a constant - c, v = None, None - (const,) = as_dataarray(t) - if model is None: - raise ValueError("Model must be provided when using constants.") - expr = LinearExpression(const, model) - else: - raise ValueError("Expected tuples of length 1 or 2.") + all_to_lhs = (self - rhs).data + data = assign_multiindex_safe( + all_to_lhs[["coeffs", "vars"]], sign=sign, rhs=-all_to_lhs.const + ) + return constraints.Constraint(data, model=self.model) - exprs.append(expr) + def reset_const(self: GenericExpression) -> GenericExpression: + """ + Reset the constant of the linear expression to zero. + """ + return self.__class__(self.data[["coeffs", "vars"]], self.model) - return merge(exprs, cls=cls) if len(exprs) > 1 else exprs[0] + def isnull(self) -> DataArray: + """ + Get a boolean mask with true values where there is only missing values in an expression. - @classmethod - def from_rule( - cls, - model: Model, - rule: Callable, - coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None, - ) -> LinearExpression: + Returns + ------- + xr.DataArray """ - Create a linear expression from a rule and a set of coordinates. - - This functionality mirrors the assignment of linear expression as done by - Pyomo. - - - Parameters - ---------- - model : linopy.Model - Passed to function `rule` as a first argument. - rule : callable - Function to be called for each combinations in `coords`. - The first argument of the function is the underlying `linopy.Model`. - The following arguments are given by the coordinates for accessing - the variables. The function has to return a - `ScalarLinearExpression`. Therefore use the `.at` accessor when - indexing variables. - coords : coordinate-like - Coordinates to processed by `xarray.DataArray`. - For each combination of coordinates, the function - given by `rule` is called. The order and size of coords has - to be same as the argument list followed by `model` in - function `rule`. - - - Returns - ------- - linopy.LinearExpression - - Examples - -------- - >>> from linopy import Model, LinearExpression - >>> m = Model() - >>> coords = pd.RangeIndex(10), ["a", "b"] - >>> x = m.add_variables(0, 100, coords) - >>> def bound(m, i, j): - ... if i % 2: - ... return (i - 1) * x.at[i - 1, j] - ... else: - ... return i * x.at[i, j] - ... - >>> expr = LinearExpression.from_rule(m, bound, coords) - >>> con = m.add_constraints(expr <= 10) - """ - if not isinstance(coords, DataArrayCoordinates): - coords = DataArray(coords=coords).coords - - # test output type - output = rule(model, *[c.values[0] for c in coords.values()]) - if not isinstance(output, ScalarLinearExpression) and output is not None: - msg = f"`rule` has to return ScalarLinearExpression not {type(output)}." - raise TypeError(msg) - - combinations = product(*[c.values for c in coords.values()]) - exprs = [] - placeholder = ScalarLinearExpression((np.nan,), (-1,), model) - exprs = [rule(model, *coord) or placeholder for coord in combinations] - return cls._from_scalarexpression_list(exprs, coords, model) - - @classmethod - def _from_scalarexpression_list( - cls, - exprs: list[ScalarLinearExpression], - coords: Mapping, - model: Model, - ) -> LinearExpression: - """ - Create a LinearExpression from a list of lists with different lengths. - """ - shape = list(map(len, coords.values())) - - coeffs = array(tuple(zip_longest(*(e.coeffs for e in exprs), fillvalue=nan))) - vars = array(tuple(zip_longest(*(e.vars for e in exprs), fillvalue=-1))) - - nterm = vars.shape[0] - coeffs = coeffs.reshape((nterm, *shape)) - vars = vars.reshape((nterm, *shape)) - - coeffdata = DataArray(coeffs, coords, dims=(TERM_DIM, *coords)) - vardata = DataArray(vars, coords, dims=(TERM_DIM, *coords)) - ds = Dataset({"coeffs": coeffdata, "vars": vardata}).transpose(..., TERM_DIM) - - return cls(ds, model) - - def to_quadexpr(self) -> QuadraticExpression: - """Convert LinearExpression to QuadraticExpression.""" - vars = self.data.vars.expand_dims(FACTOR_DIM) - fill_value = self._fill_value["vars"] - vars = xr.concat([vars, xr.full_like(vars, fill_value)], dim=FACTOR_DIM) - data = self.data.assign(vars=vars) - return QuadraticExpression(data, self.model) - - def to_constraint( - self, sign: SignLike, rhs: ConstantLike | VariableLike | ExpressionLike - ) -> Constraint: - """ - Convert a linear expression to a constraint. - - Parameters - ---------- - sign : str, array-like - Sign(s) of the constraints. - rhs : constant, Variable, LinearExpression - Right-hand side of the constraint. - - Returns - ------- - Constraint with strict separation of the linear expressions of variables - which are moved to the left-hand-side and constant values which are moved - to the right-hand side. - """ - all_to_lhs = (self - rhs).data - data = assign_multiindex_safe( - all_to_lhs[["coeffs", "vars"]], sign=sign, rhs=-all_to_lhs.const - ) - return constraints.Constraint(data, model=self.model) - - def reset_const(self) -> LinearExpression: - """ - Reset the constant of the linear expression to zero. - """ - return self.__class__(self.data[["coeffs", "vars"]], self.model) - - def isnull(self) -> DataArray: - """ - Get a boolean mask with true values where there is only missing values in an expression. - - Returns - ------- - xr.DataArray - """ - helper_dims = set(self.vars.dims).intersection(HELPER_DIMS) - return (self.vars == -1).all(helper_dims) & self.const.isnull() + helper_dims = set(self.vars.dims).intersection(HELPER_DIMS) + return (self.vars == -1).all(helper_dims) & self.const.isnull() def where( - self, + self: GenericExpression, cond: DataArray, other: LinearExpression | int @@ -1196,7 +903,7 @@ def where( | dict[str, float | int | DataArray] | None = None, **kwargs: Any, - ) -> LinearExpression | QuadraticExpression: + ) -> GenericExpression: """ Filter variables based on a condition. @@ -1218,7 +925,7 @@ def where( Returns ------- - linopy.LinearExpression + linopy.LinearExpression or linopy.QuadraticExpression """ # Cannot set `other` if drop=True _other: dict[str, float] | dict[str, int | float | DataArray] | DataArray | None @@ -1240,14 +947,14 @@ def where( return self.__class__(self.data.where(cond, other=_other, **kwargs), self.model) def fillna( - self, + self: GenericExpression, value: int | float | DataArray | Dataset | LinearExpression | dict[str, float | int | DataArray], - ) -> LinearExpression: + ) -> GenericExpression: """ Fill missing values with a given value. @@ -1263,15 +970,15 @@ def fillna( Returns ------- - linopy.LinearExpression - A new `linopy.LinearExpression` object with missing values filled with the given value. + linopy.LinearExpression or linopy.QuadraticExpression + A new object with missing values filled with the given value. """ value = _expr_unwrap(value) if isinstance(value, (DataArray, np.floating, np.integer, int, float)): value = {"const": value} return self.__class__(self.data.fillna(value), self.model) - def diff(self, dim: str, n: int = 1) -> LinearExpression: + def diff(self: GenericExpression, dim: str, n: int = 1) -> GenericExpression: """ Calculate the n-th order discrete difference along given axis. @@ -1287,7 +994,7 @@ def diff(self, dim: str, n: int = 1) -> LinearExpression: Returns ------- - linopy.LinearExpression + linopy.LinearExpression or linopy.QuadraticExpression """ return self - self.shift({dim: n}) @@ -1392,7 +1099,7 @@ def empty(self) -> EmptyDeprecationWrapper: """ return EmptyDeprecationWrapper(not self.size) - def densify_terms(self) -> LinearExpression: + def densify_terms(self: GenericExpression) -> GenericExpression: """ Move all non-zero term entries to the front and cut off all-zero entries in the term-axis. @@ -1423,24 +1130,288 @@ def densify_terms(self) -> LinearExpression: return self.__class__(data.sel({TERM_DIM: slice(0, nterm)}), self.model) - def sanitize(self) -> LinearExpression: + def sanitize(self: GenericExpression) -> GenericExpression: + """ + Sanitize LinearExpression by ensuring int dtype for variables. + + Returns + ------- + linopy.LinearExpression + """ + if not np.issubdtype(self.vars.dtype, np.integer): + return self.assign(vars=self.vars.fillna(-1).astype(int)) + + return self + + def equals(self, other: BaseExpression) -> bool: + return self.data.equals(_expr_unwrap(other)) + + def __iter__(self) -> Iterator[Hashable]: + return self.data.__iter__() + + @classmethod + def _sum( + cls: Type[GenericExpression], + expr: GenericExpression | Dataset, + dim: DimsLike | None = None, + ) -> Dataset: + data = _expr_unwrap(expr) + if cls is QuadraticExpression: + dim = dim or list(set(data.dims) - set(HELPER_DIMS)) + + if isinstance(dim, str): + dim = [dim] + elif isinstance(dim, EllipsisType): + dim = None + + if dim is None: + vars = DataArray(data.vars.data.ravel(), dims=TERM_DIM) + coeffs = DataArray(data.coeffs.data.ravel(), dims=TERM_DIM) + const = data.const.sum() + ds = xr.Dataset({"vars": vars, "coeffs": coeffs, "const": const}) + else: + dim = [d for d in dim if d != TERM_DIM] + ds = ( + data[["coeffs", "vars"]] + .reset_index(dim, drop=True) + .rename({TERM_DIM: STACKED_TERM_DIM}) + .stack({TERM_DIM: [STACKED_TERM_DIM] + dim}, create_index=False) + ) + ds = assign_multiindex_safe(ds, const=data.const.sum(dim)) + + return ds + + # Wrapped function which would convert variable to dataarray + assign = exprwrap(Dataset.assign) + + assign_multiindex_safe = exprwrap(assign_multiindex_safe) + + assign_attrs = exprwrap(Dataset.assign_attrs) + + assign_coords = exprwrap(Dataset.assign_coords) + + astype = exprwrap(Dataset.astype) + + bfill = exprwrap(Dataset.bfill) + + broadcast_like = exprwrap(Dataset.broadcast_like) + + chunk = exprwrap(Dataset.chunk) + + drop = exprwrap(Dataset.drop) + + drop_vars = exprwrap(Dataset.drop_vars) + + drop_sel = exprwrap(Dataset.drop_sel) + + drop_isel = exprwrap(Dataset.drop_isel) + + expand_dims = exprwrap(Dataset.expand_dims) + + ffill = exprwrap(Dataset.ffill) + + sel = exprwrap(Dataset.sel) + + isel = exprwrap(Dataset.isel) + + shift = exprwrap(Dataset.shift) + + swap_dims = exprwrap(Dataset.swap_dims) + + set_index = exprwrap(Dataset.set_index) + + reindex = exprwrap(Dataset.reindex, fill_value=_fill_value) + + reindex_like = exprwrap(Dataset.reindex_like, fill_value=_fill_value) + + rename = exprwrap(Dataset.rename) + + reset_index = exprwrap(Dataset.reset_index) + + rename_dims = exprwrap(Dataset.rename_dims) + + roll = exprwrap(Dataset.roll) + + stack = exprwrap(Dataset.stack) + + unstack = exprwrap(Dataset.unstack) + + iterate_slices = iterate_slices + + +GenericExpression = TypeVar("GenericExpression", bound=BaseExpression) + + +class LinearExpression(BaseExpression): + """ + A linear expression consisting of terms of coefficients and variables. + + The LinearExpression class is a subclass of xarray.Dataset which allows to + apply most xarray functions on it. However most arithmetic operations are + overwritten. Like this you can easily expand and modify the linear + expression. + + Examples + -------- + >>> from linopy import Model + >>> import pandas as pd + >>> m = Model() + >>> x = m.add_variables(pd.Series([0, 0]), 1, name="x") + >>> y = m.add_variables(4, pd.Series([8, 10]), name="y") + + Combining expressions: + + >>> expr = 3 * x + >>> type(expr) + + + >>> other = 4 * y + >>> type(expr + other) + + + Multiplying: + + >>> type(3 * expr) + + + Summation over dimensions + + >>> type(expr.sum(dim="dim_0")) + + """ + + @overload + def __add__( + self, + other: ConstantLike | VariableLike | ScalarLinearExpression | LinearExpression, + ) -> LinearExpression: ... + + @overload + def __add__(self, other: QuadraticExpression) -> QuadraticExpression: ... + + def __add__( + self, + other: ConstantLike + | VariableLike + | ScalarLinearExpression + | LinearExpression + | QuadraticExpression, + ) -> LinearExpression | QuadraticExpression: + """ + Add an expression to others. + + Note: If other is a numpy array or pandas object without axes names, + dimension names of self will be filled in other + """ + if isinstance(other, QuadraticExpression): + return other.__add__(self) + + try: + if np.isscalar(other): + return self.assign(const=self.const + other) + + other = as_expression(other, model=self.model, dims=self.coord_dims) + return merge([self, other], cls=self.__class__) + except TypeError: + return NotImplemented + + def __radd__(self, other: ConstantLike) -> LinearExpression: + try: + return self.__add__(other) + except TypeError: + return NotImplemented + + @overload + def __sub__( + self, + other: ConstantLike | VariableLike | ScalarLinearExpression | LinearExpression, + ) -> LinearExpression: ... + + @overload + def __sub__(self, other: QuadraticExpression) -> QuadraticExpression: ... + + def __sub__( + self, + other: ConstantLike + | VariableLike + | ScalarLinearExpression + | LinearExpression + | QuadraticExpression, + ) -> LinearExpression | QuadraticExpression: + try: + return self.__add__(-other) + except TypeError: + return NotImplemented + + def __rsub__(self, other: ConstantLike | Variable) -> LinearExpression: + try: + return self.__add__(-other) + except TypeError: + return NotImplemented + + @overload + def __mul__(self, other: ConstantLike) -> LinearExpression: ... + + @overload + def __mul__(self, other: VariableLike | ExpressionLike) -> QuadraticExpression: ... + + def __mul__( + self, + other: SideLike, + ) -> LinearExpression | QuadraticExpression: + """ + Multiply the expr by a factor. + """ + if isinstance(other, QuadraticExpression): + return other.__rmul__(self) + + try: + if isinstance(other, (variables.Variable, variables.ScalarVariable)): + other = other.to_linexpr() + + if isinstance(other, (LinearExpression, ScalarLinearExpression)): + return self._multiply_by_linear_expression(other) + else: + return self._multiply_by_constant(other) + except TypeError: + return NotImplemented + + def __pow__(self, other: int) -> QuadraticExpression: + """ + Power of the expression with a coefficient. The only coefficient allowed is 2. + """ + if not other == 2: + raise ValueError("Power must be 2.") + return self * self # type: ignore + + def __rmul__(self, other: ConstantLike) -> LinearExpression: + """ + Right-multiply the expr by a factor. + """ + try: + return self.__mul__(other) + except TypeError: + return NotImplemented + + @overload + def __matmul__(self, other: ConstantLike) -> LinearExpression: ... + + @overload + def __matmul__( + self, other: VariableLike | ExpressionLike + ) -> QuadraticExpression: ... + + def __matmul__( + self, other: ConstantLike | VariableLike | ExpressionLike + ) -> LinearExpression | QuadraticExpression: """ - Sanitize LinearExpression by ensuring int dtype for variables. - - Returns - ------- - linopy.LinearExpression + Matrix multiplication with other, similar to xarray dot. """ - if not np.issubdtype(self.vars.dtype, np.integer): - return self.assign(vars=self.vars.fillna(-1).astype(int)) - - return self - - def equals(self, other: LinearExpression) -> bool: - return self.data.equals(_expr_unwrap(other)) + if not isinstance(other, (LinearExpression, variables.Variable)): + other = as_dataarray(other, coords=self.coords, dims=self.coord_dims) - def __iter__(self) -> Iterator[Hashable]: - return self.data.__iter__() + common_dims = list(set(self.coord_dims).intersection(other.dims)) + return (self * other).sum(dim=common_dims) @property def flat(self) -> pd.DataFrame: @@ -1466,6 +1437,14 @@ def mask_func(data: pd.DataFrame) -> pd.Series: check_has_nulls(df, name=self.type) return df + def to_quadexpr(self) -> QuadraticExpression: + """Convert LinearExpression to QuadraticExpression.""" + vars = self.data.vars.expand_dims(FACTOR_DIM) + fill_value = self._fill_value["vars"] + vars = xr.concat([vars, xr.full_like(vars, fill_value)], dim=FACTOR_DIM) + data = self.data.assign(vars=vars) + return QuadraticExpression(data, self.model) + def to_polars(self) -> pl.DataFrame: """ Convert the expression to a polars DataFrame. @@ -1484,68 +1463,165 @@ def to_polars(self) -> pl.DataFrame: check_has_nulls_polars(df, name=self.type) return df - # Wrapped function which would convert variable to dataarray - assign = exprwrap(Dataset.assign) - - assign_multiindex_safe = exprwrap(assign_multiindex_safe) - - assign_attrs = exprwrap(Dataset.assign_attrs) - - assign_coords = exprwrap(Dataset.assign_coords) - - astype = exprwrap(Dataset.astype) - - bfill = exprwrap(Dataset.bfill) - - broadcast_like = exprwrap(Dataset.broadcast_like) - - chunk = exprwrap(Dataset.chunk) + @classmethod + def _from_scalarexpression_list( + cls, + exprs: list[ScalarLinearExpression], + coords: Mapping, + model: Model, + ) -> LinearExpression: + """ + Create a LinearExpression from a list of lists with different lengths. + """ + shape = list(map(len, coords.values())) - drop = exprwrap(Dataset.drop) + coeffs = array(tuple(zip_longest(*(e.coeffs for e in exprs), fillvalue=nan))) + vars = array(tuple(zip_longest(*(e.vars for e in exprs), fillvalue=-1))) - drop_vars = exprwrap(Dataset.drop_vars) + nterm = vars.shape[0] + coeffs = coeffs.reshape((nterm, *shape)) + vars = vars.reshape((nterm, *shape)) - drop_sel = exprwrap(Dataset.drop_sel) + coeffdata = DataArray(coeffs, coords, dims=(TERM_DIM, *coords)) + vardata = DataArray(vars, coords, dims=(TERM_DIM, *coords)) + ds = Dataset({"coeffs": coeffdata, "vars": vardata}).transpose(..., TERM_DIM) - drop_isel = exprwrap(Dataset.drop_isel) + return cls(ds, model) - expand_dims = exprwrap(Dataset.expand_dims) + @classmethod + def from_rule( + cls, + model: Model, + rule: Callable, + coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None, + ) -> LinearExpression: + """ + Create a linear expression from a rule and a set of coordinates. - ffill = exprwrap(Dataset.ffill) + This functionality mirrors the assignment of linear expression as done by + Pyomo. - sel = exprwrap(Dataset.sel) - isel = exprwrap(Dataset.isel) + Parameters + ---------- + model : linopy.Model + Passed to function `rule` as a first argument. + rule : callable + Function to be called for each combinations in `coords`. + The first argument of the function is the underlying `linopy.Model`. + The following arguments are given by the coordinates for accessing + the variables. The function has to return a + `ScalarLinearExpression`. Therefore use the `.at` accessor when + indexing variables. + coords : coordinate-like + Coordinates to processed by `xarray.DataArray`. + For each combination of coordinates, the function + given by `rule` is called. The order and size of coords has + to be same as the argument list followed by `model` in + function `rule`. - shift = exprwrap(Dataset.shift) - swap_dims = exprwrap(Dataset.swap_dims) + Returns + ------- + linopy.LinearExpression - set_index = exprwrap(Dataset.set_index) + Examples + -------- + >>> from linopy import Model, LinearExpression + >>> m = Model() + >>> coords = pd.RangeIndex(10), ["a", "b"] + >>> x = m.add_variables(0, 100, coords) + >>> def bound(m, i, j): + ... if i % 2: + ... return (i - 1) * x.at[i - 1, j] + ... else: + ... return i * x.at[i, j] + ... + >>> expr = LinearExpression.from_rule(m, bound, coords) + >>> con = m.add_constraints(expr <= 10) + """ + if not isinstance(coords, DataArrayCoordinates): + coords = DataArray(coords=coords).coords - reindex = exprwrap(Dataset.reindex, fill_value=_fill_value) + # test output type + output = rule(model, *[c.values[0] for c in coords.values()]) + if not isinstance(output, ScalarLinearExpression) and output is not None: + msg = f"`rule` has to return ScalarLinearExpression not {type(output)}." + raise TypeError(msg) - reindex_like = exprwrap(Dataset.reindex_like, fill_value=_fill_value) + combinations = product(*[c.values for c in coords.values()]) + exprs = [] + placeholder = ScalarLinearExpression((np.nan,), (-1,), model) + exprs = [rule(model, *coord) or placeholder for coord in combinations] + return cls._from_scalarexpression_list(exprs, coords, model) - rename = exprwrap(Dataset.rename) + @classmethod + def from_tuples( + cls, *tuples: tuple, model: Model | None = None + ) -> LinearExpression: + """ + Create a linear expression by using tuples of coefficients and + variables. - reset_index = exprwrap(Dataset.reset_index) + The function internally checks that all variables in the tuples belong to the same + reference model. - rename_dims = exprwrap(Dataset.rename_dims) + Parameters + ---------- + tuples : tuples of (coefficients, variables) + Each tuple represents one term in the resulting linear expression, + which can possibly span over multiple dimensions: - roll = exprwrap(Dataset.roll) + * coefficients : int/float/array_like + The coefficient(s) in the term, if the coefficients array + contains dimensions which do not appear in + the variables, the variables are broadcasted. + * variables : str/array_like/linopy.Variable + The variable(s) going into the term. These may be referenced + by name. - stack = exprwrap(Dataset.stack) + Returns + ------- + linopy.LinearExpression - unstack = exprwrap(Dataset.unstack) + Examples + -------- + >>> from linopy import Model + >>> import pandas as pd + >>> m = Model() + >>> x = m.add_variables(pd.Series([0, 0]), 1) + >>> y = m.add_variables(4, pd.Series([8, 10])) + >>> expr = LinearExpression.from_tuples((10, x), (1, y)) - iterate_slices = iterate_slices + This is the same as calling ``10*x + y`` but a bit more performant. + """ + exprs = [] + for t in tuples: + if len(t) == 2: + # assume first element is coefficient and second is variable + c, v = t + if not isinstance(v, (variables.Variable, variables.ScalarVariable)): + raise TypeError("Expected variable as second element of tuple.") + expr = v.to_linexpr(c) + const = None + if model is None: + model = expr.model # TODO: Ensure equality of models + elif len(t) == 1: + # assume that the element is a constant + c, v = None, None + (const,) = as_dataarray(t) + if model is None: + raise ValueError("Model must be provided when using constants.") + expr = LinearExpression(const, model) + else: + raise ValueError("Expected tuples of length 1 or 2.") + exprs.append(expr) -GenericLinearExpression = TypeVar("GenericLinearExpression", bound=LinearExpression) + return merge(exprs, cls=cls) if len(exprs) > 1 else exprs[0] -class QuadraticExpression(LinearExpression): +class QuadraticExpression(BaseExpression): """ A quadratic expression consisting of terms of coefficients and variables. @@ -1575,6 +1651,10 @@ def __init__(self, data: Dataset | None, model: Model) -> None: data = xr.Dataset(data.transpose(..., FACTOR_DIM, TERM_DIM)) self._data = data + @property + def type(self) -> str: + return "QuadraticExpression" + def __mul__(self, other: SideLike) -> QuadraticExpression: """ Multiply the expr by a factor. @@ -1582,8 +1662,7 @@ def __mul__(self, other: SideLike) -> QuadraticExpression: if isinstance( other, ( - LinearExpression, - QuadraticExpression, + BaseExpression, ScalarLinearExpression, variables.Variable, variables.ScalarVariable, @@ -1594,15 +1673,20 @@ def __mul__(self, other: SideLike) -> QuadraticExpression: f"{type(self)} and {type(other)}. " "Higher order non-linear expressions are not yet supported." ) - return super().__mul__(other) + try: + if isinstance(other, (variables.Variable, variables.ScalarVariable)): + other = other.to_linexpr() + + if isinstance(other, (LinearExpression, ScalarLinearExpression)): + return self._multiply_by_linear_expression(other) + else: + return self._multiply_by_constant(other) + except TypeError: + return NotImplemented def __rmul__(self, other: SideLike) -> QuadraticExpression: return self.__mul__(other) - @property - def type(self) -> str: - return "QuadraticExpression" - def __add__(self, other: SideLike) -> QuadraticExpression: """ Add an expression to others. @@ -1616,9 +1700,10 @@ def __add__(self, other: SideLike) -> QuadraticExpression: other = as_expression(other, model=self.model, dims=self.coord_dims) - if type(other) is LinearExpression: + if isinstance(other, LinearExpression): other = other.to_quadexpr() - return merge([self, other], cls=self.__class__) # type: ignore + + return merge([self, other], cls=self.__class__) except TypeError: return NotImplemented @@ -1642,7 +1727,7 @@ def __sub__(self, other: SideLike) -> QuadraticExpression: other = as_expression(other, model=self.model, dims=self.coord_dims) if type(other) is LinearExpression: other = other.to_quadexpr() - return merge([self, -other], cls=self.__class__) # type: ignore + return merge([self, -other], cls=self.__class__) except TypeError: return NotImplemented @@ -1655,11 +1740,31 @@ def __rsub__(self, other: SideLike) -> QuadraticExpression: except TypeError: return NotImplemented - def __neg__(self) -> QuadraticExpression: + def __pow__(self, other: SideLike) -> QuadraticExpression: + raise TypeError("Higher order non-linear expressions are not yet supported.") + + def __matmul__( + self, other: ConstantLike | VariableLike | ExpressionLike + ) -> QuadraticExpression: """ - Get the negative of the expression. + Matrix multiplication with other, similar to xarray dot. """ - return super().__neg__() # type: ignore + if isinstance( + other, + ( + BaseExpression, + ScalarLinearExpression, + variables.Variable, + variables.ScalarVariable, + ), + ): + raise TypeError( + "Higher order non-linear expressions are not yet supported." + ) + + other = as_dataarray(other, coords=self.coords, dims=self.coord_dims) + common_dims = list(set(self.coord_dims).intersection(other.dims)) + return (self * other).sum(dim=common_dims) @property def solution(self) -> DataArray: @@ -1673,16 +1778,6 @@ def solution(self) -> DataArray: sol = (self.coeffs * vals.prod(FACTOR_DIM)).sum(TERM_DIM) + self.const return sol.rename("solution") - @classmethod - def _sum( - cls, - expr: Dataset | LinearExpression | QuadraticExpression, - dim: DimsLike | None = None, - ) -> Dataset: - data = _expr_unwrap(expr) - dim = dim or list(set(data.dims) - set(HELPER_DIMS)) - return LinearExpression._sum(expr, dim) - def to_constraint(self, sign: SignLike, rhs: SideLike) -> NotImplementedType: raise NotImplementedError( "Quadratic expressions cannot be used in constraints." @@ -1784,7 +1879,7 @@ def as_expression( ValueError If object cannot be converted to LinearExpression. """ - if isinstance(obj, LinearExpression): + if isinstance(obj, (LinearExpression, QuadraticExpression)): return obj elif isinstance(obj, (variables.Variable, variables.ScalarVariable)): return obj.to_linexpr() @@ -1804,9 +1899,9 @@ def merge( LinearExpression | QuadraticExpression | variables.Variable | Dataset ], dim: str = TERM_DIM, - cls: type[LinearExpression | QuadraticExpression] = LinearExpression, + cls: type[GenericExpression] = None, # type: ignore **kwargs: Any, -) -> LinearExpression | QuadraticExpression: +) -> GenericExpression: """ Merge multiple expression together. @@ -1816,7 +1911,6 @@ def merge( the coordinates of the first object as a basis which overrides the coordinates of the consecutive objects. - Parameters ---------- *exprs : tuple/list @@ -1834,6 +1928,13 @@ def merge( ------- res : linopy.LinearExpression """ + if cls is None: + warn( + "Using merge without specifying the class is deprecated", + DeprecationWarning, + ) + cls = LinearExpression + linopy_types = (variables.Variable, LinearExpression, QuadraticExpression) if not isinstance(exprs, list) and len(add_exprs): @@ -1854,6 +1955,13 @@ def merge( "Convert to QuadraticExpression first." ) + if cls is not QuadraticExpression and any( + type(e) is QuadraticExpression for e in exprs + ): + raise ValueError( + "Cannot merge linear and quadratic expressions to QuadraticExpression" + ) + if cls in linopy_types and dim in HELPER_DIMS: coord_dims = [ {k: v for k, v in e.sizes.items() if k not in HELPER_DIMS} for e in exprs diff --git a/linopy/testing.py b/linopy/testing.py index f4970c5f..2cb75fcd 100644 --- a/linopy/testing.py +++ b/linopy/testing.py @@ -13,8 +13,12 @@ def assert_varequal(a: Variable, b: Variable) -> None: return assert_equal(_var_unwrap(a), _var_unwrap(b)) -def assert_linequal(a: LinearExpression, b: LinearExpression) -> None: +def assert_linequal( + a: LinearExpression | QuadraticExpression, b: LinearExpression | QuadraticExpression +) -> None: """Assert that two linear expressions are equal.""" + assert isinstance(a, LinearExpression) + assert isinstance(b, LinearExpression) return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) diff --git a/linopy/variables.py b/linopy/variables.py index cdbebd2e..563311e1 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -64,6 +64,7 @@ if TYPE_CHECKING: from linopy.constraints import AnonymousScalarConstraint, Constraint from linopy.expressions import ( + GenericExpression, LinearExpression, LinearExpressionGroupby, QuadraticExpression, @@ -425,9 +426,17 @@ def __pow__(self, other: int) -> QuadraticExpression: return expr._multiply_by_linear_expression(expr) return NotImplemented + @overload + def __matmul__(self, other: ConstantLike) -> LinearExpression: ... + + @overload + def __matmul__( + self, other: VariableLike | ExpressionLike + ) -> QuadraticExpression: ... + def __matmul__( - self, other: LinearExpression | ndarray | Variable - ) -> QuadraticExpression | LinearExpression: + self, other: ConstantLike | VariableLike | ExpressionLike + ) -> LinearExpression | QuadraticExpression: """ Matrix multiplication of variables with a coefficient. """ @@ -460,13 +469,16 @@ def __truediv__( @overload def __add__( - self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression + self, other: ConstantLike | Variable | ScalarLinearExpression ) -> LinearExpression: ... @overload - def __add__(self, other: QuadraticExpression) -> QuadraticExpression: ... + def __add__(self, other: GenericExpression) -> GenericExpression: ... - def __add__(self, other: SideLike) -> LinearExpression | QuadraticExpression: + def __add__( + self, + other: ConstantLike | Variable | ScalarLinearExpression | GenericExpression, + ) -> LinearExpression | GenericExpression: """ Add variables to linear expressions or other variables. """ @@ -483,13 +495,16 @@ def __radd__(self, other: ConstantLike) -> LinearExpression: @overload def __sub__( - self, other: ConstantLike | Variable | ScalarLinearExpression | LinearExpression + self, other: ConstantLike | Variable | ScalarLinearExpression ) -> LinearExpression: ... @overload - def __sub__(self, other: QuadraticExpression) -> QuadraticExpression: ... + def __sub__(self, other: GenericExpression) -> GenericExpression: ... - def __sub__(self, other: SideLike) -> LinearExpression | QuadraticExpression: + def __sub__( + self, + other: ConstantLike | Variable | ScalarLinearExpression | GenericExpression, + ) -> LinearExpression | GenericExpression: """ Subtract linear expressions or other variables from the variables. """ diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index be7efe6a..b9c710fd 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -209,7 +209,7 @@ def test_linear_expression_with_multiplication(x: Variable) -> None: assert isinstance(expr, LinearExpression) assert expr.__mul__(object()) is NotImplemented - assert expr.__rmul__(object()) is NotImplemented # type: ignore + assert expr.__rmul__(object()) is NotImplemented def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> None: @@ -232,11 +232,11 @@ def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> assert isinstance(expr3, QuadraticExpression) -def test_linear_expression_with_raddition(m: Model, x: Variable): +def test_linear_expression_with_raddition(m: Model, x: Variable) -> None: expr = x * 1.0 - expr_2: LinearExpression = 10.0 + expr # type: ignore + expr_2: LinearExpression = 10.0 + expr assert isinstance(expr, LinearExpression) - expr_3: LinearExpression = expr + 10.0 # type: ignore + expr_3: LinearExpression = expr + 10.0 assert_linequal(expr_2, expr_3) @@ -1031,24 +1031,25 @@ def test_merge(x: Variable, y: Variable, z: Variable) -> None: expr1 = (10 * x + y).sum("dim_0") expr2 = z.sum("dim_0") - res = merge([expr1, expr2]) + res = merge([expr1, expr2], cls=LinearExpression) assert res.nterm == 6 - res = merge([expr1, expr2]) - assert res.nterm == 6 + with pytest.warns(DeprecationWarning): + res: LinearExpression = merge([expr1, expr2]) # type: ignore + assert res.nterm == 6 # now concat with same length of terms expr1 = z.sel(dim_0=0).sum("dim_1") expr2 = z.sel(dim_0=1).sum("dim_1") - res = merge([expr1, expr2], dim="dim_1") + res = merge([expr1, expr2], dim="dim_1", cls=LinearExpression) assert res.nterm == 3 # now with different length of terms expr1 = z.sel(dim_0=0, dim_1=slice(0, 1)).sum("dim_1") expr2 = z.sel(dim_0=1).sum("dim_1") - res = merge([expr1, expr2], dim="dim_1") + res = merge([expr1, expr2], dim="dim_1", cls=LinearExpression) assert res.nterm == 3 assert res.sel(dim_1=0).vars[2].item() == -1 diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index d8a1f556..d855e740 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -9,7 +9,7 @@ from linopy import Model, Variable, merge from linopy.constants import FACTOR_DIM, TERM_DIM -from linopy.expressions import QuadraticExpression +from linopy.expressions import LinearExpression, QuadraticExpression from linopy.testing import assert_quadequal @@ -41,6 +41,12 @@ def test_quadratic_expression_from_variables_multiplication( assert quad_expr.data.sizes[FACTOR_DIM] == 2 +def test_adding_quadratic_expressions(x: Variable) -> None: + quad_expr = x * x + double_quad = quad_expr + quad_expr + assert isinstance(double_quad, QuadraticExpression) + + def test_quadratic_expression_from_variables_power(x: Variable) -> None: power_expr = x**2 target: QuadraticExpression = x * x # type: ignore @@ -202,16 +208,18 @@ def merge_raise_deprecation_warning(x: Variable, y: Variable) -> None: def test_merge_linear_expression_and_quadratic_expression( x: Variable, y: Variable ) -> None: - linexpr = 10 * x + y + 5 - quadexpr = x * y + linexpr: LinearExpression = 10 * x + y + 5 + quadexpr: QuadraticExpression = x * y # type: ignore + merge([linexpr.to_quadexpr(), quadexpr], cls=QuadraticExpression) with pytest.raises(ValueError): merge([linexpr, quadexpr], cls=QuadraticExpression) - with pytest.warns(DeprecationWarning): - merge(linexpr, quadexpr, cls=QuadraticExpression) # type: ignore - linexpr = linexpr.to_quadexpr() - merged_expr = merge([linexpr, quadexpr], cls=QuadraticExpression) + with pytest.warns(DeprecationWarning): + merge(quadexpr, quadexpr, cls=QuadraticExpression) # type: ignore + + quadexpr_2 = linexpr.to_quadexpr() + merged_expr = merge([quadexpr_2, quadexpr], cls=QuadraticExpression) assert isinstance(merged_expr, QuadraticExpression) assert merged_expr.nterm == 3 assert merged_expr.const.sum() == 10 @@ -221,6 +229,12 @@ def test_merge_linear_expression_and_quadratic_expression( first_term = merged_expr.data.isel({TERM_DIM: 0}) assert (first_term.vars.isel({FACTOR_DIM: 1}) == -1).all() + qdexpr = merge([x**2, y**2], cls=QuadraticExpression) + assert isinstance(qdexpr, QuadraticExpression) + + with pytest.raises(ValueError): + merge([x**2, y**2], cls=LinearExpression) + def test_quadratic_expression_loc(x: Variable) -> None: expr = x * x From 5e73653a53983207d7130ab217d3cb0979024cc6 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Tue, 6 May 2025 00:29:26 +0200 Subject: [PATCH 08/21] fixed tests --- linopy/common.py | 12 ++++++------ linopy/monkey_patch_xarray.py | 2 +- linopy/testing.py | 4 ++-- test/test_compatible_arithmetrics.py | 19 ++++++++++--------- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/linopy/common.py b/linopy/common.py index 3e2449f4..84bb71dd 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: from linopy.constraints import Constraint - from linopy.expressions import LinearExpression + from linopy.expressions import LinearExpression, QuadraticExpression from linopy.variables import Variable @@ -994,13 +994,13 @@ def check_common_keys_values(list_of_dicts: list[dict[str, Any]]) -> bool: def align( - *objects: LinearExpression | Variable | T_Alignable, + *objects: LinearExpression | QuadraticExpression | Variable | T_Alignable, join: JoinOptions = "inner", copy: bool = True, indexes: Any = None, exclude: str | Iterable[Hashable] = frozenset(), fill_value: Any = dtypes.NA, -) -> tuple[LinearExpression | Variable | T_Alignable, ...]: +) -> tuple[LinearExpression| QuadraticExpression | Variable | T_Alignable, ...]: """ Given any number of Variables, Expressions, Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -1055,13 +1055,13 @@ def align( """ - from linopy.expressions import LinearExpression + from linopy.expressions import LinearExpression, QuadraticExpression from linopy.variables import Variable finisher: list[partial[Any] | Callable[[Any], Any]] = [] das: list[Any] = [] for obj in objects: - if isinstance(obj, LinearExpression): + if isinstance(obj, (LinearExpression, QuadraticExpression)): finisher.append(partial(obj.__class__, model=obj.model)) das.append(obj.data) elif isinstance(obj, Variable): @@ -1090,7 +1090,7 @@ def align( return tuple([f(da) for f, da in zip(finisher, aligned)]) -LocT = TypeVar("LocT", "Dataset", "Variable", "LinearExpression", "Constraint") +LocT = TypeVar("LocT", "Dataset", "Variable", "LinearExpression", "QuadraticExpression", "Constraint") class LocIndexer(Generic[LocT]): diff --git a/linopy/monkey_patch_xarray.py b/linopy/monkey_patch_xarray.py index bd3bbce9..e5e30a55 100644 --- a/linopy/monkey_patch_xarray.py +++ b/linopy/monkey_patch_xarray.py @@ -26,6 +26,6 @@ def deco(func: Callable) -> Callable: def __mul__( da: DataArray, other: Any, unpatched_method: Callable ) -> DataArray | NotImplementedType: - if isinstance(other, (variables.Variable, expressions.LinearExpression)): + if isinstance(other, (variables.Variable, expressions.LinearExpression, expressions.QuadraticExpression)): return NotImplemented return unpatched_method(da, other) diff --git a/linopy/testing.py b/linopy/testing.py index 2cb75fcd..977ae8b7 100644 --- a/linopy/testing.py +++ b/linopy/testing.py @@ -22,8 +22,8 @@ def assert_linequal( return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) -def assert_quadequal(a: QuadraticExpression, b: QuadraticExpression) -> None: - """Assert that two linear expressions are equal.""" +def assert_quadequal(a: LinearExpression | QuadraticExpression, b: LinearExpression | QuadraticExpression) -> None: + """Assert that two quadratic or linear expressions are equal.""" return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) diff --git a/test/test_compatible_arithmetrics.py b/test/test_compatible_arithmetrics.py index 8cb01c31..92f0cc5d 100644 --- a/test/test_compatible_arithmetrics.py +++ b/test/test_compatible_arithmetrics.py @@ -6,8 +6,8 @@ import xarray as xr from linopy import LESS_EQUAL, Model, Variable -from linopy.testing import assert_linequal - +from linopy.testing import assert_linequal, assert_quadequal +from xarray.testing import assert_equal class SomeOtherDatatype: """ @@ -127,9 +127,9 @@ def test_arithmetric_operations_vars_and_expr(m: Model) -> None: x = m.variables["x"] x_expr = x * 1.0 - assert_linequal(x**2, x_expr**2) - assert_linequal(x**2 + x, x + x**2) - assert_linequal(x**2 * 2, x**2 * 2) + assert_quadequal(x**2, x_expr**2) + assert_quadequal(x**2 + x, x + x**2) + assert_quadequal(x**2 * 2, x**2 * 2) with pytest.raises(TypeError): _ = x**2 * x @@ -144,7 +144,8 @@ def test_arithmetric_operations_con(m: Model) -> None: assert_linequal(c.lhs - data, c.lhs - other_datatype) assert_linequal(c.lhs * data, c.lhs * other_datatype) assert_linequal(c.lhs / data, c.lhs / other_datatype) - assert_linequal(c.rhs + data, c.rhs + other_datatype) # type: ignore - assert_linequal(c.rhs - data, c.rhs - other_datatype) # type: ignore - assert_linequal(c.rhs * data, c.rhs * other_datatype) # type: ignore - assert_linequal(c.rhs / data, c.rhs / other_datatype) # type: ignore + + assert_equal(c.rhs + data, c.rhs + other_datatype) + assert_equal(c.rhs - data, c.rhs - other_datatype) + assert_equal(c.rhs * data, c.rhs * other_datatype) + assert_equal(c.rhs / data, c.rhs / other_datatype) From c50174d89522e02ac7df1aecf7a8cf0a3b35fec6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 May 2025 22:31:03 +0000 Subject: [PATCH 09/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/common.py | 11 +++++++++-- linopy/monkey_patch_xarray.py | 9 ++++++++- linopy/testing.py | 4 +++- test/test_compatible_arithmetrics.py | 3 ++- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/linopy/common.py b/linopy/common.py index 84bb71dd..81d4e41d 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -1000,7 +1000,7 @@ def align( indexes: Any = None, exclude: str | Iterable[Hashable] = frozenset(), fill_value: Any = dtypes.NA, -) -> tuple[LinearExpression| QuadraticExpression | Variable | T_Alignable, ...]: +) -> tuple[LinearExpression | QuadraticExpression | Variable | T_Alignable, ...]: """ Given any number of Variables, Expressions, Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -1090,7 +1090,14 @@ def align( return tuple([f(da) for f, da in zip(finisher, aligned)]) -LocT = TypeVar("LocT", "Dataset", "Variable", "LinearExpression", "QuadraticExpression", "Constraint") +LocT = TypeVar( + "LocT", + "Dataset", + "Variable", + "LinearExpression", + "QuadraticExpression", + "Constraint", +) class LocIndexer(Generic[LocT]): diff --git a/linopy/monkey_patch_xarray.py b/linopy/monkey_patch_xarray.py index e5e30a55..b2527898 100644 --- a/linopy/monkey_patch_xarray.py +++ b/linopy/monkey_patch_xarray.py @@ -26,6 +26,13 @@ def deco(func: Callable) -> Callable: def __mul__( da: DataArray, other: Any, unpatched_method: Callable ) -> DataArray | NotImplementedType: - if isinstance(other, (variables.Variable, expressions.LinearExpression, expressions.QuadraticExpression)): + if isinstance( + other, + ( + variables.Variable, + expressions.LinearExpression, + expressions.QuadraticExpression, + ), + ): return NotImplemented return unpatched_method(da, other) diff --git a/linopy/testing.py b/linopy/testing.py index 977ae8b7..dfa46081 100644 --- a/linopy/testing.py +++ b/linopy/testing.py @@ -22,7 +22,9 @@ def assert_linequal( return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) -def assert_quadequal(a: LinearExpression | QuadraticExpression, b: LinearExpression | QuadraticExpression) -> None: +def assert_quadequal( + a: LinearExpression | QuadraticExpression, b: LinearExpression | QuadraticExpression +) -> None: """Assert that two quadratic or linear expressions are equal.""" return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) diff --git a/test/test_compatible_arithmetrics.py b/test/test_compatible_arithmetrics.py index 92f0cc5d..dbdf86f2 100644 --- a/test/test_compatible_arithmetrics.py +++ b/test/test_compatible_arithmetrics.py @@ -4,10 +4,11 @@ import pandas as pd import pytest import xarray as xr +from xarray.testing import assert_equal from linopy import LESS_EQUAL, Model, Variable from linopy.testing import assert_linequal, assert_quadequal -from xarray.testing import assert_equal + class SomeOtherDatatype: """ From b8f61438a630978f7743424bcf38bfd97cb947a9 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Tue, 6 May 2025 22:41:36 +0200 Subject: [PATCH 10/21] fixed coverage --- linopy/common.py | 11 +++++++-- linopy/expressions.py | 12 +++------ linopy/monkey_patch_xarray.py | 9 ++++++- linopy/testing.py | 4 ++- linopy/variables.py | 2 +- test/test_compatible_arithmetrics.py | 3 ++- test/test_linear_expression.py | 37 +++++++++++++++++++++++++++- test/test_quadratic_expression.py | 15 ++++++++++- 8 files changed, 76 insertions(+), 17 deletions(-) diff --git a/linopy/common.py b/linopy/common.py index 84bb71dd..81d4e41d 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -1000,7 +1000,7 @@ def align( indexes: Any = None, exclude: str | Iterable[Hashable] = frozenset(), fill_value: Any = dtypes.NA, -) -> tuple[LinearExpression| QuadraticExpression | Variable | T_Alignable, ...]: +) -> tuple[LinearExpression | QuadraticExpression | Variable | T_Alignable, ...]: """ Given any number of Variables, Expressions, Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -1090,7 +1090,14 @@ def align( return tuple([f(da) for f, da in zip(finisher, aligned)]) -LocT = TypeVar("LocT", "Dataset", "Variable", "LinearExpression", "QuadraticExpression", "Constraint") +LocT = TypeVar( + "LocT", + "Dataset", + "Variable", + "LinearExpression", + "QuadraticExpression", + "Constraint", +) class LocIndexer(Generic[LocT]): diff --git a/linopy/expressions.py b/linopy/expressions.py index 7a25f3fd..670c4ebf 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1345,7 +1345,7 @@ def __sub__( def __rsub__(self, other: ConstantLike | Variable) -> LinearExpression: try: - return self.__add__(-other) + return self.__neg__().__add__(other) except TypeError: return NotImplemented @@ -1674,13 +1674,7 @@ def __mul__(self, other: SideLike) -> QuadraticExpression: "Higher order non-linear expressions are not yet supported." ) try: - if isinstance(other, (variables.Variable, variables.ScalarVariable)): - other = other.to_linexpr() - - if isinstance(other, (LinearExpression, ScalarLinearExpression)): - return self._multiply_by_linear_expression(other) - else: - return self._multiply_by_constant(other) + return self._multiply_by_constant(other) except TypeError: return NotImplemented @@ -1736,7 +1730,7 @@ def __rsub__(self, other: SideLike) -> QuadraticExpression: Subtract expression from others. """ try: - return self.__neg__() + other + return self.__neg__().__add__(other) except TypeError: return NotImplemented diff --git a/linopy/monkey_patch_xarray.py b/linopy/monkey_patch_xarray.py index e5e30a55..b2527898 100644 --- a/linopy/monkey_patch_xarray.py +++ b/linopy/monkey_patch_xarray.py @@ -26,6 +26,13 @@ def deco(func: Callable) -> Callable: def __mul__( da: DataArray, other: Any, unpatched_method: Callable ) -> DataArray | NotImplementedType: - if isinstance(other, (variables.Variable, expressions.LinearExpression, expressions.QuadraticExpression)): + if isinstance( + other, + ( + variables.Variable, + expressions.LinearExpression, + expressions.QuadraticExpression, + ), + ): return NotImplemented return unpatched_method(da, other) diff --git a/linopy/testing.py b/linopy/testing.py index 977ae8b7..dfa46081 100644 --- a/linopy/testing.py +++ b/linopy/testing.py @@ -22,7 +22,9 @@ def assert_linequal( return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) -def assert_quadequal(a: LinearExpression | QuadraticExpression, b: LinearExpression | QuadraticExpression) -> None: +def assert_quadequal( + a: LinearExpression | QuadraticExpression, b: LinearExpression | QuadraticExpression +) -> None: """Assert that two quadratic or linear expressions are equal.""" return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) diff --git a/linopy/variables.py b/linopy/variables.py index 563311e1..17aacbe3 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -424,7 +424,7 @@ def __pow__(self, other: int) -> QuadraticExpression: if isinstance(other, int) and other == 2: expr = self.to_linexpr() return expr._multiply_by_linear_expression(expr) - return NotImplemented + raise ValueError("Can only raise to the power of 2") @overload def __matmul__(self, other: ConstantLike) -> LinearExpression: ... diff --git a/test/test_compatible_arithmetrics.py b/test/test_compatible_arithmetrics.py index 92f0cc5d..dbdf86f2 100644 --- a/test/test_compatible_arithmetrics.py +++ b/test/test_compatible_arithmetrics.py @@ -4,10 +4,11 @@ import pandas as pd import pytest import xarray as xr +from xarray.testing import assert_equal from linopy import LESS_EQUAL, Model, Variable from linopy.testing import assert_linequal, assert_quadequal -from xarray.testing import assert_equal + class SomeOtherDatatype: """ diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index b9c710fd..8aa5aa20 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -17,7 +17,7 @@ from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge from linopy.constants import HELPER_DIMS, TERM_DIM from linopy.expressions import ScalarLinearExpression -from linopy.testing import assert_linequal +from linopy.testing import assert_linequal, assert_quadequal @pytest.fixture @@ -187,6 +187,9 @@ def test_linear_expression_with_multiplication(x: Variable) -> None: expr2 = x.mul(1) assert_linequal(expr, expr2) + expr3 = expr.mul(1) + assert_linequal(expr, expr3) + expr = x / 1 assert isinstance(expr, LinearExpression) @@ -196,6 +199,9 @@ def test_linear_expression_with_multiplication(x: Variable) -> None: expr2 = x.div(1) assert_linequal(expr, expr2) + expr3 = expr.div(1) + assert_linequal(expr, expr3) + expr = np.array([1, 2]) * x assert isinstance(expr, LinearExpression) @@ -228,6 +234,9 @@ def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> expr2 = x.add(y) assert_linequal(expr, expr2) + expr3 = (x * 1).add(y) + assert_linequal(expr, expr3) + expr3 = x + (x * x) assert isinstance(expr3, QuadraticExpression) @@ -248,11 +257,23 @@ def test_linear_expression_with_subtraction(m: Model, x: Variable, y: Variable) expr2 = x.sub(y) assert_linequal(expr, expr2) + expr3: LinearExpression = x * 1 + expr4 = expr3.sub(y) + assert_linequal(expr, expr4) + expr = -x - 8 * y assert isinstance(expr, LinearExpression) assert_linequal(expr, m.linexpr((-1, "x"), (-8, "y"))) +def test_linear_expression_rsubtraction(x: Variable, y: Variable) -> None: + expr = x * 1.0 + expr_2: LinearExpression = 10.0 - expr + assert isinstance(expr_2, LinearExpression) + expr_3: LinearExpression = (expr - 10.0) * -1 + assert_linequal(expr_2, expr_3) + + def test_linear_expression_with_constant(m: Model, x: Variable, y: Variable) -> None: expr = x + 1 assert isinstance(expr, LinearExpression) @@ -336,6 +357,9 @@ def test_linear_expression_addition(x: Variable, y: Variable, z: Variable) -> No assert (res.coords["dim_1"] == other.coords["dim_1"]).all() assert res.data.notnull().all().to_array().all() + res2 = expr.add(other) + assert_linequal(res, res2) + assert isinstance(x - expr, LinearExpression) assert isinstance(x + expr, LinearExpression) @@ -456,6 +480,17 @@ def test_linear_expression_sum_warn_unknown_kwargs(z: Variable) -> None: (1 * z).sum(unknown_kwarg="dim_0") +def test_linear_expression_power(x: Variable) -> None: + qd_expr = x**2 + assert isinstance(qd_expr, QuadraticExpression) + + qd_expr2 = x.pow(2) + assert_quadequal(qd_expr, qd_expr2) + + with pytest.raises(ValueError): + x**3 + + def test_linear_expression_multiplication( x: Variable, y: Variable, z: Variable ) -> None: diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index d855e740..f91fd047 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -45,6 +45,7 @@ def test_adding_quadratic_expressions(x: Variable) -> None: quad_expr = x * x double_quad = quad_expr + quad_expr assert isinstance(double_quad, QuadraticExpression) + assert double_quad.__add__(object()) is NotImplemented def test_quadratic_expression_from_variables_power(x: Variable) -> None: @@ -123,6 +124,12 @@ def test_matmul_expr_and_expr(x: Variable, y: Variable, z: Variable) -> None: assert_quadequal(expr, target) +def test_matmul_with_const(x: Variable) -> None: + expr = x * x + const = 2.0 + assert_quadequal(expr @ const, (x * const).sum()) + + def test_quadratic_expression_dot_and_matmul(x: Variable, y: Variable) -> None: matmul_expr: QuadraticExpression = 10 * x @ y # type: ignore dot_expr: QuadraticExpression = 10 * x.dot(y) # type: ignore @@ -163,6 +170,7 @@ def test_quadratic_expression_subtraction(x: Variable, y: Variable) -> None: assert isinstance(expr, QuadraticExpression) assert (expr.const == -5).all() assert expr.nterm == 2 + assert expr.__sub__(object()) is NotImplemented def test_quadratic_expression_rsubtraction(x: Variable, y: Variable) -> None: @@ -171,6 +179,11 @@ def test_quadratic_expression_rsubtraction(x: Variable, y: Variable) -> None: assert (expr.const == -5).all() assert expr.nterm == 2 + expr2 = 5 - x * y + assert isinstance(expr2, QuadraticExpression) + assert (expr2.const == 5).all() + assert expr2.nterm == 1 + def test_quadratic_expression_sum(x: Variable, y: Variable) -> None: base_expr = x * y + x + 5 @@ -314,7 +327,7 @@ def test_power_of_three(x: Variable) -> None: (x * 1) * (x * x) with pytest.raises(TypeError): (x * x) * (x * 1) - with pytest.raises(TypeError): + with pytest.raises(ValueError): x**3 with pytest.raises(TypeError): (x * x) * (x * x) From cf6418edd2036c0a872be00073c92f14c1ab54ae Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Tue, 6 May 2025 22:46:44 +0200 Subject: [PATCH 11/21] fixed precommit issue --- linopy/expressions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 670c4ebf..506dfc6b 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -13,7 +13,7 @@ from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field from itertools import product, zip_longest -from typing import TYPE_CHECKING, Any, Type, TypeVar, overload +from typing import TYPE_CHECKING, Any, TypeVar, overload from warnings import warn import numpy as np @@ -1151,8 +1151,8 @@ def __iter__(self) -> Iterator[Hashable]: @classmethod def _sum( - cls: Type[GenericExpression], - expr: GenericExpression | Dataset, + cls, + expr: BaseExpression | Dataset, dim: DimsLike | None = None, ) -> Dataset: data = _expr_unwrap(expr) From c4b002ceeab658dcfaae704cb7d9a6fe1d64da67 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Tue, 6 May 2025 22:59:20 +0200 Subject: [PATCH 12/21] fixed test --- linopy/variables.py | 4 +++- test/test_compatible_arithmetrics.py | 3 ++- test/test_quadratic_expression.py | 8 ++++++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/linopy/variables.py b/linopy/variables.py index 17aacbe3..b55edb6e 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -421,7 +421,9 @@ def __pow__(self, other: int) -> QuadraticExpression: """ Power of the variables with a coefficient. The only coefficient allowed is 2. """ - if isinstance(other, int) and other == 2: + if not isinstance(other, int): + return NotImplemented + if other == 2: expr = self.to_linexpr() return expr._multiply_by_linear_expression(expr) raise ValueError("Can only raise to the power of 2") diff --git a/test/test_compatible_arithmetrics.py b/test/test_compatible_arithmetrics.py index dbdf86f2..1d1618ba 100644 --- a/test/test_compatible_arithmetrics.py +++ b/test/test_compatible_arithmetrics.py @@ -105,7 +105,8 @@ def test_arithmetric_operations_variable(m: Model) -> None: assert x.__mul__(object()) is NotImplemented assert x.__truediv__(object()) is NotImplemented # type: ignore assert x.__pow__(object()) is NotImplemented # type: ignore - assert x.__pow__(3) is NotImplemented + with pytest.raises(ValueError): + x.__pow__(3) def test_arithmetric_operations_expr(m: Model) -> None: diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index f91fd047..59522b8e 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -126,8 +126,12 @@ def test_matmul_expr_and_expr(x: Variable, y: Variable, z: Variable) -> None: def test_matmul_with_const(x: Variable) -> None: expr = x * x - const = 2.0 - assert_quadequal(expr @ const, (x * const).sum()) + const = DataArray([2.0, 1.0], dims=["dim_0"]) + expr2 = expr @ const + assert isinstance(expr2, QuadraticExpression) + assert expr2.nterm == 2 + assert expr2.data.sizes[FACTOR_DIM] == 2 + def test_quadratic_expression_dot_and_matmul(x: Variable, y: Variable) -> None: From a80f90574ed205cbe7fb586397f4af11aa7758d8 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Tue, 6 May 2025 23:03:06 +0200 Subject: [PATCH 13/21] formatting --- linopy/expressions.py | 7 +++++++ test/test_quadratic_expression.py | 1 - 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 506dfc6b..0d361ace 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1655,6 +1655,13 @@ def __init__(self, data: Dataset | None, model: Model) -> None: def type(self) -> str: return "QuadraticExpression" + @property + def shape(self) -> tuple[int, ...]: + # TODO Implement this + raise NotImplementedError( + f"{self.__class__.__name__} does not support shape property" + ) + def __mul__(self, other: SideLike) -> QuadraticExpression: """ Multiply the expr by a factor. diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index 59522b8e..233f3c3c 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -133,7 +133,6 @@ def test_matmul_with_const(x: Variable) -> None: assert expr2.data.sizes[FACTOR_DIM] == 2 - def test_quadratic_expression_dot_and_matmul(x: Variable, y: Variable) -> None: matmul_expr: QuadraticExpression = 10 * x @ y # type: ignore dot_expr: QuadraticExpression = 10 * x.dot(y) # type: ignore From b51991f1c18da1407a0c19671b0874535f52a28f Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Wed, 7 May 2025 19:47:46 +0200 Subject: [PATCH 14/21] Added tests to improve code coverage --- linopy/expressions.py | 55 ++++++++++++++++++++++++------- linopy/model.py | 45 ++++++++++++++++++++----- linopy/variables.py | 2 +- test/test_linear_expression.py | 37 ++++++++++++++++++++- test/test_quadratic_expression.py | 4 +++ test/test_typing.py | 6 ---- 6 files changed, 121 insertions(+), 28 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 0d361ace..0f3708d6 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1557,7 +1557,11 @@ def from_rule( @classmethod def from_tuples( - cls, *tuples: tuple, model: Model | None = None + cls, + *tuples: tuple[ConstantLike, str | Variable | ScalarVariable] + | tuple[ConstantLike] + | ConstantLike, + model: Model | None = None, ) -> LinearExpression: """ Create a linear expression by using tuples of coefficients and @@ -1568,8 +1572,12 @@ def from_tuples( Parameters ---------- - tuples : tuples of (coefficients, variables) - Each tuple represents one term in the resulting linear expression, + tuples : A list of elements. Each element is either: + * (coefficients, variables) + * (constant,) + * constant + + Each (coefficients, variables) tuple represents one term in the resulting linear expression, which can possibly span over multiple dimensions: * coefficients : int/float/array_like @@ -1579,6 +1587,9 @@ def from_tuples( * variables : str/array_like/linopy.Variable The variable(s) going into the term. These may be referenced by name. + * constant: int/float/array_like + The constant value to add to the expression + model : The linopy.Model. If None this can be inferred from the provided variables Returns ------- @@ -1591,25 +1602,47 @@ def from_tuples( >>> m = Model() >>> x = m.add_variables(pd.Series([0, 0]), 1) >>> y = m.add_variables(4, pd.Series([8, 10])) - >>> expr = LinearExpression.from_tuples((10, x), (1, y)) + >>> expr = LinearExpression.from_tuples((10, x), (1, y), 1) + >>> expr = LinearExpression.from_tuples( + ... (10, x), (1, y), (1,) + ... ) # Alternative usage - This is the same as calling ``10*x + y`` but a bit more performant. + This is the same as calling ``10*x + y` + 1` but a bit more performant. """ - exprs = [] + exprs: list[LinearExpression] = [] for t in tuples: + if isinstance(t, SUPPORTED_CONSTANT_TYPES): + if model is None: + raise ValueError("Model must be provided when using constants.") + expr = LinearExpression(t, model) + exprs.append(expr) + continue + if not isinstance(t, tuple): + raise ValueError("Expected tuple or constant.") + if len(t) == 2: # assume first element is coefficient and second is variable c, v = t - if not isinstance(v, (variables.Variable, variables.ScalarVariable)): + if isinstance(v, ScalarVariable): + if not isinstance(c, (int, float)): + raise TypeError( + "Expected int or float as coefficient of scalar variable (first element of tuple)." + ) + expr = v.to_linexpr(c) + elif isinstance(v, variables.Variable): + if not isinstance(c, SUPPORTED_CONSTANT_TYPES): + raise TypeError( + "Expected constant as coefficient of variable (first element of tuple)." + ) + expr = v.to_linexpr(c) + else: raise TypeError("Expected variable as second element of tuple.") - expr = v.to_linexpr(c) - const = None + if model is None: model = expr.model # TODO: Ensure equality of models elif len(t) == 1: # assume that the element is a constant - c, v = None, None - (const,) = as_dataarray(t) + const = as_dataarray(t[0]) if model is None: raise ValueError("Model must be provided when using constants.") expr = LinearExpression(const, model) diff --git a/linopy/model.py b/linopy/model.py index 99fae982..d7991d93 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -12,7 +12,7 @@ from collections.abc import Callable, Mapping, Sequence from pathlib import Path from tempfile import NamedTemporaryFile, gettempdir -from typing import Any +from typing import Any, overload import numpy as np import pandas as pd @@ -878,17 +878,37 @@ def calculate_block_maps(self) -> None: blocks = replace_by_map(self.objective.vars, block_map) self.objective = self.objective.assign(blocks=blocks) + @overload def linexpr( - self, *args: tuple[ConstantLike, str | Variable | ScalarVariable] | Callable + self, *args: Sequence[Sequence | pd.Index | DataArray] | Mapping + ) -> LinearExpression: + ... + # A function and tuples of coordinates + + @overload + def linexpr( + self, *args: tuple[ConstantLike, str | Variable | ScalarVariable] | ConstantLike + ) -> LinearExpression: + ... + # A mixture of tuples of (coefficients, variables) and constants + + def linexpr( + self, + *args: tuple[ConstantLike, str | Variable | ScalarVariable] + | ConstantLike + | Callable + | Sequence[Sequence | pd.Index | DataArray] + | Mapping, ) -> LinearExpression: """ Create a linopy.LinearExpression from argument list. Parameters ---------- - args : tuples of (coefficients, variables) or tuples of - coordinates and a function - If args is a collection of coefficients-variables-tuples, the resulting + args : A mixture of tuples of (coefficients, variables) and constants + or a function and tuples of coordinates + + If args is a collection of coefficients-variables-tuples and constants, the resulting linear expression is built with the function LinearExpression.from_tuples. * coefficients : int/float/array_like The coefficient(s) in the term, if the coefficients array @@ -897,6 +917,8 @@ def linexpr( * variables : str/array_like/linopy.Variable The variable(s) going into the term. These may be referenced by name. + * constant: int/float/array_like + The constant value to add to the expression If args is a collection of coordinates with an appended function at the end, the function LinearExpression.from_rule is used to build the linear @@ -907,7 +929,7 @@ def linexpr( The first argument of the function is the underlying `linopy.Model`. The following arguments are given by the coordinates for accessing the variables. The function has to return a - `ScalarLinearExpression`. Therefore use the direct getter when + `ScalarLinearExpression`. Therefore, use the direct getter when indexing variables. * coords : coordinate-like Coordinates to be processed by `xarray.DataArray`. For each @@ -954,9 +976,14 @@ def linexpr( return LinearExpression.from_rule(self, rule, coords) # type: ignore if not isinstance(args, tuple): raise TypeError(f"Not supported type {args}.") - tuples = [ # type: ignore - (c, self.variables[v]) if isinstance(v, str) else (c, v) for (c, v) in args - ] + + tuples: list[tuple[ConstantLike, VariableLike] | ConstantLike] = [] + for arg in args: + if isinstance(arg, tuple): + c, v = arg + tuples.append((c, self.variables[v]) if isinstance(v, str) else (c, v)) + else: + tuples.append(arg) return LinearExpression.from_tuples(*tuples, model=self) @property diff --git a/linopy/variables.py b/linopy/variables.py index b55edb6e..32575094 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -1568,7 +1568,7 @@ def to_scalar_linexpr(self, coeff: int | float = 1) -> ScalarLinearExpression: raise TypeError(f"Coefficient must be a numeric value, got {type(coeff)}.") return expressions.ScalarLinearExpression((coeff,), (self.label,), self.model) - def to_linexpr(self, coeff: int = 1) -> LinearExpression: + def to_linexpr(self, coeff: int | float = 1) -> LinearExpression: return self.to_scalar_linexpr(coeff).to_linexpr() def __neg__(self) -> ScalarLinearExpression: diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 8aa5aa20..200cf725 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -158,6 +158,13 @@ def test_linexpr_with_scalars(m: Model) -> None: assert_equal(expr.coeffs, target) +def test_linexpr_with_variables_and_constants( + m: Model, x: Variable, y: Variable +) -> None: + expr = m.linexpr((10, x), (1, y), 2) + assert (expr.const == 2).all() + + def test_linexpr_with_series(m: Model, v: Variable) -> None: lhs = pd.Series(np.arange(20)), v expr = m.linexpr(lhs) @@ -214,6 +221,12 @@ def test_linear_expression_with_multiplication(x: Variable) -> None: expr = pd.Series([1, 2], index=pd.RangeIndex(2, name="dim_0")) * x assert isinstance(expr, LinearExpression) + quad = x * x + assert isinstance(quad, QuadraticExpression) + + with pytest.raises(TypeError): + quad * quad + assert expr.__mul__(object()) is NotImplemented assert expr.__rmul__(object()) is NotImplemented @@ -312,7 +325,7 @@ def test_linear_expression_with_errors(m: Model, x: Variable) -> None: x / (1 * x) with pytest.raises(TypeError): - m.linexpr((10, x.labels), (1, "y")) # type: ignore + m.linexpr((10, x.labels), (1, "y")) def test_linear_expression_from_rule(m: Model, x: Variable, y: Variable) -> None: @@ -1057,6 +1070,28 @@ def test_linear_expression_rolling_from_variable(v: Variable) -> None: assert rolled.nterm == 2 +def test_linear_expression_from_tuples(x: Variable, y: Variable) -> None: + expr = LinearExpression.from_tuples((10, x), (1, y)) + assert isinstance(expr, LinearExpression) + + expr2 = LinearExpression.from_tuples((10, x), (1,)) + assert isinstance(expr2, LinearExpression) + assert (expr2.const == 1).all() + + expr3 = LinearExpression.from_tuples((10, x), 1) + assert isinstance(expr3, LinearExpression) + assert_linequal(expr2, expr3) + + expr4 = LinearExpression.from_tuples((10, x), (1, y), 1) + assert isinstance(expr4, LinearExpression) + + with pytest.raises(ValueError): + LinearExpression.from_tuples((10, x), (1, y), x) + + with pytest.raises(ValueError): + LinearExpression.from_tuples((10, x, 3), (1, y), 1) + + def test_linear_expression_sanitize(x: Variable, y: Variable, z: Variable) -> None: expr = 10 * x + y + z assert isinstance(expr.sanitize(), LinearExpression) diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index 233f3c3c..0edcbff2 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -214,6 +214,10 @@ def test_quadratic_expression_wrong_multiplication(x: Variable, y: Variable) -> with pytest.raises(TypeError): x * x * y + quad = x * x + with pytest.raises(TypeError): + quad * quad + def merge_raise_deprecation_warning(x: Variable, y: Variable) -> None: expr: QuadraticExpression = x * y # type: ignore diff --git a/test/test_typing.py b/test/test_typing.py index e9afa55b..99a27033 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -1,15 +1,9 @@ import xarray as xr -from mypy import api import linopy def test_operations_with_data_arrays_are_typed_correctly() -> None: - # Get the path of this file - file_path = __file__ - result = api.run([file_path]) - assert result[2] == 0, "Mypy returned issues: " + result[0] - m = linopy.Model() a: xr.DataArray = xr.DataArray([1, 2, 3]) From a11371f9e5863d9b1a81d2675c4e00e8dacd7be0 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Wed, 7 May 2025 19:49:15 +0200 Subject: [PATCH 15/21] minor changes --- linopy/expressions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 0f3708d6..667f91f6 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1623,7 +1623,7 @@ def from_tuples( if len(t) == 2: # assume first element is coefficient and second is variable c, v = t - if isinstance(v, ScalarVariable): + if isinstance(v, variables.ScalarVariable): if not isinstance(c, (int, float)): raise TypeError( "Expected int or float as coefficient of scalar variable (first element of tuple)." From 4db1afe1c2f02b9121c3e5658d0e305ebe47d53e Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Wed, 7 May 2025 20:00:03 +0200 Subject: [PATCH 16/21] Deprecated using single-value tuple for LinearExpression.from_tuples --- linopy/expressions.py | 38 ++++++++++++++++++++-------------- test/test_linear_expression.py | 3 ++- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 667f91f6..0dee81dc 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1550,7 +1550,7 @@ def from_rule( raise TypeError(msg) combinations = product(*[c.values for c in coords.values()]) - exprs = [] + placeholder = ScalarLinearExpression((np.nan,), (-1,), model) exprs = [rule(model, *coord) or placeholder for coord in combinations] return cls._from_scalarexpression_list(exprs, coords, model) @@ -1558,9 +1558,7 @@ def from_rule( @classmethod def from_tuples( cls, - *tuples: tuple[ConstantLike, str | Variable | ScalarVariable] - | tuple[ConstantLike] - | ConstantLike, + *tuples: tuple[ConstantLike, str | Variable | ScalarVariable] | ConstantLike, model: Model | None = None, ) -> LinearExpression: """ @@ -1574,7 +1572,6 @@ def from_tuples( ---------- tuples : A list of elements. Each element is either: * (coefficients, variables) - * (constant,) * constant Each (coefficients, variables) tuple represents one term in the resulting linear expression, @@ -1603,20 +1600,23 @@ def from_tuples( >>> x = m.add_variables(pd.Series([0, 0]), 1) >>> y = m.add_variables(4, pd.Series([8, 10])) >>> expr = LinearExpression.from_tuples((10, x), (1, y), 1) - >>> expr = LinearExpression.from_tuples( - ... (10, x), (1, y), (1,) - ... ) # Alternative usage This is the same as calling ``10*x + y` + 1` but a bit more performant. """ - exprs: list[LinearExpression] = [] - for t in tuples: + + def process_one( + t: tuple[ConstantLike, str | Variable | ScalarVariable] + | tuple[ConstantLike] + | ConstantLike, + ) -> LinearExpression: + nonlocal model + if isinstance(t, SUPPORTED_CONSTANT_TYPES): if model is None: raise ValueError("Model must be provided when using constants.") expr = LinearExpression(t, model) - exprs.append(expr) - continue + return expr + if not isinstance(t, tuple): raise ValueError("Expected tuple or constant.") @@ -1640,16 +1640,22 @@ def from_tuples( if model is None: model = expr.model # TODO: Ensure equality of models - elif len(t) == 1: + return expr + + if len(t) == 1: + warn( + "Passing a single value tuple to LinearExpression.from_tuples is deprecated", + DeprecationWarning, + ) # assume that the element is a constant const = as_dataarray(t[0]) if model is None: raise ValueError("Model must be provided when using constants.") - expr = LinearExpression(const, model) + return LinearExpression(const, model) else: - raise ValueError("Expected tuples of length 1 or 2.") + raise ValueError("Expected tuples of length 2") - exprs.append(expr) + exprs = [process_one(t) for t in tuples] return merge(exprs, cls=cls) if len(exprs) > 1 else exprs[0] diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 200cf725..1ed4adfd 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1074,7 +1074,8 @@ def test_linear_expression_from_tuples(x: Variable, y: Variable) -> None: expr = LinearExpression.from_tuples((10, x), (1, y)) assert isinstance(expr, LinearExpression) - expr2 = LinearExpression.from_tuples((10, x), (1,)) + with pytest.warns(DeprecationWarning): + expr2 = LinearExpression.from_tuples((10, x), (1,)) assert isinstance(expr2, LinearExpression) assert (expr2.const == 1).all() From e687249d32456008203e5420e53164aae56ce754 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Wed, 7 May 2025 21:08:28 +0200 Subject: [PATCH 17/21] added tests to improve code coverage --- doc/release_notes.rst | 1 + linopy/expressions.py | 10 +++++----- linopy/variables.py | 3 +-- test/test_linear_expression.py | 29 ++++++++++++++++++++++++++--- test/test_quadratic_expression.py | 3 +++ 5 files changed, 36 insertions(+), 10 deletions(-) diff --git a/doc/release_notes.rst b/doc/release_notes.rst index ebe3a120..c5a97289 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -4,6 +4,7 @@ Release Notes Future Version --------------- **Minor Improvements** + * Improved variable/expression arithmetic methods so that they correctly handle types Upcoming Version diff --git a/linopy/expressions.py b/linopy/expressions.py index 5f8e40b8..d7ef52bb 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1317,7 +1317,7 @@ def __add__( def __radd__(self, other: ConstantLike) -> LinearExpression: try: - return self.__add__(other) + return self + other except TypeError: return NotImplemented @@ -1345,7 +1345,7 @@ def __sub__( def __rsub__(self, other: ConstantLike | Variable) -> LinearExpression: try: - return self.__neg__().__add__(other) + return (self * -1) + other except TypeError: return NotImplemented @@ -1389,7 +1389,7 @@ def __rmul__(self, other: ConstantLike) -> LinearExpression: Right-multiply the expr by a factor. """ try: - return self.__mul__(other) + return self * other except TypeError: return NotImplemented @@ -1725,7 +1725,7 @@ def __mul__(self, other: SideLike) -> QuadraticExpression: return NotImplemented def __rmul__(self, other: SideLike) -> QuadraticExpression: - return self.__mul__(other) + return self * other def __add__(self, other: SideLike) -> QuadraticExpression: """ @@ -1776,7 +1776,7 @@ def __rsub__(self, other: SideLike) -> QuadraticExpression: Subtract expression from others. """ try: - return self.__neg__().__add__(other) + return (self * -1) + other except TypeError: return NotImplemented diff --git a/linopy/variables.py b/linopy/variables.py index 91778639..38479cf9 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -1516,8 +1516,7 @@ class ScalarVariable: """ A scalar variable container. - In contrast to the Variable class, a ScalarVariable only contains - only one label. Use this class to create a expression or constraint + In contrast to the Variable class, a ScalarVariable only contains one label. Use this class to create a expression or constraint in a rule. """ diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 1ed4adfd..7f7603c5 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -18,6 +18,7 @@ from linopy.constants import HELPER_DIMS, TERM_DIM from linopy.expressions import ScalarLinearExpression from linopy.testing import assert_linequal, assert_quadequal +from linopy.variables import ScalarVariable @pytest.fixture @@ -227,6 +228,8 @@ def test_linear_expression_with_multiplication(x: Variable) -> None: with pytest.raises(TypeError): quad * quad + expr = x * 1 + assert isinstance(expr, LinearExpression) assert expr.__mul__(object()) is NotImplemented assert expr.__rmul__(object()) is NotImplemented @@ -285,6 +288,7 @@ def test_linear_expression_rsubtraction(x: Variable, y: Variable) -> None: assert isinstance(expr_2, LinearExpression) expr_3: LinearExpression = (expr - 10.0) * -1 assert_linequal(expr_2, expr_3) + assert expr.__rsub__(object()) is NotImplemented def test_linear_expression_with_constant(m: Model, x: Variable, y: Variable) -> None: @@ -494,14 +498,15 @@ def test_linear_expression_sum_warn_unknown_kwargs(z: Variable) -> None: def test_linear_expression_power(x: Variable) -> None: - qd_expr = x**2 + expr: LinearExpression = x * 1.0 + qd_expr = expr**2 assert isinstance(qd_expr, QuadraticExpression) - qd_expr2 = x.pow(2) + qd_expr2 = expr.pow(2) assert_quadequal(qd_expr, qd_expr2) with pytest.raises(ValueError): - x**3 + expr**3 def test_linear_expression_multiplication( @@ -1085,13 +1090,31 @@ def test_linear_expression_from_tuples(x: Variable, y: Variable) -> None: expr4 = LinearExpression.from_tuples((10, x), (1, y), 1) assert isinstance(expr4, LinearExpression) + assert (expr4.const == 1).all() + + expr5 = LinearExpression.from_tuples(1, model=x.model) + assert isinstance(expr5, LinearExpression) + +def test_linear_expression_from_tuples_bad_calls( + m: Model, x: Variable, y: Variable +) -> None: with pytest.raises(ValueError): LinearExpression.from_tuples((10, x), (1, y), x) with pytest.raises(ValueError): LinearExpression.from_tuples((10, x, 3), (1, y), 1) + sv = ScalarVariable(label=0, model=m) + with pytest.raises(TypeError): + LinearExpression.from_tuples((np.array([1, 1]), sv)) + + with pytest.raises(TypeError): + LinearExpression.from_tuples((x, x)) + + with pytest.raises(ValueError): + LinearExpression.from_tuples(10) + def test_linear_expression_sanitize(x: Variable, y: Variable, z: Variable) -> None: expr = 10 * x + y + z diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index 0edcbff2..6a41d94f 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -123,6 +123,9 @@ def test_matmul_expr_and_expr(x: Variable, y: Variable, z: Variable) -> None: assert expr.nterm == 6 assert_quadequal(expr, target) + with pytest.raises(TypeError): + (x**2) @ (y**2) + def test_matmul_with_const(x: Variable) -> None: expr = x * x From be3eeff5ea71c6583dd7e9543f9e9fca5a1154d3 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Wed, 7 May 2025 21:27:40 +0200 Subject: [PATCH 18/21] improved code coverage a tiny bit more --- linopy/model.py | 8 ++------ linopy/variables.py | 4 ++-- test/test_linear_expression.py | 3 +++ test/test_variables.py | 7 +++++++ 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index d59fd011..b71a8d93 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -884,16 +884,12 @@ def calculate_block_maps(self) -> None: @overload def linexpr( self, *args: Sequence[Sequence | pd.Index | DataArray] | Mapping - ) -> LinearExpression: - ... - # A function and tuples of coordinates + ) -> LinearExpression: ... @overload def linexpr( self, *args: tuple[ConstantLike, str | Variable | ScalarVariable] | ConstantLike - ) -> LinearExpression: - ... - # A mixture of tuples of (coefficients, variables) and constants + ) -> LinearExpression: ... def linexpr( self, diff --git a/linopy/variables.py b/linopy/variables.py index 38479cf9..695ceb85 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -490,7 +490,7 @@ def __add__( def __radd__(self, other: ConstantLike) -> LinearExpression: try: - return self.__add__(other) + return self + other except TypeError: return NotImplemented @@ -1588,7 +1588,7 @@ def __mul__(self, coeff: int | float) -> ScalarLinearExpression: return self.to_scalar_linexpr(coeff) def __rmul__(self, coeff: int | float) -> ScalarLinearExpression: - if isinstance(coeff, Variable): + if isinstance(coeff, (Variable, ScalarVariable)): return NotImplemented return self.to_scalar_linexpr(coeff) diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 7f7603c5..cec12882 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -331,6 +331,9 @@ def test_linear_expression_with_errors(m: Model, x: Variable) -> None: with pytest.raises(TypeError): m.linexpr((10, x.labels), (1, "y")) + with pytest.raises(TypeError): + m.linexpr(a=2) # type: ignore + def test_linear_expression_from_rule(m: Model, x: Variable, y: Variable) -> None: def bound(m: Model, i: int) -> ScalarLinearExpression: diff --git a/test/test_variables.py b/test/test_variables.py index a55e92bb..3984b091 100644 --- a/test/test_variables.py +++ b/test/test_variables.py @@ -13,6 +13,7 @@ import linopy from linopy import Model from linopy.testing import assert_varequal +from linopy.variables import ScalarVariable @pytest.fixture @@ -115,3 +116,9 @@ def test_variables_get_name_by_label(m: Model) -> None: with pytest.raises(ValueError): m.variables.get_name_by_label("anystring") # type: ignore + + +def test_scalar_variable(m: Model) -> None: + x = ScalarVariable(label=0, model=m) + assert isinstance(x, ScalarVariable) + assert x.__rmul__(x) is NotImplemented # type: ignore From 9b6bdbed8146fbbc72de13c4a5a75fef27b2161e Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Sat, 24 May 2025 00:16:08 +0200 Subject: [PATCH 19/21] fixed mypy warnign --- linopy/expressions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index f69ac393..36fa0354 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -508,7 +508,7 @@ def _multiply_by_linear_expression( # merge on factor dimension only returns v1 * v2 + c1 * c2 ds = other.data[["coeffs", "vars"]].sel(_term=0).broadcast_like(self.data) ds = assign_multiindex_safe(ds, const=other.const) - res = merge([self, ds], dim=FACTOR_DIM, cls=QuadraticExpression) + res = merge([self, ds], dim=FACTOR_DIM, cls=QuadraticExpression) # type: ignore # deal with cross terms c1 * v2 + c2 * v1 if self.has_constant: res = res + self.const * other.reset_const() From 703ea4a538dd0bfc1d93a8dd38736a2632bc5533 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 May 2025 22:16:18 +0000 Subject: [PATCH 20/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/expressions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 36fa0354..954c8e80 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -508,7 +508,7 @@ def _multiply_by_linear_expression( # merge on factor dimension only returns v1 * v2 + c1 * c2 ds = other.data[["coeffs", "vars"]].sel(_term=0).broadcast_like(self.data) ds = assign_multiindex_safe(ds, const=other.const) - res = merge([self, ds], dim=FACTOR_DIM, cls=QuadraticExpression) # type: ignore + res = merge([self, ds], dim=FACTOR_DIM, cls=QuadraticExpression) # type: ignore # deal with cross terms c1 * v2 + c2 * v1 if self.has_constant: res = res + self.const * other.reset_const() From a89ca80c0647623e47bd14a1ac74f9554faa0063 Mon Sep 17 00:00:00 2001 From: Robbie Muir Date: Sat, 24 May 2025 13:10:02 +0200 Subject: [PATCH 21/21] minor changes --- doc/release_notes.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 55f15c7c..1ae1f5e4 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -4,6 +4,7 @@ Release Notes .. Upcoming Version .. ---------------- + * Improved variable/expression arithmetic methods so that they correctly handle types Version 0.5.5