BUG: adjoint solve fails with Real space when using MUMPS #5140

@tlroy

Describe the bug
Adjoint differentiation fails when the forward model contains a solve on a Real function space (FunctionSpace(mesh, "R", 0) or FunctionSpace(mesh, "Real", 0)). During ReducedFunctional.derivative(), the adjoint solve requests an LU factorization through MUMPS for a PETSc matrix of type python (matfree-style). MUMPS does not support that matrix type, so the adjoint solve aborts with PETSc error code 92.

Steps to Reproduce
MFE (minimal failing example):

from firedrake import *
from firedrake.adjoint import *
from firedrake.petsc import DEFAULT_DIRECT_SOLVER
from petsc4py import PETSc
def assemble_adj_like(rho, R):
    K_trial = TrialFunction(R)
    w = TestFunction(R)
    F = K_trial * w * dx
    L = rho * w * dx
    K = Function(R, name="K_real")
    solve(F == L, K)
    return K
def main():
    PETSc.Sys.Print(f"DEFAULT_DIRECT_SOLVER = {DEFAULT_DIRECT_SOLVER}")
    continue_annotation()
    mesh = UnitSquareMesh(8, 8)
    V = FunctionSpace(mesh, "CG", 1)
    R = FunctionSpace(mesh, "R", 0)
    m = Function(V, name="m").assign(1.0)
    m_ctrl = Control(m)
    K_real = assemble_adj_like(m, R)
    J = assemble(K_real * K_real * dx)
    Jhat = ReducedFunctional(J, m_ctrl)
    # triggers adjoint solve
    Jhat.derivative()
if __name__ == "__main__":
    main()

Expected behavior
Jhat.derivative() should succeed and return a derivative w.r.t. the control. No attempt should be made to apply LU/MUMPS to an incompatible MatType=python matrix during the adjoint solve.
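
For reference, the value the derivative should take can be worked out by hand. The Real-space solve gives K = (integral of rho) / |Omega|, so J = K^2 |Omega| and the directional derivative along delta_m is 2 K times the integral of delta_m. With m = 1 on the unit square this is K = 1, J = 1, and a derivative of 2 along delta_m = 1. The sketch below checks that arithmetic in plain Python without Firedrake. It assumes a piecewise-constant control on an n x n grid of equal cells, which is a simplification and not the CG1 space used in the MFE; the helper names are hypothetical.

```python
# Plain-Python model of the MFE (assumption: piecewise-constant control on
# equal cells instead of the CG1 space; helper names are illustrative only).

def real_space_solve(m, cell_area, domain_area=1.0):
    """Solve K*w*dx == m*w*dx for K in the Real space: K is the mean of m."""
    return sum(value * cell_area for value in m) / domain_area


def functional(m, cell_area, domain_area=1.0):
    """J = integral of K*K over the domain, with K constant."""
    K = real_space_solve(m, cell_area, domain_area)
    return K * K * domain_area


def gradient(m, cell_area, domain_area=1.0):
    """dJ/dm_i = 2*K*domain_area * dK/dm_i = 2*K*cell_area for every cell."""
    K = real_space_solve(m, cell_area, domain_area)
    return [2.0 * K * cell_area for _ in m]


n = 8                          # same resolution as UnitSquareMesh(8, 8)
cell_area = 1.0 / (n * n)
m = [1.0] * (n * n)            # control m = 1 everywhere

J = functional(m, cell_area)
g = gradient(m, cell_area)

# Forward finite-difference check on the first cell value.
eps = 1e-6
m_pert = list(m)
m_pert[0] += eps
fd = (functional(m_pert, cell_area) - J) / eps

print("J =", J)
print("derivative along delta_m = 1:", sum(g))
print("analytic vs finite difference:", g[0], fd)
```

Whatever representation a fixed Jhat.derivative() returns, its action on the constant direction delta_m = 1 should match the value 2 obtained here.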

Error message

Traceback (most recent call last):
  File "/shared/tests/mfe.py", line 27, in <module>
    main()
  File "/shared/tests/mfe.py", line 25, in main
    Jhat.derivative()
  File "/usr/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pyadjoint-ad/pyadjoint/reduced_functional.py", line 264, in derivative
    derivatives = compute_derivative(self.functional,
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pyadjoint-ad/pyadjoint/drivers.py", line 42, in compute_derivative
    tape.evaluate_adj(markings=True)
  File "/opt/pyadjoint-ad/pyadjoint/tape.py", line 345, in evaluate_adj
    self._blocks[i].evaluate_adj(markings=markings)
  File "/usr/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pyadjoint-ad/pyadjoint/block.py", line 137, in evaluate_adj
    prepared = self.prepare_evaluate_adj(inputs, adj_inputs, relevant_dependencies)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/firedrake/firedrake/adjoint_utils/blocks/solving.py", line 203, in prepare_evaluate_adj
    adj_sol, adj_sol_bdy = self._assemble_and_solve_adj_eq(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/firedrake/firedrake/adjoint_utils/blocks/solving.py", line 228, in _assemble_and_solve_adj_eq
    firedrake.solve(
  File "petsc4py/PETSc/Log.pyx", line 250, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "petsc4py/PETSc/Log.pyx", line 251, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "/opt/firedrake/firedrake/adjoint_utils/solving.py", line 57, in wrapper
    output = solve(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/firedrake/firedrake/solving.py", line 145, in solve
    return _la_solve(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/firedrake/firedrake/solving.py", line 251, in _la_solve
    solver.solve(x, b)
  File "petsc4py/PETSc/Log.pyx", line 250, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "petsc4py/PETSc/Log.pyx", line 251, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "/opt/firedrake/firedrake/linear_solver.py", line 86, in solve
    super().solve()
  File "petsc4py/PETSc/Log.pyx", line 250, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "petsc4py/PETSc/Log.pyx", line 251, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "/opt/firedrake/firedrake/adjoint_utils/variational_solver.py", line 108, in wrapper
    out = solve(self, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/firedrake/firedrake/variational_solver.py", line 398, in solve
    self.snes.solve(None, work)
  File "petsc4py/PETSc/SNES.pyx", line 1740, in petsc4py.PETSc.SNES.solve
petsc4py.PETSc.Error: error code 92
[0] SNESSolve() at /opt/petsc/src/snes/interface/snes.c:4889
[0] SNESSolve_KSPONLY() at /opt/petsc/src/snes/impls/ksponly/ksponly.c:49
[0] KSPSolve() at /opt/petsc/src/ksp/ksp/interface/itfunc.c:1094
[0] KSPSolve_Private() at /opt/petsc/src/ksp/ksp/interface/itfunc.c:841
[0] KSPSetUp() at /opt/petsc/src/ksp/ksp/interface/itfunc.c:428
[0] PCSetUp() at /opt/petsc/src/ksp/pc/interface/precon.c:1124
[0] PCSetUp_LU() at /opt/petsc/src/ksp/pc/impls/factor/lu/lu.c:79
[0] PCFactorSetUpMatSolverType() at /opt/petsc/src/ksp/pc/impls/factor/factor.c:108
[0] PCFactorSetUpMatSolverType_Factor() at /opt/petsc/src/ksp/pc/impls/factor/factimpl.c:9
[0] MatGetFactor() at /opt/petsc/src/mat/interface/matrix.c:4966
[0] See https://petsc.org/release/overview/linear_solve_table/ for possible LU and Cholesky solvers
[0] MatSolverType mumps does not support matrix type python

Environment:
- Linux
- Fails on both the Docker image and a cluster install
