Skip to content

Commit ffe3246

Browse files
committed
misc: Clean up
1 parent 2c7b8ea commit ffe3246

13 files changed

Lines changed: 266 additions & 232 deletions

File tree

devito/ir/support/basic.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
1919
CriticalRegion, Function, Symbol, Temp, TempArray,
2020
TBArray)
21-
from devito.types.misc import PostIncrementIndex
2221

2322
__all__ = ['IterationInstance', 'TimedAccess', 'Scope', 'ExprGeometry']
2423

@@ -410,11 +409,7 @@ def distance(self, other):
410409
# `self.itintervals=(time, x, y)`, `n=0`
411410
continue
412411
elif not sai and not oai:
413-
# TODO: temp fix
414-
if any(isinstance(i, PostIncrementIndex)
415-
for i in (self[n], other[n])):
416-
ret.append(S.Zero)
417-
elif self[n] - other[n] == 0:
412+
if self[n] - other[n] == 0:
418413
# E.g., `self=R<a,[4]>` and `other=W<a,[4]>`
419414
ret.append(S.Zero)
420415
else:

devito/petsc/config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def core_metadata():
5353
}
5454

5555

56-
# Maximum number of bytes (including the null terminator) reserved for a
57-
# KSPType string in the profiler struct.
56+
# Maximum number of bytes for a KSPType string in the profiler struct.
5857
KSPTYPE_MAX_LEN = 64
5958

6059

devito/petsc/iet/builder.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,7 @@ def _setup(self):
7070
local_size = math.prod(
7171
v for v, dim in zip(target.shape_allocated, target.dimensions) if dim.is_Space
7272
)
73-
# TODO: Check - VecCreateSeqWithArray
74-
# local_x = petsc_call('VecCreateMPIWithArray',
75-
# [sobjs['comm'], 1, local_size, 'PETSC_DECIDE',
76-
# field_from_ptr, Byref(sobjs['xlocal'])])
73+
# TODO: Check - VecCreateSeqWithArray vs VecCreateMPIWithArray
7774
local_x = petsc_call('VecCreateSeqWithArray',
7875
['PETSC_COMM_SELF', 1, local_size,
7976
field_from_ptr, Byref(sobjs['xlocal'])])
@@ -139,7 +136,7 @@ def _extend_setup(self):
139136

140137
def _create_dmda_calls(self, dmda):
141138
dmda_create = self._create_dmda(dmda)
142-
# TODO: probs need to set the dm options prefix the same as snes?
139+
# TODO: probably need to set the dm options prefix the same as snes?
143140
dm_set_from_opts = petsc_call('DMSetFromOptions', [dmda])
144141
dm_setup = petsc_call('DMSetUp', [dmda])
145142
dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL'])
@@ -355,7 +352,7 @@ def _create_dmda_calls(self, dmda):
355352
dmda_create = self._create_dmda(dmda)
356353

