@@ -1006,8 +1006,13 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10061006 parameters = {}
10071007 parameters ['scalar_type' ] = utils .ScalarType
10081008
1009+ callables = ()
1010+ if access == op2 .INC :
1011+ callables += (tensor .zero ,)
1012+
10091013 needs_weight = isinstance (dual_arg , ufl .Cofunction ) and not to_element .is_dg ()
10101014 if needs_weight :
1015+ # Compute the reciprocal of the DOF multiplicity
10111016 W = dual_arg .function_space ()
10121017 shapes = (W .finat_element .space_dimension (), W .block_size )
10131018 domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes
@@ -1018,14 +1023,14 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10181023 """
10191024 weight = firedrake .Function (W )
10201025 firedrake .par_loop ((domain , instructions ), ufl .dx , {"w" : (weight , op2 .INC )})
1026+ with weight .dat .vec as w :
1027+ w .reciprocal ()
10211028
1022- # Create a copy and apply the weight
1023- # TODO include this in the callables
1024- v = firedrake .Function (dual_arg )
1025- with v .dat .vec as x , weight .dat .vec as w :
1026- x .pointwiseDivide (x , w )
1027-
1029+ # Create a buffer for the weighted Cofunction and a callable to apply the weight
1030+ v = firedrake .Function (W )
10281031 expr = expr ._ufl_expr_reconstruct_ (operand , v = v )
1032+ with weight .dat .vec_ro as w , dual_arg .dat .vec_ro as x , v .dat .vec_wo as y :
1033+ callables += (partial (y .pointwiseMult , x , w ),)
10291034
10301035 # We need to pass both the ufl element and the finat element
10311036 # because the finat elements might not have the right mapping
@@ -1043,6 +1048,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10431048 name = kernel .name
10441049 kernel = op2 .Kernel (ast , name , requires_zeroed_output_arguments = True ,
10451050 flop_count = kernel .flop_count , events = (kernel .event ,))
1051+
10461052 parloop_args = [kernel , cell_set ]
10471053
10481054 coefficients = tsfc_interface .extract_numbered_coefficients (expr , coefficient_numbers )
@@ -1158,7 +1164,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
11581164 if isinstance (tensor , op2 .Mat ):
11591165 return parloop_compute_callable , tensor .assemble
11601166 else :
1161- return copyin + (parloop_compute_callable , ) + copyout
1167+ return copyin + callables + (parloop_compute_callable , ) + copyout
11621168
11631169
11641170try :
0 commit comments