@@ -241,6 +241,8 @@ class Interpolator(abc.ABC):
241241 """
242242
243243 def __new__ (cls , expr , V , ** kwargs ):
244+ if isinstance (expr , ufl .Interpolate ):
245+ expr , = expr .ufl_operands
244246 target_mesh = as_domain (V )
245247 source_mesh = extract_unique_domain (expr ) or target_mesh
246248 submesh_interp_implemented = \
@@ -266,6 +268,8 @@ def __init__(
266268 allow_missing_dofs = False ,
267269 matfree = True
268270 ):
271+ if isinstance (expr , ufl .Interpolate ):
272+ expr , = expr .ufl_operands
269273 self .expr = expr
270274 self .V = V
271275 self .subset = subset
@@ -374,6 +378,8 @@ def __init__(
374378 "Can only interpolate into spaces with point evaluation nodes."
375379 )
376380
381+ if isinstance (expr , ufl .Interpolate ):
382+ expr , = expr .ufl_operands
377383 super ().__init__ (expr , V , subset , freeze_expr , access , bcs , allow_missing_dofs , matfree )
378384
379385 self .arguments = extract_arguments (expr )
@@ -540,7 +546,7 @@ def _interpolate(
540546 V_dest = self .expr .function_space ().dual ()
541547 except AttributeError :
542548 if self .nargs :
543- V_dest = self .arguments [0 ].function_space ().dual ()
549+ V_dest = self .arguments [- 1 ].function_space ().dual ()
544550 else :
545551 coeffs = extract_coefficients (self .expr )
546552 if len (coeffs ):
@@ -552,8 +558,6 @@ def _interpolate(
552558 else :
553559 if isinstance (self .V , (firedrake .Function , firedrake .Cofunction )):
554560 V_dest = self .V .function_space ()
555- elif isinstance (self .V , firedrake .Coargument ):
556- V_dest = self .V .function_space ().dual ()
557561 else :
558562 V_dest = self .V
559563 if output :
@@ -679,10 +683,14 @@ class SameMeshInterpolator(Interpolator):
679683 def __init__ (self , expr , V , subset = None , freeze_expr = False , access = op2 .WRITE ,
680684 bcs = None , matfree = True , allow_missing_dofs = False , ** kwargs ):
681685 if subset is None :
686+ if isinstance (expr , ufl .Interpolate ):
687+ operand , = expr .ufl_operands
688+ else :
689+ operand = expr
682690 target_mesh = as_domain (V )
683- source_mesh = extract_unique_domain (expr )
691+ source_mesh = extract_unique_domain (operand ) or target_mesh
684692 target = target_mesh .topology
685- source = target if source_mesh is None else source_mesh .topology
693+ source = source_mesh .topology
686694 if all (isinstance (m , firedrake .mesh .MeshTopology ) for m in [target , source ]) and target is not source :
687695 composed_map , result_integral_type = source .trans_mesh_entity_map (target , "cell" , "everywhere" , None )
688696 if result_integral_type != "cell" :
@@ -703,7 +711,7 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE,
703711 self .callable , arguments = make_interpolator (expr , V , subset , access , bcs = bcs , matfree = matfree )
704712 except FIAT .hdiv_trace .TraceError :
705713 raise NotImplementedError ("Can't interpolate onto traces sorry" )
706- self .arguments = arguments
714+ self .arguments = expr . arguments ()
707715 self .nargs = len (arguments )
708716
709717 @PETSc .Log .EventDecorator ()
@@ -735,16 +743,19 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
735743 # Interpolation action
736744 self .frozen_assembled_interpolator = assembled_interpolator .copy ()
737745
738- if self . nargs == 2 :
746+ if hasattr ( assembled_interpolator , "handle" ) and len ( function ) :
739747 function , = function
740748 if not hasattr (function , "dat" ):
741749 raise ValueError ("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!" )
742750 if adjoint :
743751 mul = assembled_interpolator .handle .multHermitian
744752 V = self .arguments [0 ].function_space ().dual ()
753+ assert function .function_space () == self .arguments [1 ].function_space ()
745754 else :
746755 mul = assembled_interpolator .handle .mult
747- V = self .V
756+ V = self .arguments [1 ].function_space ().dual ()
757+ assert function .function_space () == self .arguments [0 ].function_space ()
758+
748759 result = output or firedrake .Function (V )
749760 with function .dat .vec_ro as x , result .dat .vec_wo as out :
750761 if x is not out :
@@ -772,29 +783,23 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
772783
773784@PETSc .Log .EventDecorator ()
774785def make_interpolator (expr , V , subset , access , bcs = None , matfree = True ):
775- assert isinstance (expr , ufl .classes .Expr )
786+ assert isinstance (expr , ufl .Interpolate )
787+ dual_arg , operand = expr .argument_slots ()
788+ assert isinstance (dual_arg , (ufl .Coargument , ufl .Cofunction ))
789+
790+ target_mesh = as_domain (dual_arg )
791+ source_mesh = extract_unique_domain (operand ) or target_mesh
792+ same_mesh = target_mesh is source_mesh
793+
794+ vom_onto_other_vom = (
795+ isinstance (target_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology )
796+ and isinstance (source_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology )
797+ and target_mesh is not source_mesh
798+ )
776799
777- if isinstance (V , (ufl .Coargument , ufl .Cofunction )):
778- dual_arg = V
779- V = dual_arg .function_space ().dual ()
780- elif isinstance (V , (ufl .FunctionSpace , ufl .Coefficient )):
781- fs = V if isinstance (V , ufl .FunctionSpace ) else V .function_space ()
782- dual_arg = Coargument (fs .dual (), number = 0 )
783-
784- arguments = extract_arguments (expr )
785- if isinstance (dual_arg , ufl .Coargument ):
786- arguments .append (dual_arg )
800+ arguments = expr .arguments ()
787801 rank = len (arguments )
788-
789- target_mesh = as_domain (V )
790802 if rank <= 1 :
791- source_mesh = extract_unique_domain (expr ) or target_mesh
792- vom_onto_other_vom = (
793- isinstance (target_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology )
794- and isinstance (source_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology )
795- and target_mesh is not source_mesh
796- )
797-
798803 if rank == 0 :
799804 R = firedrake .FunctionSpace (target_mesh , "Real" , 0 )
800805 f = firedrake .Function (R )
@@ -817,15 +822,8 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
817822 raise ValueError ("Cannot interpolate an expression with an argument into a Function" )
818823 if len (V ) > 1 :
819824 raise NotImplementedError ("Interpolation of mixed expressions with arguments is not supported" )
820-
821825 argfs = arguments [0 ].function_space ()
822- source_mesh = argfs .mesh ()
823826 argfs_map = argfs .cell_node_map ()
824- vom_onto_other_vom = (
825- isinstance (target_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology )
826- and isinstance (source_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology )
827- and target_mesh is not source_mesh
828- )
829827 if isinstance (target_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology ) and target_mesh is not source_mesh and not vom_onto_other_vom :
830828 if not isinstance (target_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology ):
831829 raise NotImplementedError ("Can only interpolate onto a Vertex Only Mesh" )
@@ -863,8 +861,10 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
863861 else :
864862 raise ValueError ("Cannot interpolate an expression with %d arguments" % rank )
865863
864+ if not same_mesh :
865+ arguments = extract_arguments (operand )
866866 if vom_onto_other_vom :
867- wrapper = VomOntoVomWrapper (V , source_mesh , target_mesh , expr , arguments , matfree )
867+ wrapper = VomOntoVomWrapper (V , source_mesh , target_mesh , operand , arguments , matfree )
868868 # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the
869869 # data, including the correct data size and dimensional information
870870 # (so for vector function spaces in 2 dimensions we might need a
@@ -874,13 +874,13 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
874874 # when it is called.
875875 assert f .dat is tensor
876876 wrapper .mpi_type , _ = get_dat_mpi_type (f .dat )
877- assert not len (arguments )
877+ assert len (arguments ) == 0
878878
879879 def callable ():
880880 wrapper .forward_operation (f .dat )
881881 return f
882882 else :
883- assert rank == 2
883+ assert len ( arguments ) == 1
884884 assert tensor is None
885885 # we know we will be outputting either a function or a cofunction,
886886 # both of which will use a dat as a data carrier. At present, the
@@ -904,38 +904,40 @@ def callable():
904904 # Make sure we have an expression of the right length i.e. a value for
905905 # each component in the value shape of each function space
906906 loops = []
907- if numpy .prod (expr .ufl_shape , dtype = int ) != V .value_size :
907+ if numpy .prod (operand .ufl_shape , dtype = int ) != V .value_size :
908908 raise RuntimeError ('Expression of length %d required, got length %d'
909- % (V .value_size , numpy .prod (expr .ufl_shape , dtype = int )))
909+ % (V .value_size , numpy .prod (operand .ufl_shape , dtype = int )))
910910
911911 if len (V ) == 1 :
912- loops .extend (_interpolator (V , tensor , expr , dual_arg , subset , arguments , access , bcs = bcs ))
912+ loops .extend (_interpolator (V , tensor , expr , subset , arguments , access , bcs = bcs ))
913913 else :
914- if (hasattr (expr , "subfunctions" ) and len (expr .subfunctions ) == len (V )
915- and all (sub_expr .ufl_shape == Vsub .value_shape for Vsub , sub_expr in zip (V , expr .subfunctions ))):
914+ if (hasattr (operand , "subfunctions" ) and len (operand .subfunctions ) == len (V )
915+ and all (sub_op .ufl_shape == Vsub .value_shape for Vsub , sub_op in zip (V , operand .subfunctions ))):
916916 # Use subfunctions if they match the target shapes
917- expressions = expr .subfunctions
917+ operands = operand .subfunctions
918918 else :
919919 # Unflatten the expression into the shapes of the mixed components
920920 offset = 0
921- expressions = []
921+ operands = []
922922 for Vsub in V :
923923 if len (Vsub .value_shape ) == 0 :
924- expressions .append (expr [offset ])
924+ operands .append (operand [offset ])
925925 else :
926- components = [expr [offset + j ] for j in range (Vsub .value_size )]
927- expressions .append (ufl .as_tensor (numpy .reshape (components , Vsub .value_shape )))
926+ components = [operand [offset + j ] for j in range (Vsub .value_size )]
927+ operands .append (ufl .as_tensor (numpy .reshape (components , Vsub .value_shape )))
928928 offset += Vsub .value_size
929929
930930 if isinstance (dual_arg , Cofunction ):
931931 duals = dual_arg .subfunctions
932932 elif isinstance (dual_arg , Coargument ):
933- duals = [Coargument (Vsub . dual () , number = dual_arg .number ()) for Vsub in V ]
933+ duals = [Coargument (Vsub , number = dual_arg .number ()) for Vsub in dual_arg . function_space () ]
934934 else :
935935 raise ValueError ("dual_arg must be a Cofunction or Coargument" )
936+
936937 # Interpolate each sub expression into each function space
937- for Vsub , sub_tensor , sub_expr , sub_dual in zip (V , tensor , expressions , duals ):
938- loops .extend (_interpolator (Vsub , sub_tensor , sub_expr , sub_dual , subset , arguments , access , bcs = bcs ))
938+ for Vsub , sub_tensor , sub_op , sub_dual in zip (V , tensor , operands , duals ):
939+ sub_expr = ufl .Interpolate (sub_op , sub_dual )
940+ loops .extend (_interpolator (Vsub , sub_tensor , sub_expr , subset , arguments , access , bcs = bcs ))
939941
940942 if bcs and rank == 1 :
941943 loops .extend (partial (bc .apply , f ) for bc in bcs )
@@ -949,11 +951,18 @@ def callable(loops, f):
949951
950952
951953@utils .known_pyop2_safe
952- def _interpolator (V , tensor , expr , dual_arg , subset , arguments , access , bcs = None ):
954+ def _interpolator (V , tensor , expr , subset , arguments , access , bcs = None ):
953955 try :
954956 expr = ufl .as_ufl (expr )
955957 except ValueError :
956958 raise ValueError ("Expecting to interpolate a UFL expression" )
959+
960+ interp_expr = expr
961+ if isinstance (expr , ufl .Interpolate ):
962+ dual_arg , expr = expr .argument_slots ()
963+ else :
964+ dual_arg = Coargument (V .dual (), number = 0 )
965+
957966 try :
958967 to_element = create_element (V .ufl_element ())
959968 except KeyError :
@@ -1029,7 +1038,7 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None
10291038 # FIXME: for the runtime unknown point set (for cross-mesh
10301039 # interpolation) we have to pass the finat element we construct
10311040 # here. Ideally we would only pass the UFL element through.
1032- kernel = compile_expression (cell_set .comm , expr , dual_arg , to_element , V .ufl_element (),
1041+ kernel = compile_expression (cell_set .comm , interp_expr , to_element , V .ufl_element (),
10331042 domain = source_mesh , parameters = parameters )
10341043 ast = kernel .ast
10351044 oriented = kernel .oriented
@@ -1042,7 +1051,7 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None
10421051
10431052 parloop_args = [kernel , cell_set ]
10441053
1045- interp_expr = ufl .Interpolate (expr , dual_arg )
1054+ interp_expr = ufl .Interpolate (expr , v = dual_arg )
10461055 coefficients = tsfc_interface .extract_numbered_coefficients (interp_expr , coefficient_numbers )
10471056 if needs_external_coords :
10481057 coefficients = [source_mesh .coordinates ] + coefficients
@@ -1061,12 +1070,13 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None
10611070 if isinstance (tensor , op2 .Global ):
10621071 parloop_args .append (tensor (access ))
10631072 elif isinstance (tensor , op2 .Dat ):
1064- V_dest = arguments [0 ].function_space () if isinstance (dual_arg , ufl .Cofunction ) else V
1073+ V_dest = arguments [- 1 ].function_space () if isinstance (dual_arg , ufl .Cofunction ) else V
10651074 parloop_args .append (tensor (access , V_dest .cell_node_map ()))
10661075 else :
10671076 assert access == op2 .WRITE # Other access descriptors not done for Matrices.
10681077 rows_map = V .cell_node_map ()
10691078 Vcol = arguments [0 ].function_space ()
1079+ assert tensor .handle .getSize () == (V .dim (), Vcol .dim ())
10701080 if isinstance (target_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology ):
10711081 columns_map = Vcol .cell_node_map ()
10721082 if target_mesh is not source_mesh :
@@ -1165,9 +1175,10 @@ def _interpolator(V, tensor, expr, dual_arg, subset, arguments, access, bcs=None
11651175 f"firedrake-tsfc-expression-kernel-cache-uid{ os .getuid ()} " )
11661176
11671177
1168- def _compile_expression_key (comm , expr , dual_arg , to_element , ufl_element , domain , parameters ) -> tuple [Hashable , ...]:
1178+ def _compile_expression_key (comm , expr , to_element , ufl_element , domain , parameters ) -> tuple [Hashable , ...]:
11691179 """Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`."""
1170- return (hash_expr (expr ), type (dual_arg ), hash (ufl_element ), utils .tuplify (parameters ))
1180+ dual_arg , operand = expr .argument_slots ()
1181+ return (hash_expr (operand ), type (dual_arg ), hash (ufl_element ), utils .tuplify (parameters ))
11711182
11721183
11731184@memory_and_disk_cache (
0 commit comments