@@ -93,10 +93,20 @@ def __init__(self, expr, v,
9393 and reduce operations.
9494 """
9595 # Check function space
96+ expr = ufl .as_ufl (expr )
9697 if isinstance (v , functionspaceimpl .WithGeometry ):
97- expr_args = extract_arguments (ufl . as_ufl ( expr ) )
98+ expr_args = extract_arguments (expr )
9899 is_adjoint = len (expr_args ) and expr_args [0 ].number () == 0
99100 v = Argument (v .dual (), 1 if is_adjoint else 0 )
101+
102+ V = v .arguments ()[0 ].function_space ()
103+ if len (expr .ufl_shape ) != len (V .value_shape ):
104+ raise RuntimeError ('Rank mismatch: Expression rank %d, FunctionSpace rank %d'
105+ % (len (expr .ufl_shape ), len (V .value_shape )))
106+
107+ if expr .ufl_shape != V .value_shape :
108+ raise RuntimeError ('Shape mismatch: Expression shape %r, FunctionSpace shape %r'
109+ % (expr .ufl_shape , V .value_shape ))
100110 super ().__init__ (expr , v )
101111
102112 # -- Interpolate data (e.g. `subset` or `access`) -- #
@@ -173,7 +183,7 @@ def interpolate(expr, V, subset=None, access=op2.WRITE, allow_missing_dofs=False
173183 raise TypeError (f"Expected a one-form, provided form had { rank } arguments" )
174184 elif isinstance (V , functionspaceimpl .WithGeometry ):
175185 dual_arg = Coargument (V .dual (), 0 )
176- expr_args = extract_arguments (expr )
186+ expr_args = extract_arguments (ufl . as_ufl ( expr ) )
177187 if expr_args and expr_args [0 ].number () == 0 :
178188 # In this case we are doing adjoint interpolation
179189 # When V is a FunctionSpace and expr contains Argument(0),
@@ -483,7 +493,7 @@ def __init__(
483493 if len (shape ) == 0 :
484494 fs_type = firedrake .FunctionSpace
485495 elif len (shape ) == 1 :
486- fs_type = firedrake .VectorFunctionSpace
496+ fs_type = partial ( firedrake .VectorFunctionSpace , dim = shape [ 0 ])
487497 else :
488498 fs_type = partial (firedrake .TensorFunctionSpace , shape = shape )
489499 P0DG_vom = fs_type (self .vom_dest_node_coords_in_src_mesh , "DG" , 0 )
@@ -710,7 +720,7 @@ def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE,
710720 super ().__init__ (expr , V , subset = subset , freeze_expr = freeze_expr ,
711721 access = access , bcs = bcs , matfree = matfree , allow_missing_dofs = allow_missing_dofs )
712722 try :
713- self .callable , arguments = make_interpolator (expr , V , subset , access , bcs = bcs , matfree = matfree )
723+ self .callable = make_interpolator (expr , V , subset , access , bcs = bcs , matfree = matfree )
714724 except FIAT .hdiv_trace .TraceError :
715725 raise NotImplementedError ("Can't interpolate onto traces sorry" )
716726 self .arguments = expr .arguments ()
@@ -726,6 +736,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
726736 if transpose is not None :
727737 warnings .warn ("'transpose' argument is deprecated, use 'adjoint' instead" , FutureWarning )
728738 adjoint = transpose or adjoint
739+
729740 try :
730741 assembled_interpolator = self .frozen_assembled_interpolator
731742 copy_required = True
@@ -740,7 +751,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
740751 # Interpolation action
741752 self .frozen_assembled_interpolator = assembled_interpolator .copy ()
742753
743- if hasattr ( assembled_interpolator , "handle" ) and len (function ):
754+ if len ( self . arguments ) == 2 and len (function ):
744755 function , = function
745756 if not hasattr (function , "dat" ):
746757 raise ValueError ("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!" )
@@ -783,15 +794,11 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
783794@PETSc .Log .EventDecorator ()
784795def make_interpolator (expr , V , subset , access , bcs = None , matfree = True ):
785796 if not isinstance (expr , ufl .Interpolate ):
786- fs = V if isinstance (V , ufl .FunctionSpace ) else V .function_space ()
787- expr = Interpolate (expr , fs )
797+ raise ValueError (f"Expecting to interpolate a ufl.Interpolate, got { type (expr ).__name__ } ." )
788798 dual_arg , operand = expr .argument_slots ()
789- assert isinstance (dual_arg , (ufl .Coargument , ufl .Cofunction ))
790799
791800 target_mesh = as_domain (dual_arg )
792801 source_mesh = extract_unique_domain (operand ) or target_mesh
793- same_mesh = target_mesh is source_mesh
794-
795802 vom_onto_other_vom = (
796803 isinstance (target_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology )
797804 and isinstance (source_mesh .topology , firedrake .mesh .VertexOnlyMeshTopology )
@@ -803,7 +810,7 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
803810 if rank <= 1 :
804811 if rank == 0 :
805812 R = firedrake .FunctionSpace (target_mesh , "Real" , 0 )
806- f = firedrake .Function (R )
813+ f = firedrake .Function (R , dtype = utils . ScalarType )
807814 elif isinstance (V , firedrake .Function ):
808815 f = V
809816 V = f .function_space ()
@@ -862,10 +869,8 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
862869 else :
863870 raise ValueError ("Cannot interpolate an expression with %d arguments" % rank )
864871
865- if not same_mesh :
866- arguments = extract_arguments (operand )
867872 if vom_onto_other_vom :
868- wrapper = VomOntoVomWrapper (V , source_mesh , target_mesh , operand , arguments , matfree )
873+ wrapper = VomOntoVomWrapper (V , source_mesh , target_mesh , operand , matfree )
869874 # NOTE: get_dat_mpi_type ensures we get the correct MPI type for the
870875 # data, including the correct data size and dimensional information
871876 # (so for vector function spaces in 2 dimensions we might need a
@@ -875,13 +880,13 @@ def make_interpolator(expr, V, subset, access, bcs=None, matfree=True):
875880 # when it is called.
876881 assert f .dat is tensor
877882 wrapper .mpi_type , _ = get_dat_mpi_type (f .dat )
878- assert len (arguments ) == 0
883+ assert len (arguments ) == 1
879884
880885 def callable ():
881886 wrapper .forward_operation (f .dat )
882887 return f
883888 else :
884- assert len (arguments ) == 1
889+ assert len (arguments ) == 2
885890 assert tensor is None
886891 # we know we will be outputting either a function or a cofunction,
887892 # both of which will use a dat as a data carrier. At present, the
@@ -900,15 +905,9 @@ def callable():
900905 def callable ():
901906 return wrapper
902907
903- return callable , arguments
908+ return callable
904909 else :
905- # Make sure we have an expression of the right length i.e. a value for
906- # each component in the value shape of each function space
907910 loops = []
908- if numpy .prod (operand .ufl_shape , dtype = int ) != V .value_size :
909- raise RuntimeError ('Expression of length %d required, got length %d'
910- % (V .value_size , numpy .prod (operand .ufl_shape , dtype = int )))
911-
912911 if len (V ) == 1 :
913912 loops .extend (_interpolator (V , tensor , expr , subset , arguments , access , bcs = bcs ))
914913 else :
@@ -933,8 +932,7 @@ def callable():
933932 elif isinstance (dual_arg , Coargument ):
934933 duals = [Coargument (Vsub , number = dual_arg .number ()) for Vsub in dual_arg .function_space ()]
935934 else :
936- raise ValueError ("dual_arg must be a Cofunction or Coargument" )
937-
935+ duals = [v for _ , v in sorted (firedrake .formmanipulation .split_form (dual_arg ))]
938936 # Interpolate each sub expression into each function space
939937 for Vsub , sub_tensor , sub_op , sub_dual in zip (V , tensor , operands , duals ):
940938 sub_expr = expr ._ufl_expr_reconstruct_ (sub_op , sub_dual )
@@ -948,7 +946,7 @@ def callable(loops, f):
948946 l ()
949947 return f
950948
951- return partial (callable , loops , f ), arguments
949+ return partial (callable , loops , f )
952950
953951
954952@utils .known_pyop2_safe
@@ -966,14 +964,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
966964 if access is op2 .READ :
967965 raise ValueError ("Can't have READ access for output function" )
968966
969- if len (operand .ufl_shape ) != len (V .value_shape ):
970- raise RuntimeError ('Rank mismatch: Expression rank %d, FunctionSpace rank %d'
971- % (len (operand .ufl_shape ), len (V .value_shape )))
972-
973- if operand .ufl_shape != V .value_shape :
974- raise RuntimeError ('Shape mismatch: Expression shape %r, FunctionSpace shape %r'
975- % (operand .ufl_shape , V .value_shape ))
976-
977967 # NOTE: The par_loop is always over the target mesh cells.
978968 target_mesh = as_domain (V )
979969 source_mesh = extract_unique_domain (operand ) or target_mesh
@@ -1401,17 +1391,15 @@ class VomOntoVomWrapper(object):
14011391 expr : `ufl.Expr`
14021392 The expression to interpolate. If ``arguments`` is not empty, those
14031393 arguments must be present within it.
1404- arguments : list of `ufl.Argument`
1405- The arguments in the expression. These are not extracted from expr here
1406- since, where we use this, we already have them.
14071394 matfree : bool
14081395 If ``False``, the matrix representating the permutation of the points is
14091396 constructed and used to perform the interpolation. If ``True``, then the
14101397 interpolation is performed using the broadcast and reduce operations on the
14111398 PETSc Star Forest.
14121399 """
14131400
1414- def __init__ (self , V , source_vom , target_vom , expr , arguments , matfree ):
1401+ def __init__ (self , V , source_vom , target_vom , expr , matfree ):
1402+ arguments = extract_arguments (expr )
14151403 reduce = False
14161404 if source_vom .input_ordering is target_vom :
14171405 reduce = True
0 commit comments