Skip to content

Commit acb8a53

Browse files
normalize input
1 parent 9537598 commit acb8a53

File tree

5 files changed

+39
-13
lines changed

5 files changed

+39
-13
lines changed

core/data.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,10 @@ def get_files(mydir):
7171

7272
# Dataset that does not composite images online (offline variant)
7373
class MatDatasetOffline(torch.utils.data.Dataset):
74-
def __init__(self, args, transform=None):
74+
def __init__(self, args, transform=None, normalize=None):
7575
self.samples=[]
7676
self.transform = transform
77+
self.normalize = normalize
7778
self.args = args
7879
self.size_h = args.size_h
7980
self.size_w = args.size_w
@@ -144,6 +145,14 @@ def __getitem__(self,index):
144145
trimap = gen_trimap(alpha)
145146
grad = compute_gradient(img)
146147

148+
if self.normalize:
149+
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
150+
# first, ToTensor: HWC in 0-255 -> CHW float in 0-1
151+
# second, Normalize: (x - mean) / std per channel
152+
img_norm = self.normalize(img_rgb)
153+
else:
154+
img_norm = None
155+
147156
#img_id = img_info[0].split('/')[-1]
148157
#cv2.imwrite("result/debug/{}_img.png".format(img_id), img)
149158
#cv2.imwrite("result/debug/{}_alpha.png".format(img_id), alpha)
@@ -158,7 +167,7 @@ def __getitem__(self,index):
158167
fg = torch.from_numpy(fg.astype(np.float32)).permute(2, 0, 1)
159168
bg = torch.from_numpy(bg.astype(np.float32)).permute(2, 0, 1)
160169

161-
return img, alpha, fg, bg, trimap, grad, img_info
170+
return img, alpha, fg, bg, trimap, grad, img_norm, img_info
162171

163172
def __len__(self):
164173
return len(self.samples)

core/deploy.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,18 @@ def inference_once(args, model, scale_img, scale_trimap, aligned=True):
6262
assert(scale_img.shape[0] == args.size_h)
6363
assert(scale_img.shape[1] == args.size_w)
6464

65+
normalize = transforms.Compose([
66+
transforms.ToTensor(),
67+
transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
68+
])
69+
70+
scale_img_rgb = cv2.cvtColor(scale_img, cv2.COLOR_BGR2RGB)
71+
# first, ToTensor: HWC in 0-255 -> CHW float in 0-1
72+
# second, Normalize: (x - mean) / std per channel
73+
tensor_img = normalize(scale_img_rgb).unsqueeze(0)
74+
6575
scale_grad = compute_gradient(scale_img)
66-
tensor_img = torch.from_numpy(scale_img.astype(np.float32)[np.newaxis, :, :, :]).permute(0, 3, 1, 2)
76+
#tensor_img = torch.from_numpy(scale_img.astype(np.float32)[np.newaxis, :, :, :]).permute(0, 3, 1, 2)
6777
tensor_trimap = torch.from_numpy(scale_trimap.astype(np.float32)[np.newaxis, np.newaxis, :, :])
6878
tensor_grad = torch.from_numpy(scale_grad.astype(np.float32)[np.newaxis, np.newaxis, :, :])
6979

@@ -73,7 +83,7 @@ def inference_once(args, model, scale_img, scale_trimap, aligned=True):
7383
tensor_grad = tensor_grad.cuda()
7484
#print('Img Shape:{} Trimap Shape:{}'.format(img.shape, trimap.shape))
7585

76-
input_t = torch.cat((tensor_img, tensor_trimap), 1)
86+
input_t = torch.cat((tensor_img, tensor_trimap / 255.), 1)
7787

7888
# forward
7989
if args.stage <= 1:

core/train.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@ def get_dataset(args):
5959
args.crop_h = [int(i) for i in args.crop_h.split(',')]
6060
args.crop_w = [int(i) for i in args.crop_w.split(',')]
6161

