@@ -552,6 +552,8 @@ def _interpolate(
552552 else :
553553 if isinstance (self .V , (firedrake .Function , firedrake .Cofunction )):
554554 V_dest = self .V .function_space ()
555+ elif isinstance (self .V , firedrake .Coargument ):
556+ V_dest = self .V .function_space ().dual ()
555557 else :
556558 V_dest = self .V
557559 if output :
@@ -677,9 +679,10 @@ class SameMeshInterpolator(Interpolator):
677679 def __init__ (self , expr , V , subset = None , freeze_expr = False , access = op2 .WRITE ,
678680 bcs = None , matfree = True , allow_missing_dofs = False , ** kwargs ):
679681 if subset is None :
680- target = V .function_space ().mesh ().topology if isinstance (V , firedrake .Function ) else V .mesh ().topology
681- temp = extract_unique_domain (expr )
682- source = target if temp is None else temp .topology
682+ target_mesh = as_domain (V )
683+ source_mesh = extract_unique_domain (expr )
684+ target = target_mesh .topology
685+ source = target if source_mesh is None else source_mesh .topology
683686 if all (isinstance (m , firedrake .mesh .MeshTopology ) for m in [target , source ]) and target is not source :
684687 composed_map , result_integral_type = source .trans_mesh_entity_map (target , "cell" , "everywhere" , None )
685688 if result_integral_type != "cell" :
@@ -732,7 +735,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
732735 # Interpolation action
733736 self .frozen_assembled_interpolator = assembled_interpolator .copy ()
734737
735- if self .nargs :
738+ if self .nargs == 2 :
736739 function , = function
737740 if not hasattr (function , "dat" ):
738741 raise ValueError ("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!" )
@@ -770,20 +773,37 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
770773@PETSc .Log .EventDecorator ()
771774def make_interpolator (expr , V , subset , access , bcs = None , matfree = True ):
772775 assert isinstance (expr , ufl .classes .Expr )
776+
777+ if isinstance (V , (ufl .Coargument , ufl .Cofunction )):
778+ dual_arg = V
779+ V = dual_arg .function_space ().dual ()
780+ elif isinstance (V , (ufl .FunctionSpace , ufl .Coefficient )):
781+ fs = V if isinstance (V , ufl .FunctionSpace ) else V .function_space ()
782+ dual_arg = Coargument (fs .dual (), number = 0 )
783+
773784 arguments = extract_arguments (expr )
785+ if isinstance (dual_arg , ufl .Coargument ):
786+ arguments .append (dual_arg )
787+ rank = len (arguments )
788+
774789 target_mesh = as_domain (V )
775- if len ( arguments ) == 0 :
790+ if rank <= 1 :
776791 source_mesh = extract_unique_domain (expr ) or target_mesh
777792 vom_onto_other_vom = (
778793 isinstance (target_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology )
779794 and isinstance (source_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology )
780795 and target_mesh is not source_mesh
781796 )
797+
798+ if rank == 0 :
799+ # FIXME
800+ V = firedrake .Function (firedrake .FunctionSpace (target_mesh , "Real" , 0 ))
782801 if isinstance (V , firedrake .Function ):
783802 f = V
784803 V = f .function_space ()
785804 else :
786- f = firedrake .Function (V )
805+ V_dest = arguments [- 1 ].function_space ().dual ()
806+ f = firedrake .Function (V_dest )
787807 if access in {firedrake .MIN , firedrake .MAX }:
788808 finfo = numpy .finfo (f .dat .dtype )
789809 if access == firedrake .MIN :
@@ -792,11 +812,12 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
792812 val = firedrake .Constant (finfo .min )
793813 f .assign (val )
794814 tensor = f .dat
795- elif len ( arguments ) == 1 :
815+ elif rank == 2 :
796816 if isinstance (V , firedrake .Function ):
797817 raise ValueError ("Cannot interpolate an expression with an argument into a Function" )
798818 if len (V ) > 1 :
799819 raise NotImplementedError ("Interpolation of mixed expressions with arguments is not supported" )
820+
800821 argfs = arguments [0 ].function_space ()
801822 source_mesh = argfs .mesh ()
802823 argfs_map = argfs .cell_node_map ()
@@ -840,7 +861,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
840861 tensor = op2 .Mat (sparsity )
841862 f = tensor
842863 else :
843- raise ValueError ("Cannot interpolate an expression with %d arguments" % len ( arguments ) )
864+ raise ValueError ("Cannot interpolate an expression with %d arguments" % rank )
844865
845866 if vom_onto_other_vom :
846867 wrapper = VomOntoVomWrapper (V , source_mesh , target_mesh , expr , arguments , matfree )
@@ -859,7 +880,7 @@ def callable():
859880 wrapper .forward_operation (f .dat )
860881 return f
861882 else :
862- assert len ( arguments ) == 1
883+ assert rank == 2
863884 assert tensor is None
864885 # we know we will be outputting either a function or a cofunction,
865886 # both of which will use a dat as a data carrier. At present, the
@@ -888,7 +909,7 @@ def callable():
888909 % (V .value_size , numpy .prod (expr .ufl_shape , dtype = int )))
889910
890911 if len (V ) == 1 :
891- loops .extend (_interpolator (V , tensor , expr , subset , arguments , access , bcs = bcs ))
912+ loops .extend (_interpolator (V , tensor , expr , dual_arg , subset , arguments , access , bcs = bcs ))
892913 else :
893914 if (hasattr (expr , "subfunctions" ) and len (expr .subfunctions ) == len (V )
894915 and all (sub_expr .ufl_shape == Vsub .value_shape for Vsub , sub_expr in zip (V , expr .subfunctions ))):
@@ -905,11 +926,18 @@ def callable():
905926 components = [expr [offset + j ] for j in range (Vsub .value_size )]
906927 expressions .append (ufl .as_tensor (numpy .reshape (components , Vsub .value_shape )))
907928 offset += Vsub .value_size
929+
930+ if isinstance (dual_arg , Cofunction ):
931+ duals = dual_arg .subfunctions
932+ elif isinstance (dual_arg , Coargument ):
933+ duals = [Coargument (Vsub .dual (), number = dual_arg .number ()) for Vsub in V ]
934+ else :
935+ raise ValueError ("dual_arg must be a Cofunction or Coargument" )
908936 # Interpolate each sub expression into each function space
909- for Vsub , sub_tensor , sub_expr in zip (V , tensor , expressions ):
910- loops .extend (_interpolator (Vsub , sub_tensor , sub_expr , subset , arguments , access , bcs = bcs ))
937+ for Vsub , sub_tensor , sub_expr , sub_dual in zip (V , tensor , expressions , duals ):
938+ loops .extend (_interpolator (Vsub , sub_tensor , sub_expr , sub_dual , subset , arguments , access , bcs = bcs ))
911939
912- if bcs and len ( arguments ) == 0 :
940+ if bcs and rank == 1 :
913941 loops .extend (partial (bc .apply , f ) for bc in bcs )
914942
915943 def callable (loops , f ):
@@ -921,7 +949,7 @@ def callable(loops, f):
921949
922950
923951@utils .known_pyop2_safe
924- def _interpolator (V , tensor , expr , subset , arguments , access , bcs = None ):
952+ def _interpolator (V , tensor , expr , dual_arg , subset , arguments , access , bcs = None ):
925953 try :
926954 expr = ufl .as_ufl (expr )
927955 except ValueError :
@@ -977,13 +1005,31 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
9771005 parameters = {}
9781006 parameters ['scalar_type' ] = utils .ScalarType
9791007
1008+ needs_weight = isinstance (dual_arg , ufl .Cofunction ) and not to_element .is_dg ()
1009+ if needs_weight :
1010+ W = dual_arg .function_space ()
1011+ shapes = (W .finat_element .space_dimension (), W .block_size )
1012+ domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes
1013+ instructions = """
1014+ for i, j
1015+ w[i,j] = w[i,j] + 1
1016+ end
1017+ """
1018+ weight = firedrake .Function (W )
1019+ firedrake .par_loop ((domain , instructions ), ufl .dx , {"w" : (weight , op2 .INC )})
1020+
1021+ tmp = firedrake .Function (W )
1022+ with weight .dat .vec as w , dual_arg .dat .vec as x , tmp .dat .vec as y :
1023+ y .pointwiseDivide (x , w )
1024+ dual_arg = tmp
1025+
9801026 # We need to pass both the ufl element and the finat element
9811027 # because the finat elements might not have the right mapping
9821028 # (e.g. L2 Piola, or tensor element with symmetries)
9831029 # FIXME: for the runtime unknown point set (for cross-mesh
9841030 # interpolation) we have to pass the finat element we construct
9851031 # here. Ideally we would only pass the UFL element through.
986- kernel = compile_expression (cell_set .comm , expr , to_element , V .ufl_element (),
1032+ kernel = compile_expression (cell_set .comm , expr , dual_arg , to_element , V .ufl_element (),
9871033 domain = source_mesh , parameters = parameters )
9881034 ast = kernel .ast
9891035 oriented = kernel .oriented
@@ -996,7 +1042,8 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
9961042
9971043 parloop_args = [kernel , cell_set ]
9981044
999- coefficients = tsfc_interface .extract_numbered_coefficients (expr , coefficient_numbers )
1045+ interp_expr = ufl .Interpolate (expr , dual_arg )
1046+ coefficients = tsfc_interface .extract_numbered_coefficients (interp_expr , coefficient_numbers )
10001047 if needs_external_coords :
10011048 coefficients = [source_mesh .coordinates ] + coefficients
10021049
@@ -1014,7 +1061,8 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10141061 if isinstance (tensor , op2 .Global ):
10151062 parloop_args .append (tensor (access ))
10161063 elif isinstance (tensor , op2 .Dat ):
1017- parloop_args .append (tensor (access , V .cell_node_map ()))
1064+ V_dest = arguments [0 ].function_space () if isinstance (dual_arg , ufl .Cofunction ) else V
1065+ parloop_args .append (tensor (access , V_dest .cell_node_map ()))
10181066 else :
10191067 assert access == op2 .WRITE # Other access descriptors not done for Matrices.
10201068 rows_map = V .cell_node_map ()
@@ -1117,9 +1165,9 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
11171165 f"firedrake-tsfc-expression-kernel-cache-uid{ os .getuid ()} " )
11181166
11191167
1120- def _compile_expression_key (comm , expr , to_element , ufl_element , domain , parameters ) -> tuple [Hashable , ...]:
1168+ def _compile_expression_key (comm , expr , dual_arg , to_element , ufl_element , domain , parameters ) -> tuple [Hashable , ...]:
11211169 """Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`."""
1122- return (hash_expr (expr ), hash (ufl_element ), utils .tuplify (parameters ))
1170+ return (hash_expr (expr ), type ( dual_arg ), hash (ufl_element ), utils .tuplify (parameters ))
11231171
11241172
11251173@memory_and_disk_cache (
0 commit comments