Skip to content

Commit 4ba0933

Browse files
committed
Update interpolation.py
1 parent e956adb commit 4ba0933

4 files changed

Lines changed: 113 additions & 82 deletions

File tree

firedrake/assemble.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -590,28 +590,38 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
590590
if not is_adjoint and rank == 2:
591591
_, v1 = expr.arguments()
592592
operand = ufl.replace(operand, {v1: v1.reconstruct(number=0)})
593+
594+
target_mesh = V.mesh()
595+
source_mesh = extract_unique_domain(operand) or target_mesh
596+
same_mesh = source_mesh.topology is target_mesh.topology
597+
593598
# Get the interpolator
594-
interp_data = expr.interp_data.copy()
599+
interp_data = expr.interp_data
595600
default_missing_val = interp_data.pop('default_missing_val', None)
596-
if (is_adjoint and rank == 1) or rank == 0:
601+
if same_mesh and ((is_adjoint and rank == 1) or rank == 0):
602+
interp_data = interp_data.copy()
597603
interp_data["access"] = op2.INC
598-
interpolator = firedrake.Interpolator(operand, v, **interp_data)
604+
605+
dual_arg = v if same_mesh else V
606+
interp_expr = firedrake.Interpolate(operand, v=dual_arg, **interp_data)
607+
interpolator = firedrake.Interpolator(interp_expr, V, **interp_data)
608+
599609
# Assembly
600610
if rank == 0:
601611
result = interpolator._interpolate(output=tensor, default_missing_val=default_missing_val)
602612
return result.dat.data.item() if tensor is None else result
603613
elif rank == 1:
604614
# Assembling the action of the Jacobian adjoint.
605615
if is_adjoint:
606-
return interpolator._interpolate(v, output=tensor, default_missing_val=default_missing_val)
616+
return interpolator._interpolate(v, output=tensor, adjoint=True, default_missing_val=default_missing_val)
607617
# Assembling the Jacobian action.
608618
elif interpolator.nargs:
609619
return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val)
610620
# Assembling the operator
611621
elif tensor is None:
612622
return interpolator._interpolate(default_missing_val=default_missing_val)
613623
else:
614-
return firedrake.Interpolator(operand, tensor, **interp_data)._interpolate(default_missing_val=default_missing_val)
624+
return firedrake.Interpolator(interp_expr, tensor, **interp_data)._interpolate(default_missing_val=default_missing_val)
615625
elif rank == 2:
616626
res = tensor.petscmat if tensor else PETSc.Mat()
617627
# Get the interpolation matrix

firedrake/interpolation.py

