diff --git a/firedrake/bcs.py b/firedrake/bcs.py index 2203f2b3d5..ed51e1a48d 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -8,6 +8,7 @@ import ufl from ufl import as_ufl, as_tensor +from ufl.algorithms import extract_coefficients from finat.ufl import VectorElement import finat @@ -18,6 +19,7 @@ import firedrake import firedrake.matrix as matrix +import firedrake.utils as utils from firedrake import ufl_expr from firedrake import slate from firedrake import solving @@ -224,6 +226,9 @@ def set(self, r, val): def integrals(self): raise NotImplementedError("integrals() method has to be overwritten") + def coefficients(self): + raise NotImplementedError("coefficients() method has to be overwritten") + @PETSc.Log.EventDecorator() def as_subspace(self, field, V, use_split): fs = self._function_space @@ -315,8 +320,6 @@ def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=Fal V = V.sub(index) if g is None: g = self._original_arg - if isinstance(g, firedrake.Function) and g.function_space() != V: - g = firedrake.Function(V).interpolate(g) if sub_domain is None: sub_domain = self.sub_domain if field is not None: @@ -342,11 +345,11 @@ def function_arg(self, g): del self._function_arg_update except AttributeError: pass + self._coefficients = () V = self.function_space() - if isinstance(g, firedrake.Function) and g.ufl_element().family() != "Real": - if g.function_space() != V: - raise RuntimeError("%r is defined on incompatible FunctionSpace!" % g) + if isinstance(g, firedrake.Function) and g.function_space() == V: self._function_arg = g + self._coefficients = (g,) elif isinstance(g, ufl.classes.Zero): if g.ufl_shape and g.ufl_shape != V.value_shape: raise ValueError(f"Provided boundary value {g} does not match shape of space") @@ -355,17 +358,18 @@ def function_arg(self, g): elif isinstance(g, ufl.classes.Expr): if g.ufl_shape != V.value_shape: raise RuntimeError(f"Provided boundary value {g} does not match shape of space") + self._function_arg = firedrake.Function(V) try: - self._function_arg = firedrake.Function(V) interpolator = firedrake.get_interpolator(firedrake.interpolate(g, V)) # Call this here to check if the element supports interpolation # TODO: It's probably better to have a more explicit way of checking this - interpolator._get_callable() self._function_arg_update = partial(interpolator.assemble, tensor=self._function_arg) - except (NotImplementedError, AttributeError): + self._function_arg_update() + except NotImplementedError: # Element doesn't implement interpolation - self._function_arg = firedrake.Function(V).project(g) self._function_arg_update = firedrake.Projector(g, self._function_arg).project + self._function_arg_update() + self._coefficients = tuple(extract_coefficients(g)) else: try: g = as_ufl(g) @@ -460,6 +464,9 @@ def apply(self, r, u=None): def integrals(self): return [] + def coefficients(self): + return self._coefficients + def extract_form(self, form_type): # DirichletBC is directly used in assembly. return self @@ -567,6 +574,10 @@ def reconstruct(self, V, subu, u, field, is_linear): def _as_nonlinear_variational_problem_arg(self, is_linear=False): return self + def coefficients(self): + splits = (self._F, self._J, self._Jp) + return utils.unique(itertools.chain.from_iterable(f.coefficients() for f in splits)) + class EquationBCSplit(BCBase): r'''Class for a BC tree that stores/manipulates either `F`, `J`, or `Jp`. @@ -613,6 +624,10 @@ def sorted_equation_bcs(self): def integrals(self): return self.f.integrals() + def coefficients(self): + subs = (self.f, *self.bcs) + return utils.unique(itertools.chain.from_iterable(f.coefficients() for f in subs)) + def add(self, bc): if not isinstance(bc, (DirichletBC, EquationBCSplit)): raise TypeError("EquationBCSplit.add expects an instance of DirichletBC or EquationBCSplit.") diff --git a/firedrake/mg/ufl_utils.py b/firedrake/mg/ufl_utils.py index b40bfbd7ac..c215026a39 100644 --- a/firedrake/mg/ufl_utils.py +++ b/firedrake/mg/ufl_utils.py @@ -126,7 +126,7 @@ def coarsen_formsum(form, self, coefficient_mapping=None): @coarsen.register(firedrake.DirichletBC) def coarsen_bc(bc, self, coefficient_mapping=None): V = self(bc.function_space(), self, coefficient_mapping=coefficient_mapping) - val = self(bc.function_arg, self, coefficient_mapping=coefficient_mapping) + val = self(bc._original_arg, self, coefficient_mapping=coefficient_mapping) subdomain = bc.sub_domain return type(bc)(V, val, subdomain) diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 4031bf3c5c..a6b5bb3b31 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -359,11 +359,12 @@ def solve(self, bounds=None): # Make sure appcontext is attached to every DM from every coefficient and DirichletBC before we solve. problem = self._problem forms = (problem.F, problem.J, problem.Jp) - coefficients = utils.unique(chain.from_iterable(form.coefficients() for form in forms if form is not None)) + objs = (*forms, *problem.bcs) + coefficients = utils.unique(chain.from_iterable(f.coefficients() for f in objs if f is not None)) solution_dm = self.snes.getDM() # Grab the unique DMs for this problem problem_dms = [] - for c in chain(coefficients, problem.dirichlet_bcs()): + for c in coefficients: dm = c.function_space().dm if dm == solution_dm: # Make sure the solution dm is visited last diff --git a/tests/firedrake/regression/test_bcs.py b/tests/firedrake/regression/test_bcs.py index 9e43ba805b..c89ca2fb5b 100644 --- a/tests/firedrake/regression/test_bcs.py +++ b/tests/firedrake/regression/test_bcs.py @@ -83,17 +83,29 @@ def test_zero_bcs_wrong_fs(V, f2): bc.zero(f2) -def test_init_bcs_wrong_fs(V, f2): - "Initialise a DirichletBC with a Function on an incompatible FunctionSpace." - with pytest.raises(RuntimeError): - DirichletBC(V, f2, 1) +def test_init_bcs_mismatching_fs(V, f2): + "Initialise a DirichletBC with a Function on a different FunctionSpace." + if V.value_shape == f2.ufl_shape: + bc = DirichletBC(V, f2, 1) + g = bc.function_arg + assert g.function_space() == V + assert errornorm(f2, g) < 1E-12 + else: + with pytest.raises(RuntimeError): + DirichletBC(V, f2, 1) -def test_set_bcs_wrong_fs(V, f2): - "Set a DirichletBC to a Function on an incompatible FunctionSpace." +def test_set_bcs_mismatching_fs(V, f2): + "Set a DirichletBC to a Function on a different FunctionSpace." bc = DirichletBC(V, 32, 1) - with pytest.raises(RuntimeError): + if V.value_shape == f2.ufl_shape: bc.set_value(f2) + g = bc.function_arg + assert g.function_space() == V + assert errornorm(f2, g) < 1E-12 + else: + with pytest.raises(RuntimeError): + bc.set_value(f2) def test_homogeneous_bcs(a, u, V):