Skip to content

Commit 2e36a01

Browse files
Fix ce of dygraph quant (PaddlePaddle#873)
1 parent fd5084c commit 2e36a01

File tree

3 files changed

+38
-35
lines changed

3 files changed

+38
-35
lines changed
Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,21 @@
1-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
151
import os
16-
import cv2
17-
import math
18-
import random
192
import numpy as np
203
from PIL import Image
21-
22-
from paddle.vision.datasets import DatasetFolder
4+
from paddle.io import Dataset
235
from paddle.vision.transforms import transforms
246

257

26-
class ImageNetDataset(DatasetFolder):
8+
class ImageNetDataset(Dataset):
279
def __init__(self,
28-
path,
10+
data_dir,
2911
mode='train',
3012
image_size=224,
3113
resize_short_size=256):
32-
super(ImageNetDataset, self).__init__(path)
14+
super(ImageNetDataset, self).__init__()
15+
train_file_list = os.path.join(data_dir, 'train_list.txt')
16+
val_file_list = os.path.join(data_dir, 'val_list.txt')
17+
test_file_list = os.path.join(data_dir, 'test_list.txt')
18+
self.data_dir = data_dir
3319
self.mode = mode
3420

3521
normalize = transforms.Normalize(
@@ -47,11 +33,35 @@ def __init__(self,
4733
normalize
4834
])
4935

50-
def __getitem__(self, idx):
51-
img_path, label = self.samples[idx]
36+
if mode == 'train':
37+
with open(train_file_list) as flist:
38+
full_lines = [line.strip() for line in flist]
39+
np.random.shuffle(full_lines)
40+
if os.getenv('PADDLE_TRAINING_ROLE'):
41+
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
42+
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
43+
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
44+
per_node_lines = len(full_lines) // trainer_count
45+
lines = full_lines[trainer_id * per_node_lines:(
46+
trainer_id + 1) * per_node_lines]
47+
print(
48+
"read images from %d, length: %d, lines length: %d, total: %d"
49+
% (trainer_id * per_node_lines, per_node_lines,
50+
len(lines), len(full_lines)))
51+
else:
52+
lines = full_lines
53+
self.data = [line.split() for line in lines]
54+
else:
55+
with open(val_file_list) as flist:
56+
lines = [line.strip() for line in flist]
57+
self.data = [line.split() for line in lines]
58+
59+
def __getitem__(self, index):
60+
img_path, label = self.data[index]
61+
img_path = os.path.join(self.data_dir, img_path)
5262
img = Image.open(img_path).convert('RGB')
5363
label = np.array([label]).astype(np.int64)
5464
return self.transform(img), label
5565

5666
def __len__(self):
57-
return len(self.samples)
67+
return len(self.data)

ce_tests/dygraph/quant/src/ptq.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ def main():
6060
fp32_model = models.__dict__[FLAGS.arch](pretrained=True)
6161
fp32_model.eval()
6262

63-
val_dataset = ImageNetDataset(
64-
os.path.join(FLAGS.data, FLAGS.val_dir), mode='val')
63+
val_dataset = ImageNetDataset(FLAGS.data, mode='val')
6564

6665
# 2 quantizations
6766
ptq = PTQ()

ce_tests/dygraph/quant/src/qat.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,8 @@ def main():
8686
print("Resume from " + FLAGS.resume)
8787
model.load(FLAGS.resume)
8888

89-
train_dataset = ImageNetDataset(
90-
os.path.join(FLAGS.data, 'train'), mode='train')
91-
val_dataset = ImageNetDataset(
92-
os.path.join(FLAGS.data, FLAGS.val_dir), mode='val')
89+
train_dataset = ImageNetDataset(FLAGS.data, mode='train')
90+
val_dataset = ImageNetDataset(FLAGS.data, mode='val')
9391

9492
optim = make_optimizer(
9593
np.ceil(
@@ -152,10 +150,6 @@ def main():
152150
default="",
153151
help='path to dataset '
154152
'(should have subdirectories named "train" and "val"')
155-
parser.add_argument(
156-
'--val_dir',
157-
default="val",
158-
help='the dir that saves val images for paddle.Model')
159153

160154
# train
161155
parser.add_argument(

0 commit comments

Comments
 (0)