Skip to content

Commit f1080ea

Browse files
committed
Interpolate: support fieldsplit
1 parent 76ec367 commit f1080ea

2 files changed

Lines changed: 61 additions & 60 deletions

File tree

firedrake/formmanipulation.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy
33
import collections
44

5-
from ufl import as_vector, split
5+
from ufl import as_tensor, as_vector, split
66
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
77
from ufl.algorithms.map_integrands import map_integrand_dags
88
from ufl.algorithms import expand_derivatives
@@ -14,6 +14,7 @@
1414
from firedrake.petsc import PETSc
1515
from firedrake.functionspace import MixedFunctionSpace
1616
from firedrake.cofunction import Cofunction
17+
from firedrake.ufl_expr import Coargument
1718
from firedrake.matrix import AssembledMatrix
1819

1920

@@ -175,7 +176,32 @@ def interpolate(self, o, operand):
175176
if isinstance(operand, Zero):
176177
return ZeroBaseForm(o.arguments())
177178

178-
return o._ufl_expr_reconstruct_(operand)
179+
dual_arg, _ = o.argument_slots()
180+
V = dual_arg.function_space()
181+
if len(V) == 1:
182+
return o._ufl_expr_reconstruct_(operand, dual_arg)
183+
184+
# Split the target (dual) argument
185+
if isinstance(dual_arg, Coargument):
186+
indices = self.blocks[dual_arg.number()]
187+
W = subspace(dual_arg.function_space(), indices)
188+
dual_arg = Coargument(W, dual_arg.number())
189+
else:
190+
raise NotImplementedError()
191+
192+
# Unflatten the expression into the target shapes
193+
cur = 0
194+
operands = []
195+
components = numpy.reshape(operand, (-1,))
196+
for i, Vi in enumerate(V):
197+
if i in indices:
198+
operands.extend(components[cur:cur+Vi.value_size])
199+
cur += Vi.value_size
200+
201+
operand = as_tensor(numpy.reshape(operands, W.value_shape))
202+
if isinstance(operand, Zero):
203+
return ZeroBaseForm(o.arguments())
204+
return o._ufl_expr_reconstruct_(operand, dual_arg)
179205

180206

181207
SplitForm = collections.namedtuple("SplitForm", ["indices", "form"])

firedrake/interpolation.py

Lines changed: 33 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -985,14 +985,18 @@ def callable():
985985
return callable
986986
else:
987987
loops = []
988-
expressions = split_interpolate_target(expr)
989988

990989
if access == op2.INC:
991990
loops.append(tensor.zero)
992991

993992
# Interpolate each sub expression into each function space
994-
for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions):
995-
loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs))
993+
tensors = list(tensor)
994+
if len(tensors) == 1:
995+
split = [((0,), expr)]
996+
else:
997+
split = firedrake.formmanipulation.split_form(expr)
998+
for (i,), sub_expr in split:
999+
loops.extend(_interpolator(V[i], tensors[i], sub_expr, subset, arguments, access, bcs=bcs))
9961000

9971001
if bcs and rank == 1:
9981002
loops.extend(partial(bc.apply, f) for bc in bcs)
@@ -1707,36 +1711,6 @@ def duplicate(self, mat=None, op=None):
17071711
return self._wrap_dummy_mat()
17081712

17091713

