Skip to content

Commit c1c8baf

Browse files
committed
Update QUDA to lattice/quda#1489.
1 parent 55eff52 commit c1c8baf

File tree

11 files changed

+126
-55
lines changed

11 files changed

+126
-55
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Python wrapper for [QUDA](https://github.com/lattice/quda) written in Cython.
44

55
This project aims to benefit from the optimized linear algebra library [CuPy](https://cupy.dev/) in Python based on CUDA. CuPy and QUDA will allow us to perform most lattice QCD research operations with high performance. [PyTorch](https://pytorch.org/) is an alternative option.
66

7-
This project is based on the latest QUDA `develop` branch. PyQUDA should be compatible with any commit of QUDA after 2024, but leave some features disabled.
7+
This project is based on the latest QUDA `develop` branch. PyQUDA should be compatible with any QUDA commit after https://github.com/lattice/quda/pull/1489, but leaves some features disabled.
88

99
## Feature
1010

pyquda/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
LatticeGauge,
1212
LatticeMom,
1313
LatticeFermion,
14+
MultiLatticeFermion,
1415
LatticeStaggeredFermion,
16+
MultiLatticeStaggeredFermion,
1517
LatticePropagator,
1618
LatticeStaggeredPropagator,
1719
lexico,

pyquda/dirac/__init__.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
QudaGaugeSmearParam,
1010
QudaGaugeObservableParam,
1111
invertQuda,
12+
invertMultiSrcQuda,
1213
invertMultiShiftQuda,
1314
MatQuda,
1415
MatDagMatQuda,
1516
dslashQuda,
17+
dslashMultiSrcQuda,
1618
newMultigridQuda,
1719
updateMultigridQuda,
1820
destroyMultigridQuda,
@@ -232,6 +234,19 @@ def dslash(self, x: LatticeFermion, parity: QudaParity):
232234
dslashQuda(b.data_ptr, x.data_ptr, self.invert_param, parity)
233235
return b
234236

237+
def invertMultiSrc(self, b: MultiLatticeFermion):
238+
self.invert_param.num_src = b.L5
239+
x = MultiLatticeFermion(b.latt_info, b.L5)
240+
invertMultiSrcQuda(x.data_ptrs, b.data_ptrs, self.invert_param)
241+
self.performance()
242+
return x
243+
244+
def dslashMultiSrc(self, x: MultiLatticeFermion, parity: QudaParity):
245+
self.invert_param.num_src = x.L5
246+
b = MultiLatticeFermion(x.latt_info, x.L5)
247+
dslashMultiSrcQuda(b.data_ptrs, x.data_ptrs, self.invert_param, parity)
248+
return b
249+
235250
def _invertMultiShiftParam(self, offset: List[float], residue: List[float], norm: float = None):
236251
assert len(offset) == len(residue)
237252
num_offset = len(offset)
@@ -349,6 +364,19 @@ def dslash(self, x: LatticeStaggeredFermion, parity: QudaParity):
349364
dslashQuda(b.data_ptr, x.data_ptr, self.invert_param, parity)
350365
return b
351366

367+
def invertMultiSrc(self, b: MultiLatticeStaggeredFermion):
368+
self.invert_param.num_src = b.L5
369+
x = MultiLatticeStaggeredFermion(b.latt_info, b.L5)
370+
invertMultiSrcQuda(x.data_ptrs, b.data_ptrs, self.invert_param)
371+
self.performance()
372+
return x
373+
374+
def dslashMultiSrc(self, x: MultiLatticeStaggeredFermion, parity: QudaParity):
375+
self.invert_param.num_src = x.L5
376+
b = MultiLatticeStaggeredFermion(x.latt_info, x.L5)
377+
dslashMultiSrcQuda(b.data_ptrs, x.data_ptrs, self.invert_param, parity)
378+
return b
379+
352380
def invertMultiShiftPC(
353381
self, b: LatticeStaggeredFermion, offset: List[float], residue: List[float], norm: float = None
354382
) -> Union[LatticeStaggeredFermion, MultiLatticeStaggeredFermion]:

pyquda/dirac/general.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def newQudaMultigridParam(
232232
mg_param.n_block_ortho = [1] * QUDA_MAX_MG_LEVEL
233233

234234
mg_param.setup_inv_type = [QudaInverterType.QUDA_CGNR_INVERTER] * QUDA_MAX_MG_LEVEL
235+
mg_param.n_vec_batch = [1] * QUDA_MAX_MG_LEVEL
235236
mg_param.num_setup_iter = [1] * QUDA_MAX_MG_LEVEL
236237
mg_param.setup_tol = [setup_tol] * QUDA_MAX_MG_LEVEL
237238
mg_param.setup_maxiter = [setup_maxiter] * QUDA_MAX_MG_LEVEL

pyquda/enum_quda.in.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,9 @@
2727
This number may be changed if need be.
2828
"""
2929

30-
QUDA_MAX_BLOCK_SRC = 64
30+
QUDA_MAX_MULTI_SRC = 128
3131
"""
32-
Maximum number of sources that can be supported by the block solver
33-
"""
34-
35-
QUDA_MAX_ARRAY_SIZE = max(QUDA_MAX_MULTI_SHIFT, QUDA_MAX_BLOCK_SRC)
36-
"""
37-
Maximum array length used in QudaInvertParam arrays
32+
Maximum number of sources that can be supported by the multi-src solver
3833
"""
3934

4035
QUDA_MAX_DWF_LS = 32
@@ -97,6 +92,11 @@ class QudaGaugeFixed(IntEnum):
9792

9893

9994
class QudaDslashType(IntEnum):
95+
"""
96+
Note: make sure QudaDslashType has corresponding entries in
97+
tests/utils/misc.cpp
98+
"""
99+
100100
pass
101101

102102

pyquda/pyquda.pyi

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ from .enum_quda import ( # noqa: F401
1515
QUDA_MAX_DIM,
1616
QUDA_MAX_GEOMETRY,
1717
QUDA_MAX_MULTI_SHIFT,
18-
QUDA_MAX_BLOCK_SRC,
19-
QUDA_MAX_ARRAY_SIZE,
18+
QUDA_MAX_MULTI_SRC,
2019
QUDA_MAX_DWF_LS,
2120
QUDA_MAX_MG_LEVEL,
2221
qudaError_t,
@@ -277,10 +276,10 @@ class QudaInvertParam:
277276

278277
compute_true_res: int
279278
"""Whether to compute the true residual post solve"""
280-
true_res: double
281-
"""Actual L2 residual norm achieved in solver"""
282-
true_res_hq: double
283-
"""Actual heavy quark residual norm achieved in solver"""
279+
true_res: List[double, QUDA_MAX_MULTI_SRC]
280+
"""Actual L2 residual norm achieved in the solver"""
281+
true_res_hq: List[double, QUDA_MAX_MULTI_SRC]
282+
"""Actual heavy quark residual norm achieved in the solver"""
284283
maxiter: int
285284
"""Maximum number of iterations in the linear solver"""
286285
reliable_delta: double
@@ -473,6 +472,14 @@ class QudaInvertParam:
473472
"""The Gflops rate of the solver"""
474473
secs: double
475474
"""The time taken by the solver"""
475+
energy: double
476+
"""The energy consumed by the solver"""
477+
power: double
478+
"""The mean power of the solver"""
479+
temp: double
480+
"""The mean temperature of the device for the duration of the solve"""
481+
clock: double
482+
"""The mean clock frequency of the device for the duration of the solve"""
476483

477484
tune: QudaTune
478485
"""Enable auto-tuning? (default = QUDA_TUNE_YES)"""
@@ -763,6 +770,8 @@ class QudaEigParam:
763770
"""For the Ritz rotation, the maximal number of extra vectors the solver may allocate"""
764771
block_size: int
765772
"""For block method solvers, the block size"""
773+
compute_evals_batch_size: int
774+
"""The batch size used when computing eigenvalues"""
766775
max_ortho_attempts: int
767776
"""For block method solvers, quit after n attempts at block orthonormalisation"""
768777
ortho_block_size: int
@@ -815,12 +824,6 @@ class QudaEigParam:
815824
partfile: QudaBoolean
816825
"""Whether to save eigenvectors in QIO singlefile or partfile format"""
817826

818-
gflops: double
819-
"""The Gflops rate of the eigensolver setup"""
820-
821-
secs: double
822-
"""The time taken by the eigensolver setup"""
823-
824827
extlib_type: QudaExtLibType
825828
"""Which external library to use in the deflation operations (Eigen)"""
826829

@@ -868,6 +871,9 @@ class QudaMultigridParam:
868871
setup_inv_type: List[QudaInverterType, QUDA_MAX_MG_LEVEL]
869872
"""Inverter to use in the setup phase"""
870873

874+
n_vec_batch: List[int, QUDA_MAX_MG_LEVEL]
875+
"""Solver batch size to use in the setup phase"""
876+
871877
num_setup_iter: List[int, QUDA_MAX_MG_LEVEL]
872878
"""Number of setup iterations"""
873879

@@ -1022,12 +1028,6 @@ class QudaMultigridParam:
10221028
preserve_deflation: QudaBoolean
10231029
"""Whether to preserve the deflation space during MG update"""
10241030

1025-
gflops: double
1026-
"""The Gflops rate of the multigrid solver setup"""
1027-
1028-
secs: double
1029-
"""The time taken by the multigrid solver setup"""
1030-
10311031
mu_factor: List[double, QUDA_MAX_MG_LEVEL]
10321032
"""Multiplicative factor for the mu parameter"""
10331033

@@ -1364,6 +1364,25 @@ def invertQuda(h_x: Pointer, h_b: Pointer, param: QudaInvertParam) -> None:
13641364
Contains all metadata regarding host and device
13651365
storage and solver parameters
13661366
"""
1367+
1368+
def invertMultiSrcQuda(_hp_x: Pointers, _hp_b: Pointers, param: QudaInvertParam) -> None:
1369+
"""
1370+
Perform the solve like @invertQuda but for multiple rhs by splitting the comm grid into
1371+
sub-partitions: each sub-partition inverts one or more rhs.
1372+
The QudaInvertParam object specifies how the solve should be performed on each sub-partition.
1373+
Unlike @invertQuda, the interface also takes the host side gauge as input. The gauge pointer and
1374+
gauge_param are used if for inv_param split_grid[0] * split_grid[1] * split_grid[2] * split_grid[3]
1375+
is larger than 1, in which case gauge field is not required to be loaded beforehand; otherwise
1376+
this interface would just work as @invertQuda, which requires gauge field to be loaded beforehand,
1377+
and the gauge field pointer and gauge_param are not used.
1378+
1379+
@param _hp_x:
1380+
Array of solution spinor fields
1381+
@param _hp_b:
1382+
Array of source spinor fields
1383+
@param param:
1384+
Contains all metadata regarding host and device storage and solver parameters
1385+
"""
13671386
...
13681387

13691388
def invertMultiShiftQuda(_hp_x: Pointers, _hp_b: Pointer, param: QudaInvertParam) -> None:
@@ -1456,6 +1475,24 @@ def dslashQuda(h_out: Pointer, h_in: Pointer, inv_param: QudaInvertParam, parity
14561475
"""
14571476
...
14581477

1478+
def dslashMultiSrcQuda(_hp_x: Pointers, _hp_b: Pointers, param: QudaInvertParam, parity: QudaParity) -> None:
1479+
"""
1480+
Apply the dslash like @dslashQuda but for multiple rhs by splitting the comm grid into
1481+
sub-partitions: each sub-partition handles one or more rhs.
1482+
The QudaInvertParam object specifies how the solve should be performed on each sub-partition.
1483+
Unlike @invertQuda, the interface also takes the host side gauge as
1484+
input - gauge field is not required to be loaded beforehand.
1485+
1486+
@param _hp_x:
1487+
Array of solution spinor fields
1488+
@param _hp_b:
1489+
Array of source spinor fields
1490+
@param param:
1491+
Contains all metadata regarding host and device storage and solver parameters
1492+
@param parity:
1493+
Parity to apply dslash on
1494+
"""
1495+
14591496
def cloverQuda(h_out: Pointer, h_in: Pointer, inv_param: QudaInvertParam, parity: QudaParity, inverse: int) -> None:
14601497
"""
14611498
Apply the clover operator or its inverse.
@@ -2074,6 +2111,14 @@ class QudaQuarkSmearParam:
20742111
"""Time taken for the smearing operations"""
20752112
gflops: double
20762113
"""Flops count for the smearing operations"""
2114+
energy: double
2115+
"""The energy consumed by the smearing operations"""
2116+
power: double
2117+
"""The mean power of the smearing operations"""
2118+
temp: double
2119+
"""The mean temperature of the device for the duration of the smearing operations"""
2120+
clock: double
2121+
"""The mean clock frequency of the device for the duration of the smearing operations"""
20772122

20782123
def performTwoLinkGaussianSmearNStep(h_in: Pointer, smear_param: QudaQuarkSmearParam) -> None:
20792124
"""

pyquda/quda/include/enum_quda.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ typedef enum QudaGaugeFixed_s {
8989
// Types used in QudaInvertParam
9090
//
9191

92+
// Note: make sure QudaDslashType has corresponding entries in
93+
// tests/utils/misc.cpp
9294
typedef enum QudaDslashType_s {
9395
QUDA_WILSON_DSLASH,
9496
QUDA_CLOVER_WILSON_DSLASH,

pyquda/quda/include/quda.h

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ extern "C" {
145145
double tol_hq; /**< Solver tolerance in the heavy quark residual norm */
146146

147147
int compute_true_res; /** Whether to compute the true residual post solve */
148-
double true_res; /**< Actual L2 residual norm achieved in solver */
149-
double true_res_hq; /**< Actual heavy quark residual norm achieved in solver */
148+
double true_res[QUDA_MAX_MULTI_SRC]; /**< Actual L2 residual norm achieved in the solver */
149+
double true_res_hq[QUDA_MAX_MULTI_SRC]; /**< Actual heavy quark residual norm achieved in the solver */
150150
int maxiter; /**< Maximum number of iterations in the linear solver */
151151
double reliable_delta; /**< Reliable update tolerance */
152152
double reliable_delta_refinement; /**< Reliable update tolerance used in post multi-shift solver refinement */
@@ -278,6 +278,10 @@ extern "C" {
278278
int iter; /**< The number of iterations performed by the solver */
279279
double gflops; /**< The Gflops rate of the solver */
280280
double secs; /**< The time taken by the solver */
281+
double energy; /**< The energy consumed by the solver */
282+
double power; /**< The mean power of the solver */
283+
double temp; /**< The mean temperature of the device for the duration of the solve */
284+
double clock; /**< The mean clock frequency of the device for the duration of the solve */
281285

282286
QudaTune tune; /**< Enable auto-tuning? (default = QUDA_TUNE_YES) */
283287

@@ -550,6 +554,8 @@ extern "C" {
550554
int batched_rotate;
551555
/** For block method solvers, the block size **/
552556
int block_size;
557+
/** The batch size used when computing eigenvalues **/
558+
int compute_evals_batch_size;
553559
/** For block method solvers, quit after n attempts at block orthonormalisation **/
554560
int max_ortho_attempts;
555561
/** For hybrid modified Gram-Schmidt orthonormalisations **/
@@ -602,12 +608,6 @@ extern "C" {
602608
/** Whether to save eigenvectors in QIO singlefile or partfile format */
603609
QudaBoolean partfile;
604610

605-
/** The Gflops rate of the eigensolver setup */
606-
double gflops;
607-
608-
/**< The time taken by the eigensolver setup */
609-
double secs;
610-
611611
/** Which external library to use in the deflation operations (Eigen) */
612612
QudaExtLibType extlib_type;
613613
//-------------------------------------------------
@@ -655,6 +655,9 @@ extern "C" {
655655
/** Inverter to use in the setup phase */
656656
QudaInverterType setup_inv_type[QUDA_MAX_MG_LEVEL];
657657

658+
/** Solver batch size to use in the setup phase */
659+
int n_vec_batch[QUDA_MAX_MG_LEVEL];
660+
658661
/** Number of setup iterations */
659662
int num_setup_iter[QUDA_MAX_MG_LEVEL];
660663

@@ -805,12 +808,6 @@ extern "C" {
805808
/** Whether to preserve the deflation space during MG update */
806809
QudaBoolean preserve_deflation;
807810

808-
/** The Gflops rate of the multigrid solver setup */
809-
double gflops;
810-
811-
/**< The time taken by the multigrid solver setup */
812-
double secs;
813-
814811
/** Multiplicative factor for the mu parameter */
815812
double mu_factor[QUDA_MAX_MG_LEVEL];
816813

@@ -1819,6 +1816,10 @@ extern "C" {
18191816
double secs;
18201817
/** Flops count for the smearing operations **/
18211818
double gflops;
1819+
double energy; /**< The energy consumed by the smearing operations */
1820+
double power; /**< The mean power of the smearing operations */
1821+
double temp; /**< The mean temperature of the device for the duration of the smearing operations */
1822+
double clock; /**< The mean clock frequency of the device for the duration of the smearing operations */
18221823

18231824
} QudaQuarkSmearParam;
18241825

pyquda/quda/include/quda_constants.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,9 @@
3232

3333
/**
3434
* @def QUDA_MAX_MULTI_SRC
35-
* @brief Maximum number of sources that can be supported by the block solver
35+
* @brief Maximum number of sources that can be supported by the multi-src solver
3636
*/
37-
#define QUDA_MAX_BLOCK_SRC 64
38-
39-
/**
40-
* @def QUDA_MAX_ARRAY
41-
* @brief Maximum array length used in QudaInvertParam arrays
42-
*/
43-
#define QUDA_MAX_ARRAY_SIZE (QUDA_MAX_MULTI_SHIFT > QUDA_MAX_BLOCK_SRC ? QUDA_MAX_MULTI_SHIFT : QUDA_MAX_BLOCK_SRC)
37+
#define QUDA_MAX_MULTI_SRC 128
4438

4539
/**
4640
* @def QUDA_MAX_DWF_LS

pyquda/src/pyquda.in.pyx

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,8 @@ def eigensolveQuda(Pointers h_evecs, ndarray[double_complex, ndim=1] h_evals, Qu
251251
def invertQuda(Pointer h_x, Pointer h_b, QudaInvertParam param):
252252
quda.invertQuda(h_x.ptr, h_b.ptr, &param.param)
253253

254-
# def invertMultiSrcQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, Pointer h_gauge, QudaGaugeParam gauge_param)
255-
# def invertMultiSrcStaggeredQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, Pointer milc_fatlinks, Pointer milc_longlinks, QudaGaugeParam gauge_param)
256-
# def invertMultiSrcCloverQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, Pointer h_gauge, QudaGaugeParam gauge_param, Pointer h_clover, Pointer h_clovinv)
254+
def invertMultiSrcQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param):
255+
quda.invertMultiSrcQuda(_hp_x.ptrs, _hp_b.ptrs, &param.param)
257256

258257
def invertMultiShiftQuda(Pointers _hp_x, Pointer _hp_b, QudaInvertParam param):
259258
quda.invertMultiShiftQuda(_hp_x.ptrs, _hp_b.ptr, &param.param)
@@ -276,9 +275,8 @@ def dumpMultigridQuda(Pointer mg_instance, QudaMultigridParam param):
276275
def dslashQuda(Pointer h_out, Pointer h_in, QudaInvertParam inv_param, quda.QudaParity parity):
277276
quda.dslashQuda(h_out.ptr, h_in.ptr, &inv_param.param, parity)
278277

279-
# def dslashMultiSrcQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, QudaParity parity, Pointer h_gauge, QudaGaugeParam gauge_param)
280-
# def dslashMultiSrcStaggeredQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, QudaParity parity, Pointers milc_fatlinks, Pointers milc_longlinks, QudaGaugeParam gauge_param)
281-
# def dslashMultiSrcCloverQuda(Pointers_hp_x, Pointers_hp_b, QudaInvertParam param, QudaParity parity, Pointer h_gauge, QudaGaugeParam gauge_param, Pointer h_clover, Pointer h_clovinv)
278+
def dslashMultiSrcQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, quda.QudaParity parity):
279+
quda.dslashMultiSrcQuda(_hp_x.ptrs, _hp_b.ptrs, &param.param, parity)
282280

283281
def cloverQuda(Pointer h_out, Pointer h_in, QudaInvertParam inv_param, quda.QudaParity parity, int inverse):
284282
quda.cloverQuda(h_out.ptr, h_in.ptr, &inv_param.param, parity, inverse)

0 commit comments

Comments
 (0)