Skip to content

Commit 120a2a3

Browse files
committed
Reusable Interpolator
1 parent 44a98ed commit 120a2a3

4 files changed

Lines changed: 22 additions & 13 deletions

File tree

firedrake/assemble.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
538538
result = expr.assemble(assembly_opts=opts)
539539
return tensor.assign(result) if tensor else result
540540
elif isinstance(expr, ufl.Interpolate):
541+
if not isinstance(expr, firedrake.Interpolate):
542+
expr = firedrake.Interpolate(*reversed(expr.dual_args()))
541543
orig_expr = expr
542544
# Replace assembled children
543545
_, operand = expr.argument_slots()

firedrake/interpolation.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,8 +1006,13 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10061006
parameters = {}
10071007
parameters['scalar_type'] = utils.ScalarType
10081008

1009+
callables = ()
1010+
if access == op2.INC:
1011+
callables += (tensor.zero,)
1012+
10091013
needs_weight = isinstance(dual_arg, ufl.Cofunction) and not to_element.is_dg()
10101014
if needs_weight:
1015+
# Compute the reciprocal of the DOF multiplicity
10111016
W = dual_arg.function_space()
10121017
shapes = (W.finat_element.space_dimension(), W.block_size)
10131018
domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes
@@ -1018,14 +1023,14 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10181023
"""
10191024
weight = firedrake.Function(W)
10201025
firedrake.par_loop((domain, instructions), ufl.dx, {"w": (weight, op2.INC)})
1026+
with weight.dat.vec as w:
1027+
w.reciprocal()
10211028

1022-
# Create a copy and apply the weight
1023-
# TODO include this in the callables
1024-
v = firedrake.Function(dual_arg)
1025-
with v.dat.vec as x, weight.dat.vec as w:
1026-
x.pointwiseDivide(x, w)
1027-
1029+
# Create a buffer for the weighted Cofunction and a callable to apply the weight
1030+
v = firedrake.Function(W)
10281031
expr = expr._ufl_expr_reconstruct_(operand, v=v)
1032+
with weight.dat.vec_ro as w, dual_arg.dat.vec_ro as x, v.dat.vec_wo as y:
1033+
callables += (partial(y.pointwiseMult, x, w),)
10291034

10301035
# We need to pass both the ufl element and the finat element
10311036
# because the finat elements might not have the right mapping
@@ -1043,6 +1048,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10431048
name = kernel.name
10441049
kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True,
10451050
flop_count=kernel.flop_count, events=(kernel.event,))
1051+
10461052
parloop_args = [kernel, cell_set]
10471053

10481054
coefficients = tsfc_interface.extract_numbered_coefficients(expr, coefficient_numbers)
@@ -1158,7 +1164,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
11581164
if isinstance(tensor, op2.Mat):
11591165
return parloop_compute_callable, tensor.assemble
11601166
else:
1161-
return copyin + (parloop_compute_callable, ) + copyout
1167+
return copyin + callables + (parloop_compute_callable, ) + copyout
11621168

11631169

11641170
try:

firedrake/variational_solver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from firedrake import dmhooks, slate, solving, solving_utils, ufl_expr, utils
88
from firedrake.petsc import PETSc, DEFAULT_KSP_PARAMETERS, DEFAULT_SNES_PARAMETERS
99
from firedrake.function import Function
10-
from firedrake.interpolation import Interpolate
10+
from firedrake.interpolation import interpolate
1111
from firedrake.matrix import MatrixBase
1212
from firedrake.ufl_expr import TrialFunction, TestFunction
1313
from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space
@@ -98,7 +98,7 @@ def __init__(self, F, u, bcs=None, J=None,
9898
F_arg, = F.arguments()
9999
self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict})
100100
else:
101-
self.F = Interpolate(v_res, replace(F, {self.u: self.u_restrict}))
101+
self.F = interpolate(v_res, replace(F, {self.u: self.u_restrict}))
102102

103103
v_arg, u_arg = self.J.arguments()
104104
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})

tests/firedrake/regression/test_interpolate.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def test_trace():
329329

330330
@pytest.mark.parametrize("rank", (0, 1))
331331
@pytest.mark.parametrize("mat_type", ("matfree", "aij"))
332-
@pytest.mark.parametrize("degree", range(1, 4))
332+
@pytest.mark.parametrize("degree", (1, 3))
333333
@pytest.mark.parametrize("cell", ["triangle", "quadrilateral"])
334334
@pytest.mark.parametrize("shape", ("scalar", "vector", "tensor"))
335335
def test_adjoint_Pk(rank, mat_type, degree, cell, shape):
@@ -350,14 +350,15 @@ def test_adjoint_Pk(rank, mat_type, degree, cell, shape):
350350
operand = TestFunction(Pk)
351351

352352
if mat_type == "matfree":
353-
result = assemble(interpolate(operand, v))
353+
interp = interpolate(operand, v)
354354
else:
355355
adj_interp = assemble(interpolate(operand, TrialFunction(Pkp1.dual())))
356356
if rank == 0:
357-
result = assemble(action(v, adj_interp))
357+
interp = action(v, adj_interp)
358358
else:
359-
result = assemble(action(adj_interp, v))
359+
interp = action(adj_interp, v)
360360

361+
result = assemble(interp)
361362
expect = assemble(inner(expr, operand) * dx)
362363
if rank == 0:
363364
assert np.allclose(result, expect)

0 commit comments

Comments
 (0)