Skip to content

Commit ad9c62b

Browse files
authored
make loop example parallel safe (#4)
1 parent 2b9fbb1 commit ad9c62b

File tree

1 file changed

+114
-41
lines changed

1 file changed

+114
-41
lines changed

Notebooks/Examples-Loop-Mesh/Ex_Darcy_3D_Loop_Mesh_Fault.py

Lines changed: 114 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,35 @@
1111
import nest_asyncio
1212
nest_asyncio.apply()
1313

14-
# +
1514
import underworld3 as uw
1615
import numpy as np
1716
from enum import Enum
1817
from petsc4py import PETSc
1918
import sympy
2019
from sympy import Piecewise, ceiling, Abs
21-
import matplotlib.pyplot as plt
22-
20+
import os
2321
options = PETSc.Options()
24-
# -
2522

2623
# vis tools
2724
if uw.mpi.size == 1:
2825
import pyvista as pv
2926
import underworld3.visualisation as vis
27+
import matplotlib.pyplot as plt
3028

3129
# importing loop meshing tools
3230
from underworld3.utilities import create_dmplex_from_medit
3331

32+
# +
33+
# create output dir
34+
if uw.mpi.size==1:
35+
output_dir = './output/darcy_loop_mesh/serial/'
36+
else:
37+
output_dir = './output/darcy_loop_mesh/parallel/'
38+
39+
if uw.mpi.rank == 0:
40+
os.makedirs(output_dir, exist_ok=True)
41+
# -
42+
3443
# loading mesh data from dmplex
3544
medit_plex = create_dmplex_from_medit('./meshout.mesh')
3645

@@ -220,7 +229,7 @@ def plot_P_V(_mesh, _p_soln, _v_soln):
220229
# +
221230
# set up two materials
222231
interface = 2.5
223-
k1 = 1e-3
232+
k1 = 1e-1
224233
k2 = 1e-5
225234

226235
# Groundwater pressure boundary condition on the left wall
@@ -247,21 +256,73 @@ def plot_P_V(_mesh, _p_soln, _v_soln):
247256
perm_arr = np.zeros_like(permeability.data)
248257

249258
# +
250-
# labels start and end
251-
cStart, cEnd = mesh.dm.getHeightStratum(0)
259+
# # dealing with vertices
252260

253-
# fault tetra
254-
fault_tetra = mesh.dm.getStratumIS("TetraLabels", 404).array
261+
# pStart, pEnd = mesh.dm.getDepthStratum(0)
262+
# pNum = pEnd-pStart
255263

256-
for c in range(cStart, cEnd):
257-
if c in fault_tetra:
258-
perm_arr[c] = k1
259-
# with mesh.access(permeability):
260-
# permeability.data[c] = k1
264+
# # Get coordinates array
265+
# coords = mesh.dm.getCoordinates().array.reshape(pNum, mesh.cdim)
266+
267+
# # dealing with tets
268+
# tet_Start, tet_End = mesh.dm.getDepthStratum(3)
269+
270+
# # fault tetra
271+
# fault_tetra = mesh.dm.getStratumIS("TetraLabels", 404).array
272+
273+
# for t in range(tet_Start, tet_End):
274+
# if t in fault_tetra:
275+
# coneclose, orient = mesh.dm.getTransitiveClosure(t)
276+
# if np.any(coords[coneclose[-4:]-tet_pEnd][:,1]>5):
277+
# perm_arr[t] = k1
278+
# else:
279+
# perm_arr[t] = 0.0
280+
# else:
281+
# perm_arr[t] = k2
282+
283+
# +
284+
comm = mesh.dm.getComm() # PETSc communicator
285+
286+
# Get vertex depth range
287+
pStart, pEnd = mesh.dm.getDepthStratum(0)
288+
pNum = pEnd - pStart # Number of local vertices
289+
290+
# Get coordinates
291+
coord_sec = mesh.dm.getCoordinateSection()
292+
coord_vec = mesh.dm.getCoordinatesLocal().array
293+
294+
# Ensure proper reshaping by getting the actual number of vertices
295+
actual_pNum = coord_vec.shape[0] // mesh.cdim
296+
coords = coord_vec.reshape(actual_pNum, mesh.cdim) # Use computed pNum
297+
298+
# Get tetrahedral depth range
299+
tet_Start, tet_End = mesh.dm.getDepthStratum(3)
300+
301+
# Get fault tetra indices
302+
fault_tetra_set = set(mesh.dm.getStratumIS("TetraLabels", 404).array)
303+
304+
# Parallel-safe permeability assignment
305+
for t in range(tet_Start, tet_End):
306+
if t in fault_tetra_set:
307+
coneclose, orient = mesh.dm.getTransitiveClosure(t, useCone=True)
308+
309+
# Get the last 4 entries (vertices) and their y-coordinates
310+
vertex_indices = [v for v in coneclose if pStart <= v < pEnd]
311+
if len(vertex_indices) != 4:
312+
continue # Ensure we have exactly 4 vertices for the tetrahedron
313+
314+
y_coords = [coords[v - pStart][1] for v in vertex_indices]
315+
316+
# Assign permeability based on y-coordinate
317+
if np.any(np.array(y_coords) > 5):
318+
perm_arr[t] = k1
319+
else:
320+
perm_arr[t] = 0.0
261321
else:
262-
perm_arr[c] = k2
263-
# with mesh.access(permeability):
264-
# permeability.data[c] = k2
322+
perm_arr[t] = k2
323+
324+
# Ensure consistency across MPI ranks
325+
comm.Barrier()
265326
# -
266327

267328
# assigning k values to mesh variable
@@ -271,14 +332,22 @@ def plot_P_V(_mesh, _p_soln, _v_soln):
271332
darcy.constitutive_model.Parameters.permeability = permeability.sym[0]
272333

273334
# darcy solve without gravity
274-
darcy.solve()
335+
darcy.solve(verbose=True)
336+
337+
# +
338+
# # saving output
275339

276-
# saving output
277340
mesh.petsc_save_checkpoint(index=0, meshVars=[p_soln, v_soln],
278-
outputPath='./output/darcy_3d_loop_mesh_fault_no_g')
341+
outputPath=f'{output_dir}darcy_3d_loop_mesh_fault_no_g')
342+
343+
# mesh.write_timestep(f'darcy_3d_loop_mesh_fault_no_g', meshUpdates=True,
344+
# meshVars=[v_soln, p_soln],
345+
# outputPath='./output/', index=0,)
346+
# -
279347

