@@ -232,7 +232,7 @@ def provider(input_types=None,
232232 check = False ,
233233 check_fail_continue = False ,
234234 init_hook = None ,
235- ** kwargs ):
235+ ** outter_kwargs ):
236236 """
237237 Provider decorator. Use it to make a function into PyDataProvider2 object.
238238 In this function, user only need to get each sample for some train/test
@@ -318,11 +318,6 @@ def __init__(self, file_list, **kwargs):
318318 self .logger = logging .getLogger ("" )
319319 self .logger .setLevel (logging .INFO )
320320 self .input_types = None
321- if 'slots' in kwargs :
322- self .logger .warning ('setting slots value is deprecated, '
323- 'please use input_types instead.' )
324- self .slots = kwargs ['slots' ]
325- self .slots = input_types
326321 self .should_shuffle = should_shuffle
327322
328323 true_table = [1 , 't' , 'true' , 'on' ]
@@ -358,9 +353,19 @@ def __init__(self, file_list, **kwargs):
358353 self .check = check
359354 if init_hook is not None :
360355 init_hook (self , file_list = file_list , ** kwargs )
356+
357+ if 'slots' in outter_kwargs :
358+ self .logger .warning ('setting slots value is deprecated, '
359+ 'please use input_types instead.' )
360+ self .slots = outter_kwargs ['slots' ]
361+ if input_types is not None :
362+ self .slots = input_types
363+
361364 if self .input_types is not None :
362365 self .slots = self .input_types
363- assert self .slots is not None
366+
367+ assert self .slots is not None , \
368+ "Data Provider's input_types must be set"
364369 assert self .generator is not None
365370
366371 use_dynamic_order = False
0 commit comments