1- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2- #
3- # Licensed under the Apache License, Version 2.0 (the "License");
4- # you may not use this file except in compliance with the License.
5- # You may obtain a copy of the License at
6- #
7- # http://www.apache.org/licenses/LICENSE-2.0
8- #
9- # Unless required by applicable law or agreed to in writing, software
10- # distributed under the License is distributed on an "AS IS" BASIS,
11- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12- # See the License for the specific language governing permissions and
13- # limitations under the License.
14-
151import os
16- import cv2
17- import math
18- import random
192import numpy as np
203from PIL import Image
21-
22- from paddle .vision .datasets import DatasetFolder
4+ from paddle .io import Dataset
235from paddle .vision .transforms import transforms
246
257
26- class ImageNetDataset (DatasetFolder ):
8+ class ImageNetDataset (Dataset ):
279 def __init__ (self ,
28- path ,
10+ data_dir ,
2911 mode = 'train' ,
3012 image_size = 224 ,
3113 resize_short_size = 256 ):
32- super (ImageNetDataset , self ).__init__ (path )
14+ super (ImageNetDataset , self ).__init__ ()
15+ train_file_list = os .path .join (data_dir , 'train_list.txt' )
16+ val_file_list = os .path .join (data_dir , 'val_list.txt' )
17+ test_file_list = os .path .join (data_dir , 'test_list.txt' )
18+ self .data_dir = data_dir
3319 self .mode = mode
3420
3521 normalize = transforms .Normalize (
@@ -47,11 +33,35 @@ def __init__(self,
4733 normalize
4834 ])
4935
50- def __getitem__ (self , idx ):
51- img_path , label = self .samples [idx ]
36+ if mode == 'train' :
37+ with open (train_file_list ) as flist :
38+ full_lines = [line .strip () for line in flist ]
39+ np .random .shuffle (full_lines )
40+ if os .getenv ('PADDLE_TRAINING_ROLE' ):
41+ # distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
42+ trainer_id = int (os .getenv ("PADDLE_TRAINER_ID" , "0" ))
43+ trainer_count = int (os .getenv ("PADDLE_TRAINERS_NUM" , "1" ))
44+ per_node_lines = len (full_lines ) // trainer_count
45+ lines = full_lines [trainer_id * per_node_lines :(
46+ trainer_id + 1 ) * per_node_lines ]
47+ print (
48+ "read images from %d, length: %d, lines length: %d, total: %d"
49+ % (trainer_id * per_node_lines , per_node_lines ,
50+ len (lines ), len (full_lines )))
51+ else :
52+ lines = full_lines
53+ self .data = [line .split () for line in lines ]
54+ else :
55+ with open (val_file_list ) as flist :
56+ lines = [line .strip () for line in flist ]
57+ self .data = [line .split () for line in lines ]
58+
def __getitem__(self, index):
    """Load one sample and return it as ``(transformed_image, label)``.

    The entry ``self.data[index]`` is unpacked into an image path and a
    label. The path is joined onto ``self.data_dir``, the image is opened
    with PIL and converted to RGB, and the label is wrapped in a
    one-element array cast to int64. The image is passed through
    ``self.transform`` before being returned.
    """
    relative_path, raw_label = self.data[index]
    full_path = os.path.join(self.data_dir, relative_path)
    rgb_image = Image.open(full_path).convert('RGB')
    label_array = np.array([raw_label]).astype(np.int64)
    return self.transform(rgb_image), label_array
5565
def __len__(self):
    """Return the number of entries held in ``self.data``."""
    sample_count = len(self.data)
    return sample_count
0 commit comments