Skip to content

Commit 515b7f9

Browse files
committed
Matfree adjoint interpolation
1 parent 2b008b3 commit 515b7f9

4 files changed

Lines changed: 120 additions & 32 deletions

File tree

firedrake/assemble.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -591,17 +591,18 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
591591
_, v1 = expr.arguments()
592592
operand = ufl.replace(operand, {v1: v1.reconstruct(number=0)})
593593
# Get the interpolator
594-
interp_data = expr.interp_data
594+
interp_data = expr.interp_data.copy()
595595
default_missing_val = interp_data.pop('default_missing_val', None)
596-
interpolator = firedrake.Interpolator(operand, V, **interp_data)
596+
if (is_adjoint and rank == 1) or rank == 0:
597+
interp_data["access"] = op2.INC
598+
interpolator = firedrake.Interpolator(operand, v, **interp_data)
597599
# Assembly
598600
if rank == 0:
599-
Iu = interpolator._interpolate(default_missing_val=default_missing_val)
600-
return assemble(ufl.Action(v, Iu), tensor=tensor)
601+
return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val)
601602
elif rank == 1:
602603
# Assembling the action of the Jacobian adjoint.
603604
if is_adjoint:
604-
return interpolator._interpolate(v, output=tensor, adjoint=True, default_missing_val=default_missing_val)
605+
return interpolator._interpolate(v, output=tensor, default_missing_val=default_missing_val)
605606
# Assembling the Jacobian action.
606607
elif interpolator.nargs:
607608
return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val)

firedrake/interpolation.py

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,8 @@ def _interpolate(
552552
else:
553553
if isinstance(self.V, (firedrake.Function, firedrake.Cofunction)):
554554
V_dest = self.V.function_space()
555+
elif isinstance(self.V, firedrake.Coargument):
556+
V_dest = self.V.function_space().dual()
555557
else:
556558
V_dest = self.V
557559
if output:
@@ -677,9 +679,10 @@ class SameMeshInterpolator(Interpolator):
677679
def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE,
678680
bcs=None, matfree=True, allow_missing_dofs=False, **kwargs):
679681
if subset is None:
680-
target = V.function_space().mesh().topology if isinstance(V, firedrake.Function) else V.mesh().topology
681-
temp = extract_unique_domain(expr)
682-
source = target if temp is None else temp.topology
682+
target_mesh = as_domain(V)
683+
source_mesh = extract_unique_domain(expr)
684+
target = target_mesh.topology
685+
source = target if source_mesh is None else source_mesh.topology
683686
if all(isinstance(m, firedrake.mesh.MeshTopology) for m in [target, source]) and target is not source:
684687
composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None)
685688
if result_integral_type != "cell":
@@ -732,7 +735,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
732735
# Interpolation action
733736
self.frozen_assembled_interpolator = assembled_interpolator.copy()
734737

