22
33#include " source_base/module_external/lapack_connector.h"
44#include " source_base/module_external/scalapack_connector.h"
5-
65#include " source_io/module_parameter/parameter.h" // use PARAM.globalv
76namespace elecstate
87{
98// calculate the energy density matrix (EDM) from the original formula using the Hamiltonian matrix H: EDM = 0.5 * (S^-1 * H * DM + DM * H * S^-1)
109void cal_edm_tddft (Parallel_Orbitals& pv,
11- LCAO_domain::Setup_DM<std::complex <double >> & dmat,
10+ LCAO_domain::Setup_DM<std::complex <double >>& dmat,
1211 K_Vectors& kv,
1312 hamilt::Hamilt<std::complex <double >>* p_hamilt)
1413{
15- // mohan add 2024-03-27
14+ ModuleBase::timer::tick (" elecstate" , " cal_edm_tddft" );
15+
1616 const int nlocal = PARAM.globalv .nlocal ;
1717 assert (nlocal >= 0 );
1818
@@ -25,10 +25,6 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
2525 ModuleBase::ComplexMatrix& tmp_edmk = dmat.dm ->EDMK [ik];
2626
2727#ifdef __MPI
28-
29- // mohan add 2024-03-27
30- // ! be careful, the type of nloc is 'long'
31- // ! whether the long type is safe, needs more discussion
3228 const int nloc = pv.nloc ;
3329 const int ncol = pv.ncol ;
3430 const int nrow = pv.nrow ;
@@ -54,14 +50,14 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
5450 hamilt::MatrixBlock<std::complex <double >> s_mat;
5551
5652 p_hamilt->matrix (h_mat, s_mat);
57- zcopy_ (& nloc, h_mat.p , & inc, Htmp, & inc);
58- zcopy_ (& nloc, s_mat.p , & inc, Sinv, & inc);
53+ BlasConnector::copy ( nloc, h_mat.p , inc, Htmp, inc);
54+ BlasConnector::copy ( nloc, s_mat.p , inc, Sinv, inc);
5955
6056 vector<int > ipiv (nloc, 0 );
6157 int info = 0 ;
6258 const int one_int = 1 ;
6359
64- pzgetrf_ (& nlocal, & nlocal, Sinv, & one_int, & one_int, pv.desc , ipiv.data (), &info);
60+ ScalapackConnector::getrf ( nlocal, nlocal, Sinv, one_int, one_int, pv.desc , ipiv.data (), &info);
6561
6662 int lwork = -1 ;
6763 int liwork = -1 ;
@@ -72,136 +68,136 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
7268 // liwork = -1 requests a workspace-size query, so iwork only needs to hold one element here.
7369 std::vector<int > iwork (1 , 0 );
7470
75- pzgetri_ (& nlocal,
76- Sinv,
77- & one_int,
78- & one_int,
79- pv.desc ,
80- ipiv.data (),
81- work.data (),
82- &lwork,
83- iwork.data (),
84- &liwork,
85- &info);
71+ ScalapackConnector::getri ( nlocal,
72+ Sinv,
73+ one_int,
74+ one_int,
75+ pv.desc ,
76+ ipiv.data (),
77+ work.data (),
78+ &lwork,
79+ iwork.data (),
80+ &liwork,
81+ &info);
8682
8783 lwork = work[0 ].real ();
8884 work.resize (lwork, 0 );
8985 liwork = iwork[0 ];
9086 iwork.resize (liwork, 0 );
9187
92- pzgetri_ (& nlocal,
93- Sinv,
94- & one_int,
95- & one_int,
96- pv.desc ,
97- ipiv.data (),
98- work.data (),
99- &lwork,
100- iwork.data (),
101- &liwork,
102- &info);
88+ ScalapackConnector::getri ( nlocal,
89+ Sinv,
90+ one_int,
91+ one_int,
92+ pv.desc ,
93+ ipiv.data (),
94+ work.data (),
95+ &lwork,
96+ iwork.data (),
97+ &liwork,
98+ &info);
10399
104100 const char N_char = ' N' ;
105101 const char T_char = ' T' ;
106- const std::complex <double > one_float = {1.0 , 0.0 };
107- const std::complex <double > zero_float = {0.0 , 0.0 };
108- const std::complex <double > half_float = {0.5 , 0.0 };
109-
110- pzgemm_ (& N_char,
111- & N_char,
112- & nlocal,
113- & nlocal,
114- & nlocal,
115- &one_float ,
116- Htmp,
117- & one_int,
118- & one_int,
119- pv.desc ,
120- Sinv,
121- & one_int,
122- & one_int,
123- pv.desc ,
124- &zero_float ,
125- tmp1,
126- & one_int,
127- & one_int,
128- pv.desc );
129-
130- pzgemm_ (& T_char,
131- & N_char,
132- & nlocal,
133- & nlocal,
134- & nlocal,
135- &one_float ,
136- tmp1,
137- & one_int,
138- & one_int,
139- pv.desc ,
140- tmp_dmk,
141- & one_int,
142- & one_int,
143- pv.desc ,
144- &zero_float ,
145- tmp2,
146- & one_int,
147- & one_int,
148- pv.desc );
149-
150- pzgemm_ (& N_char,
151- & N_char,
152- & nlocal,
153- & nlocal,
154- & nlocal,
155- &one_float ,
156- Sinv,
157- & one_int,
158- & one_int,
159- pv.desc ,
160- Htmp,
161- & one_int,
162- & one_int,
163- pv.desc ,
164- &zero_float ,
165- tmp3,
166- & one_int,
167- & one_int,
168- pv.desc );
169-
170- pzgemm_ (& N_char,
171- & T_char,
172- & nlocal,
173- & nlocal,
174- & nlocal,
175- &one_float ,
176- tmp_dmk,
177- & one_int,
178- & one_int,
179- pv.desc ,
180- tmp3,
181- & one_int,
182- & one_int,
183- pv.desc ,
184- &zero_float ,
185- tmp4,
186- & one_int,
187- & one_int,
188- pv.desc );
189-
190- pzgeadd_ (& N_char,
191- & nlocal,
192- & nlocal,
193- &half_float ,
194- tmp2,
195- & one_int,
196- & one_int,
197- pv.desc ,
198- &half_float ,
199- tmp4,
200- & one_int,
201- & one_int,
202- pv.desc );
203-
204- zcopy_ (& nloc, tmp4, & inc, tmp_edmk.c , & inc);
102+ const std::complex <double > one_complex = {1.0 , 0.0 };
103+ const std::complex <double > zero_complex = {0.0 , 0.0 };
104+ const std::complex <double > half_complex = {0.5 , 0.0 };
105+
106+ ScalapackConnector::gemm ( N_char,
107+ N_char,
108+ nlocal,
109+ nlocal,
110+ nlocal,
111+ one_complex ,
112+ Htmp,
113+ one_int,
114+ one_int,
115+ pv.desc ,
116+ Sinv,
117+ one_int,
118+ one_int,
119+ pv.desc ,
120+ zero_complex ,
121+ tmp1,
122+ one_int,
123+ one_int,
124+ pv.desc );
125+
126+ ScalapackConnector::gemm ( T_char,
127+ N_char,
128+ nlocal,
129+ nlocal,
130+ nlocal,
131+ one_complex ,
132+ tmp1,
133+ one_int,
134+ one_int,
135+ pv.desc ,
136+ tmp_dmk,
137+ one_int,
138+ one_int,
139+ pv.desc ,
140+ zero_complex ,
141+ tmp2,
142+ one_int,
143+ one_int,
144+ pv.desc );
145+
146+ ScalapackConnector::gemm ( N_char,
147+ N_char,
148+ nlocal,
149+ nlocal,
150+ nlocal,
151+ one_complex ,
152+ Sinv,
153+ one_int,
154+ one_int,
155+ pv.desc ,
156+ Htmp,
157+ one_int,
158+ one_int,
159+ pv.desc ,
160+ zero_complex ,
161+ tmp3,
162+ one_int,
163+ one_int,
164+ pv.desc );
165+
166+ ScalapackConnector::gemm ( N_char,
167+ T_char,
168+ nlocal,
169+ nlocal,
170+ nlocal,
171+ one_complex ,
172+ tmp_dmk,
173+ one_int,
174+ one_int,
175+ pv.desc ,
176+ tmp3,
177+ one_int,
178+ one_int,
179+ pv.desc ,
180+ zero_complex ,
181+ tmp4,
182+ one_int,
183+ one_int,
184+ pv.desc );
185+
186+ ScalapackConnector::geadd ( N_char,
187+ nlocal,
188+ nlocal,
189+ half_complex ,
190+ tmp2,
191+ one_int,
192+ one_int,
193+ pv.desc ,
194+ half_complex ,
195+ tmp4,
196+ one_int,
197+ one_int,
198+ pv.desc );
199+
200+ BlasConnector::copy ( nloc, tmp4, inc, tmp_edmk.c , inc);
205201
206202 delete[] Htmp;
207203 delete[] Sinv;
@@ -219,7 +215,7 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
219215 hamilt::MatrixBlock<std::complex <double >> s_mat;
220216
221217 p_hamilt->matrix (h_mat, s_mat);
222- // cout<<"hmat "<<h_mat.p[0]<<endl;
218+
223219 for (int i = 0 ; i < nlocal; i++)
224220 {
225221 for (int j = 0 ; j < nlocal; j++)
@@ -251,7 +247,9 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
251247 tmp_edmk = 0.5 * (Sinv * Htmp * tmp_dmk_base + tmp_dmk_base * Htmp * Sinv);
252248 delete[] work;
253249#endif
254- }
250+ } // end ik
251+
252+ ModuleBase::timer::tick (" elecstate" , " cal_edm_tddft" );
255253 return ;
256254} // cal_edm_tddft
257255} // namespace elecstate
0 commit comments