Skip to content

Commit 2642efc

Browse files
committed
fix infer iterable signature and clean up TypeInferencer API
1 parent c4e6f5a commit 2642efc

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

src/configmypy/type_inference.py

+18-13
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def infer_str(var, strict:bool=True):
6262
else:
6363
return str(var)
6464

65-
def infer_iterable(var, inner_type: Callable=None, strict: bool=True):
65+
def infer_iterable(var, strict: bool=True, inner_type: Callable=None):
6666
# Use ast.literal_eval to parse the iterable tree,
6767
# then use custom type handling to infer the inner types
6868
raw_ast_iter = literal_eval(var)
@@ -109,25 +109,30 @@ def __init__(self, orig_type: Callable, strict: bool=True):
109109
110110
"""
111111
if orig_type == type(None):
112-
self.orig_type = infer_str
112+
self.orig_type = str
113113
else:
114114
self.orig_type = orig_type
115115
self.strict = strict
116+
117+
if self.orig_type == str or self.orig_type == type(None):
118+
self.type_callable = infer_str
119+
elif self.orig_type == bool or self.orig_type == ScalarBoolean:
120+
self.type_callable = infer_boolean
121+
elif self.orig_type == float or self.orig_type == int or self.orig_type == ScalarFloat or self.orig_type == ScalarInt:
122+
self.type_callable = infer_numeric
123+
elif self.orig_type == tuple or self.orig_type == list or self.orig_type == CommentedMap or self.orig_type == CommentedSeq:
124+
self.type_callable = infer_iterable
125+
else:
126+
self.type_callable = self.orig_type
116127

117128
def __call__(self, var):
118129
"""
119130
Callable method passed to argparse's builtin Callable type argument.
120131
var: original variable (any type)
121132
133+
calls the proper type inferencer
122134
"""
123-
if self.orig_type == bool or self.orig_type == ScalarBoolean:
124-
return infer_boolean(var, self.strict)
125-
elif self.orig_type == float or self.orig_type == int or self.orig_type == ScalarFloat or self.orig_type == ScalarInt:
126-
return infer_numeric(var, self.strict)
127-
elif self.orig_type == tuple or self.orig_type == list or self.orig_type == CommentedMap or self.orig_type == CommentedSeq:
128-
return infer_iterable(var, None, self.strict)
129-
else:
130-
if self.strict:
131-
return infer_str(var)
132-
else:
133-
return self.orig_type(var)
135+
return self.type_callable(var, self.strict)
136+
137+
def __repr__(self):
138+
return f"TypeInferencer[{self.orig_type}]"

0 commit comments

Comments
 (0)