1710-
def split_interpolate_target(expr: ufl.Interpolate):
1711-
"""Split an Interpolate into the components (subfunctions) of the target space."""
1712-
dual_arg, operand = expr.argument_slots()
1713-
V = dual_arg.function_space().dual()
1714-
if len(V) == 1:
1715-
return (expr,)
1716-
# Split the target (dual) argument
1717-
if isinstance(dual_arg, Cofunction):
1718-
duals = dual_arg.subfunctions
1719-
elif isinstance(dual_arg, ufl.Coargument):
1720-
duals = [Coargument(Vsub, dual_arg.number()) for Vsub in dual_arg.function_space()]
1721-
else:
1722-
duals = [vi for _, vi in sorted(firedrake.formmanipulation.split_form(dual_arg))]
1723-
# Split the operand into the target shapes
1724-
if (isinstance(operand, firedrake.Function) and len(operand.subfunctions) == len(V)
1725-
and all(fsub.ufl_shape == Vsub.value_shape for Vsub, fsub in zip(V, operand.subfunctions))):
1726-
# Use subfunctions if they match the target shapes
1727-
operands = operand.subfunctions
1728-
else:
1729-
# Unflatten the expression into the target shapes
1730-
cur = 0
1731-
operands = []
1732-
components = numpy.reshape(operand, (-1,))
1733-
for Vi in V:
1734-
operands.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape)))
1735-
cur += Vi.value_size
1736-
expressions = tuple(map(expr._ufl_expr_reconstruct_, operands, duals))
1737-
return expressions
1738-
1739-
17401714
class MixedInterpolator(Interpolator):
17411715
"""A reusable interpolation object between MixedFunctionSpaces.
17421716
@@ -1754,39 +1728,40 @@ class MixedInterpolator(Interpolator):
17541728
For details see :class:`firedrake.interpolation.Interpolator`.
17551729
"""
17561730
def __init__(self, expr, V, bcs=None, **kwargs):
1731+
bcs = bcs or ()
17571732
super(MixedInterpolator, self).__init__(expr, V, bcs=bcs, **kwargs)
1733+
17581734
expr = self.ufl_interpolate
1759-
bcs = bcs or ()
17601735
self.arguments = expr.arguments()
1761-
1762-
# Split the target (dual) argument
1763-
dual_split = split_interpolate_target(expr)
1764-
self.sub_interpolators = {}
1765-
for i, form in enumerate(dual_split):
1766-
# Split the source (primal) argument
1767-
for j, sub_interp in firedrake.formmanipulation.split_form(form):
1768-
j = max(j) if j else 0
1769-
# Ensure block sparsity
1770-
if not isinstance(sub_interp, ufl.ZeroBaseForm):
1771-
vi, operand = sub_interp.argument_slots()
1772-
Vtarget = vi.function_space().dual()
1773-
adjoint = vi.number() == 1 if isinstance(vi, Coargument) else True
1774-
1775-
args = sub_interp.arguments()
1776-
Vsource = args[0 if adjoint else 1].function_space()
1777-
sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}]
1778-
1779-
indices = (j, i) if adjoint else (i, j)
1780-
Isub = Interpolator(sub_interp, Vtarget, bcs=sub_bcs, **kwargs)
1781-
self.sub_interpolators[indices] = Isub
1782-
1736+
rank = len(self.arguments)
1737+
if rank < 2:
1738+
dual_arg, operand = expr.argument_slots()
1739+
# Split the dual argument
1740+
dual_split = dict(firedrake.formmanipulation.split_form(dual_arg))
1741+
# Create the Jacobian to split into blocks
1742+
expr = expr._ufl_expr_reconstruct_(operand, firedrake.TrialFunction(dual_arg.function_space()))
1743+
1744+
Isub = {}
1745+
for indices, form in firedrake.formmanipulation.split_form(expr):
1746+
if not isinstance(form, ufl.ZeroBaseForm):
1747+
args = form.arguments()
1748+
vi, operand = form.argument_slots()
1749+
Vtarget = vi.function_space().dual()
1750+
Vsource = args[1-vi.number()].function_space()
1751+
sub_bcs = [bc for bc in self.bcs if bc.function_space() in {Vsource, Vtarget}]
1752+
if rank == 1:
1753+
# Take the action of each sub-cofunction against each block
1754+
form = action(form, dual_split[indices[1:]])
1755+
Isub[indices] = Interpolator(form, Vtarget, bcs=sub_bcs, **kwargs)
1756+
1757+
self.sub_interpolators = Isub
17831758
self.callable = self._get_callable
17841759

17851760
def _get_callable(self):
17861761
"""Assemble the operator."""
1762+
Isub = self.sub_interpolators
17871763
shape = tuple(len(a.function_space()) for a in self.arguments)
1788-
Isubs = self.sub_interpolators
1789-
blocks = numpy.reshape([Isubs[ij].callable().handle if ij in Isubs else PETSc.Mat()
1764+
blocks = numpy.reshape([Isub[ij].callable().handle if ij in Isub else PETSc.Mat()
17901765
for ij in numpy.ndindex(shape)], shape)
17911766
petscmat = PETSc.Mat().createNest(blocks)
17921767
tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat)

0 commit comments

Comments
 (0)