Skip to content

Commit 4e34220

Browse files
authored
Merge pull request #970 from reyoung/feature/clean_parameter_updater_finish_pass
Remove unused cost parameter in ParameterUpdater
2 parents 06ea2bf + 71a316e commit 4e34220

File tree

7 files changed

+14
-14
lines changed

7 files changed

+14
-14
lines changed

paddle/parameter/ParameterUpdaterBase.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class ParameterUpdater {
3838
virtual void startPass() {}
3939

4040
// called by Trainer then finishing a pass, ruturn true if pass accepted
41-
virtual bool finishPass(real cost = 0) { return true; }
41+
virtual bool finishPass() { return true; }
4242

4343
// called by Trainer before backward() of a batch
4444
// Return the type of pass it needs. This pass type will be passed
@@ -112,9 +112,9 @@ class ParameterUpdaterComposite : public ParameterUpdater {
112112
[&](int tid, size_t numThreads) { updaters_[tid]->startPass(); });
113113
}
114114

115-
virtual bool finishPass(real cost = 0) {
115+
virtual bool finishPass() {
116116
syncThreadPool_->execPlusOwner(
117-
[&](int tid, size_t numThreads) { updaters_[tid]->finishPass(cost); });
117+
[&](int tid, size_t numThreads) { updaters_[tid]->finishPass(); });
118118
return true;
119119
}
120120

paddle/trainer/ParameterUpdater.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,9 @@ class SgdLocalUpdater : public ParameterUpdater {
102102
* @param cost sum cost during one pass.
103103
* @return true if accept (used for owlqn).
104104
*/
105-
virtual bool finishPass(real cost) {
105+
virtual bool finishPass() {
106106
optimizer_->finishPass();
107-
return ParameterUpdater::finishPass(cost);
107+
return ParameterUpdater::finishPass();
108108
}
109109

110110
/**
@@ -220,9 +220,9 @@ class SgdUpdaterWithCpuAverager : public SgdLocalUpdater {
220220
averager_->startPass();
221221
SgdLocalUpdater::startPass();
222222
}
223-
virtual bool finishPass(real cost) {
223+
virtual bool finishPass() {
224224
averager_->finishPass();
225-
return SgdLocalUpdater::finishPass(cost);
225+
return SgdLocalUpdater::finishPass();
226226
}
227227

228228
/// apply the averaged parameter to PARAMETER_VALUE

paddle/trainer/RemoteParameterUpdater.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ void RemoteParameterUpdater::startPass() {
309309
}
310310
}
311311

312-
bool RemoteParameterUpdater::finishPass(real cost) {
312+
bool RemoteParameterUpdater::finishPass() {
313313
if (localUpdater_) {
314314
localUpdater_->finishPass();
315315
}
@@ -712,7 +712,7 @@ void SparseRemoteParameterUpdater::startPass() {
712712
}
713713
}
714714

715-
bool SparseRemoteParameterUpdater::finishPass(real cost) {
715+
bool SparseRemoteParameterUpdater::finishPass() {
716716
if (config_.algorithm() == TrainAlgorithm::SGD) {
717717
parameterClient_->waitPassFinish();
718718
} else {

paddle/trainer/RemoteParameterUpdater.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
9090
*/
9191
virtual void finishBatch(real cost);
9292
virtual void startPass();
93-
virtual bool finishPass(real cost);
93+
virtual bool finishPass();
9494

9595
#ifndef PADDLE_DISABLE_TIMER
9696
virtual void setForwardbackwardTime(uint64_t delta) {
@@ -281,7 +281,7 @@ class SparseRemoteParameterUpdater : public ParameterUpdater {
281281
/// send all sparse related parameters to all pservers
282282
virtual void finishBatch(real cost);
283283
virtual void startPass();
284-
virtual bool finishPass(real cost);
284+
virtual bool finishPass();
285285

286286
virtual void apply();
287287
virtual void restore();

paddle/trainer/ThreadParameterUpdater.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ void SgdThreadUpdater::startPass() {
7070
}
7171
}
7272

73-
bool SgdThreadUpdater::finishPass(real cost) {
73+
bool SgdThreadUpdater::finishPass() {
7474
catchUpWith();
7575

7676
for (auto& para : parameters_) {

paddle/trainer/ThreadParameterUpdater.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class SgdThreadUpdater : public ParameterUpdater {
4747
virtual void startPass();
4848

4949
// Use the finishPass() function of the base optimizer.
50-
virtual bool finishPass(real cost);
50+
virtual bool finishPass();
5151

5252
virtual void init(const std::vector<ParameterPtr>& parameters);
5353
virtual PassType startBatch(int64_t batchSize);

paddle/trainer/Trainer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ void Trainer::trainOnePassBatch(int passId) {
537537

538538
trainerInternal_.getGradientMachine()->onPassEnd();
539539

540-
bool accepted = trainerInternal_.getParameterUpdater()->finishPass(cost);
540+
bool accepted = trainerInternal_.getParameterUpdater()->finishPass();
541541

542542
globalStat.setThreadInfo(true);
543543
globalStat.printAllStatus();

0 commit comments

Comments
 (0)