735-
if self.nargs:
738+
if self.nargs == 2:
736739
function, = function
737740
if not hasattr(function, "dat"):
738741
raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!")
@@ -770,20 +773,37 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
770773
@PETSc.Log.EventDecorator()
771774
def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
772775
assert isinstance(expr, ufl.classes.Expr)
776+
777+
if isinstance(V, (ufl.Coargument, ufl.Cofunction)):
778+
dual_arg = V
779+
V = dual_arg.function_space().dual()
780+
elif isinstance(V, (ufl.FunctionSpace, ufl.Coefficient)):
781+
fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space()
782+
dual_arg = Coargument(fs.dual(), number=0)
783+
773784
arguments = extract_arguments(expr)
785+
if isinstance(dual_arg, ufl.Coargument):
786+
arguments.append(dual_arg)
787+
rank = len(arguments)
788+
774789
target_mesh = as_domain(V)
775-
if len(arguments) == 0:
790+
if rank <= 1:
776791
source_mesh = extract_unique_domain(expr) or target_mesh
777792
vom_onto_other_vom = (
778793
isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology)
779794
and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology)
780795
and target_mesh is not source_mesh
781796
)
797+
798+
if rank == 0:
799+
# FIXME
800+
V = firedrake.Function(firedrake.FunctionSpace(target_mesh, "Real", 0))
782801
if isinstance(V, firedrake.Function):
783802
f = V
784803
V = f.function_space()
785804
else:
786-
f = firedrake.Function(V)
805+
V_dest = arguments[-1].function_space().dual()
806+
f = firedrake.Function(V_dest)
787807
if access in {firedrake.MIN, firedrake.MAX}:
788808
finfo = numpy.finfo(f.dat.dtype)
789809
if access == firedrake.MIN:
@@ -792,11 +812,12 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
792812
val = firedrake.Constant(finfo.min)
793813
f.assign(val)
794814
tensor = f.dat
795-
elif len(arguments) == 1:
815+
elif rank == 2:
796816
if isinstance(V, firedrake.Function):
797817
raise ValueError("Cannot interpolate an expression with an argument into a Function")
798818
if len(V) > 1:
799819
raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported")
820+
800821
argfs = arguments[0].function_space()
801822
source_mesh = argfs.mesh()
802823
argfs_map = argfs.cell_node_map()
@@ -840,7 +861,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
840861
tensor = op2.Mat(sparsity)
841862
f = tensor
842863
else:
843-
raise ValueError("Cannot interpolate an expression with %d arguments" % len(arguments))
864+
raise ValueError("Cannot interpolate an expression with %d arguments" % rank)
844865

845866
if vom_onto_other_vom:
846867
wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, expr, arguments, matfree)
@@ -859,7 +880,7 @@ def callable():
859880
wrapper.forward_operation(f.dat)
860881
return f
861882
else:
862-
assert len(arguments) == 1
883+
assert rank == 2
863884
assert tensor is None
864885
# we know we will be outputting either a function or a cofunction,
865886
# both of which will use a dat as a data carrier. At present, the
@@ -888,7 +909,7 @@ def callable():
888909
% (V.value_size, numpy.prod(expr.ufl_shape, dtype=int)))
889910

890911
if len(V) == 1:
891-
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs))
912+
loops.extend(_interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=bcs))
892913
else:
893914
if (hasattr(expr, "subfunctions") and len(expr.subfunctions) == len(V)
894915
and all(sub_expr.ufl_shape == Vsub.value_shape for Vsub, sub_expr in zip(V, expr.subfunctions))):
@@ -905,11 +926,18 @@ def callable():
905926
components = [expr[offset + j] for j in range(Vsub.value_size)]
906927
expressions.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape)))
907928
offset += Vsub.value_size
929+
930+
if isinstance(dual_arg, Cofunction):
931+
duals = dual_arg.subfunctions
932+
elif isinstance(dual_arg, Coargument):
933+
duals = [Coargument(Vsub.dual(), number=dual_arg.number()) for Vsub in V]
934+
else:
935+
raise ValueError("dual_arg must be a Cofunction or Coargument")
908936
# Interpolate each sub expression into each function space
909-
for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions):
910-
loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs))
937+
for Vsub, sub_tensor, sub_expr, sub_dual in zip(V, tensor, expressions, duals):
938+
loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, sub_dual, subset, arguments, access, bcs=bcs))
911939

912-
if bcs and len(arguments) == 0:
940+
if bcs and rank == 1:
913941
loops.extend(partial(bc.apply, f) for bc in bcs)
914942

915943
def callable(loops, f):
@@ -921,7 +949,7 @@ def callable(loops, f):
921949

922950

923951
@utils.known_pyop2_safe
924-
def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
952+
def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None):
925953
try:
926954
expr = ufl.as_ufl(expr)
927955
except ValueError:
@@ -977,13 +1005,31 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
9771005
parameters = {}
9781006
parameters['scalar_type'] = utils.ScalarType
9791007

