@@ -15,8 +15,7 @@ from .enum_quda import (  # noqa: F401
1515    QUDA_MAX_DIM ,
1616    QUDA_MAX_GEOMETRY ,
1717    QUDA_MAX_MULTI_SHIFT ,
18-     QUDA_MAX_BLOCK_SRC ,
19-     QUDA_MAX_ARRAY_SIZE ,
18+     QUDA_MAX_MULTI_SRC ,
2019    QUDA_MAX_DWF_LS ,
2120    QUDA_MAX_MG_LEVEL ,
2221    qudaError_t ,
@@ -277,10 +276,10 @@ class QudaInvertParam:
277276
278277    compute_true_res : int 
279278    """Whether to compute the true residual post solve""" 
280-     true_res : double 
281-     """Actual L2 residual norm achieved in solver""" 
282-     true_res_hq : double 
283-     """Actual heavy quark residual norm achieved in solver""" 
279+     true_res : List [ double ,  QUDA_MAX_MULTI_SRC ] 
280+     """Actual L2 residual norm achieved in the solver""" 
281+     true_res_hq : List [ double ,  QUDA_MAX_MULTI_SRC ] 
282+     """Actual heavy quark residual norm achieved in the solver""" 
284283    maxiter : int 
285284    """Maximum number of iterations in the linear solver""" 
286285    reliable_delta : double 
@@ -473,6 +472,14 @@ class QudaInvertParam:
473472    """The Gflops rate of the solver""" 
474473    secs : double 
475474    """The time taken by the solver""" 
475+     energy : double 
476+     """The energy consumed by the solver""" 
477+     power : double 
478+     """The mean power of the solver""" 
479+     temp : double 
480+     """The mean temperature of the device for the duration of the solve""" 
481+     clock : double 
482+     """The mean clock frequency of the device for the duration of the solve""" 
476483
477484    tune : QudaTune 
478485    """Enable auto-tuning? (default = QUDA_TUNE_YES)""" 
@@ -763,6 +770,8 @@ class QudaEigParam:
763770    """For the Ritz rotation, the maximal number of extra vectors the solver may allocate""" 
764771    block_size : int 
765772    """For block method solvers, the block size""" 
773+     compute_evals_batch_size : int 
774+     """The batch size used when computing eigenvalues""" 
766775    max_ortho_attempts : int 
767776    """For block method solvers, quit after n attempts at block orthonormalisation""" 
768777    ortho_block_size : int 
@@ -815,12 +824,6 @@ class QudaEigParam:
815824    partfile : QudaBoolean 
816825    """Whether to save eigenvectors in QIO singlefile or partfile format""" 
817826
818-     gflops : double 
819-     """The Gflops rate of the eigensolver setup""" 
820- 
821-     secs : double 
822-     """The time taken by the eigensolver setup""" 
823- 
824827    extlib_type : QudaExtLibType 
825828    """Which external library to use in the deflation operations (Eigen)""" 
826829
@@ -868,6 +871,9 @@ class QudaMultigridParam:
868871    setup_inv_type : List [QudaInverterType , QUDA_MAX_MG_LEVEL ]
869872    """Inverter to use in the setup phase""" 
870873
874+     n_vec_batch : List [int , QUDA_MAX_MG_LEVEL ]
875+     """Solver batch size to use in the setup phase""" 
876+ 
871877    num_setup_iter : List [int , QUDA_MAX_MG_LEVEL ]
872878    """Number of setup iterations""" 
873879
@@ -1022,12 +1028,6 @@ class QudaMultigridParam:
10221028    preserve_deflation : QudaBoolean 
10231029    """Whether to preserve the deflation space during MG update""" 
10241030
1025-     gflops : double 
1026-     """The Gflops rate of the multigrid solver setup""" 
1027- 
1028-     secs : double 
1029-     """The time taken by the multigrid solver setup""" 
1030- 
10311031    mu_factor : List [double , QUDA_MAX_MG_LEVEL ]
10321032    """Multiplicative factor for the mu parameter""" 
10331033
@@ -1364,6 +1364,25 @@ def invertQuda(h_x: Pointer, h_b: Pointer, param: QudaInvertParam) -> None:
13641364        Contains all metadata regarding host and device 
13651365        storage and solver parameters 
13661366    """ 
1367+ 
1368+ def  invertMultiSrcQuda (_hp_x : Pointers , _hp_b : Pointers , param : QudaInvertParam ) ->  None :
1369+     """ 
1370+     Perform the solve like @invertQuda but for multiple rhs by splitting the comm grid into 
1371+     sub-partitions: each sub-partition inverts one or more rhs'. 
1372+     The QudaInvertParam object specifies how the solve should be performed on each sub-partition. 
1373+     Unlike @invertQuda, the interface also takes the host side gauge as input. The gauge pointer and 
1374+     gauge_param are used if for inv_param split_grid[0] * split_grid[1] * split_grid[2] * split_grid[3] 
1375+     is larger than 1, in which case gauge field is not required to be loaded beforehand; otherwise 
1376+     this interface would just work as @invertQuda, which requires gauge field to be loaded beforehand, 
1377+     and the gauge field pointer and gauge_param are not used. 
1378+ 
1379+     @param _hp_x: 
1380+         Array of solution spinor fields 
1381+     @param _hp_b: 
1382+         Array of source spinor fields 
1383+     @param param: 
1384+         Contains all metadata regarding host and device storage and solver parameters 
1385+     """ 
13671386    ...
13681387
13691388def  invertMultiShiftQuda (_hp_x : Pointers , _hp_b : Pointer , param : QudaInvertParam ) ->  None :
@@ -1456,6 +1475,24 @@ def dslashQuda(h_out: Pointer, h_in: Pointer, inv_param: QudaInvertParam, parity
14561475    """ 
14571476    ...
14581477
1478+ def  dslashMultiSrcQuda (_hp_x : Pointers , _hp_b : Pointers , param : QudaInvertParam , parity : QudaParity ) ->  None :
1479+     """ 
1480+     Perform the solve like @dslashQuda but for multiple rhs by splitting the comm grid into 
1481+     sub-partitions: each sub-partition does one or more rhs'. 
1482+     The QudaInvertParam object specifies how the solve should be performed on each sub-partition. 
1483+     Unlike @invertQuda, the interface also takes the host side gauge as 
1484+     input - gauge field is not required to be loaded beforehand. 
1485+ 
1486+     @param _hp_x: 
1487+         Array of solution spinor fields 
1488+     @param _hp_b: 
1489+         Array of source spinor fields 
1490+     @param param: 
1491+         Contains all metadata regarding host and device storage and solver parameters 
1492+     @param parity: 
1493+         Parity to apply dslash on 
1494+     """ 
1495+ 
14591496def  cloverQuda (h_out : Pointer , h_in : Pointer , inv_param : QudaInvertParam , parity : QudaParity , inverse : int ) ->  None :
14601497    """ 
14611498    Apply the clover operator or its inverse. 
@@ -2074,6 +2111,14 @@ class QudaQuarkSmearParam:
20742111    """Time taken for the smearing operations""" 
20752112    gflops : double 
20762113    """Flops count for the smearing operations""" 
2114+     energy : double 
2115+     """The energy consumed by the smearing operations""" 
2116+     power : double 
2117+     """The mean power of the smearing operations""" 
2118+     temp : double 
2119+     """The mean temperature of the device for the duration of the smearing operations""" 
2120+     clock : double 
2121+     """The mean clock frequency of the device for the duration of the smearing operations""" 
20772122
20782123def  performTwoLinkGaussianSmearNStep (h_in : Pointer , smear_param : QudaQuarkSmearParam ) ->  None :
20792124    """ 
0 commit comments