Lines changed: 68 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ class Interpolator(abc.ABC):
241241
"""
242242

243243
def __new__(cls, expr, V, **kwargs):
244+
if isinstance(expr, ufl.Interpolate):
245+
expr, = expr.ufl_operands
244246
target_mesh = as_domain(V)
245247
source_mesh = extract_unique_domain(expr) or target_mesh
246248
submesh_interp_implemented = \
@@ -266,6 +268,8 @@ def __init__(
266268
allow_missing_dofs=False,
267269
matfree=True
268270
):
271+
if isinstance(expr, ufl.Interpolate):
272+
expr, = expr.ufl_operands
269273
self.expr = expr
270274
self.V = V
271275
self.subset = subset
@@ -374,6 +378,8 @@ def __init__(
374378
"Can only interpolate into spaces with point evaluation nodes."
375379
)
376380

381+
if isinstance(expr, ufl.Interpolate):
382+
expr, = expr.ufl_operands
377383
super().__init__(expr, V, subset, freeze_expr, access, bcs, allow_missing_dofs, matfree)
378384

379385
self.arguments = extract_arguments(expr)
@@ -540,7 +546,7 @@ def _interpolate(
540546
V_dest = self.expr.function_space().dual()
541547
except AttributeError:
542548
if self.nargs:
543-
V_dest = self.arguments[0].function_space().dual()
549+
V_dest = self.arguments[-1].function_space().dual()
544550
else:
545551
coeffs = extract_coefficients(self.expr)
546552
if len(coeffs):
@@ -552,8 +558,6 @@ def _interpolate(
552558
else:
553559
if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)):
554560
V_dest = self.V.function_space()
555-
elif isinstance(self.V, firedrake.Coargument):
556-
V_dest = self.V.function_space().dual()
557561
else:
558562
V_dest = self.V
559563
if output:
@@ -679,10 +683,14 @@ class SameMeshInterpolator(Interpolator):
679683
def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE,
680684
bcs=None, matfree=True, allow_missing_dofs=False, **kwargs):
681685
if subset is None:
686+
if isinstance(expr, ufl.Interpolate):
687+
operand, = expr.ufl_operands
688+
else:
689+
operand = expr
682690
target_mesh = as_domain(V)
683-
source_mesh = extract_unique_domain(expr)
691+
source_mesh = extract_unique_domain(operand) or target_mesh
684692
target = target_mesh.topology
685-
source = target if source_mesh is None else source_mesh.topology
693+
source = source_mesh.topology
686694
if all(isinstance(m, firedrake.mesh.MeshTopology) for m in [target, source]) and target is not source:
687695
composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None)
688696
if result_integral_type != "cell":
@@ -703,7 +711,7 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE,
703711
self.callable, arguments = make_interpolator(expr, V, subset, access, bcs=bcs, matfree=matfree)
704712
except FIAT.hdiv_trace.TraceError:
705713
raise NotImplementedError("Can't interpolate onto traces sorry")
706-
self.arguments = arguments
714+
self.arguments = expr.arguments()
707715
self.nargs = len(arguments)
708716

709717
@PETSc.Log.EventDecorator()
@@ -735,16 +743,19 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
735743
# Interpolation action
736744
self.frozen_assembled_interpolator = assembled_interpolator.copy()
737745

738-
if self.nargs == 2:
746+
if hasattr(assembled_interpolator, "handle") and len(function):
739747
function, = function
740748
if not hasattr(function, "dat"):
741749
raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!")
742750
if adjoint:
743751
mul = assembled_interpolator.handle.multHermitian
744752
V = self.arguments[0].function_space().dual()
753+
assert function.function_space() == self.arguments[1].function_space()
745754
else:
746755
mul = assembled_interpolator.handle.mult
747-
V = self.V
756+
V = self.arguments[1].function_space().dual()
757+
assert function.function_space() == self.arguments[0].function_space()
758+
748759
result = output or firedrake.Function(V)
749760
with function.dat.vec_ro as x, result.dat.vec_wo as out:
750761
if x is not out:
@@ -772,29 +783,23 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
772783

773784
@PETSc.Log.EventDecorator()
774785
def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
775-
assert isinstance(expr, ufl.classes.Expr)
786+
assert isinstance(expr, ufl.Interpolate)
787+
dual_arg, operand = expr.argument_slots()
788+
assert isinstance(dual_arg, (ufl.Coargument, ufl.Cofunction))
789+
790+
target_mesh = as_domain(dual_arg)
791+
source_mesh = extract_unique_domain(operand) or target_mesh
792+
same_mesh = target_mesh is source_mesh
793+
794+
vom_onto_other_vom = (
795+
isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology)
796+
and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology)
797+
and target_mesh is not source_mesh
798+
)
776799

777-
if isinstance(V, (ufl.Coargument, ufl.Cofunction)):
778-
dual_arg = V
779-
V = dual_arg.function_space().dual()
780-
elif isinstance(V, (ufl.FunctionSpace, ufl.Coefficient)):
781-
fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space()
782-
dual_arg = Coargument(fs.dual(), number=0)
783-
784-
arguments = extract_arguments(expr)
785-
if isinstance(dual_arg, ufl.Coargument):
786-
arguments.append(dual_arg)
800+
arguments = expr.arguments()
787801
rank = len(arguments)
788-
789-
target_mesh = as_domain(V)
790802
if rank <= 1:
791-
source_mesh = extract_unique_domain(expr) or target_mesh
792-
vom_onto_other_vom = (
793-
isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology)
794-
and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology)
795-
and target_mesh is not source_mesh
796-
)
797-
798803
if rank == 0:
799804
R = firedrake.FunctionSpace(target_mesh, "Real", 0)
800805
f = firedrake.Function(R)
@@ -817,15 +822,8 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
817822
raise ValueError("Cannot interpolate an expression with an argument into a Function")
818823
if len(V) > 1:
819824
raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported")
820-
821825
argfs = arguments[0].function_space()
822-
source_mesh = argfs.mesh()
823826
argfs_map = argfs.cell_node_map()
824-
vom_onto_other_vom = (
825-
isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology)
826-
and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology)
827-
and target_mesh is not source_mesh
828-
)
829827
if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology) and target_mesh is not source_mesh and not vom_onto_other_vom:
830828
if not isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology):
831829
raise NotImplementedError("Can only interpolate onto a Vertex Only Mesh")
@@ -863,8 +861,10 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
863861
else:
864862
raise ValueError("Cannot interpolate an expression with %d arguments" % rank)
865863

864+
if not same_mesh:
865+
arguments = extract_arguments(operand)
866866
if vom_onto_other_vom:
867-
wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, expr, arguments, matfree)
867+
wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, arguments, matfree)
868868
# NOTE: get_dat_mpi_type ensures we get the correct MPI type for the
869869
# data, including the correct data size and dimensional information
870870
# (so for vector function spaces in 2 dimensions we might need a
@@ -874,13 +874,13 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
874874
# when it is called.
875875
assert f.dat is tensor
876876
wrapper.mpi_type, _ = get_dat_mpi_type(f.dat)
877-
assert not len(arguments)
877+
assert len(arguments) == 0
878878

879879
def callable():
880880
wrapper.forward_operation(f.dat)
881881
return f
882882
else:
883-
assert rank == 2
883+
assert len(arguments) == 1
884884
assert tensor is None
885885
# we know we will be outputting either a function or a cofunction,
886886
# both of which will use a dat as a data carrier. At present, the
@@ -904,38 +904,40 @@ def callable():
904904
# Make sure we have an expression of the right length i.e. a value for
905905
# each component in the value shape of each function space
906906
loops = []
907-
if numpy.prod(expr.ufl_shape, dtype=int) != V.value_size:
907+
if numpy.prod(operand.ufl_shape, dtype=int) != V.value_size:
908908
raise RuntimeError('Expression of length %d required, got length %d'
909-
% (V.value_size, numpy.prod(expr.ufl_shape, dtype=int)))
909+
% (V.value_size, numpy.prod(operand.ufl_shape, dtype=int)))
910910

911911
if len(V) == 1:
912-
loops.extend(_interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=bcs))
912+
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs))
913913
else:
914-
if (hasattr(expr, "subfunctions") and len(expr.subfunctions) == len(V)
915-
and all(sub_expr.ufl_shape == Vsub.value_shape for Vsub, sub_expr in zip(V, expr.subfunctions))):
914+
if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(V)
915+
and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(V, operand.subfunctions))):
916916
# Use subfunctions if they match the target shapes
917-
expressions = expr.subfunctions
917+
operands = operand.subfunctions
918918
else:
919919
# Unflatten the expression into the shapes of the mixed components
920920
offset = 0
921-
expressions = []
921+
operands = []
922922
for Vsub in V:
923923
if len(Vsub.value_shape) == 0:
924-
expressions.append(expr[offset])
924+
operands.append(operand[offset])
925925
else:
926-
components = [expr[offset + j] for j in range(Vsub.value_size)]
927-
expressions.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape)))
926+
components = [operand[offset + j] for j in range(Vsub.value_size)]
927+
operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape)))
928928
offset += Vsub.value_size
929929

930930
if isinstance(dual_arg, Cofunction):
931931
duals = dual_arg.subfunctions
932932
elif isinstance(dual_arg, Coargument):
933-
duals = [Coargument(Vsub.dual(), number=dual_arg.number()) for Vsub in V]
933+
duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()]
934934
else:
935935
raise ValueError("dual_arg must be a Cofunction or Coargument")
936+
936937
# Interpolate each sub expression into each function space
937-
for Vsub, sub_tensor, sub_expr, sub_dual in zip(V, tensor, expressions, duals):
938-
loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, sub_dual, subset, arguments, access, bcs=bcs))
938+
for Vsub, sub_tensor, sub_op, sub_dual in zip(V, tensor, operands, duals):
939+
sub_expr = ufl.Interpolate(sub_op, sub_dual)
940+
loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs))
939941

940942
if bcs and rank == 1:
941943
loops.extend(partial(bc.apply, f) for bc in bcs)
@@ -949,11 +951,18 @@ def callable(loops, f):
949951

950952

951953
@utils.known_pyop2_safe
952-
def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None):
954+
def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
953955
try:
954956
expr = ufl.as_ufl(expr)
955957
except ValueError:
956958
raise ValueError("Expecting to interpolate a UFL expression")
959+
960+
interp_expr = expr
961+
if isinstance(expr, ufl.Interpolate):
962+
dual_arg, expr = expr.argument_slots()
963+
else:
964+
dual_arg = Coargument(V.dual(), number=0)
965+
957966
try:
958967
to_element = create_element(V.ufl_element())
959968
except KeyError:
@@ -1029,7 +1038,7 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None
10291038
# FIXME: for the runtime unknown point set (for cross-mesh
10301039
# interpolation) we have to pass the finat element we construct
10311040
# here. Ideally we would only pass the UFL element through.
1032-
kernel = compile_expression(cell_set.comm, expr, dual_arg, to_element, V.ufl_element(),
1041+
kernel = compile_expression(cell_set.comm, interp_expr, to_element, V.ufl_element(),
10331042
domain=source_mesh, parameters=parameters)
10341043
ast = kernel.ast
10351044
oriented = kernel.oriented
@@ -1042,7 +1051,7 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None
10421051

10431052
parloop_args = [kernel, cell_set]
10441053

1045-
interp_expr = ufl.Interpolate(expr, dual_arg)
1054+
interp_expr = ufl.Interpolate(expr, v=dual_arg)
10461055
coefficients = tsfc_interface.extract_numbered_coefficients(interp_expr, coefficient_numbers)
10471056
if needs_external_coords:
10481057
coefficients = [source_mesh.coordinates] + coefficients
@@ -1061,12 +1070,13 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None
10611070
if isinstance(tensor, op2.Global):
10621071
parloop_args.append(tensor(access))
10631072
elif isinstance(tensor, op2.Dat):
1064-
V_dest = arguments[0].function_space() if isinstance(dual_arg, ufl.Cofunction) else V
1073+
V_dest = arguments[-1].function_space() if isinstance(dual_arg, ufl.Cofunction) else V
10651074
parloop_args.append(tensor(access, V_dest.cell_node_map()))
10661075
else:
10671076
assert access == op2.WRITE # Other access descriptors not done for Matrices.
10681077
rows_map = V.cell_node_map()
10691078
Vcol = arguments[0].function_space()
1079+
assert tensor.handle.getSize() == (V.dim(), Vcol.dim())
10701080
if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology):
10711081
columns_map = Vcol.cell_node_map()
10721082
if target_mesh is not source_mesh:
@@ -1165,9 +1175,10 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None
11651175
f"firedrake-tsfc-expression-kernel-cache-uid{os.getuid()}")
11661176

11671177

1168-
def _compile_expression_key(comm, expr, dual_arg, to_element, ufl_element, domain, parameters) -> tuple[Hashable, ...]:
1178+
def _compile_expression_key(comm, expr, to_element, ufl_element, domain, parameters) -> tuple[Hashable, ...]:
11691179
"""Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`."""
1170-
return (hash_expr(expr), type(dual_arg), hash(ufl_element), utils.tuplify(parameters))
1180+
dual_arg, operand = expr.argument_slots()
1181+
return (hash_expr(operand), type(dual_arg), hash(ufl_element), utils.tuplify(parameters))
11711182

11721183

11731184
@memory_and_disk_cache(

tests/firedrake/regression/test_interpolate.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,6 @@ def test_adjoint_Pk(degree):
341341

342342
v_adj_form = assemble(interpolate(TestFunction(Pk), v * dx))
343343

344-
assert v_adj.function_space() == v_adj_form.function_space()
345344
assert np.allclose(v_adj_form.dat.data, v_adj.dat.data)
346345

347346

@@ -354,7 +353,6 @@ def test_adjoint_quads():
354353
u_P1 = assemble(conj(TestFunction(P1)) * dx)
355354
v_adj = assemble(interpolate(TestFunction(P1), assemble(v * dx)))
356355

357-
assert v_adj.function_space() == u_P1.function_space()
358356
assert np.allclose(u_P1.dat.data, v_adj.dat.data)
359357

360358

@@ -367,7 +365,6 @@ def test_adjoint_dg():
367365
u_cg = assemble(conj(TestFunction(cg1)) * dx)
368366
v_adj = assemble(interpolate(TestFunction(cg1), assemble(v * dx)))
369367

370-
assert v_adj.function_space() == u_cg.function_space()
371368
assert np.allclose(u_cg.dat.data, v_adj.dat.data)
372369

373370

0 commit comments

Comments
 (0)