1008+
needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg()
1009+
if needs_weight:
1010+
W = dual_arg.function_space()
1011+
shapes = (W.finat_element.space_dimension(), W.block_size)
1012+
domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes
1013+
instructions = """
1014+
for i, j
1015+
w[i,j] = w[i,j] + 1
1016+
end
1017+
"""
1018+
weight = firedrake.Function(W)
1019+
firedrake.par_loop((domain, instructions), ufl.dx, {"w": (weight, op2.INC)})
1020+
1021+
tmp = firedrake.Function(W)
1022+
with weight.dat.vec as w, dual_arg.dat.vec as x, tmp.dat.vec as y:
1023+
y.pointwiseDivide(x, w)
1024+
dual_arg = tmp
1025+
9801026
# We need to pass both the ufl element and the finat element
9811027
# because the finat elements might not have the right mapping
9821028
# (e.g. L2 Piola, or tensor element with symmetries)
9831029
# FIXME: for the runtime unknown point set (for cross-mesh
9841030
# interpolation) we have to pass the finat element we construct
9851031
# here. Ideally we would only pass the UFL element through.
986-
kernel = compile_expression(cell_set.comm, expr, to_element, V.ufl_element(),
1032+
kernel = compile_expression(cell_set.comm, expr, dual_arg, to_element, V.ufl_element(),
9871033
domain=source_mesh, parameters=parameters)
9881034
ast = kernel.ast
9891035
oriented = kernel.oriented
@@ -996,7 +1042,8 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
9961042

9971043
parloop_args = [kernel, cell_set]
9981044

999-
coefficients = tsfc_interface.extract_numbered_coefficients(expr, coefficient_numbers)
1045+
interp_expr = ufl.Interpolate(expr, dual_arg)
1046+
coefficients = tsfc_interface.extract_numbered_coefficients(interp_expr, coefficient_numbers)
10001047
if needs_external_coords:
10011048
coefficients = [source_mesh.coordinates] + coefficients
10021049

@@ -1014,7 +1061,8 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10141061
if isinstance(tensor, op2.Global):
10151062
parloop_args.append(tensor(access))
10161063
elif isinstance(tensor, op2.Dat):
1017-
parloop_args.append(tensor(access, V.cell_node_map()))
1064+
V_dest = arguments[0].function_space() if isinstance(dual_arg, ufl.Cofunction) else V
1065+
parloop_args.append(tensor(access, V_dest.cell_node_map()))
10181066
else:
10191067
assert access == op2.WRITE # Other access descriptors not done for Matrices.
10201068
rows_map = V.cell_node_map()
@@ -1117,9 +1165,9 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
11171165
f"firedrake-tsfc-expression-kernel-cache-uid{os.getuid()}")
11181166

11191167

1120-
def _compile_expression_key(comm, expr, to_element, ufl_element, domain, parameters) -> tuple[Hashable, ...]:
1168+
def _compile_expression_key(comm, expr, dual_arg, to_element, ufl_element, domain, parameters) -> tuple[Hashable, ...]:
11211169
"""Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`."""
1122-
return (hash_expr(expr), hash(ufl_element), utils.tuplify(parameters))
1170+
return (hash_expr(expr), type(dual_arg), hash(ufl_element), utils.tuplify(parameters))
11231171

11241172

