Skip to content

Commit 4ce9f8e

Browse files
committed
Fixup
1 parent 9826f0d commit 4ce9f8e

3 files changed

Lines changed: 32 additions & 40 deletions

File tree

firedrake/assemble.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,10 +602,13 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
602602
expr = reconstruct_interp(operand, v=V)
603603

604604
# Get the interpolator
605-
interp_data = expr.interp_data.copy()
605+
interp_data = expr.interp_data
606606
default_missing_val = interp_data.pop('default_missing_val', None)
607607
if same_mesh and ((is_adjoint and rank == 1) or rank == 0):
608608
interp_data["access"] = op2.INC
609+
610+
if rank == 1 and ((same_mesh and tensor) or isinstance(tensor, firedrake.Function)):
611+
V = tensor
609612
interpolator = firedrake.Interpolator(expr, V, **interp_data)
610613

611614
# Assembly

firedrake/interpolation.py

Lines changed: 26 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,20 @@ def __init__(self, expr, v,
9393
and reduce operations.
9494
"""
9595
# Check function space
96+
expr = ufl.as_ufl(expr)
9697
if isinstance(v, functionspaceimpl.WithGeometry):
97-
expr_args = extract_arguments(ufl.as_ufl(expr))
98+
expr_args = extract_arguments(expr)
9899
is_adjoint = len(expr_args) and expr_args[0].number() == 0
99100
v = Argument(v.dual(), 1 if is_adjoint else 0)
101+
102+
V = v.arguments()[0].function_space()
103+
if len(expr.ufl_shape) != len(V.value_shape):
104+
raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d'
105+
% (len(expr.ufl_shape), len(V.value_shape)))
106+
107+
if expr.ufl_shape != V.value_shape:
108+
raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r'
109+
% (expr.ufl_shape, V.value_shape))
100110
super().__init__(expr, v)
101111

102112
# -- Interpolate data (e.g. `subset` or `access`) -- #
@@ -173,7 +183,7 @@ def interpolate(expr, V, subset=None, access=op2.WRITE, allow_missing_dofs=False
173183
raise TypeError(f"Expected a one-form, provided form had {rank} arguments")
174184
elif isinstance(V, functionspaceimpl.WithGeometry):
175185
dual_arg = Coargument(V.dual(), 0)
176-
expr_args = extract_arguments(expr)
186+
expr_args = extract_arguments(ufl.as_ufl(expr))
177187
if expr_args and expr_args[0].number() == 0:
178188
# In this case we are doing adjoint interpolation
179189
# When V is a FunctionSpace and expr contains Argument(0),
@@ -483,7 +493,7 @@ def __init__(
483493
if len(shape) == 0:
484494
fs_type = firedrake.FunctionSpace
485495
elif len(shape) == 1:
486-
fs_type = firedrake.VectorFunctionSpace
496+
fs_type = partial(firedrake.VectorFunctionSpace, dim=shape[0])
487497
else:
488498
fs_type = partial(firedrake.TensorFunctionSpace, shape=shape)
489499
P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0)
@@ -710,7 +720,7 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE,
710720
super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr,
711721
access=access, bcs=bcs, matfree=matfree, allow_missing_dofs=allow_missing_dofs)
712722
try:
713-
self.callable, arguments = make_interpolator(expr, V, subset, access, bcs=bcs, matfree=matfree)
723+
self.callable = make_interpolator(expr, V, subset, access, bcs=bcs, matfree=matfree)
714724
except FIAT.hdiv_trace.TraceError:
715725
raise NotImplementedError("Can't interpolate onto traces sorry")
716726
self.arguments = expr.arguments()
@@ -726,6 +736,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
726736
if transpose is not None:
727737
warnings.warn("'transpose' argument is deprecated, use 'adjoint' instead", FutureWarning)
728738
adjoint = transpose or adjoint
739+
729740
try:
730741
assembled_interpolator = self.frozen_assembled_interpolator
731742
copy_required = True
@@ -740,7 +751,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
740751
# Interpolation action
741752
self.frozen_assembled_interpolator = assembled_interpolator.copy()
742753

743-
if hasattr(assembled_interpolator, "handle") and len(function):
754+
if len(self.arguments) == 2 and len(function):
744755
function, = function
745756
if not hasattr(function, "dat"):
746757
raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!")
@@ -783,15 +794,11 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
783794
@PETSc.Log.EventDecorator()
784795
def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
785796
if not isinstance(expr, ufl.Interpolate):
786-
fs = V if isinstance(V, ufl.FunctionSpace) else V.function_space()
787-
expr = Interpolate(expr, fs)
797+
raise ValueError(f"Expecting to interpolate a ufl.Interpolate, got {type(expr).__name__}.")
788798
dual_arg, operand = expr.argument_slots()
789-
assert isinstance(dual_arg, (ufl.Coargument, ufl.Cofunction))
790799

791800
target_mesh = as_domain(dual_arg)
792801
source_mesh = extract_unique_domain(operand) or target_mesh
793-
same_mesh = target_mesh is source_mesh
794-
795802
vom_onto_other_vom = (
796803
isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology)
797804
and isinstance(source_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology)
@@ -803,7 +810,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
803810
if rank <= 1:
804811
if rank == 0:
805812
R = firedrake.FunctionSpace(target_mesh, "Real", 0)
806-
f = firedrake.Function(R)
813+
f = firedrake.Function(R, dtype=utils.ScalarType)
807814
elif isinstance(V, firedrake.Function):
808815
f = V
809816
V = f.function_space()
@@ -862,10 +869,8 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
862869
else:
863870
raise ValueError("Cannot interpolate an expression with %d arguments" % rank)
864871

865-
if not same_mesh:
866-
arguments = extract_arguments(operand)
867872
if vom_onto_other_vom:
868-
wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, arguments, matfree)
873+
wrapper = VomOntoVomWrapper(V, source_mesh, target_mesh, operand, matfree)
869874
# NOTE: get_dat_mpi_type ensures we get the correct MPI type for the
870875
# data, including the correct data size and dimensional information
871876
# (so for vector function spaces in 2 dimensions we might need a
@@ -875,13 +880,13 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
875880
# when it is called.
876881
assert f.dat is tensor
877882
wrapper.mpi_type, _ = get_dat_mpi_type(f.dat)
878-
assert len(arguments) == 0
883+
assert len(arguments) == 1
879884

880885
def callable():
881886
wrapper.forward_operation(f.dat)
882887
return f
883888
else:
884-
assert len(arguments) == 1
889+
assert len(arguments) == 2
885890
assert tensor is None
886891
# we know we will be outputting either a function or a cofunction,
887892
# both of which will use a dat as a data carrier. At present, the
@@ -900,15 +905,9 @@ def callable():
900905
def callable():
901906
return wrapper
902907

903-
return callable, arguments
908+
return callable
904909
else:
905-
# Make sure we have an expression of the right length i.e. a value for
906-
# each component in the value shape of each function space
907910
loops = []
908-
if numpy.prod(operand.ufl_shape, dtype=int) != V.value_size:
909-
raise RuntimeError('Expression of length %d required, got length %d'
910-
% (V.value_size, numpy.prod(operand.ufl_shape, dtype=int)))
911-
912911
if len(V) == 1:
913912
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs))
914913
else:
@@ -933,8 +932,7 @@ def callable():
933932
elif isinstance(dual_arg, Coargument):
934933
duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()]
935934
else:
936-
raise ValueError("dual_arg must be a Cofunction or Coargument")
937-
935+
duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))]
938936
# Interpolate each sub expression into each function space
939937
for Vsub, sub_tensor, sub_op, sub_dual in zip(V, tensor, operands, duals):
940938
sub_expr = expr._ufl_expr_reconstruct_(sub_op, sub_dual)
@@ -948,7 +946,7 @@ def callable(loops, f):
948946
l()
949947
return f
950948

951-
return partial(callable, loops, f), arguments
949+
return partial(callable, loops, f)
952950

953951

954952
@utils.known_pyop2_safe
@@ -966,14 +964,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
966964
if access is op2.READ:
967965
raise ValueError("Can't have READ access for output function")
968966

969-
if len(operand.ufl_shape) != len(V.value_shape):
970-
raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d'
971-
% (len(operand.ufl_shape), len(V.value_shape)))
972-
973-
if operand.ufl_shape != V.value_shape:
974-
raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r'
975-
% (operand.ufl_shape, V.value_shape))
976-
977967
# NOTE: The par_loop is always over the target mesh cells.
978968
target_mesh = as_domain(V)
979969
source_mesh = extract_unique_domain(operand) or target_mesh
@@ -1401,17 +1391,15 @@ class VomOntoVomWrapper(object):
14011391
expr : `ufl.Expr`
14021392
The expression to interpolate. If ``arguments`` is not empty, those
14031393
arguments must be present within it.
1404-
arguments : list of `ufl.Argument`
1405-
The arguments in the expression. These are not extracted from expr here
1406-
since, where we use this, we already have them.
14071394
matfree : bool
14081395
If ``False``, the matrix representating the permutation of the points is
14091396
constructed and used to perform the interpolation. If ``True``, then the
14101397
interpolation is performed using the broadcast and reduce operations on the
14111398
PETSc Star Forest.
14121399
"""
14131400

1414-
def __init__(self, V, source_vom, target_vom, expr, arguments, matfree):
1401+
def __init__(self, V, source_vom, target_vom, expr, matfree):
1402+
arguments = extract_arguments(expr)
14151403
reduce = False
14161404
if source_vom.input_ordering is target_vom:
14171405
reduce = True

tsfc/driver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,8 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
223223
operand = ufl_utils.preprocess_expression(operand, complex_mode=complex_mode)
224224

225225
if isinstance(expression, ufl.Interpolate):
226-
expression = expression._ufl_expr_reconstruct_(operand)
226+
v, _ = expression.argument_slots()
227+
expression = ufl.Interpolate(operand, v)
227228
else:
228229
expression = operand
229230

0 commit comments

Comments
 (0)