@@ -985,14 +985,18 @@ def callable():
985985 return callable
986986 else :
987987 loops = []
988- expressions = split_interpolate_target (expr )
989988
990989 if access == op2 .INC :
991990 loops .append (tensor .zero )
992991
993992 # Interpolate each sub expression into each function space
994- for Vsub , sub_tensor , sub_expr in zip (V , tensor , expressions ):
995- loops .extend (_interpolator (Vsub , sub_tensor , sub_expr , subset , arguments , access , bcs = bcs ))
993+ tensors = list (tensor )
994+ if len (tensors ) == 1 :
995+ split = [((0 ,), expr )]
996+ else :
997+ split = firedrake .formmanipulation .split_form (expr )
998+ for (i ,), sub_expr in split :
999+ loops .extend (_interpolator (V [i ], tensors [i ], sub_expr , subset , arguments , access , bcs = bcs ))
9961000
9971001 if bcs and rank == 1 :
9981002 loops .extend (partial (bc .apply , f ) for bc in bcs )
@@ -1707,36 +1711,6 @@ def duplicate(self, mat=None, op=None):
17071711 return self ._wrap_dummy_mat ()
17081712
17091713
1710- def split_interpolate_target (expr : ufl .Interpolate ):
1711- """Split an Interpolate into the components (subfunctions) of the target space."""
1712- dual_arg , operand = expr .argument_slots ()
1713- V = dual_arg .function_space ().dual ()
1714- if len (V ) == 1 :
1715- return (expr ,)
1716- # Split the target (dual) argument
1717- if isinstance (dual_arg , Cofunction ):
1718- duals = dual_arg .subfunctions
1719- elif isinstance (dual_arg , ufl .Coargument ):
1720- duals = [Coargument (Vsub , dual_arg .number ()) for Vsub in dual_arg .function_space ()]
1721- else :
1722- duals = [vi for _ , vi in sorted (firedrake .formmanipulation .split_form (dual_arg ))]
1723- # Split the operand into the target shapes
1724- if (isinstance (operand , firedrake .Function ) and len (operand .subfunctions ) == len (V )
1725- and all (fsub .ufl_shape == Vsub .value_shape for Vsub , fsub in zip (V , operand .subfunctions ))):
1726- # Use subfunctions if they match the target shapes
1727- operands = operand .subfunctions
1728- else :
1729- # Unflatten the expression into the target shapes
1730- cur = 0
1731- operands = []
1732- components = numpy .reshape (operand , (- 1 ,))
1733- for Vi in V :
1734- operands .append (ufl .as_tensor (components [cur :cur + Vi .value_size ].reshape (Vi .value_shape )))
1735- cur += Vi .value_size
1736- expressions = tuple (map (expr ._ufl_expr_reconstruct_ , operands , duals ))
1737- return expressions
1738-
1739-
17401714class MixedInterpolator (Interpolator ):
17411715 """A reusable interpolation object between MixedFunctionSpaces.
17421716
@@ -1754,39 +1728,40 @@ class MixedInterpolator(Interpolator):
17541728 For details see :class:`firedrake.interpolation.Interpolator`.
17551729 """
17561730 def __init__ (self , expr , V , bcs = None , ** kwargs ):
1731+ bcs = bcs or ()
17571732 super (MixedInterpolator , self ).__init__ (expr , V , bcs = bcs , ** kwargs )
1733+
17581734 expr = self .ufl_interpolate
1759- bcs = bcs or ()
17601735 self .arguments = expr .arguments ()
1761-
1762- # Split the target (dual) argument
1763- dual_split = split_interpolate_target ( expr )
1764- self . sub_interpolators = {}
1765- for i , form in enumerate ( dual_split ):
1766- # Split the source (primal) argument
1767- for j , sub_interp in firedrake .formmanipulation . split_form ( form ):
1768- j = max ( j ) if j else 0
1769- # Ensure block sparsity
1770- if not isinstance ( sub_interp , ufl . ZeroBaseForm ):
1771- vi , operand = sub_interp . argument_slots ()
1772- Vtarget = vi . function_space (). dual ()
1773- adjoint = vi . number () == 1 if isinstance ( vi , Coargument ) else True
1774-
1775- args = sub_interp . arguments ()
1776- Vsource = args [ 0 if adjoint else 1 ] .function_space ()
1777- sub_bcs = [ bc for bc in bcs if bc . function_space () in { Vsource , Vtarget }]
1778-
1779- indices = ( j , i ) if adjoint else ( i , j )
1780- Isub = Interpolator (sub_interp , Vtarget , bcs = sub_bcs , ** kwargs )
1781- self . sub_interpolators [ indices ] = Isub
1782-
1736+ rank = len ( self . arguments )
1737+ if rank < 2 :
1738+ dual_arg , operand = expr . argument_slots ( )
1739+ # Split the dual argument
1740+ dual_split = dict ( firedrake . formmanipulation . split_form ( dual_arg ))
1741+ # Create the Jacobian to split into blocks
1742+ expr = expr . _ufl_expr_reconstruct_ ( operand , firedrake .TrialFunction ( dual_arg . function_space ()))
1743+
1744+ Isub = {}
1745+ for indices , form in firedrake . formmanipulation . split_form ( expr ):
1746+ if not isinstance ( form , ufl . ZeroBaseForm ):
1747+ args = form . arguments ()
1748+ vi , operand = form . argument_slots ()
1749+ Vtarget = vi . function_space (). dual ()
1750+ Vsource = args [ 1 - vi . number ()]. function_space ()
1751+ sub_bcs = [ bc for bc in self . bcs if bc .function_space () in { Vsource , Vtarget }]
1752+ if rank == 1 :
1753+ # Take the action of each sub-cofunction against each block
1754+ form = action ( form , dual_split [ indices [ 1 :]] )
1755+ Isub [ indices ] = Interpolator (form , Vtarget , bcs = sub_bcs , ** kwargs )
1756+
1757+ self . sub_interpolators = Isub
17831758 self .callable = self ._get_callable
17841759
17851760 def _get_callable (self ):
17861761 """Assemble the operator."""
1762+ Isub = self .sub_interpolators
17871763 shape = tuple (len (a .function_space ()) for a in self .arguments )
1788- Isubs = self .sub_interpolators
1789- blocks = numpy .reshape ([Isubs [ij ].callable ().handle if ij in Isubs else PETSc .Mat ()
1764+ blocks = numpy .reshape ([Isub [ij ].callable ().handle if ij in Isub else PETSc .Mat ()
17901765 for ij in numpy .ndindex (shape )], shape )
17911766 petscmat = PETSc .Mat ().createNest (blocks )
17921767 tensor = firedrake .AssembledMatrix (self .arguments , self .bcs , petscmat )
0 commit comments