File tree Expand file tree Collapse file tree 7 files changed +14
-14
lines changed Expand file tree Collapse file tree 7 files changed +14
-14
lines changed Original file line number Diff line number Diff line change @@ -38,7 +38,7 @@ class ParameterUpdater {
3838 virtual void startPass () {}
3939
4040 // called by Trainer then finishing a pass, ruturn true if pass accepted
41- virtual bool finishPass (real cost = 0 ) { return true ; }
41+ virtual bool finishPass () { return true ; }
4242
4343 // called by Trainer before backward() of a batch
4444 // Return the type of pass it needs. This pass type will be passed
@@ -112,9 +112,9 @@ class ParameterUpdaterComposite : public ParameterUpdater {
112112 [&](int tid, size_t numThreads) { updaters_[tid]->startPass (); });
113113 }
114114
115- virtual bool finishPass (real cost = 0 ) {
115+ virtual bool finishPass () {
116116 syncThreadPool_->execPlusOwner (
117- [&](int tid, size_t numThreads) { updaters_[tid]->finishPass (cost ); });
117+ [&](int tid, size_t numThreads) { updaters_[tid]->finishPass (); });
118118 return true ;
119119 }
120120
Original file line number Diff line number Diff line change @@ -102,9 +102,9 @@ class SgdLocalUpdater : public ParameterUpdater {
102102 * @param cost sum cost during one pass.
103103 * @return true if accept (used for owlqn).
104104 */
105- virtual bool finishPass (real cost ) {
105+ virtual bool finishPass () {
106106 optimizer_->finishPass ();
107- return ParameterUpdater::finishPass (cost );
107+ return ParameterUpdater::finishPass ();
108108 }
109109
110110 /* *
@@ -220,9 +220,9 @@ class SgdUpdaterWithCpuAverager : public SgdLocalUpdater {
220220 averager_->startPass ();
221221 SgdLocalUpdater::startPass ();
222222 }
223- virtual bool finishPass (real cost ) {
223+ virtual bool finishPass () {
224224 averager_->finishPass ();
225- return SgdLocalUpdater::finishPass (cost );
225+ return SgdLocalUpdater::finishPass ();
226226 }
227227
228228 // / apply the averaged parameter to PARAMETER_VALUE
Original file line number Diff line number Diff line change @@ -309,7 +309,7 @@ void RemoteParameterUpdater::startPass() {
309309 }
310310}
311311
312- bool RemoteParameterUpdater::finishPass (real cost ) {
312+ bool RemoteParameterUpdater::finishPass () {
313313 if (localUpdater_) {
314314 localUpdater_->finishPass ();
315315 }
@@ -712,7 +712,7 @@ void SparseRemoteParameterUpdater::startPass() {
712712 }
713713}
714714
715- bool SparseRemoteParameterUpdater::finishPass (real cost ) {
715+ bool SparseRemoteParameterUpdater::finishPass () {
716716 if (config_.algorithm () == TrainAlgorithm::SGD) {
717717 parameterClient_->waitPassFinish ();
718718 } else {
Original file line number Diff line number Diff line change @@ -90,7 +90,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
9090 */
9191 virtual void finishBatch (real cost);
9292 virtual void startPass ();
93- virtual bool finishPass (real cost );
93+ virtual bool finishPass ();
9494
9595#ifndef PADDLE_DISABLE_TIMER
9696 virtual void setForwardbackwardTime (uint64_t delta) {
@@ -281,7 +281,7 @@ class SparseRemoteParameterUpdater : public ParameterUpdater {
281281 // / send all sparse related parameters to all pservers
282282 virtual void finishBatch (real cost);
283283 virtual void startPass ();
284- virtual bool finishPass (real cost );
284+ virtual bool finishPass ();
285285
286286 virtual void apply ();
287287 virtual void restore ();
Original file line number Diff line number Diff line change @@ -70,7 +70,7 @@ void SgdThreadUpdater::startPass() {
7070 }
7171}
7272
73- bool SgdThreadUpdater::finishPass (real cost ) {
73+ bool SgdThreadUpdater::finishPass () {
7474 catchUpWith ();
7575
7676 for (auto & para : parameters_) {
Original file line number Diff line number Diff line change @@ -47,7 +47,7 @@ class SgdThreadUpdater : public ParameterUpdater {
4747 virtual void startPass ();
4848
4949 // Use the finishPass() function of the base optimizer.
50- virtual bool finishPass (real cost );
50+ virtual bool finishPass ();
5151
5252 virtual void init (const std::vector<ParameterPtr>& parameters);
5353 virtual PassType startBatch (int64_t batchSize);
Original file line number Diff line number Diff line change @@ -537,7 +537,7 @@ void Trainer::trainOnePassBatch(int passId) {
537537
538538 trainerInternal_.getGradientMachine ()->onPassEnd ();
539539
540- bool accepted = trainerInternal_.getParameterUpdater ()->finishPass (cost );
540+ bool accepted = trainerInternal_.getParameterUpdater ()->finishPass ();
541541
542542 globalStat.setThreadInfo (true );
543543 globalStat.printAllStatus ();
You can’t perform that action at this time.
0 commit comments