@@ -35,16 +35,16 @@ def check_logits_losses(logits_list, losses):
 
 def loss_computation(logits_list, labels, losses, edges=None):
     check_logits_losses(logits_list, losses)
-    loss = 0
+    loss_list = []
     for i in range(len(logits_list)):
         logits = logits_list[i]
         loss_i = losses['types'][i]
-        # Whether to use edges as labels According to loss type .
+        # Whether to use edges as labels According to loss type.
         if loss_i.__class__.__name__ in ('BCELoss', ) and loss_i.edge_label:
-            loss += losses['coef'][i] * loss_i(logits, edges)
+            loss_list.append(losses['coef'][i] * loss_i(logits, edges))
         else:
-            loss += losses['coef'][i] * loss_i(logits, labels)
-    return loss
+            loss_list.append(losses['coef'][i] * loss_i(logits, labels))
+    return loss_list
 
 
 def train(model,
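With this change, loss_computation returns one weighted term per output head instead of a pre-summed scalar, and the caller decides when to collapse the list. A minimal sketch of the new contract, using hypothetical stand-in loss callables rather than PaddleSeg's actual loss classes:

# Stand-in for the losses['types'] / losses['coef'] structure built from the
# config (hypothetical callables, purely for illustration).
losses = {
    'types': [lambda logits, labels: (logits - labels) ** 2,  # main loss
              lambda logits, labels: abs(logits - labels)],   # auxiliary loss
    'coef': [1.0, 0.4],
}
logits_list = [0.8, 0.3]  # one entry per model output head
labels = 1.0

loss_list = [
    losses['coef'][i] * losses['types'][i](logits_list[i], labels)
    for i in range(len(logits_list))
]
loss = sum(loss_list)   # collapsed only where a single scalar is needed
print(loss_list, loss)  # components stay available for per-loss logging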
@@ -115,6 +115,7 @@ def train(model,
 
     timer = Timer()
     avg_loss = 0.0
+    avg_loss_list = []
     iters_per_epoch = len(batch_sampler)
     best_mean_iou = -1.0
     best_model_iter = -1
@@ -140,11 +141,12 @@ def train(model,
                 logits_list = ddp_model(images)
             else:
                 logits_list = model(images)
-            loss = loss_computation(
+            loss_list = loss_computation(
                 logits_list=logits_list,
                 labels=labels,
                 losses=losses,
                 edges=edges)
+            loss = sum(loss_list)
             loss.backward()
 
             optimizer.step()
@@ -154,10 +156,18 @@ def train(model,
                 optimizer._learning_rate.step()
             model.clear_gradients()
             avg_loss += loss.numpy()[0]
+            if not avg_loss_list:
+                avg_loss_list = [l for l in loss_list]
+            else:
+                for i in range(len(loss_list)):
+                    avg_loss_list[i] += loss_list[i]
             train_batch_cost += timer.elapsed_time()
 
             if (iter) % log_iters == 0 and local_rank == 0:
                 avg_loss /= log_iters
+                avg_loss_list = [
+                    l.numpy()[0] / log_iters for l in avg_loss_list
+                ]
                 avg_train_reader_cost = train_reader_cost / log_iters
                 avg_train_batch_cost = train_batch_cost / log_iters
                 train_reader_cost = 0.0
@@ -171,12 +181,22 @@ def train(model,
                             avg_train_reader_cost, eta))
                 if use_vdl:
                     log_writer.add_scalar('Train/loss', avg_loss, iter)
+                    # Record all losses if there are more than 2 losses.
+                    if len(avg_loss_list) > 1:
+                        avg_loss_dict = {}
+                        for i, value in enumerate(avg_loss_list):
+                            avg_loss_dict['loss_' + str(i)] = value
+                        for key, value in avg_loss_dict.items():
+                            log_tag = 'Train/' + key
+                            log_writer.add_scalar(log_tag, value, iter)
+
                     log_writer.add_scalar('Train/lr', lr, iter)
                     log_writer.add_scalar('Train/batch_cost',
                                           avg_train_batch_cost, iter)
                     log_writer.add_scalar('Train/reader_cost',
                                           avg_train_reader_cost, iter)
                 avg_loss = 0.0
+                avg_loss_list = []
 
             if (iter % save_interval == 0
                     or iter == iters) and (val_dataset is not None):
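The per-component averages then appear as separate curves in VisualDL alongside the total loss. A minimal standalone sketch of that logging pattern (logdir, step, and values are illustrative; assumes visualdl is installed):

from visualdl import LogWriter

avg_loss_list = [0.31, 0.12]  # e.g. averaged main and auxiliary losses
step = 100
with LogWriter(logdir='./log/loss_components') as log_writer:
    log_writer.add_scalar('Train/loss', sum(avg_loss_list), step)
    # Break the losses out only when there are several components.
    if len(avg_loss_list) > 1:
        for i, value in enumerate(avg_loss_list):
            log_writer.add_scalar('Train/loss_' + str(i), value, step)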