26
26
27
27
28
28
def get_data (is_train ):
29
- voc_data = VocSegDataset (opt .voc_root , is_train , opt .crop_size , img_transforms )
30
- return DataLoader (voc_data , opt .batch_size , True , num_workers = opt .num_workers )
29
+ voc_data = VocSegDataset (opt .voc_root , is_train , opt .crop_size ,
30
+ img_transforms )
31
+ return DataLoader (
32
+ voc_data , opt .batch_size , True , num_workers = opt .num_workers )
31
33
32
34
33
35
def get_model (num_classes ):
@@ -38,7 +40,8 @@ def get_model(num_classes):
38
40
39
41
40
42
def get_optimizer (model ):
41
- optimizer = torch .optim .SGD (model .parameters (), lr = opt .lr , weight_decay = opt .weight_decay )
43
+ optimizer = torch .optim .SGD (
44
+ model .parameters (), lr = opt .lr , weight_decay = opt .weight_decay )
42
45
return ScheduledOptim (optimizer )
43
46
44
47
@@ -64,6 +67,7 @@ def __init__(self):
64
67
self .metric_meter [m ] = meter .AverageValueMeter ()
65
68
66
69
def train (self , kwargs ):
70
+ self .reset_meter ()
67
71
self .model .train ()
68
72
train_data = kwargs ['train_data' ]
69
73
for data in tqdm (train_data ):
@@ -97,28 +101,37 @@ def train(self, kwargs):
97
101
98
102
if (self .n_iter + 1 ) % opt .plot_freq == 0 :
99
103
# Plot metrics curve in tensorboard.
100
- self .writer .add_scalars ('loss' , {'train' : self .metric_meter ['loss' ].value ()[0 ]}, self .n_plot )
101
- self .writer .add_scalars ('acc' , {'train' : self .metric_meter ['acc' ].value ()[0 ]}, self .n_plot )
102
- self .writer .add_scalars ('iou' , {'train' : self .metric_meter ['iou' ].value ()[0 ]}, self .n_plot )
104
+ self .writer .add_scalars (
105
+ 'loss' , {'train' : self .metric_meter ['loss' ].value ()[0 ]},
106
+ self .n_plot )
107
+ self .writer .add_scalars (
108
+ 'acc' , {'train' : self .metric_meter ['acc' ].value ()[0 ]},
109
+ self .n_plot )
110
+ self .writer .add_scalars (
111
+ 'iou' , {'train' : self .metric_meter ['iou' ].value ()[0 ]},
112
+ self .n_plot )
103
113
104
114
# Show segmentation images.
105
115
# Get prediction segmentation and ground truth segmentation.
106
116
origin_image = inverse_normalization (imgs [0 ].cpu ().data )
107
117
pred_seg = cm [pred_labels [0 ]]
108
118
gt_seg = cm [true_labels [0 ]]
109
119
110
- self .writer .add_image ('train ori_img' , origin_image , self .n_plot )
120
+ self .writer .add_image ('train ori_img' , origin_image ,
121
+ self .n_plot )
111
122
self .writer .add_image ('train gt' , gt_seg , self .n_plot )
112
123
self .writer .add_image ('train pred' , pred_seg , self .n_plot )
113
124
self .n_plot += 1
114
125
115
126
self .n_iter += 1
116
127
117
128
self .metric_log ['Train Loss' ] = self .metric_meter ['loss' ].value ()[0 ]
118
- self .metric_log ['Train Mean Class Accuracy' ] = self .metric_meter ['acc' ].value ()[0 ]
129
+ self .metric_log ['Train Mean Class Accuracy' ] = self .metric_meter [
130
+ 'acc' ].value ()[0 ]
119
131
self .metric_log ['Train Mean IoU' ] = self .metric_meter ['iou' ].value ()[0 ]
120
132
121
133
def test (self , kwargs ):
134
+ self .reset_meter ()
122
135
self .model .eval ()
123
136
test_data = kwargs ['test_data' ]
124
137
for data in tqdm (test_data ):
@@ -146,9 +159,13 @@ def test(self, kwargs):
146
159
self .metric_meter ['iou' ].add (eval_metrics ['miou' ])
147
160
148
161
# Plot metrics curve in tensorboard.
149
- self .writer .add_scalars ('loss' , {'test' : self .metric_meter ['loss' ].value ()[0 ]}, self .n_plot )
150
- self .writer .add_scalars ('acc' , {'test' : self .metric_meter ['acc' ].value ()[0 ]}, self .n_plot )
151
- self .writer .add_scalars ('iou' , {'test' : self .metric_meter ['iou' ].value ()[0 ]}, self .n_plot )
162
+ self .writer .add_scalars ('loss' ,
163
+ {'test' : self .metric_meter ['loss' ].value ()[0 ]},
164
+ self .n_plot )
165
+ self .writer .add_scalars (
166
+ 'acc' , {'test' : self .metric_meter ['acc' ].value ()[0 ]}, self .n_plot )
167
+ self .writer .add_scalars (
168
+ 'iou' , {'test' : self .metric_meter ['iou' ].value ()[0 ]}, self .n_plot )
152
169
153
170
origin_img = inverse_normalization (imgs [0 ].cpu ().data )
154
171
pred_seg = cm [pred_labels [0 ]]
@@ -160,7 +177,8 @@ def test(self, kwargs):
160
177
self .n_plot += 1
161
178
162
179
self .metric_log ['Test Loss' ] = self .metric_meter ['loss' ].value ()[0 ]
163
- self .metric_log ['Test Mean Class Accuracy' ] = self .metric_meter ['acc' ].value ()[0 ]
180
+ self .metric_log ['Test Mean Class Accuracy' ] = self .metric_meter [
181
+ 'acc' ].value ()[0 ]
164
182
self .metric_log ['Test Mean IoU' ] = self .metric_meter ['iou' ].value ()[0 ]
165
183
166
184
def get_best_model (self ):
@@ -178,7 +196,8 @@ def train(**kwargs):
178
196
fcn_trainer = FcnTrainer ()
179
197
train_data = get_data (is_train = True )
180
198
test_data = get_data (is_train = False )
181
- fcn_trainer .fit (train_data = train_data , test_data = test_data , epochs = opt .max_epoch )
199
+ fcn_trainer .fit (
200
+ train_data = train_data , test_data = test_data , epochs = opt .max_epoch )
182
201
183
202
184
203
if __name__ == '__main__' :
0 commit comments