Skip to content

Commit 44a98ed

Browse files
committed
Suggestions from review
1 parent 9329d1a commit 44a98ed

2 files changed

Lines changed: 10 additions & 4 deletions

File tree

firedrake/assemble.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
594594
v1: v1.reconstruct(number=v0.number())})
595595
v, operand = expr.argument_slots()
596596

597-
# Assemble the interpolator matrix if the meshes are different
597+
# Matrix-free adjoint interpolation is only implemented by SameMeshInterpolator
598+
# so we need assemble the interpolator matrix if the meshes are different
598599
target_mesh = V.mesh()
599600
source_mesh = extract_unique_domain(operand) or target_mesh
600601
if is_adjoint and rank < 2 and source_mesh is not target_mesh:
@@ -605,6 +606,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
605606
interp_data = expr.interp_data.copy()
606607
default_missing_val = interp_data.pop('default_missing_val', None)
607608
if matfree and ((is_adjoint and rank == 1) or rank == 0):
609+
# Adjoint interpolation of a Cofunction or the action of a
610+
# Cofunction on an interpolated Function require INC access
611+
# on the output tensor
608612
interp_data["access"] = op2.INC
609613

610614
if rank == 1 and matfree and isinstance(tensor, firedrake.Function):

firedrake/interpolation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,9 @@ def __init__(
389389
)
390390

391391
if isinstance(expr, ufl.Interpolate):
392-
expr, = expr.ufl_operands
392+
dual_arg, expr = expr.argument_slots()
393+
if not isinstance(dual_arg, Coargument):
394+
raise NotImplementedError(f"{type(self).__name__} does not support matrix-free adjoint interpolation.")
393395
super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree)
394396

395397
self.arguments = extract_arguments(expr)
@@ -749,7 +751,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
749751
# Interpolation action
750752
self.frozen_assembled_interpolator = assembled_interpolator.copy()
751753

752-
if len(self.arguments) == 2 and len(function):
754+
if len(self.arguments) == 2 and len(function) > 0:
753755
function, = function
754756
if not hasattr(function, "dat"):
755757
raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!")
@@ -865,7 +867,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
865867
tensor = op2.Mat(sparsity)
866868
f = tensor
867869
else:
868-
raise ValueError("Cannot interpolate an expression with %d arguments" % rank)
870+
raise ValueError(f"Cannot interpolate an expression with {rank} arguments")
869871

870872
if vom_onto_other_vom:
871873
wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, matfree)

0 commit comments

Comments
 (0)