Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
result = expr.assemble(assembly_opts=opts)
return tensor.assign(result) if tensor else result
elif isinstance(expr, ufl.Interpolate):
if not isinstance(expr, firedrake.Interpolate):
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
expr = firedrake.Interpolate(*reversed(expr.dual_args()))
Comment thread
pbrubeck marked this conversation as resolved.
Outdated
orig_expr = expr
# Replace assembled children
_, operand = expr.argument_slots()
v, *assembled_operand = args
Expand Down Expand Up @@ -588,28 +591,47 @@ def base_form_assembly_visitor(self, expr, tensor, *args):

# Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument.
if not is_adjoint and rank == 2:
_, v1 = expr.arguments()
operand = ufl.replace(operand, {v1: v1.reconstruct(number=0)})
v0, v1 = expr.arguments()
expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()),
v1: v1.reconstruct(number=v0.number())})
Comment thread
pbrubeck marked this conversation as resolved.
Comment thread
pbrubeck marked this conversation as resolved.
v, operand = expr.argument_slots()

# Matrix-free adjoint interpolation is only implemented by SameMeshInterpolator
# so we need assemble the interpolator matrix if the meshes are different
target_mesh = V.mesh()
source_mesh = extract_unique_domain(operand) or target_mesh
if is_adjoint and rank < 2 and source_mesh is not target_mesh:
expr = reconstruct_interp(operand, v=V)
matfree = (rank == len(expr.arguments())) and (rank < 2)

# Get the interpolator
interp_data = expr.interp_data
interp_data = expr.interp_data.copy()
default_missing_val = interp_data.pop('default_missing_val', None)
interpolator = firedrake.Interpolator(operand, V, **interp_data)
if matfree and ((is_adjoint and rank == 1) or rank == 0):
Comment thread
pbrubeck marked this conversation as resolved.
Comment thread
pbrubeck marked this conversation as resolved.
# Adjoint interpolation of a Cofunction or the action of a
# Cofunction on an interpolated Function require INC access
# on the output tensor
interp_data["access"] = op2.INC

if rank == 1 and matfree and isinstance(tensor, firedrake.Function):
V = tensor
interpolator = firedrake.Interpolator(expr, V, **interp_data)

# Assembly
if rank == 0:
if matfree:
# Assembling the operator
return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val)
elif rank == 0:
# Assembling the double action.
Iu = interpolator._interpolate(default_missing_val=default_missing_val)
return assemble(ufl.Action(v, Iu), tensor=tensor)
elif rank == 1:
# Assembling the action of the Jacobian adjoint.
if is_adjoint:
return interpolator._interpolate(v, output=tensor, adjoint=True, default_missing_val=default_missing_val)
# Assembling the Jacobian action.
elif interpolator.nargs:
return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val)
# Assembling the operator
elif tensor is None:
return interpolator._interpolate(default_missing_val=default_missing_val)
else:
return firedrake.Interpolator(operand, tensor, **interp_data)._interpolate(default_missing_val=default_missing_val)
return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val)
elif rank == 2:
res = tensor.petscmat if tensor else PETSc.Mat()
# Get the interpolation matrix
Expand All @@ -618,14 +640,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
if is_adjoint:
# Out-of-place Hermitian transpose
petsc_mat.hermitianTranspose(out=res)
else:
elif res:
# Copy the interpolation matrix into the output tensor
petsc_mat.copy(result=res)
Comment thread
pbrubeck marked this conversation as resolved.
else:
res = petsc_mat
if tensor is None:
tensor = self.assembled_matrix(expr, res)
tensor = self.assembled_matrix(orig_expr, res)
return tensor
else:
# The case rank == 0 is handled via the DAG restructuring
raise ValueError("Incompatible number of arguments.")
elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)):
return tensor.assign(expr)
Expand Down
Loading
Loading