Skip to content

Commit 46bc1a4

Browse files
authored
Refactor cal_edm_tddft to replace raw ScaLAPACK and BLAS calls with ScalapackConnector and BlasConnector interfaces (#6687)
1 parent 443f89e commit 46bc1a4

File tree

1 file changed

+131
-133
lines changed

1 file changed

+131
-133
lines changed

source/source_estate/module_dm/cal_edm_tddft.cpp

Lines changed: 131 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22

33
#include "source_base/module_external/lapack_connector.h"
44
#include "source_base/module_external/scalapack_connector.h"
5-
65
#include "source_io/module_parameter/parameter.h" // use PARAM.globalv
76
namespace elecstate
87
{
98
// Use the original formula (based on the Hamiltonian matrix) to calculate the energy density matrix.
109
void cal_edm_tddft(Parallel_Orbitals& pv,
11-
LCAO_domain::Setup_DM<std::complex<double>> &dmat,
10+
LCAO_domain::Setup_DM<std::complex<double>>& dmat,
1211
K_Vectors& kv,
1312
hamilt::Hamilt<std::complex<double>>* p_hamilt)
1413
{
15-
// mohan add 2024-03-27
14+
ModuleBase::timer::tick("elecstate", "cal_edm_tddft");
15+
1616
const int nlocal = PARAM.globalv.nlocal;
1717
assert(nlocal >= 0);
1818

@@ -25,10 +25,6 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
2525
ModuleBase::ComplexMatrix& tmp_edmk = dmat.dm->EDMK[ik];
2626

2727
#ifdef __MPI
28-
29-
// mohan add 2024-03-27
30-
//! be careful, the type of nloc is 'long'
31-
//! whether the long type is safe, needs more discussion
3228
const int nloc = pv.nloc;
3329
const int ncol = pv.ncol;
3430
const int nrow = pv.nrow;
@@ -54,14 +50,14 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
5450
hamilt::MatrixBlock<std::complex<double>> s_mat;
5551

5652
p_hamilt->matrix(h_mat, s_mat);
57-
zcopy_(&nloc, h_mat.p, &inc, Htmp, &inc);
58-
zcopy_(&nloc, s_mat.p, &inc, Sinv, &inc);
53+
BlasConnector::copy(nloc, h_mat.p, inc, Htmp, inc);
54+
BlasConnector::copy(nloc, s_mat.p, inc, Sinv, inc);
5955

6056
vector<int> ipiv(nloc, 0);
6157
int info = 0;
6258
const int one_int = 1;
6359

64-
pzgetrf_(&nlocal, &nlocal, Sinv, &one_int, &one_int, pv.desc, ipiv.data(), &info);
60+
ScalapackConnector::getrf(nlocal, nlocal, Sinv, one_int, one_int, pv.desc, ipiv.data(), &info);
6561

6662
int lwork = -1;
6763
int liwork = -1;
@@ -72,136 +68,136 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
7268
// If liwork = -1, the call is only a workspace query, so iwork needs a length of at least 1.
7369
std::vector<int> iwork(1, 0);
7470

75-
pzgetri_(&nlocal,
76-
Sinv,
77-
&one_int,
78-
&one_int,
79-
pv.desc,
80-
ipiv.data(),
81-
work.data(),
82-
&lwork,
83-
iwork.data(),
84-
&liwork,
85-
&info);
71+
ScalapackConnector::getri(nlocal,
72+
Sinv,
73+
one_int,
74+
one_int,
75+
pv.desc,
76+
ipiv.data(),
77+
work.data(),
78+
&lwork,
79+
iwork.data(),
80+
&liwork,
81+
&info);
8682

8783
lwork = work[0].real();
8884
work.resize(lwork, 0);
8985
liwork = iwork[0];
9086
iwork.resize(liwork, 0);
9187

92-
pzgetri_(&nlocal,
93-
Sinv,
94-
&one_int,
95-
&one_int,
96-
pv.desc,
97-
ipiv.data(),
98-
work.data(),
99-
&lwork,
100-
iwork.data(),
101-
&liwork,
102-
&info);
88+
ScalapackConnector::getri(nlocal,
89+
Sinv,
90+
one_int,
91+
one_int,
92+
pv.desc,
93+
ipiv.data(),
94+
work.data(),
95+
&lwork,
96+
iwork.data(),
97+
&liwork,
98+
&info);
10399

104100
const char N_char = 'N';
105101
const char T_char = 'T';
106-
const std::complex<double> one_float = {1.0, 0.0};
107-
const std::complex<double> zero_float = {0.0, 0.0};
108-
const std::complex<double> half_float = {0.5, 0.0};
109-
110-
pzgemm_(&N_char,
111-
&N_char,
112-
&nlocal,
113-
&nlocal,
114-
&nlocal,
115-
&one_float,
116-
Htmp,
117-
&one_int,
118-
&one_int,
119-
pv.desc,
120-
Sinv,
121-
&one_int,
122-
&one_int,
123-
pv.desc,
124-
&zero_float,
125-
tmp1,
126-
&one_int,
127-
&one_int,
128-
pv.desc);
129-
130-
pzgemm_(&T_char,
131-
&N_char,
132-
&nlocal,
133-
&nlocal,
134-
&nlocal,
135-
&one_float,
136-
tmp1,
137-
&one_int,
138-
&one_int,
139-
pv.desc,
140-
tmp_dmk,
141-
&one_int,
142-
&one_int,
143-
pv.desc,
144-
&zero_float,
145-
tmp2,
146-
&one_int,
147-
&one_int,
148-
pv.desc);
149-
150-
pzgemm_(&N_char,
151-
&N_char,
152-
&nlocal,
153-
&nlocal,
154-
&nlocal,
155-
&one_float,
156-
Sinv,
157-
&one_int,
158-
&one_int,
159-
pv.desc,
160-
Htmp,
161-
&one_int,
162-
&one_int,
163-
pv.desc,
164-
&zero_float,
165-
tmp3,
166-
&one_int,
167-
&one_int,
168-
pv.desc);
169-
170-
pzgemm_(&N_char,
171-
&T_char,
172-
&nlocal,
173-
&nlocal,
174-
&nlocal,
175-
&one_float,
176-
tmp_dmk,
177-
&one_int,
178-
&one_int,
179-
pv.desc,
180-
tmp3,
181-
&one_int,
182-
&one_int,
183-
pv.desc,
184-
&zero_float,
185-
tmp4,
186-
&one_int,
187-
&one_int,
188-
pv.desc);
189-
190-
pzgeadd_(&N_char,
191-
&nlocal,
192-
&nlocal,
193-
&half_float,
194-
tmp2,
195-
&one_int,
196-
&one_int,
197-
pv.desc,
198-
&half_float,
199-
tmp4,
200-
&one_int,
201-
&one_int,
202-
pv.desc);
203-
204-
zcopy_(&nloc, tmp4, &inc, tmp_edmk.c, &inc);
102+
const std::complex<double> one_complex = {1.0, 0.0};
103+
const std::complex<double> zero_complex = {0.0, 0.0};
104+
const std::complex<double> half_complex = {0.5, 0.0};
105+
106+
ScalapackConnector::gemm(N_char,
107+
N_char,
108+
nlocal,
109+
nlocal,
110+
nlocal,
111+
one_complex,
112+
Htmp,
113+
one_int,
114+
one_int,
115+
pv.desc,
116+
Sinv,
117+
one_int,
118+
one_int,
119+
pv.desc,
120+
zero_complex,
121+
tmp1,
122+
one_int,
123+
one_int,
124+
pv.desc);
125+
126+
ScalapackConnector::gemm(T_char,
127+
N_char,
128+
nlocal,
129+
nlocal,
130+
nlocal,
131+
one_complex,
132+
tmp1,
133+
one_int,
134+
one_int,
135+
pv.desc,
136+
tmp_dmk,
137+
one_int,
138+
one_int,
139+
pv.desc,
140+
zero_complex,
141+
tmp2,
142+
one_int,
143+
one_int,
144+
pv.desc);
145+
146+
ScalapackConnector::gemm(N_char,
147+
N_char,
148+
nlocal,
149+
nlocal,
150+
nlocal,
151+
one_complex,
152+
Sinv,
153+
one_int,
154+
one_int,
155+
pv.desc,
156+
Htmp,
157+
one_int,
158+
one_int,
159+
pv.desc,
160+
zero_complex,
161+
tmp3,
162+
one_int,
163+
one_int,
164+
pv.desc);
165+
166+
ScalapackConnector::gemm(N_char,
167+
T_char,
168+
nlocal,
169+
nlocal,
170+
nlocal,
171+
one_complex,
172+
tmp_dmk,
173+
one_int,
174+
one_int,
175+
pv.desc,
176+
tmp3,
177+
one_int,
178+
one_int,
179+
pv.desc,
180+
zero_complex,
181+
tmp4,
182+
one_int,
183+
one_int,
184+
pv.desc);
185+
186+
ScalapackConnector::geadd(N_char,
187+
nlocal,
188+
nlocal,
189+
half_complex,
190+
tmp2,
191+
one_int,
192+
one_int,
193+
pv.desc,
194+
half_complex,
195+
tmp4,
196+
one_int,
197+
one_int,
198+
pv.desc);
199+
200+
BlasConnector::copy(nloc, tmp4, inc, tmp_edmk.c, inc);
205201

206202
delete[] Htmp;
207203
delete[] Sinv;
@@ -219,7 +215,7 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
219215
hamilt::MatrixBlock<std::complex<double>> s_mat;
220216

221217
p_hamilt->matrix(h_mat, s_mat);
222-
// cout<<"hmat "<<h_mat.p[0]<<endl;
218+
223219
for (int i = 0; i < nlocal; i++)
224220
{
225221
for (int j = 0; j < nlocal; j++)
@@ -251,7 +247,9 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
251247
tmp_edmk = 0.5 * (Sinv * Htmp * tmp_dmk_base + tmp_dmk_base * Htmp * Sinv);
252248
delete[] work;
253249
#endif
254-
}
250+
} // end ik
251+
252+
ModuleBase::timer::tick("elecstate", "cal_edm_tddft");
255253
return;
256254
} // cal_edm_tddft
257255
} // namespace elecstate

0 commit comments

Comments
 (0)