11251173
@memory_and_disk_cache(

tests/firedrake/regression/test_interpolate.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def test_function_cofunction(degree):
379379
f = assemble(interpolate(sin(2*pi*x[0])*sin(2*pi*x[1]), Pk))
380380

381381
fhat = assemble(f*v1*dx)
382-
norm_i = assemble(interpolate(f, fhat))
382+
norm_i = assemble(interpolate(f, fhat)).dat.data.item()
383383
norm = assemble(f*f*dx)
384384

385385
assert np.allclose(norm_i, norm)
@@ -505,3 +505,22 @@ def test_interpolate_logical_not():
505505
a = assemble(interpolate(conditional(Not(x < .2), 1, 0), V))
506506
b = assemble(interpolate(conditional(x >= .2, 1, 0), V))
507507
assert np.allclose(a.dat.data, b.dat.data)
508+
509+
510+
@pytest.mark.parametrize("family", ["CG", "DG"])
511+
def test_interpolate_adjoint_matfree(family):
512+
mesh = UnitSquareMesh(2, 2)
513+
514+
V1 = FunctionSpace(mesh, family, 1)
515+
V2 = FunctionSpace(mesh, family, 2)
516+
517+
v1 = TestFunction(V1)
518+
c2 = Cofunction(V2.dual()).assign(1)
519+
520+
result = assemble(interpolate(v1, c2))
521+
assert result.function_space() == V1.dual()
522+
523+
I = assemble(interpolate(v1, V2))
524+
expected = assemble(action(adjoint(I), c2))
525+
526+
assert np.allclose(result.dat.data_ro, expected.dat.data_ro)

tsfc/driver.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,16 @@ def preprocess_parameters(parameters):
182182
return parameters
183183

184184

185-
def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
185+
def compile_expression_dual_evaluation(expression, dual_arg,
186+
to_element, ufl_element, *,
186187
domain=None, interface=None,
187188
parameters=None):
188189
"""Compile a UFL expression to be evaluated against a compile-time known reference element's dual basis.
189190
190191
Useful for interpolating UFL expressions into e.g. N1curl spaces.
191192
192-
:arg expression: UFL expression
193+
:arg expression: UFL expression to interpolate
194+
:arg dual_arg: A Cofunction or Coargument to act on the interpolated expression
193195
:arg to_element: A FInAT element for the target space
194196
:arg ufl_element: The UFL element of the target space.
195197
:arg domain: optional UFL domain the expression is defined on (required when expression contains no domain).
@@ -210,7 +212,11 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
210212
if isinstance(to_element, (PhysicallyMappedElement, DirectlyDefinedElement)):
211213
raise NotImplementedError("Don't know how to interpolate onto zany spaces, sorry")
212214

213-
orig_expression = expression
215+
if not isinstance(dual_arg, (ufl.Coargument, ufl.Cofunction)):
216+
raise ValueError(f"Expecting a Coargument or Cofunction, not {type(dual_arg).__name__}")
217+
218+
interp_expression = ufl.Interpolate(expression, dual_arg)
219+
orig_expression = interp_expression
214220

215221
# Map into reference space
216222
expression = apply_mapping(expression, ufl_element, domain)
@@ -235,8 +241,9 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
235241
assert domain is not None
236242

237243
# Collect required coefficients and determine numbering
238-
coefficients = extract_coefficients(expression)
244+
coefficients = extract_coefficients(interp_expression)
239245
orig_coefficients = extract_coefficients(orig_expression)
246+
240247
coefficient_numbers = tuple(map(orig_coefficients.index, coefficients))
241248
builder.set_coefficient_numbers(coefficient_numbers)
242249

@@ -252,7 +259,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
252259
needs_external_coords = True
253260
builder.set_coefficients(coefficients)
254261

255-
constants = extract_firedrake_constants(expression)
262+
constants = extract_firedrake_constants(interp_expression)
256263
builder.set_constants(constants)
257264

258265
# Split mixed coefficients
@@ -281,11 +288,24 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
281288
# indices needed for compilation of the expression
282289
evaluation, basis_indices = to_element.dual_evaluation(fn)
283290

291+
# Compute the adjoint by contracting against the dual argument
292+
if dual_arg in coefficients:
293+
beta = basis_indices
294+
shape = tuple(i.extent for i in beta)
295+
gem_dual = gem.Variable(f"w_{coefficients.index(dual_arg)}", shape)
296+
evaluation = gem.IndexSum(evaluation * gem_dual[beta], beta)
297+
basis_indices = ()
298+
284299
# Build kernel body
285300
return_indices = basis_indices + tuple(chain(*argument_multiindices))
286301
return_shape = tuple(i.extent for i in return_indices)
287-
return_var = gem.Variable('A', return_shape)
288-
return_expr = gem.Indexed(return_var, return_indices)
302+
if return_shape:
303+
return_var = gem.Variable('A', return_shape)
304+
return_expr = gem.Indexed(return_var, return_indices)
305+
else:
306+
# 0-forms are expected to write into an indexed array
307+
return_var = gem.Variable('A', (1,))
308+
return_expr = gem.Indexed(return_var, (0,))
289309

290310
# TODO: one should apply some GEM optimisations as in assembly,
291311
# but we don't for now.

0 commit comments

Comments
 (0)