@@ -122,7 +122,7 @@ def _make_core(self):
122122 # Make the initial guess callback
123123 if self .field_data .initial_guess .exprs :
124124 self ._make_initial_guess ()
125- # Make the callback used to constrain boundary nodes
125+ # Make the callback to constrain boundary nodes
126126 if self .field_data .constrain_bc :
127127 self ._make_constrain_bc ()
128128 self ._make_user_struct_efunc ()
@@ -650,14 +650,9 @@ def _make_constrain_bc(self):
650650 Constructs the `CountBCs` and `SetPointBCs` efuncs. Works for both
651651 single- and multi-field.
652652 """
653- constrain_bc = self .field_data .constrain_bc
653+ constrain_bc_dict = self .field_data .constrain_bc
654654 sobjs = self .solver_objs
655655
656- # Normalize to dict {target: ConstrainBC}
657- if isinstance (constrain_bc , dict ):
658- constrain_bc_dict = constrain_bc
659- else :
660- constrain_bc_dict = {self .field_data .target : constrain_bc }
661656 targets = list (constrain_bc_dict .keys ())
662657
663658 all_increment_exprs = [
@@ -741,136 +736,14 @@ def _create_count_bc_body(self, body, pairs):
741736 return Uxreplace (subs ).visit (body )
742737
743738 def _create_set_point_bc_body (self , body , constrain_bc_dict ):
744- """Single-field SetPointBCs body. `constrain_bc_dict` has one entry."""
745- (target , constrain_bc ), = constrain_bc_dict .items ()
746- tname = target .name
747- linsolve_expr = self .inject_solve .expr .rhs
748- objs = self .objs
749- sobjs = self .solver_objs
750-
751- dmda = sobjs ['callbackdm' ]
752- ctx = objs ['dummyctx' ]
753-
754- dm_get_local_info = petsc_call (
755- 'DMDAGetLocalInfo' , [dmda , Byref (linsolve_expr .localinfo )]
756- )
757-
758- body = self .time_dependence .uxreplace_time (body )
759-
760- fields = get_user_struct_fields (body )
761- self ._struct_params .extend (fields )
762-
763- dm_get_app_context = petsc_call (
764- 'DMGetApplicationContext' , [dmda , Byref (ctx ._C_symbol )]
765- )
766- petsc_obj_comm = Call ('PetscObjectComm' , arguments = [PetscObjectCast (dmda )])
767- is_create_general = petsc_call (
768- 'ISCreateGeneral' ,
769- [petsc_obj_comm , sobjs [f'numBC_{ tname } ' ], sobjs [f'bcPointsArr_{ tname } ' ],
770- 'PETSC_OWN_POINTER' , Byref (sobjs ['bcPointsIS' ])]
771- )
772- malloc_bc_points_arr = petsc_call (
773- 'PetscMalloc1' ,
774- [sobjs [f'numBC_{ tname } ' ], Byref (sobjs [f'bcPointsArr_{ tname } ' ]._C_symbol )]
775- )
776- malloc_bc_points = petsc_call (
777- 'PetscMalloc1' , [1 , Byref (sobjs ['bcPoints' ]._C_symbol )]
778- )
779- dummy_expr = DummyExpr (sobjs ['bcPoints' ].indexed [0 ], sobjs ['bcPointsIS' ])
780- set_point_bc = petsc_call (
781- 'DMDASetPointBC' , [dmda , 1 , sobjs ['bcPoints' ], Null ]
782- )
783- body = body ._rebuild (
784- body = (
785- (malloc_bc_points_arr ,)
786- + body .body
787- + (is_create_general , malloc_bc_points , dummy_expr , set_point_bc ,)
788- )
789- )
790-
791- derefs = dereference_funcs (ctx , fields )
792- standalones = [
793- Definition (ctx ),
794- dm_get_app_context ,
795- Definition (sobjs [f'k_iter_{ tname } ' ])
796- ]
797- body = self ._make_callable_body (
798- body , standalones = standalones , stacks = (dm_get_local_info ,) + derefs
799- )
800-
801- subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for
802- i in fields if not isinstance (i .function , AbstractFunction )}
803- subs [constrain_bc .counter ._C_symbol ] = \
804- sobjs [f'bcPointsArr_{ tname } ' ].indexed [sobjs [f'k_iter_{ tname } ' ]]
805-
806- return Uxreplace (subs ).visit (body )
807-
808- def _make_user_struct_efunc (self ):
809- """
810- This is the struct initialised inside the main kernel and
811- attached to the DM via DMSetApplicationContext.
812- """
813- mainctx = self .solver_objs ['userctx' ] = MainUserStruct (
814- name = self .sregistry .make_name (prefix = 'ctx' ),
815- pname = self .sregistry .make_name (prefix = 'UserCtx' ),
816- fields = self .filtered_struct_params ,
817- liveness = 'lazy' ,
818- modifier = None
819- )
820- body = [
821- DummyExpr (FieldFromPointer (i ._C_symbol , mainctx ), i ._C_symbol )
822- for i in mainctx .callback_fields
823- ]
824- struct_callback_body = self ._make_callable_body (body )
825- cb = Callable (
826- self .sregistry .make_name (prefix = 'PopulateUserContext' ),
827- struct_callback_body , self .objs ['err' ],
828- parameters = [mainctx ]
829- )
830- self ._efuncs [cb .name ] = cb
831- self ._user_struct_efunc = cb
832-
833- def _uxreplace_efuncs (self ):
834- sobjs = self .solver_objs
835- callback_user_struct = CallbackUserStruct (
836- name = sobjs ['userctx' ].name ,
837- pname = sobjs ['userctx' ].pname ,
838- fields = self .filtered_struct_params ,
839- liveness = 'lazy' ,
840- modifier = ' *' ,
841- parent = sobjs ['userctx' ]
842- )
843- mapper = {}
844- visitor = Uxreplace ({self .objs ['dummyctx' ]: callback_user_struct })
845- for k , v in self ._efuncs .items ():
846- mapper .update ({k : visitor .visit (v )})
847- return mapper
848-
849-
850- class CoupledCallbackBuilder (BaseCallbackBuilder ):
851- def __init__ (self , ** kwargs ):
852- self ._submatrices_callback = None
853- self ._destroy_submat_callback = None
854- super ().__init__ (** kwargs )
855-
856- @property
857- def submatrices_callback (self ):
858- return self ._submatrices_callback
859-
860- def _create_set_point_bc_body (self , body , _constrain_bc_dict ):
861- return self ._create_set_point_bc_body_coupled (body )
862-
863- def _create_set_point_bc_body_coupled (self , body ):
864739 """
865- # TODO : ADD DOCS - MAKE IT CLEARER
866- Combined SetPointBCs body for all target fields.
740+ Generic SetPointBCs body, handles single- and multi-field.
867741 """
742+ targets = list (constrain_bc_dict .keys ())
743+ nfields = len (targets )
868744 linsolve_expr = self .inject_solve .expr .rhs
869745 objs = self .objs
870746 sobjs = self .solver_objs
871- constrain_bc = self .field_data .constrain_bc
872- targets = self .field_data .targets
873- nfields = len (targets )
874747 dmda = sobjs ['callbackdm' ]
875748 ctx = objs ['dummyctx' ]
876749
@@ -946,11 +819,63 @@ def _create_set_point_bc_body_coupled(self, body):
946819 i in fields if not isinstance (i .function , AbstractFunction )}
947820 for t in targets :
948821 tname = t .name
949- subs [constrain_bc [t ].counter ._C_symbol ] = \
822+ subs [constrain_bc_dict [t ].counter ._C_symbol ] = \
950823 sobjs [f'bcPointsArr_{ tname } ' ].indexed [sobjs [f'k_iter_{ tname } ' ]]
951824
952825 return Uxreplace (subs ).visit (body )
953826
827+ def _make_user_struct_efunc (self ):
828+ """
829+ This is the struct initialised inside the main kernel and
830+ attached to the DM via DMSetApplicationContext.
831+ """
832+ mainctx = self .solver_objs ['userctx' ] = MainUserStruct (
833+ name = self .sregistry .make_name (prefix = 'ctx' ),
834+ pname = self .sregistry .make_name (prefix = 'UserCtx' ),
835+ fields = self .filtered_struct_params ,
836+ liveness = 'lazy' ,
837+ modifier = None
838+ )
839+ body = [
840+ DummyExpr (FieldFromPointer (i ._C_symbol , mainctx ), i ._C_symbol )
841+ for i in mainctx .callback_fields
842+ ]
843+ struct_callback_body = self ._make_callable_body (body )
844+ cb = Callable (
845+ self .sregistry .make_name (prefix = 'PopulateUserContext' ),
846+ struct_callback_body , self .objs ['err' ],
847+ parameters = [mainctx ]
848+ )
849+ self ._efuncs [cb .name ] = cb
850+ self ._user_struct_efunc = cb
851+
852+ def _uxreplace_efuncs (self ):
853+ sobjs = self .solver_objs
854+ callback_user_struct = CallbackUserStruct (
855+ name = sobjs ['userctx' ].name ,
856+ pname = sobjs ['userctx' ].pname ,
857+ fields = self .filtered_struct_params ,
858+ liveness = 'lazy' ,
859+ modifier = ' *' ,
860+ parent = sobjs ['userctx' ]
861+ )
862+ mapper = {}
863+ visitor = Uxreplace ({self .objs ['dummyctx' ]: callback_user_struct })
864+ for k , v in self ._efuncs .items ():
865+ mapper .update ({k : visitor .visit (v )})
866+ return mapper
867+
868+
869+ class CoupledCallbackBuilder (BaseCallbackBuilder ):
870+ def __init__ (self , ** kwargs ):
871+ self ._submatrices_callback = None
872+ self ._destroy_submat_callback = None
873+ super ().__init__ (** kwargs )
874+
875+ @property
876+ def submatrices_callback (self ):
877+ return self ._submatrices_callback
878+
954879 @property
955880 def jacobian (self ):
956881 return self .inject_solve .expr .rhs .field_data .jacobian
@@ -1175,9 +1100,10 @@ def _whole_formfunc_body(self, body):
11751100 return Uxreplace (subs ).visit (formfunc_body )
11761101
11771102 def _create_destroy_submatrix (self ):
1178- # Need a special destroy because each submatrix has a manually
1179- # PetscMalloc'ed context attached via MatShellSetContext
1180-
1103+ """
1104+ Each submatrix has a PetscMalloc'd context attached via MatShellSetContext
1105+ that PETSc's default MatDestroy won't free, so we register a custom destroy.
1106+ """
11811107 objs = self .objs
11821108
11831109 get_ctx = petsc_call (
0 commit comments