11
11
np .random .seed (0 )
12
12
13
13
DATA_DIM = 224
14
+ RESIZE_DIM = 256
14
15
15
16
THREAD = 16
16
17
BUF_SIZE = 10240
@@ -34,8 +35,8 @@ def crop_image(img, target_size, center):
34
35
width , height = img .size
35
36
size = target_size
36
37
if center == True :
37
- w_start = (width - size ) / 2
38
- h_start = (height - size ) / 2
38
+ w_start = (width - size ) // 2
39
+ h_start = (height - size ) // 2
39
40
else :
40
41
w_start = np .random .randint (0 , width - size + 1 )
41
42
h_start = np .random .randint (0 , height - size + 1 )
@@ -98,7 +99,7 @@ def random_color(img, lower=0.5, upper=1.5):
98
99
return img
99
100
100
101
101
- def process_image (sample , mode , color_jitter , rotate ):
102
+ def process_image (sample , mode , color_jitter , rotate , crop_size , resize_size ):
102
103
img_path = sample [0 ]
103
104
104
105
try :
@@ -108,10 +109,10 @@ def process_image(sample, mode, color_jitter, rotate):
108
109
return None
109
110
if mode == 'train' :
110
111
if rotate : img = rotate_image (img )
111
- img = random_crop (img , DATA_DIM )
112
+ img = random_crop (img , crop_size )
112
113
else :
113
- img = resize_short (img , target_size = 256 )
114
- img = crop_image (img , target_size = DATA_DIM , center = True )
114
+ img = resize_short (img , target_size = resize_size )
115
+ img = crop_image (img , target_size = crop_size , center = True )
115
116
if mode == 'train' :
116
117
if color_jitter :
117
118
img = distort_color (img )
@@ -185,9 +186,15 @@ def test(data_dir=DATA_DIR):
185
186
186
187
187
188
class ImageNetDataset (Dataset ):
188
- def __init__ (self , data_dir = DATA_DIR , mode = 'train' ):
189
+ def __init__ (self ,
190
+ data_dir = DATA_DIR ,
191
+ mode = 'train' ,
192
+ crop_size = DATA_DIM ,
193
+ resize_size = RESIZE_DIM ):
189
194
super (ImageNetDataset , self ).__init__ ()
190
195
self .data_dir = data_dir
196
+ self .crop_size = crop_size
197
+ self .resize_size = resize_size
191
198
train_file_list = os .path .join (data_dir , 'train_list.txt' )
192
199
val_file_list = os .path .join (data_dir , 'val_list.txt' )
193
200
test_file_list = os .path .join (data_dir , 'test_list.txt' )
@@ -211,21 +218,27 @@ def __getitem__(self, index):
211
218
[data_path , sample [1 ]],
212
219
mode = 'train' ,
213
220
color_jitter = False ,
214
- rotate = False )
221
+ rotate = False ,
222
+ crop_size = self .crop_size ,
223
+ resize_size = self .resize_size )
215
224
return data , np .array ([label ]).astype ('int64' )
216
225
elif self .mode == 'val' :
217
226
data , label = process_image (
218
227
[data_path , sample [1 ]],
219
228
mode = 'val' ,
220
229
color_jitter = False ,
221
- rotate = False )
230
+ rotate = False ,
231
+ crop_size = self .crop_size ,
232
+ resize_size = self .resize_size )
222
233
return data , np .array ([label ]).astype ('int64' )
223
234
elif self .mode == 'test' :
224
235
data = process_image (
225
236
[data_path , sample [1 ]],
226
237
mode = 'test' ,
227
238
color_jitter = False ,
228
- rotate = False )
239
+ rotate = False ,
240
+ crop_size = self .crop_size ,
241
+ resize_size = self .resize_size )
229
242
return data
230
243
231
244
def __len__ (self ):
0 commit comments