diff --git a/firedrake/__init__.py b/firedrake/__init__.py index 5652943c6f..1ec4327557 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -102,7 +102,9 @@ def init_petsc(): NonNestedHierarchy, SemiCoarsenedExtrudedHierarchy, SubmeshHierarchy, prolong, restrict, inject, TransferManager, OpenCascadeMeshHierarchy, AdaptiveMeshHierarchy, - AdaptiveTransferManager + AdaptiveTransferManager, + CoarsePatchTransferManager, + FinePatchTransferManager, ) from firedrake.norms import errornorm, norm # noqa: F401 from firedrake.nullspace import VectorSpaceBasis, MixedVectorSpaceBasis # noqa: F401 diff --git a/firedrake/mg/__init__.py b/firedrake/mg/__init__.py index c73e5c7849..09cd8db55c 100644 --- a/firedrake/mg/__init__.py +++ b/firedrake/mg/__init__.py @@ -10,3 +10,6 @@ from firedrake.mg.opencascade_mh import OpenCascadeMeshHierarchy # noqa F401 from firedrake.mg.adaptive_hierarchy import AdaptiveMeshHierarchy # noqa F401 from firedrake.mg.adaptive_transfer_manager import AdaptiveTransferManager # noqa: F401 +from firedrake.mg.robust_transfer_manager import ( # noqa: F401 + CoarsePatchTransferManager, FinePatchTransferManager, +) diff --git a/firedrake/mg/kernels.py b/firedrake/mg/kernels.py index 749466f05e..ff9983a8d4 100644 --- a/firedrake/mg/kernels.py +++ b/firedrake/mg/kernels.py @@ -345,7 +345,7 @@ def restrict_kernel(Vf, Vc): def inject_kernel(Vf, Vc): - if Vc.finat_element.is_dg(): + if Vc.finat_element.is_dg() and Vf.ufl_element() == Vc.ufl_element(): hierarchy, level = utils.get_level(Vc.mesh()) if Vf.extruded: assert Vc.extruded diff --git a/firedrake/mg/mesh.py b/firedrake/mg/mesh.py index fecbae149e..1d71cf7300 100644 --- a/firedrake/mg/mesh.py +++ b/firedrake/mg/mesh.py @@ -11,6 +11,7 @@ from functools import cached_property from firedrake import utils +from firedrake.petsc import PETSc from firedrake.cython import mgimpl as impl from .utils import set_level @@ -29,6 +30,8 @@ class HierarchyBase(object): :arg refinements_per_level: number of mesh refinements each multigrid level should "see". :arg nested: Is this mesh hierarchy nested? + :arg coarse_facet_label: Optional subdomain ID to label the coarse facets on + each level of the hierarchy. .. note:: @@ -37,7 +40,7 @@ class HierarchyBase(object): :func:`ExtrudedMeshHierarchy`, or :func:`NonNestedHierarchy`. """ def __init__(self, meshes, coarse_to_fine_cells, fine_to_coarse_cells, - refinements_per_level=1, nested=False): + refinements_per_level=1, nested=False, coarse_facet_label=None): petsctools.cite("Mitchell2016") self._meshes = tuple(meshes) self.meshes = tuple(meshes[::refinements_per_level]) @@ -45,6 +48,7 @@ def __init__(self, meshes, coarse_to_fine_cells, fine_to_coarse_cells, self.fine_to_coarse_cells = fine_to_coarse_cells self.refinements_per_level = refinements_per_level self.nested = nested + self._coarse_facet_label = coarse_facet_label for level, m in enumerate(meshes): set_level(m, self, Fraction(level, refinements_per_level)) for level, m in enumerate(self): @@ -79,7 +83,8 @@ def MeshHierarchy(mesh, refinement_levels, netgen_flags=False, reorder=None, distribution_parameters=None, callbacks=None, - mesh_builder=firedrake.Mesh): + mesh_builder=firedrake.Mesh, + coarse_facet_label=None): """Build a hierarchy of meshes by uniformly refining a coarse mesh. Parameters @@ -109,6 +114,10 @@ def MeshHierarchy(mesh, refinement_levels, callback receives the refined DM (and the current level). mesh_builder Function to turn a DM into a ``Mesh``. Used by pyadjoint. + coarse_facet_label : int | None + Optional subdomain ID to label the coarse facets on each + level of the hierarchy. + Returns ------- A :py:class:`HierarchyBase` object representing the @@ -139,11 +148,29 @@ def MeshHierarchy(mesh, refinement_levels, else: before = after = lambda dm, i: None for i in range(refinement_levels*refinements_per_level): + if coarse_facet_label is not None: + # Create a temporary label on all the facets of the coarse dm + # to label every coarse facet on the fine dm + fstart, fend = cdm.getHeightStratum(1) + iset = PETSc.IS().createStride(fend-fstart, first=fstart, comm=cdm.comm) + cdm.createLabel("temp_label") + label = cdm.getLabel("temp_label") + label.setStratumIS(1, iset) + if i % refinements_per_level == 0: before(cdm, i) rdm = cdm.refine() if i % refinements_per_level == 0: after(rdm, i) + + if coarse_facet_label is not None: + # Move coarse_facet_label into FACE_SETS_LABEL + iset = rdm.getLabel("temp_label").getStratumIS(1) + label = rdm.getLabel(dmcommon.FACE_SETS_LABEL) + label.setStratumIS(coarse_facet_label, iset) + rdm.removeLabel("temp_label") + cdm.removeLabel("temp_label") + dms.append(rdm) cdm = rdm # Fix up coords if refining embedded circle or sphere @@ -192,7 +219,8 @@ def MeshHierarchy(mesh, refinement_levels, fine_to_coarse_cells = dict((Fraction(i, refinements_per_level), f2c) for i, f2c in enumerate(fine_to_coarse_cells)) return HierarchyBase(meshes, coarse_to_fine_cells, fine_to_coarse_cells, - refinements_per_level, nested=True) + refinements_per_level, nested=True, + coarse_facet_label=coarse_facet_label) def ExtrudedMeshHierarchy(base_hierarchy, height, base_layer=-1, refinement_ratio=2, layers=None, diff --git a/firedrake/mg/robust_transfer_manager.py b/firedrake/mg/robust_transfer_manager.py new file mode 100644 index 0000000000..3bd52f5464 --- /dev/null +++ b/firedrake/mg/robust_transfer_manager.py @@ -0,0 +1,342 @@ +from functools import partial +from ufl import H1 +from finat.ufl import FiniteElement, NodalEnrichedElement, TensorElement + +from firedrake import dmhooks +from firedrake.assemble import assemble, get_assembler +from firedrake.bcs import DirichletBC, restricted_function_space +from firedrake.function import Function +from firedrake.functionspace import MixedFunctionSpace +from firedrake.interpolation import interpolate, get_interpolator +from firedrake.slate import Inverse, Tensor +from firedrake.ufl_expr import action, TestFunction, TrialFunction +from firedrake.utils import complex_mode +from firedrake.variational_solver import LinearVariationalProblem, LinearVariationalSolver +from .embedded import TransferManager +from .utils import get_level + + +__all__ = ("CoarsePatchTransferManager", "FinePatchTransferManager", "RobustTransferManager") + + +DEFAULT_PATCH_PARAMETERS = { + "ksp_type": "preonly", + "pc_type": "bjacobi", + "sub_pc_type": "lu", + "sub_pc_factor_mat_solver_type": "petsc", + "sub_pc_factor_shift_type": "nonzero", +} + + +class RobustTransferManager(TransferManager): + """An object for managing transfers between levels in a multigrid hierarchy + via standard interpolation into subdomain boundaries followed by an extension + into the interior of the subdomains by solving the homogeneous PDE. + + The subdomain solver options are under the prefix ``mg_transfer_``. + """ + + class TransferCallable: + """Internal class to apply a sequence on linear operations + by transfering the input and output into local buffers + referenced in the list of callables. + """ + def __init__(self, x_buffer, y_buffer, callables): + self.x_buffer = x_buffer + self.y_buffer = y_buffer + self.callables = callables + + def __call__(self, x, y): + self.x_buffer.assign(x) + for c in self.callables: + c() + return y.assign(self.y_buffer) + + def form(self, V): + """Get the preconditioning Form in the _SNESContext of a FunctionSpace.""" + form = None + ctx = dmhooks.get_appctx(V.dm) + if ctx is not None: + form = ctx._problem.Jp or ctx._problem.J + # Only return form if the trial space is V + if form.arguments()[1].function_space() != V: + form = None + return form + + def options_prefix(self, V): + """Get the options prefix in the _SNESContext of a FunctionSpace.""" + ctx = dmhooks.get_appctx(V.dm) + prefix = ctx.options_prefix if ctx else "" + prefix += "mg_transfer_" + return prefix + + def auxiliary_target_space(self, V): + """Construct an auxiliary target FunctionSpace.""" + raise NotImplementedError("Must be implemented by subclass.") + + def build_auxiliary_target_space(self, V): + """Dispatch auxiliary_target_space on the subspaces of a MixedFunctionSpace.""" + subspaces = tuple(map(self.auxiliary_target_space, V)) + return MixedFunctionSpace(subspaces) if len(subspaces) > 1 else subspaces[0] + + def build_patch_solver(self, form, V): + """Build a solver to extend the solution from the residual in the + auxiliary space into the entire space V.""" + raise NotImplementedError("Must be implemented by subclass.") + + def get_patch_solver(self, form, V): + """Cache the patch solver.""" + cache = form._cache + key = (type(self).__name__, "patch_solver") + try: + return cache[key] + except KeyError: + return cache.setdefault(key, self.build_patch_solver(form, V)) + + def build_transfer_callables(self, form, Vc, Vf): + """Construct prolongation and restriction TransferCallables.""" + uc = Function(Vc) + uf = Function(Vf) + P = self.prolong_callable(form, uc, uf) + rc = Function(Vc.dual(), val=uc.dat) + rf = Function(Vf.dual(), val=uf.dat) + R = self.restrict_callable(form, rf, rc) + return P, R + + def get_transfer_callables(self, Vc, Vf): + """Cache the prolongation and restriction TransferCallables.""" + form = self.form(Vf) + cache = form._cache + key = (type(self).__name__, "transfer_callables") + try: + return cache[key] + except KeyError: + return cache.setdefault(key, self.build_transfer_callables(form, Vc, Vf)) + + def prolong_callable(self, form, uc, uf): + """Return a TransferCallable that interpolates uc into uf such that + uc = uf on patch boundaries and form(v, uf) = 0 for all v on the patch + subspaces.""" + V = uf.function_space() + V_aux = self.build_auxiliary_target_space(V) + u_aux = Function(V_aux) + + solver, r_patch, u_patch = self.get_patch_solver(form, V) + if solver is None: + # patch problem is empty + callables = ( + partial(TransferManager.prolong, self, uc, u_aux), + partial(u_aux.dat.copy, uf.dat), + ) + else: + if len(set(f.ufl_element() for f in (uf, u_aux, u_patch))) == 1: + copy_update = partial(uf.assign, u_aux - u_patch) + else: + wtest = TestFunction(V.dual()) + Iv = get_interpolator(interpolate(u_aux - u_patch, wtest)) + copy_update = partial(Iv.assemble, tensor=uf) + + v_patch, = r_patch.arguments() + residual = get_assembler(form(v_patch, u_aux)) + callables = ( + partial(TransferManager.prolong, self, uc, u_aux), + partial(residual.assemble, tensor=r_patch), + solver, + copy_update, + ) + return self.TransferCallable(uc, uf, callables) + + def restrict_callable(self, form, rf, rc): + """Return a TransferCallable with the adjoint of prolong.""" + V = rf.function_space().dual() + V_aux = self.build_auxiliary_target_space(V) + r_aux = Function(V_aux.dual()) + Au = Function(V_aux.dual()) + + solver, r_patch, u_patch = self.get_patch_solver(form, V) + if solver is None: + # patch problem is empty + callables = ( + partial(rf.dat.copy, r_aux.dat), + partial(TransferManager.restrict, self, r_aux, rc), + ) + else: + + def copy_callable(source, dest): + if source.ufl_element() == dest.ufl_element(): + return partial(dest.assign, source) + else: + R = get_interpolator(interpolate(dest.arguments()[0], source)) + return partial(R.assemble, tensor=dest) + + v_aux, = r_aux.arguments() + residual = get_assembler(form(u_patch, v_aux)) + callables = ( + copy_callable(rf, r_aux), + copy_callable(rf, r_patch), + solver, + partial(residual.assemble, tensor=Au), + partial(r_aux.assign, r_aux - Au), + partial(TransferManager.restrict, self, r_aux, rc), + ) + return self.TransferCallable(rf, rc, callables) + + def prolong(self, uc, uf): + Vc = uc.function_space() + Vf = uf.function_space() + form = self.form(Vf) + if form is not None: + P, R = self.get_transfer_callables(Vc, Vf) + return P(uc, uf) + else: + return super().prolong(uc, uf) + + def restrict(self, rf, rc): + Vc = rc.function_space().dual() + Vf = rf.function_space().dual() + form = self.form(Vf) + if form is not None: + P, R = self.get_transfer_callables(Vc, Vf) + return R(rf, rc) + else: + return super().restrict(rf, rc) + + +class CoarsePatchTransferManager(RobustTransferManager): + """An object for managing transfers between levels in a multigrid hierarchy + via standard interpolation into coarse cell boundaries followed by an extension + into the interior of the coarse cell patches by solving the homogeneous PDE. + + This class will raise an error when the coarse facets are not labeled across + the MeshHierarchy. + """ + + def auxiliary_target_space(self, V): + """Construct a facet space for inter-grid interpolation.""" + element = V.ufl_element() + if (element.family() in {"Lagrange", "Discontinuous Lagrange"} + and V.finat_element.complex.is_macrocell()): + # If the element is a split variant, reconstruct the unsplit one + element = element.reconstruct(variant=None) + + # FIXME + restrict = False + if not V.finat_element.is_dg() and restrict: + entity_dofs = V.finat_element.entity_dofs() + sd = max(entity_dofs) + if len(entity_dofs[sd][0]) > 0: + element = element["facet"] + + if element == V.ufl_element(): + return V + return V.collapse().reconstruct(element=element) + + def build_patch_solver(self, form, V): + """Solve form(test, u_patch) = r_patch on coarse cell patches.""" + V_patch = self.get_patch_function_space(V) + u_patch = Function(V_patch) + r_patch = Function(V_patch.dual()) + test = TestFunction(V_patch) + trial = TrialFunction(V_patch) + + if len(V_patch) == 1: + bcs = DirichletBC(V_patch, 0, V_patch.boundary_set) + else: + bcs = [DirichletBC(V_patch.sub(i), 0, V_.boundary_set) + for i, V_ in enumerate(V_patch) if len(V_.boundary_set) > 0] + + a = assemble(form(test, trial), bcs=bcs) + problem = LinearVariationalProblem(a, r_patch, u_patch) + solver = LinearVariationalSolver(problem, + solver_parameters=DEFAULT_PATCH_PARAMETERS, + options_prefix=self.options_prefix(V)) + return (solver.solve, r_patch, u_patch) + + def get_patch_function_space(self, V): + """Construct a space with boundary conditions on the coarse facets.""" + boundary_sets = [] + for V_ in V: + if V_.finat_element.is_dg(): + boundary_sets.append(()) + else: + mesh = V_.mesh() + mh, _ = get_level(mesh) + label = mh._coarse_facet_label + if label not in mesh.interior_facets.unique_markers: + raise ValueError("Expecting a hierarchy with a coarse facet label.") + boundary_sets.append((label,)) + return restricted_function_space(V, boundary_sets) + + +class FinePatchTransferManager(RobustTransferManager): + """An object for managing transfers between levels in a multigrid hierarchy + via standard interpolation into fine cell boundaries followed by an extension + into the interior of the fine cells by solving the homogeneous PDE. + """ + + def auxiliary_target_space(self, V): + """Construct a facet space for inter-grid interpolation.""" + if V.finat_element.is_dg(): + return V + + element = V.ufl_element() + quad_scheme = element._quad_scheme + if V.finat_element.complex.is_macrocell(): + # Macroelements require a composite quadrature scheme + if element.sobolev_space == H1 and V.finat_element.degree < 4: + quad_scheme = "powell-sabin,KMV(2)" + else: + quad_scheme = "powell-sabin" + + tdim = V.mesh().topological_dimension + if V.finat_element.has_pointwise_dual_basis and V.finat_element.degree == tdim: + # Facet moment degrees of freedom for CG elements + CG = FiniteElement("CG", degree=tdim, variant="chebyshev") + CR = FiniteElement("CR", degree=1, variant="integral", quad_scheme=quad_scheme) + element = NodalEnrichedElement(CG["ridge"], CR) + if V.value_shape != (): + element = TensorElement(element, shape=V.value_shape) + else: + # Take the facet element with the new quadrature scheme + if quad_scheme != element._quad_scheme: + element = element.reconstruct(quad_scheme=quad_scheme) + entity_dofs = V.finat_element.entity_dofs() + sd = max(entity_dofs) + if len(entity_dofs[sd][0]) > 0: + element = element["facet"] + + if element == V.ufl_element(): + return V + return V.collapse().reconstruct(element=element) + + def build_patch_solver(self, form, V): + """Solve form(test, u_patch) = r_patch on fine cell patches""" + tdim = V.mesh().topological_dimension + if any(len(V_.finat_element.entity_dofs()[tdim][0]) == 0 for V_ in V): + # The element has no interior DOFs + return (None, None, None) + + # Reconstruct the space on the interior with standard quadrature + element = V.ufl_element() + if element._quad_scheme is not None: + element = element.reconstruct(quad_scheme=None) + V_patch = V.reconstruct(element=element["interior"]) + u_patch = Function(V_patch) + r_patch = Function(V_patch.dual()) + test = TestFunction(V_patch) + trial = TrialFunction(V_patch) + a = form(test, trial) + + use_slate_for_inverse = not complex_mode + if use_slate_for_inverse: + ainv = assemble(Inverse(Tensor(a))) + assembler = get_assembler(action(ainv, r_patch)) + solve = partial(assembler.assemble, tensor=u_patch) + else: + a = assemble(a) + problem = LinearVariationalProblem(a, r_patch, u_patch) + solver = LinearVariationalSolver(problem, + solver_parameters=DEFAULT_PATCH_PARAMETERS, + options_prefix=self.options_prefix(V)) + solve = solver.solve + return (solve, r_patch, u_patch) diff --git a/firedrake/mg/ufl_utils.py b/firedrake/mg/ufl_utils.py index b40bfbd7ac..ba67a666f7 100644 --- a/firedrake/mg/ufl_utils.py +++ b/firedrake/mg/ufl_utils.py @@ -116,6 +116,13 @@ def coarsen_form(form, self, coefficient_mapping=None): return form +@coarsen.register(ufl.Interpolate) +def coarsen_interpolate(interp, self, coefficient_mapping=None): + dual_arg, operand = interp.argument_slots() + return interp._ufl_expr_reconstruct_(self(operand, self, coefficient_mapping=coefficient_mapping), + self(dual_arg, self, coefficient_mapping=coefficient_mapping)) + + @coarsen.register(ufl.FormSum) def coarsen_formsum(form, self, coefficient_mapping=None): return type(form)(*[(self(ci, self, coefficient_mapping=coefficient_mapping), diff --git a/tests/firedrake/multigrid/test_robust_transfer.py b/tests/firedrake/multigrid/test_robust_transfer.py new file mode 100644 index 0000000000..7b4637bc4b --- /dev/null +++ b/tests/firedrake/multigrid/test_robust_transfer.py @@ -0,0 +1,80 @@ +import pytest +from firedrake import * + + +@pytest.fixture +def hierarchy(): + distribution_parameters = {"overlap_type": (DistributedMeshOverlapType.VERTEX, 1)} + nx = 4 + refine = 3 + base = UnitSquareMesh(nx, nx, distribution_parameters=distribution_parameters) + mh = MeshHierarchy(base, refine, coarse_facet_label=1000) + return mh + + +@pytest.fixture +def mesh(hierarchy): + return hierarchy[-1] + + +@pytest.fixture +def V(mesh): + degree = mesh.topological_dimension + V = VectorFunctionSpace(mesh, "CG", degree, variant="alfeld") + return V + + +@pytest.fixture +def solver(V): + uh = Function(V) + u = TrialFunction(V) + v = TestFunction(V) + x = SpatialCoordinate(V.mesh()) + uexact = x * sum(x) + + mu = Constant(1) + lam = Constant(1E4) + eps = lambda u: sym(grad(u)) + a = inner(2*mu*eps(u), eps(v))*dx + inner(lam*div(u), div(v))*dx + L = a(v, uexact) + bcs = DirichletBC(V, uexact, "on_boundary") + + solver_parameters = { + "mat_type": "aij", + "snes_type": "ksponly", + "ksp_type": "cg", + "ksp_rtol": 1e-8, + "ksp_monitor": None, + "pc_type": "mg", + "mg_levels": { + "ksp_type": "chebyshev", + "ksp_max_it": 2, + "pc_type": "python", + "pc_python_type": "firedrake.ASMStarPC", + "pc_star_sub_sub_pc_type": "cholesky", + "pc_star_sub_sub_pc_factor_mat_solver_type": "petsc", + "pc_star_mat_ordering_type": "nd", + "pc_star_use_coloring": True, + }, + "mg_coarse": { + "mat_type": "aij", + "pc_type": "cholesky", + "pc_factor_mat_solver_type": "mumps", + } + } + + problem = LinearVariationalProblem(a, L, uh, bcs=bcs) + solver = LinearVariationalSolver(problem, + solver_parameters=solver_parameters) + return solver + + +@pytest.mark.parallel([1, 3]) +@pytest.mark.parametrize("create_transfer", [CoarsePatchTransferManager, FinePatchTransferManager]) +def test_robust_transfer(solver, create_transfer): + tm = create_transfer() + u = solver._problem.u + u.zero() + solver.set_transfer_manager(tm) + solver.solve() + assert solver.snes.ksp.getIterationNumber() < 15