357354
# TODO: likely need to set the dm options prefix the same as snes?
358-
# likely shouldn't hardcode this option like this.. (should be set in the options
355+
# Probably shouldn't hardcode this option.. (should be set in the options
359356
# callback)
360357
da_create_section = petsc_call(
361358
'PetscOptionsSetValue', [Null, String("-da_use_section"), Null]

devito/petsc/iet/callbacks.py

Lines changed: 62 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def _make_core(self):
122122
# Make the initial guess callback
123123
if self.field_data.initial_guess.exprs:
124124
self._make_initial_guess()
125-
# Make the callback used to constrain boundary nodes
125+
# Make the callback to constrain boundary nodes
126126
if self.field_data.constrain_bc:
127127
self._make_constrain_bc()
128128
self._make_user_struct_efunc()
@@ -650,14 +650,9 @@ def _make_constrain_bc(self):
650650
Constructs the `CountBCs` and `SetPointBCs` efuncs. Works for both
651651
single- and multi-field.
652652
"""
653-
constrain_bc = self.field_data.constrain_bc
653+
constrain_bc_dict = self.field_data.constrain_bc
654654
sobjs = self.solver_objs
655655

656-
# Normalize to dict {target: ConstrainBC}
657-
if isinstance(constrain_bc, dict):
658-
constrain_bc_dict = constrain_bc
659-
else:
660-
constrain_bc_dict = {self.field_data.target: constrain_bc}
661656
targets = list(constrain_bc_dict.keys())
662657

663658
all_increment_exprs = [
@@ -741,136 +736,14 @@ def _create_count_bc_body(self, body, pairs):
741736
return Uxreplace(subs).visit(body)
742737

743738
def _create_set_point_bc_body(self, body, constrain_bc_dict):
744-
"""Single-field SetPointBCs body. `constrain_bc_dict` has one entry."""
745-
(target, constrain_bc), = constrain_bc_dict.items()
746-
tname = target.name
747-
linsolve_expr = self.inject_solve.expr.rhs
748-
objs = self.objs
749-
sobjs = self.solver_objs
750-
751-
dmda = sobjs['callbackdm']
752-
ctx = objs['dummyctx']
753-
754-
dm_get_local_info = petsc_call(
755-
'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)]
756-
)
757-
758-
body = self.time_dependence.uxreplace_time(body)
759-
760-
fields = get_user_struct_fields(body)
761-
self._struct_params.extend(fields)
762-
763-
dm_get_app_context = petsc_call(
764-
'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)]
765-
)
766-
petsc_obj_comm = Call('PetscObjectComm', arguments=[PetscObjectCast(dmda)])
767-
is_create_general = petsc_call(
768-
'ISCreateGeneral',
769-
[petsc_obj_comm, sobjs[f'numBC_{tname}'], sobjs[f'bcPointsArr_{tname}'],
770-
'PETSC_OWN_POINTER', Byref(sobjs['bcPointsIS'])]
771-
)
772-
malloc_bc_points_arr = petsc_call(
773-
'PetscMalloc1',
774-
[sobjs[f'numBC_{tname}'], Byref(sobjs[f'bcPointsArr_{tname}']._C_symbol)]
775-
)
776-
malloc_bc_points = petsc_call(
777-
'PetscMalloc1', [1, Byref(sobjs['bcPoints']._C_symbol)]
778-
)
779-
dummy_expr = DummyExpr(sobjs['bcPoints'].indexed[0], sobjs['bcPointsIS'])
780-
set_point_bc = petsc_call(
781-
'DMDASetPointBC', [dmda, 1, sobjs['bcPoints'], Null]
782-
)
783-
body = body._rebuild(
784-
body=(
785-
(malloc_bc_points_arr,)
786-
+ body.body
787-
+ (is_create_general, malloc_bc_points, dummy_expr, set_point_bc,)
788-
)
789-
)
790-
791-
derefs = dereference_funcs(ctx, fields)
792-
standalones = [
793-
Definition(ctx),
794-
dm_get_app_context,
795-
Definition(sobjs[f'k_iter_{tname}'])
796-
]
797-
body = self._make_callable_body(
798-
body, standalones=standalones, stacks=(dm_get_local_info,) + derefs
799-
)
800-
801-
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for
802-
i in fields if not isinstance(i.function, AbstractFunction)}
803-
subs[constrain_bc.counter._C_symbol] = \
804-
sobjs[f'bcPointsArr_{tname}'].indexed[sobjs[f'k_iter_{tname}']]
805-
806-
return Uxreplace(subs).visit(body)
807-
808-
def _make_user_struct_efunc(self):
809-
"""
810-
This is the struct initialised inside the main kernel and
811-
attached to the DM via DMSetApplicationContext.
812-
"""
813-
mainctx = self.solver_objs['userctx'] = MainUserStruct(
814-
name=self.sregistry.make_name(prefix='ctx'),
815-
pname=self.sregistry.make_name(prefix='UserCtx'),
816-
fields=self.filtered_struct_params,
817-
liveness='lazy',
818-
modifier=None
819-
)
820-
body = [
821-
DummyExpr(FieldFromPointer(i._C_symbol, mainctx), i._C_symbol)
822-
for i in mainctx.callback_fields
823-
]
824-
struct_callback_body = self._make_callable_body(body)
825-
cb = Callable(
826-
self.sregistry.make_name(prefix='PopulateUserContext'),
827-
struct_callback_body, self.objs['err'],
828-
parameters=[mainctx]
829-
)
830-
self._efuncs[cb.name] = cb
831-
self._user_struct_efunc = cb
832-
833-
def _uxreplace_efuncs(self):
834-
sobjs = self.solver_objs
835-
callback_user_struct = CallbackUserStruct(
836-
name=sobjs['userctx'].name,
837-
pname=sobjs['userctx'].pname,
838-
fields=self.filtered_struct_params,
839-
liveness='lazy',
840-
modifier=' *',
841-
parent=sobjs['userctx']
842-
)
843-
mapper = {}
844-
visitor = Uxreplace({self.objs['dummyctx']: callback_user_struct})
845-
for k, v in self._efuncs.items():
846-
mapper.update({k: visitor.visit(v)})
847-
return mapper
848-
849-
850-
class CoupledCallbackBuilder(BaseCallbackBuilder):
851-
def __init__(self, **kwargs):
852-
self._submatrices_callback = None
853-
self._destroy_submat_callback = None
854-
super().__init__(**kwargs)
855-
856-
@property
857-
def submatrices_callback(self):
858-
return self._submatrices_callback
859-
860-
def _create_set_point_bc_body(self, body, _constrain_bc_dict):
861-
return self._create_set_point_bc_body_coupled(body)
862-
863-
def _create_set_point_bc_body_coupled(self, body):
864739
"""
865-
# TODO : ADD DOCS - MAKE IT CLEARER
866-
Combined SetPointBCs body for all target fields.
740+
Generic SetPointBCs body, handles single- and multi-field.
867741
"""
742+
targets = list(constrain_bc_dict.keys())
743+
nfields = len(targets)
868744
linsolve_expr = self.inject_solve.expr.rhs
869745
objs = self.objs
870746
sobjs = self.solver_objs
871-
constrain_bc = self.field_data.constrain_bc
872-
targets = self.field_data.targets
873-
nfields = len(targets)
874747
dmda = sobjs['callbackdm']
875748
ctx = objs['dummyctx']
876749

@@ -946,11 +819,63 @@ def _create_set_point_bc_body_coupled(self, body):
946819
i in fields if not isinstance(i.function, AbstractFunction)}
947820
for t in targets:
948821
tname = t.name
949-
subs[constrain_bc[t].counter._C_symbol] = \
822+
subs[constrain_bc_dict[t].counter._C_symbol] = \
950823
sobjs[f'bcPointsArr_{tname}'].indexed[sobjs[f'k_iter_{tname}']]
951824

952825
return Uxreplace(subs).visit(body)
953826

827+
def _make_user_struct_efunc(self):
828+
"""
829+
This is the struct initialised inside the main kernel and
830+
attached to the DM via DMSetApplicationContext.
831+
"""
832+
mainctx = self.solver_objs['userctx'] = MainUserStruct(
833+
name=self.sregistry.make_name(prefix='ctx'),
834+
pname=self.sregistry.make_name(prefix='UserCtx'),
835+
fields=self.filtered_struct_params,
836+
liveness='lazy',
837+
modifier=None
838+
)
839+
body = [
840+
DummyExpr(FieldFromPointer(i._C_symbol, mainctx), i._C_symbol)
841+
for i in mainctx.callback_fields
842+
]
843+
struct_callback_body = self._make_callable_body(body)
844+
cb = Callable(
845+
self.sregistry.make_name(prefix='PopulateUserContext'),
846+
struct_callback_body, self.objs['err'],
847+
parameters=[mainctx]
848+
)
849+
self._efuncs[cb.name] = cb
850+
self._user_struct_efunc = cb
851+
852+
def _uxreplace_efuncs(self):
853+
sobjs = self.solver_objs
854+
callback_user_struct = CallbackUserStruct(
855+
name=sobjs['userctx'].name,
856+
pname=sobjs['userctx'].pname,
857+
fields=self.filtered_struct_params,
858+
liveness='lazy',
859+
modifier=' *',
860+
parent=sobjs['userctx']
861+
)
862+
mapper = {}
863+
visitor = Uxreplace({self.objs['dummyctx']: callback_user_struct})
864+
for k, v in self._efuncs.items():
865+
mapper.update({k: visitor.visit(v)})
866+
return mapper
867+
868+
869+
class CoupledCallbackBuilder(BaseCallbackBuilder):
870+
def __init__(self, **kwargs):
871+
self._submatrices_callback = None
872+
self._destroy_submat_callback = None
873+
super().__init__(**kwargs)
874+
875+
@property
876+
def submatrices_callback(self):
877+
return self._submatrices_callback
878+
954879
@property
955880
def jacobian(self):
956881
return self.inject_solve.expr.rhs.field_data.jacobian
@@ -1175,9 +1100,10 @@ def _whole_formfunc_body(self, body):
11751100
return Uxreplace(subs).visit(formfunc_body)
11761101

11771102
def _create_destroy_submatrix(self):
1178-
# Need a special destroy because each submatrix has a manually
1179-
# PetscMalloc'ed context attached via MatShellSetContext
1180-
1103+
"""
1104+
Each submatrix has a PetscMalloc'd context attached via MatShellSetContext
1105+
that PETSc's default MatDestroy won't free, so we register a custom destroy.
1106+
"""
11811107
objs = self.objs
11821108

11831109
get_ctx = petsc_call(

devito/petsc/iet/passes.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -137,34 +137,12 @@ def lower_petsc(iet, **kwargs):
137137

138138
iet = Transformer(subs).visit(iet)
139139
body = core + tuple(setup) + iet.body.body + tuple(clear_options)
140-
# from IPython import embed; embed()
141140
body = iet.body._rebuild(body=body)
142141
iet = iet._rebuild(body=body)
143-
# from IPython import embed; embed()
144142
metadata = {**core_metadata(), 'efuncs': tuple(efuncs.values())}
145143
return iet, metadata
146144

147145

148-
@iet_pass
149-
def strip_petsc_callback_halos(iet, **kwargs):
150-
"""
151-
Remove any HaloSpot nodes that `mpiize` may have injected into PETSc
152-
callback functions (FormFunction, SetPointBCs, FormRHS, etc.).
153-
154-
HaloSpots should only appear in the main kernel.
155-
"""
156-
if not isinstance(iet, PETScCallable):
157-
return iet, {}
158-
159-
halos = FindNodes(HaloSpot).visit(iet)
160-
if not halos:
161-
return iet, {}
162-
163-
# Replace each HaloSpot with its body (unwrap it)
164-
mapper = {hs: hs.body for hs in halos}
165-
return Transformer(mapper).visit(iet), {}
166-
167-
168146
def lower_petsc_symbols(iet, **kwargs):
169147
"""
170148
The `place_definitions` and `place_casts` passes may introduce new
@@ -185,6 +163,10 @@ def lower_petsc_symbols(iet, **kwargs):
185163
@iet_pass
186164
def linear_indices(iet, **kwargs):
187165
"""
166+
Convert multidimensional grid accesses in the callback `SetPointBCs` to flat
167+
linear indices. DMDASetPointBC expects BC points as 1D offsets (e.g. i*ny + j),
168+
so each u[x, y] access is linearised to its stride expression and written into
169+
bcPointsArr.
188170
"""
189171
if not iet.name.startswith("SetPointBCs"):
190172
return iet, {}
@@ -196,9 +178,7 @@ def linear_indices(iet, **kwargs):
196178

197179
tracker = Tracker('basic', dtype, kwargs['sregistry'])
198180

199-
# TODO: CLEAN UP this is a hack
200-
# Exclude SubDomainSet backing functions from linearization - in SETPOINTBCS
201-
# I don't want to linearize the accesses to the SubDomainSet
181+
# TODO: Rethink - bit of a hack to only linearise the relevant index accesses
202182
indexeds = [
203183
i for i in FindSymbols('indexeds').visit(iet)
204184
if not isinstance(i.function, LocalType)

devito/petsc/iet/type_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
PointerIS, PointerDM, VecScatter, JacobianStruct, SubMatrixStruct,
1111
PetscMPIInt, PetscErrorCode, PointerMat, MatReuse,
1212
DummyArg, NofSubMats, PetscSectionGlobal,
13-
PetscSectionLocal, PetscSF, CallbackPetscInt, PointerPetscInt, SingleIS
13+
PetscSectionLocal, PetscSF, CallbackPetscInt, PointerPetscInt
1414
)
1515

1616

@@ -223,8 +223,8 @@ def _extend_build(self, base_dict):
223223
base_dict[f'k_iter_{tname}'] = PostIncrementIndex(
224224
name='k_iter', initvalue=0
225225
)
226-
base_dict['bcPointsIS'] = SingleIS(name='bcPointsIS')
227-
base_dict['bcPoints'] = PointerIS(name='bcPoints')
226+
base_dict['bcPointsIS'] = PointerIS(name='bcPointsIS', nindices=1)
227+
base_dict['bcCompsIS'] = PointerIS(name='bcCompsIS', nindices=1)
228228
return base_dict
229229

230230

devito/petsc/logging.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ def fields(self):
168168
def _C_typedecl(self):
169169
"""
170170
Override generated C struct declaration so that KSPType fields
171-
are emitted as ``char name[KSPTYPE_MAX_LEN]`` rather than ``KSPType name``
172-
(i.e. ``const char *``). This avoids a segfault when Python reads the
173-
profiler struct after SNESDestroy has freed the KSP.
171+
are emitted as `char name[KSPTYPE_MAX_LEN]` rather than `KSPType name`
172+
(i.e. `const char *`). This provides a buffer for PetscStrncpy to copy
173+
the type string before SNESDestroy frees the KSP.
174174
"""
175175
entries = []
176176
for field in self._fields:

0 commit comments

Comments
 (0)