@@ -62,7 +62,7 @@ def infer_str(var, strict:bool=True):
62
62
else :
63
63
return str (var )
64
64
65
- def infer_iterable (var , inner_type : Callable = None , strict : bool = True ):
65
+ def infer_iterable (var , strict : bool = True , inner_type : Callable = None ):
66
66
# Use ast.literal_eval to parse the iterable tree,
67
67
# then use custom type handling to infer the inner types
68
68
raw_ast_iter = literal_eval (var )
@@ -109,25 +109,30 @@ def __init__(self, orig_type: Callable, strict: bool=True):
109
109
110
110
"""
111
111
if orig_type == type (None ):
112
- self .orig_type = infer_str
112
+ self .orig_type = str
113
113
else :
114
114
self .orig_type = orig_type
115
115
self .strict = strict
116
+
117
+ if self .orig_type == str or self .orig_type == type (None ):
118
+ self .type_callable = infer_str
119
+ elif self .orig_type == bool or self .orig_type == ScalarBoolean :
120
+ self .type_callable = infer_boolean
121
+ elif self .orig_type == float or self .orig_type == int or self .orig_type == ScalarFloat or self .orig_type == ScalarInt :
122
+ self .type_callable = infer_numeric
123
+ elif self .orig_type == tuple or self .orig_type == list or self .orig_type == CommentedMap or self .orig_type == CommentedSeq :
124
+ self .type_callable = infer_iterable
125
+ else :
126
+ self .type_callable = self .orig_type
116
127
117
128
def __call__ (self , var ):
118
129
"""
119
130
Callable method passed to argparse's builtin Callable type argument.
120
131
var: original variable (any type)
121
132
133
+ calls the proper type inferencer
122
134
"""
123
- if self .orig_type == bool or self .orig_type == ScalarBoolean :
124
- return infer_boolean (var , self .strict )
125
- elif self .orig_type == float or self .orig_type == int or self .orig_type == ScalarFloat or self .orig_type == ScalarInt :
126
- return infer_numeric (var , self .strict )
127
- elif self .orig_type == tuple or self .orig_type == list or self .orig_type == CommentedMap or self .orig_type == CommentedSeq :
128
- return infer_iterable (var , None , self .strict )
129
- else :
130
- if self .strict :
131
- return infer_str (var )
132
- else :
133
- return self .orig_type (var )
135
+ return self .type_callable (var , self .strict )
136
+
137
+ def __repr__ (self ):
138
+ return f"TypeInferencer[{ self .orig_type } ]"
0 commit comments