@@ -59,7 +59,12 @@ def get_dataset(args):
5959 args .crop_h = [int (i ) for i in args .crop_h .split (',' )]
6060 args .crop_w = [int (i ) for i in args .crop_w .split (',' )]
6161
62- train_set = MatDatasetOffline (args , train_transform )
62+ normalize = transforms .Compose ([
63+ transforms .ToTensor (),
64+ transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ],std = [0.229 , 0.224 , 0.225 ])
65+ ])
66+
67+ train_set = MatDatasetOffline (args , train_transform , normalize )
6368 train_loader = DataLoader (dataset = train_set , num_workers = args .threads , batch_size = args .batchSize , shuffle = True )
6469
6570 return train_loader
@@ -173,6 +178,7 @@ def train(args, model, optimizer, train_loader, epoch):
173178 fg = Variable (batch [2 ])
174179 bg = Variable (batch [3 ])
175180 trimap = Variable (batch [4 ])
181+ img_norm = Variable (batch [6 ])
176182 img_info = batch [- 1 ]
177183
178184 if args .cuda :
@@ -181,14 +187,15 @@ def train(args, model, optimizer, train_loader, epoch):
181187 fg = fg .cuda ()
182188 bg = bg .cuda ()
183189 trimap = trimap .cuda ()
190+ img_norm = img_norm .cuda ()
184191
185192 #print("Shape: Img:{} Alpha:{} Fg:{} Bg:{} Trimap:{}".format(img.shape, alpha.shape, fg.shape, bg.shape, trimap.shape))
186193 #print("Val: Img:{} Alpha:{} Fg:{} Bg:{} Trimap:{} Img_info".format(img, alpha, fg, bg, trimap, img_info))
187194
188195 adjust_learning_rate (args , optimizer , epoch )
189196 optimizer .zero_grad ()
190197
191- pred_mattes , pred_alpha = model (torch .cat ((img , trimap ), 1 ))
198+ pred_mattes , pred_alpha = model (torch .cat ((img_norm , trimap / 255. ), 1 ))
192199
193200 if args .stage == 0 :
194201 # stage0 loss, simple alpha loss
0 commit comments