Skip to content

Commit 1cde704

Browse files
committed
Add typing generics Function
1 parent dff1b7b commit 1cde704

1 file changed

Lines changed: 14 additions & 14 deletions

File tree

python/dolfinx/fem/function.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def dtype(self) -> npt.DTypeLike:
311311
return np.dtype(self._cpp_object.dtype)
312312

313313

314-
class Function(ufl.Coefficient):
314+
class Function(ufl.Coefficient, Generic[_S]):
315315
"""A finite element function.
316316
317317
A finite element function is represented by a function space
@@ -330,7 +330,7 @@ class Function(ufl.Coefficient):
330330
def __init__(
331331
self,
332332
V: FunctionSpace,
333-
x: la.Vector | None = None,
333+
x: la.Vector[_S] | None = None,
334334
name: str | None = None,
335335
dtype: npt.DTypeLike | None = None,
336336
):
@@ -395,7 +395,7 @@ def function_space(self) -> FunctionSpace:
395395
"""FunctionSpace that the Function is defined on."""
396396
return self._V
397397

398-
def eval(self, x: npt.ArrayLike, cells: npt.ArrayLike, u=None) -> np.ndarray:
398+
def eval(self, x: npt.ArrayLike, cells: npt.NDArray[np.int32], u=None) -> npt.NDArray[_S]:
399399
"""Evaluate Function at points x.
400400
401401
Points where x has shape (num_points, 3), and cells has shape
@@ -431,7 +431,7 @@ def eval(self, x: npt.ArrayLike, cells: npt.ArrayLike, u=None) -> np.ndarray:
431431
return u
432432

433433
def interpolate_nonmatching(
434-
self, u0: Function, cells: npt.NDArray[np.int32], interpolation_data: PointOwnershipData
434+
self, u0: Function[_S], cells: npt.NDArray[np.int32], interpolation_data: PointOwnershipData
435435
) -> None:
436436
"""Interpolate a Function on a non-matching mesh.
437437
@@ -447,9 +447,9 @@ def interpolate_nonmatching(
447447

448448
def interpolate(
449449
self,
450-
u0: Callable | Expression | Function,
451-
cells0: np.ndarray | None = None,
452-
cells1: np.ndarray | None = None,
450+
u0: Callable | Expression[_S] | Function[_S],
451+
cells0: npt.NDArray[np.int32] | None = None,
452+
cells1: npt.NDArray[np.int32] | None = None,
453453
) -> None:
454454
"""Interpolate an expression.
455455
@@ -495,7 +495,7 @@ def _(e0: Expression):
495495
)
496496
self._cpp_object.interpolate_f(np.asarray(u0(x), dtype=self.dtype), cells0)
497497

498-
def copy(self) -> Function:
498+
def copy(self) -> Function[_S]:
499499
"""Create a copy of the Function.
500500
501501
The function space is shared and the degree-of-freedom vector is
@@ -511,12 +511,12 @@ def copy(self) -> Function:
511511
)
512512

513513
@property
514-
def x(self) -> la.Vector:
514+
def x(self) -> la.Vector[_S]:
515515
"""Vector holding the degrees-of-freedom."""
516516
return self._x
517517

518518
@property
519-
def dtype(self) -> np.dtype:
519+
def dtype(self) -> npt.DTypeLike:
520520
"""Function value dtype."""
521521
return np.dtype(self._cpp_object.x.array.dtype)
522522

@@ -529,11 +529,11 @@ def name(self) -> str:
529529
def name(self, name):
530530
self._cpp_object.name = name
531531

532-
def __str__(self):
532+
def __str__(self) -> str:
533533
"""Pretty print representation."""
534534
return self.name
535535

536-
def sub(self, i: int) -> Function:
536+
def sub(self, i: int) -> Function[_S]:
537537
"""Return a sub-function (a view into the ``Function``).
538538
539539
Sub-functions are indexed ``i = 0, ..., N-1``, where ``N`` is
@@ -552,7 +552,7 @@ def sub(self, i: int) -> Function:
552552
"""
553553
return Function(self._V.sub(i), self.x, name=f"{self!s}_{i}")
554554

555-
def split(self) -> tuple[Function, ...]:
555+
def split(self) -> tuple[Function[_S], ...]:
556556
"""Extract (any) sub-functions.
557557
558558
A sub-function can be extracted from a discrete function that is
@@ -567,7 +567,7 @@ def split(self) -> tuple[Function, ...]:
567567
raise RuntimeError("No subfunctions to extract")
568568
return tuple(self.sub(i) for i in range(num_sub_spaces))
569569

570-
def collapse(self) -> Function:
570+
def collapse(self) -> Function[_S]:
571571
"""Create a collapsed version of this Function."""
572572
u_collapsed = self._cpp_object.collapse() # type: ignore
573573
V_collapsed = FunctionSpace(

0 commit comments

Comments
 (0)