|
| 1 | +# Infer types of argparse arguments intelligently |
| 2 | + |
| 3 | +from argparse import ArgumentTypeError |
| 4 | +from ast import literal_eval |
| 5 | +from typing import Callable |
| 6 | + |
| 7 | + |
| 8 | +def infer_boolean(var, strict: bool=True): |
| 9 | + """ |
| 10 | + accept argparse inputs of string, correctly |
| 11 | + convert into bools. Without this behavior, |
| 12 | + bool('False') becomes True by default. |
| 13 | +
|
| 14 | + Parameters |
| 15 | + ---------- |
| 16 | + var: input from parser.args |
| 17 | + """ |
| 18 | + if var.lower() == 'true': |
| 19 | + return True |
| 20 | + elif var.lower() == 'false': |
| 21 | + return False |
| 22 | + elif var.lower() == 'None': |
| 23 | + return None |
| 24 | + elif strict: |
| 25 | + raise ArgumentTypeError() |
| 26 | + else: |
| 27 | + return str(var) |
| 28 | + |
| 29 | +def infer_numeric(var, strict: bool=True): |
| 30 | + # int if possible -> float -> NoneType -> Err |
| 31 | + if var.isnumeric(): |
| 32 | + return int(var) |
| 33 | + elif '.' in list(var): |
| 34 | + decimal_left_right = var.split('.') |
| 35 | + if len(decimal_left_right) == 2: |
| 36 | + if sum([x.isnumeric() for x in decimal_left_right]) == 2: # True, True otherwise False |
| 37 | + return float(var) |
| 38 | + elif var.lower() == 'none': |
| 39 | + return None |
| 40 | + elif strict: |
| 41 | + raise ArgumentTypeError() |
| 42 | + else: |
| 43 | + return str(var) |
| 44 | + |
| 45 | +def infer_str(var, strict:bool=True): |
| 46 | + """ |
| 47 | + infer string values and handle fuzzy None and bool types (optional) |
| 48 | +
|
| 49 | + var: str input |
| 50 | + strict: whether to allow "fuzzy" None types |
| 51 | + """ |
| 52 | + if strict: |
| 53 | + return str(var) |
| 54 | + elif var.lower() == 'None': |
| 55 | + return None |
| 56 | + else: |
| 57 | + return str(var) |
| 58 | + |
| 59 | +def infer_iterable(var, inner_type: Callable=None, strict: bool=True): |
| 60 | + # Use ast.literal_eval to parse the iterable tree, |
| 61 | + # then use custom type handling to infer the inner types |
| 62 | + raw_ast_iter = literal_eval(var) |
| 63 | + if inner_type is not None: |
| 64 | + return iterable_helper(raw_ast_iter, inner_type, strict) |
| 65 | + else: |
| 66 | + # currently argparse config cannot support inferring |
| 67 | + # more granular types than list or tuple |
| 68 | + return raw_ast_iter |
| 69 | + |
| 70 | + |
| 71 | +def iterable_helper(var, inner_type: Callable, strict: bool=True): |
| 72 | + """ |
| 73 | + recursively loop through iterable and apply custom type |
| 74 | + callables to each inner variable to conform to strictness |
| 75 | + """ |
| 76 | + if isinstance(var, list): |
| 77 | + return [iterable_helper(x,inner_type, strict) for x in var] |
| 78 | + elif isinstance(var, tuple): |
| 79 | + return tuple([iterable_helper(x,inner_type, strict) for x in var]) |
| 80 | + else: |
| 81 | + return inner_type(str(var),strict) |
| 82 | + |
| 83 | + |
| 84 | +class TypeInferencer(object): |
| 85 | + def __init__(self, orig_type: Callable, strict: bool=True): |
| 86 | + """ |
| 87 | + TypeInferencer mediates between argparse |
| 88 | + and ArgparseConfig |
| 89 | + |
| 90 | + orig_type: Callable type |
| 91 | + type of original var from config |
| 92 | + cannot be NoneType - defaults to infer_str |
| 93 | + strict: bool, default True |
| 94 | + whether to use type inferencing. If False, |
| 95 | + default to simply applying default type converter. |
| 96 | + This may cause issues for some types: |
| 97 | + ex. if argparseConfig takes a boolean for arg 'x', |
| 98 | + passing --x False into the command line will return a value |
| 99 | + of True, since argparse works with strings and bool('False') = True. |
| 100 | + strict: bool, default True |
| 101 | + if True, raise ArgumentTypeError when the value cannot be cast to |
| 102 | + the allowed types. Otherwise default to str. |
| 103 | +
|
| 104 | + """ |
| 105 | + if orig_type == type(None): |
| 106 | + self.orig_type = infer_str |
| 107 | + else: |
| 108 | + self.orig_type = orig_type |
| 109 | + self.strict = strict |
| 110 | + |
| 111 | + def __call__(self, var): |
| 112 | + """ |
| 113 | + Callable method passed to argparse's builtin Callable type argument. |
| 114 | + var: original variable (any type) |
| 115 | +
|
| 116 | + """ |
| 117 | + if self.orig_type == bool: |
| 118 | + return infer_boolean(var, self.strict) |
| 119 | + elif self.orig_type == float or self.orig_type == int: |
| 120 | + return infer_numeric(var, self.strict) |
| 121 | + elif self.orig_type == tuple or self.orig_type == list: |
| 122 | + return infer_iterable(var, None, self.strict) |
| 123 | + else: |
| 124 | + if self.strict: |
| 125 | + return infer_str(var) |
| 126 | + else: |
| 127 | + return self.orig_type(var) |
0 commit comments