|
4 | 4 | from ast import literal_eval
|
5 | 5 | from typing import Callable
|
6 | 6 |
|
| 7 | +# import custom Yaml types to handle sequences |
| 8 | +from ruamel.yaml import CommentedSeq, CommentedMap |
| 9 | +from ruamel.yaml.scalarfloat import ScalarFloat |
| 10 | +from ruamel.yaml.scalarint import ScalarInt |
| 11 | +from ruamel.yaml.scalarbool import ScalarBoolean |
| 12 | + |
7 | 13 |
|
8 | 14 | def infer_boolean(var, strict: bool=True):
|
9 | 15 | """
|
@@ -56,7 +62,7 @@ def infer_str(var, strict:bool=True):
|
56 | 62 | else:
|
57 | 63 | return str(var)
|
58 | 64 |
|
59 |
| -def infer_iterable(var, inner_type: Callable=None, strict: bool=True): |
| 65 | +def infer_iterable(var, strict: bool=True, inner_type: Callable=None): |
60 | 66 | # Use ast.literal_eval to parse the iterable tree,
|
61 | 67 | # then use custom type handling to infer the inner types
|
62 | 68 | raw_ast_iter = literal_eval(var)
|
@@ -103,25 +109,30 @@ def __init__(self, orig_type: Callable, strict: bool=True):
|
103 | 109 |
|
104 | 110 | """
|
105 | 111 | if orig_type == type(None):
|
106 |
| - self.orig_type = infer_str |
| 112 | + self.orig_type = str |
107 | 113 | else:
|
108 | 114 | self.orig_type = orig_type
|
109 | 115 | self.strict = strict
|
| 116 | + |
| 117 | + if self.orig_type == str or self.orig_type == type(None): |
| 118 | + self.type_callable = infer_str |
| 119 | + elif self.orig_type == bool or self.orig_type == ScalarBoolean: |
| 120 | + self.type_callable = infer_boolean |
| 121 | + elif self.orig_type == float or self.orig_type == int or self.orig_type == ScalarFloat or self.orig_type == ScalarInt: |
| 122 | + self.type_callable = infer_numeric |
| 123 | + elif self.orig_type == tuple or self.orig_type == list or self.orig_type == CommentedMap or self.orig_type == CommentedSeq: |
| 124 | + self.type_callable = infer_iterable |
| 125 | + else: |
| 126 | + self.type_callable = self.orig_type |
110 | 127 |
|
111 | 128 | def __call__(self, var):
|
112 | 129 | """
|
113 | 130 | Callable method passed to argparse's builtin Callable type argument.
|
114 | 131 | var: original variable (any type)
|
115 | 132 |
|
| 133 | + calls the proper type inferencer |
116 | 134 | """
|
117 |
| - if self.orig_type == bool: |
118 |
| - return infer_boolean(var, self.strict) |
119 |
| - elif self.orig_type == float or self.orig_type == int: |
120 |
| - return infer_numeric(var, self.strict) |
121 |
| - elif self.orig_type == tuple or self.orig_type == list: |
122 |
| - return infer_iterable(var, None, self.strict) |
123 |
| - else: |
124 |
| - if self.strict: |
125 |
| - return infer_str(var) |
126 |
| - else: |
127 |
| - return self.orig_type(var) |
| 135 | + return self.type_callable(var, self.strict) |
| 136 | + |
| 137 | + def __repr__(self): |
| 138 | + return f"TypeInferencer[{self.orig_type}]" |
0 commit comments