62-
train_set = MatDatasetOffline(args, train_transform)
62+
normalize = transforms.Compose([
63+
transforms.ToTensor(),
64+
transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
65+
])
66+
67+
train_set = MatDatasetOffline(args, train_transform, normalize)
6368
train_loader = DataLoader(dataset=train_set, num_workers=args.threads, batch_size=args.batchSize, shuffle=True)
6469

6570
return train_loader
@@ -173,6 +178,7 @@ def train(args, model, optimizer, train_loader, epoch):
173178
fg = Variable(batch[2])
174179
bg = Variable(batch[3])
175180
trimap = Variable(batch[4])
181+
img_norm = Variable(batch[6])
176182
img_info = batch[-1]
177183

178184
if args.cuda:
@@ -181,14 +187,15 @@ def train(args, model, optimizer, train_loader, epoch):
181187
fg = fg.cuda()
182188
bg = bg.cuda()
183189
trimap = trimap.cuda()
190+
img_norm = img_norm.cuda()
184191

185192
#print("Shape: Img:{} Alpha:{} Fg:{} Bg:{} Trimap:{}".format(img.shape, alpha.shape, fg.shape, bg.shape, trimap.shape))
186193
#print("Val: Img:{} Alpha:{} Fg:{} Bg:{} Trimap:{} Img_info".format(img, alpha, fg, bg, trimap, img_info))
187194

188195
adjust_learning_rate(args, optimizer, epoch)
189196
optimizer.zero_grad()
190197

191-
pred_mattes, pred_alpha = model(torch.cat((img, trimap), 1))
198+
pred_mattes, pred_alpha = model(torch.cat((img_norm, trimap / 255.), 1))
192199

193200
if args.stage == 0:
194201
# stage0 loss, simple alpha loss

deploy.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#/bin/bash
22

3-
TEST_DATA_ROOT=/data/datasets/matting/Combined_Dataset/Test_set/comp
3+
TEST_DATA_ROOT=/home/liuliang/DISK_2T/datasets/matting/Combined_Dataset/Test_set/comp
44

55
python core/deploy.py \
66
--size_h=320 \
@@ -9,7 +9,7 @@ python core/deploy.py \
99
--trimapDir=$TEST_DATA_ROOT/trimap \
1010
--alphaDir=$TEST_DATA_ROOT/alpha \
1111
--saveDir=result/stage0 \
12-
--resume=model/stage0/ckpt_e19.pth \
12+
--resume=model/stage0_norm/ckpt_e1.pth \
1313
--cuda \
1414
--stage=0 \
1515
--crop_or_resize=whole \

train.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#/bin/bash
2-
DATA_ROOT=/data/datasets/matting/Combined_Dataset
2+
DATA_ROOT=/home/liuliang/DISK_2T/datasets/matting/Combined_Dataset
33
TRAIN_DATA_ROOT=$DATA_ROOT/Training_set/comp
44
TEST_DATA_ROOT=$DATA_ROOT/Test_set/comp
55

@@ -12,14 +12,14 @@ python core/train.py \
1212
--fgDir=$TRAIN_DATA_ROOT/fg \
1313
--bgDir=$TRAIN_DATA_ROOT/bg \
1414
--imgDir=$TRAIN_DATA_ROOT/image \
15-
--saveDir=model/stage0 \
15+
--saveDir=model/stage0_norm \
1616
--batchSize=1 \
1717
--nEpochs=25 \
1818
--step=-1 \
1919
--lr=0.00001 \
2020
--wl_weight=0.5 \
2121
--threads=4 \
22-
--printFreq=1 \
22+
--printFreq=10 \
2323
--ckptSaveFreq=1 \
2424
--cuda \
2525
--stage=0 \
@@ -30,5 +30,5 @@ python core/train.py \
3030
--testAlphaDir=$TEST_DATA_ROOT/alpha \
3131
--testResDir=result/tmp \
3232
--crop_or_resize=whole \
33-
--max_size=1600
34-
#--resume=model/stage0/ckpt_e6.pth \
33+
--max_size=1600 \
34+
#--resume=model/stage0_norm/ckpt_e2.pth \

0 commit comments

Comments
 (0)