Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import ufl
from ufl import as_ufl, as_tensor
from ufl.algorithms import extract_coefficients
from finat.ufl import VectorElement
import finat

Expand All @@ -18,6 +19,7 @@

import firedrake
import firedrake.matrix as matrix
import firedrake.utils as utils
from firedrake import ufl_expr
from firedrake import slate
from firedrake import solving
Expand Down Expand Up @@ -224,6 +226,9 @@ def set(self, r, val):
def integrals(self):
raise NotImplementedError("integrals() method has to be overwritten")

def coefficients(self):
raise NotImplementedError("coefficients() method has to be overwritten")

@PETSc.Log.EventDecorator()
def as_subspace(self, field, V, use_split):
fs = self._function_space
Expand Down Expand Up @@ -315,8 +320,6 @@ def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=Fal
V = V.sub(index)
if g is None:
g = self._original_arg
if isinstance(g, firedrake.Function) and g.function_space() != V:
g = firedrake.Function(V).interpolate(g)
if sub_domain is None:
sub_domain = self.sub_domain
if field is not None:
Expand All @@ -342,11 +345,11 @@ def function_arg(self, g):
del self._function_arg_update
except AttributeError:
pass
self._coefficients = ()
V = self.function_space()
if isinstance(g, firedrake.Function) and g.ufl_element().family() != "Real":
if g.function_space() != V:
raise RuntimeError("%r is defined on incompatible FunctionSpace!" % g)
if isinstance(g, firedrake.Function) and g.function_space() == V:
self._function_arg = g
self._coefficients = (g,)
elif isinstance(g, ufl.classes.Zero):
if g.ufl_shape and g.ufl_shape != V.value_shape:
raise ValueError(f"Provided boundary value {g} does not match shape of space")
Expand All @@ -355,17 +358,18 @@ def function_arg(self, g):
elif isinstance(g, ufl.classes.Expr):
if g.ufl_shape != V.value_shape:
raise RuntimeError(f"Provided boundary value {g} does not match shape of space")
self._function_arg = firedrake.Function(V)
try:
self._function_arg = firedrake.Function(V)
interpolator = firedrake.get_interpolator(firedrake.interpolate(g, V))
# Call this here to check if the element supports interpolation
# TODO: It's probably better to have a more explicit way of checking this
interpolator._get_callable()
self._function_arg_update = partial(interpolator.assemble, tensor=self._function_arg)
except (NotImplementedError, AttributeError):
self._function_arg_update()
except NotImplementedError:
# Element doesn't implement interpolation
self._function_arg = firedrake.Function(V).project(g)
self._function_arg_update = firedrake.Projector(g, self._function_arg).project
self._function_arg_update()
self._coefficients = tuple(extract_coefficients(g))
else:
try:
g = as_ufl(g)
Expand Down Expand Up @@ -460,6 +464,9 @@ def apply(self, r, u=None):
def integrals(self):
return []

def coefficients(self):
return self._coefficients

def extract_form(self, form_type):
# DirichletBC is directly used in assembly.
return self
Expand Down Expand Up @@ -567,6 +574,10 @@ def reconstruct(self, V, subu, u, field, is_linear):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
return self

def coefficients(self):
splits = (self._F, self._J, self._Jp)
return utils.unique(itertools.chain.from_iterable(f.coefficients() for f in splits))


class EquationBCSplit(BCBase):
r'''Class for a BC tree that stores/manipulates either `F`, `J`, or `Jp`.
Expand Down Expand Up @@ -613,6 +624,10 @@ def sorted_equation_bcs(self):
def integrals(self):
return self.f.integrals()

def coefficients(self):
subs = (self.f, *self.bcs)
return utils.unique(itertools.chain.from_iterable(f.coefficients() for f in subs))

def add(self, bc):
if not isinstance(bc, (DirichletBC, EquationBCSplit)):
raise TypeError("EquationBCSplit.add expects an instance of DirichletBC or EquationBCSplit.")
Expand Down
2 changes: 1 addition & 1 deletion firedrake/mg/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def coarsen_formsum(form, self, coefficient_mapping=None):
@coarsen.register(firedrake.DirichletBC)
def coarsen_bc(bc, self, coefficient_mapping=None):
V = self(bc.function_space(), self, coefficient_mapping=coefficient_mapping)
val = self(bc.function_arg, self, coefficient_mapping=coefficient_mapping)
val = self(bc._original_arg, self, coefficient_mapping=coefficient_mapping)
subdomain = bc.sub_domain

return type(bc)(V, val, subdomain)
Expand Down
5 changes: 3 additions & 2 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,12 @@ def solve(self, bounds=None):
# Make sure appcontext is attached to every DM from every coefficient and DirichletBC before we solve.
problem = self._problem
forms = (problem.F, problem.J, problem.Jp)
coefficients = utils.unique(chain.from_iterable(form.coefficients() for form in forms if form is not None))
objs = (*forms, *problem.bcs)
coefficients = utils.unique(chain.from_iterable(f.coefficients() for f in objs if f is not None))
solution_dm = self.snes.getDM()
# Grab the unique DMs for this problem
problem_dms = []
for c in chain(coefficients, problem.dirichlet_bcs()):
for c in coefficients:
dm = c.function_space().dm
if dm == solution_dm:
# Make sure the solution dm is visited last
Expand Down
26 changes: 19 additions & 7 deletions tests/firedrake/regression/test_bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,29 @@ def test_zero_bcs_wrong_fs(V, f2):
bc.zero(f2)


def test_init_bcs_wrong_fs(V, f2):
"Initialise a DirichletBC with a Function on an incompatible FunctionSpace."
with pytest.raises(RuntimeError):
DirichletBC(V, f2, 1)
def test_init_bcs_mismatching_fs(V, f2):
"Initialise a DirichletBC with a Function on a different FunctionSpace."
if V.value_shape == f2.ufl_shape:
bc = DirichletBC(V, f2, 1)
g = bc.function_arg
assert g.function_space() == V
assert errornorm(f2, g) < 1E-12
else:
with pytest.raises(RuntimeError):
DirichletBC(V, f2, 1)


def test_set_bcs_wrong_fs(V, f2):
"Set a DirichletBC to a Function on an incompatible FunctionSpace."
def test_set_bcs_mismatching_fs(V, f2):
"Set a DirichletBC to a Function on a different FunctionSpace."
bc = DirichletBC(V, 32, 1)
with pytest.raises(RuntimeError):
if V.value_shape == f2.ufl_shape:
bc.set_value(f2)
g = bc.function_arg
assert g.function_space() == V
assert errornorm(f2, g) < 1E-12
else:
with pytest.raises(RuntimeError):
bc.set_value(f2)


def test_homogeneous_bcs(a, u, V):
Expand Down
Loading