280348
# plotting soln without gravity
281-
plot_P_V(mesh, p_soln, v_soln)
349+
if uw.mpi.size==1:
350+
plot_P_V(mesh, p_soln, v_soln)
282351

283352
# # copy soln
284353
with mesh.access(p_soln_0, v_soln_0):
@@ -287,34 +356,38 @@ def plot_P_V(_mesh, _p_soln, _v_soln):
287356

288357
# now switch on gravity
289358
darcy.constitutive_model.Parameters.s = sympy.Matrix([0, 0, -1]).T
290-
darcy.solve()
359+
darcy.solve(verbose=True)
291360

361+
# +
292362
# saving output
363+
293364
mesh.petsc_save_checkpoint(index=0, meshVars=[p_soln, v_soln, permeability],
294-
outputPath='./output/darcy_3d_loop_mesh_fault_g')
365+
outputPath=f'{output_dir}darcy_3d_loop_mesh_fault_g')
366+
# -
295367

296368
# plotting soln without gravity
297-
plot_P_V(mesh, p_soln, v_soln)
369+
if uw.mpi.size==1:
370+
plot_P_V(mesh, p_soln, v_soln)
298371

299-
# +
300372
# set up interpolation coordinates
301-
xcoords = np.linspace(minX + 0.001 * (maxX - minX), maxX - 0.001 * (maxX - minX), 100)
302-
ycoords = np.full_like(xcoords, 5)
303-
zcoords = np.full_like(xcoords, 2)
304-
xyz_coords = np.column_stack([xcoords, ycoords, zcoords])
305-
306-
pressure_interp = uw.function.evaluate(p_soln.sym[0], xyz_coords)
307-
pressure_interp_0 = uw.function.evaluate(p_soln_0.sym[0], xyz_coords)
308-
# -
373+
if uw.mpi.size==1:
374+
xcoords = np.linspace(minX + 0.001 * (maxX - minX), maxX - 0.001 * (maxX - minX), 100)
375+
ycoords = np.full_like(xcoords, 5)
376+
zcoords = np.full_like(xcoords, 2)
377+
xyz_coords = np.column_stack([xcoords, ycoords, zcoords])
378+
379+
pressure_interp = uw.function.evaluate(p_soln.sym[0], xyz_coords)
380+
pressure_interp_0 = uw.function.evaluate(p_soln_0.sym[0], xyz_coords)
309381

310382
# plotting numerical and analytical solution
311-
fig = plt.figure(figsize=(15,7))
312-
ax1 = fig.add_subplot(111, xlabel="X-Distance", ylabel="Pressure")
313-
ax1.plot(xcoords, pressure_interp, linewidth=3, label="Numerical solution")
314-
ax1.plot(xcoords, pressure_interp_0, linewidth=3, label="Numerical solution (no G)")
315-
# ax1.plot(pressure_analytic, xcoords, linewidth=3, linestyle="--", label="Analytic solution")
316-
# ax1.plot(pressure_analytic_noG, xcoords, linewidth=3, linestyle="--", label="Analytic (no gravity)")
317-
ax1.grid("on")
318-
ax1.legend()
383+
if uw.mpi.size==1:
384+
fig = plt.figure(figsize=(15,7))
385+
ax1 = fig.add_subplot(111, xlabel="X-Distance", ylabel="Pressure")
386+
ax1.plot(xcoords, pressure_interp, linewidth=3, label="Numerical solution")
387+
ax1.plot(xcoords, pressure_interp_0, linewidth=3, label="Numerical solution (no G)")
388+
# ax1.plot(pressure_analytic, xcoords, linewidth=3, linestyle="--", label="Analytic solution")
389+
# ax1.plot(pressure_analytic_noG, xcoords, linewidth=3, linestyle="--", label="Analytic (no gravity)")
390+
ax1.grid("on")
391+
ax1.legend()
319392

320393

0 commit comments

Comments
 (0)