Skip to content

Commit 5113864

Browse files
committed
Reverse indices for dual_arg
1 parent a33cfb3 commit 5113864

3 files changed

Lines changed: 5 additions & 6 deletions

File tree

firedrake/assemble.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -596,14 +596,13 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
596596
same_mesh = source_mesh.topology is target_mesh.topology
597597

598598
# Get the interpolator
599-
interp_data = expr.interp_data
599+
interp_data = expr.interp_data.copy()
600600
default_missing_val = interp_data.pop('default_missing_val', None)
601601
if same_mesh and ((is_adjoint and rank == 1) or rank == 0):
602-
interp_data = interp_data.copy()
603602
interp_data["access"] = op2.INC
604603

605604
dual_arg = v if same_mesh else V
606-
interp_expr = firedrake.Interpolate(operand, v=dual_arg, **interp_data)
605+
interp_expr = reconstruct_interp(operand, v=dual_arg)
607606
interpolator = firedrake.Interpolator(interp_expr, V, **interp_data)
608607

609608
# Assembly

firedrake/interpolation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,7 @@ def callable():
938938

939939
# Interpolate each sub expression into each function space
940940
for Vsub, sub_tensor, sub_op, sub_dual in zip(V, tensor, operands, duals):
941-
sub_expr = ufl.Interpolate(sub_op, sub_dual)
941+
sub_expr = expr._ufl_expr_reconstruct_(sub_op, sub_dual)
942942
loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs))
943943

944944
if bcs and rank == 1:
@@ -1050,7 +1050,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10501050
name = kernel.name
10511051
kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True,
10521052
flop_count=kernel.flop_count, events=(kernel.event,))
1053-
10541053
parloop_args = [kernel, cell_set]
10551054

10561055
expr = ufl.Interpolate(operand, v=dual_arg)
@@ -1069,6 +1068,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10691068
else:
10701069
copyin = ()
10711070
copyout = ()
1071+
10721072
if isinstance(tensor, op2.Global):
10731073
parloop_args.append(tensor(access))
10741074
elif isinstance(tensor, op2.Dat):

tsfc/driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
300300

301301
# Compute the adjoint by contracting against the dual argument
302302
if dual_arg in coefficients:
303-
beta = basis_indices
303+
beta = basis_indices[::-1]
304304
shape = tuple(i.extent for i in beta)
305305
gem_dual = gem.Variable(f"w_{coefficients.index(dual_arg)}", shape)
306306
evaluation = gem.IndexSum(evaluation * gem_dual[beta], beta)

0 commit comments

Comments
 (0)