Skip to content

Commit 836382b

Browse files
authored
Adds type inference to argparse (#1)
* infer boolean or numeric * add typeinferencer class * type inferencer * default to infer string * fuzzy read_conf * remove fuzzy tag * add support for iterable types for depth 1 * fix default type to str * add basic iterable inference w/support for nested iterables, not strict * add tests for inferring iterables & string-to-bool * add infer types choices * update argparse config signature * set argparse config to default to fuzzy * add int --> float test
1 parent a7baa4a commit 836382b

File tree

3 files changed

+178
-17
lines changed

3 files changed

+178
-17
lines changed

src/configmypy/argparse_config.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import argparse
22
from copy import deepcopy
3+
from ctypes import ArgumentError
34
from .bunch import Bunch
45
from .utils import iter_nested_dict_flat
56
from .utils import update_nested_dict_from_flat
6-
7+
from .type_inference import TypeInferencer
78

89
class ArgparseConfig:
910
"""Read config from the command-line using argparse
@@ -12,8 +13,10 @@ class ArgparseConfig:
1213
1314
Parameters
1415
----------
15-
infer_types : bool, default is True
16-
if True, use type(value) to indicate expected type
16+
infer_types : False or literal['fuzzy', 'strict']
17+
if False, use python's typecasting from type(value) to indicate expected type
18+
if fuzzy, use custom type handling and allow all types to take values of None
19+
if strict, use custom type handling but do not allow passing bool or None to other types.
1720
1821
overwrite_nested_config : bool, default is True
1922
if True, users can set values at different levels of nesting
@@ -29,9 +32,12 @@ class ArgparseConfig:
2932
**additional_config : dict
3033
key, values to read from command-line and pass on to the next config
3134
"""
32-
def __init__(self, infer_types=True, overwrite_nested_config=False, **additional_config):
35+
def __init__(self, infer_types="fuzzy", overwrite_nested_config=False, **additional_config):
3336
self.additional_config = Bunch(additional_config)
34-
self.infer_types = infer_types
37+
if infer_types in [False, "fuzzy", "strict"]:
38+
self.infer_types = infer_types
39+
else:
40+
raise ArgumentError("Error: infer_types only takes values False, \'fuzzy\', \'strict\'")
3541
self.overwrite_nested_config = overwrite_nested_config
3642

3743
def read_conf(self, config=None, **additional_config):
@@ -44,6 +50,12 @@ def read_conf(self, config=None, **additional_config):
4450
----------
4551
config : dict, default is None
4652
if not None, a dict config to update
53+
infer_types: bool, default is False
54+
if True, uses custom type handling
55+
where all values, regardless of type,
56+
may take the value None.
57+
If false, uses default argparse
58+
typecasting
4759
additional_config : dict
4860
4961
Returns
@@ -61,10 +73,19 @@ def read_conf(self, config=None, **additional_config):
6173

6274
parser = argparse.ArgumentParser(description='Read the config from the commandline.')
6375
for key, value in iter_nested_dict_flat(config, return_intermediate_keys=self.overwrite_nested_config):
64-
if self.infer_types and value is not None:
65-
parser.add_argument(f'--{key}', type=type(value), default=value)
66-
else:
67-
parser.add_argument(f'--{key}', default=value)
76+
# smartly infer types if infer_types is turned on
77+
# otherwise force default typecasting
78+
if self.infer_types:
79+
if self.infer_types == 'strict':
80+
strict=True
81+
elif self.infer_types == 'fuzzy':
82+
strict=False
83+
84+
type_inferencer = TypeInferencer(orig_type=type(value), strict=strict)
85+
else:
86+
type_inferencer = type(value)
87+
88+
parser.add_argument(f'--{key}', type=type_inferencer, default=value)
6889

6990
args = parser.parse_args()
7091
self.config = Bunch(args.__dict__)

src/configmypy/tests/test_argparse_config.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,46 @@
33

44

55
TEST_CONFIG_DICT = {
6-
'default': {'opt': {'optimizer': 'adam', 'lr': 0.1},
7-
'data': {'dataset': 'ns', 'batch_size': 12}},
8-
'test': {'opt': {'optimizer': 'SGD'}}
6+
'default': {'opt': {'optimizer': 'adam', 'lr': 1, 'regularizer': True},
7+
'data': {'dataset': 'ns', 'batch_size': 12, 'test_resolutions':[16,32]}},
8+
'test': {'opt': {'optimizer': 'SGD'}}
99
}
1010

1111

1212
def test_ArgparseConfig(monkeypatch):
1313
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24'])
1414
config = Bunch(TEST_CONFIG_DICT['default'])
15-
parser = ArgparseConfig(infer_types=True)
15+
parser = ArgparseConfig(infer_types='fuzzy')
1616
args, kwargs = parser.read_conf(config)
1717
config.data.batch_size = 24
1818
assert config == args
1919
assert kwargs == {}
20-
21-
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24', '--config_name', 'test'])
20+
21+
# test ability to infer None, int-->float and iterables
22+
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24',\
23+
'--data.test_resolutions', '[8, None]',\
24+
'--opt.lr', '0.01',\
25+
'--opt.regularizer', 'False',\
26+
'--config_name', 'test'])
2227
config = Bunch(TEST_CONFIG_DICT['default'])
23-
parser = ArgparseConfig(infer_types=True, config_name=None)
28+
parser = ArgparseConfig(infer_types='fuzzy', config_name=None)
2429
args, kwargs = parser.read_conf(config)
2530
config.data.batch_size = 24
31+
assert args.opt.lr == 0.01 # ensure no casting to int
32+
assert args.data.test_resolutions == [8, None]
33+
assert not args.opt.regularizer #boolean False, not 'False'
34+
35+
36+
args.data.test_resolutions = [16,32]
37+
args.opt.regularizer = True
38+
args.opt.lr = 1
2639
assert config == args
2740
assert kwargs == Bunch(dict(config_name='test'))
2841

2942
# Test overwriting entire nested argument
3043
monkeypatch.setattr("sys.argv", ['test', '--data', '0', '--config_name', 'test'])
3144
config = Bunch(TEST_CONFIG_DICT['default'])
32-
parser = ArgparseConfig(infer_types=True, config_name=None, overwrite_nested_config=True)
45+
parser = ArgparseConfig(infer_types='fuzzy', config_name=None, overwrite_nested_config=True)
3346
args, kwargs = parser.read_conf(config)
3447
config['data'] = 0
3548
assert config == args

src/configmypy/type_inference.py

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Infer types of argparse arguments intelligently
2+
3+
from argparse import ArgumentTypeError
4+
from ast import literal_eval
5+
from typing import Callable
6+
7+
8+
def infer_boolean(var, strict: bool=True):
    """
    Convert a command-line string into a bool (or None).

    argparse passes every value as a string, and bool('False') is True,
    so the text is compared explicitly instead of being cast.

    Parameters
    ----------
    var : str
        raw input from the parser, matched case-insensitively
    strict : bool, default is True
        if True, raise ArgumentTypeError for anything other than
        'true', 'false' or 'none'; if False, return the input as a str

    Returns
    -------
    bool, None or str

    Raises
    ------
    ArgumentTypeError
        if strict is True and var is not a recognised keyword
    """
    lowered = var.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False
    # Compare against lowercase 'none': lowered text can never equal 'None'.
    if lowered == 'none':
        return None
    if strict:
        raise ArgumentTypeError(
            f"expected 'true', 'false' or 'none', got {var!r}"
        )
    return str(var)
29+
def infer_numeric(var, strict: bool=True):
    """
    Convert a command-line string into an int, float or None.

    Conversion is attempted in the order int -> float -> None, relying on
    the built-in parsers so that signed values ('-3', '-1.5'), bare
    decimals ('.5') and scientific notation ('1e-3') are all accepted.

    Parameters
    ----------
    var : str
        raw input from the parser
    strict : bool, default is True
        if True, raise ArgumentTypeError when var is neither numeric
        nor 'none'; if False, return the input as a str

    Returns
    -------
    int, float, None or str

    Raises
    ------
    ArgumentTypeError
        if strict is True and var cannot be interpreted
    """
    # int first so that '12' stays an int rather than becoming 12.0
    try:
        return int(var)
    except ValueError:
        pass
    try:
        return float(var)
    except ValueError:
        pass
    if var.strip().lower() == 'none':
        return None
    # Every path must end in an explicit return or raise, so that
    # malformed input such as '1.2.3' can never silently become None.
    if strict:
        raise ArgumentTypeError(
            f"expected an int, a float or 'none', got {var!r}"
        )
    return str(var)
44+
45+
def infer_str(var, strict:bool=True):
    """
    Return the input as a str, optionally mapping 'none' to None.

    Parameters
    ----------
    var : str
        raw input from the parser
    strict : bool, default is True
        if True, always return str(var); if False ("fuzzy"), the text
        'none' (any casing) is returned as None instead

    Returns
    -------
    str or None
    """
    # Compare against lowercase 'none': lowered text can never equal 'None'.
    if not strict and var.lower() == 'none':
        return None
    return str(var)
58+
59+
def infer_iterable(var, inner_type: Callable=None, strict: bool=True):
    """
    Parse a command-line string such as '[8, None]' into a Python literal.

    ast.literal_eval builds the (possibly nested) structure; when
    inner_type is given, it is then applied to every leaf value.

    Parameters
    ----------
    var : str
        raw input from the parser
    inner_type : Callable, default is None
        converter called as inner_type(str(leaf), strict) on each leaf;
        if None, the literal_eval result is returned untouched
    strict : bool, default is True
        forwarded to inner_type

    Returns
    -------
    the evaluated literal, with leaves converted when inner_type is set

    Raises
    ------
    ArgumentTypeError
        if var is not a valid Python literal, whatever the value of strict
    """
    try:
        raw_ast_iter = literal_eval(var)
    except (ValueError, SyntaxError) as err:
        # argparse only reports ArgumentTypeError/TypeError/ValueError from
        # a type callable as a usage error; a SyntaxError would escape as
        # a traceback, so both failure modes are normalised here.
        raise ArgumentTypeError(
            f"could not parse {var!r} as a Python literal"
        ) from err

    if inner_type is None:
        # currently argparse config cannot support inferring
        # more granular types than list or tuple
        return raw_ast_iter
    return iterable_helper(raw_ast_iter, inner_type, strict)
69+
70+
71+
def iterable_helper(var, inner_type: Callable, strict: bool=True):
    """
    Recursively apply inner_type to every leaf of a nested list/tuple.

    Parameters
    ----------
    var : list, tuple or any leaf value
        structure to walk; lists and tuples are rebuilt with the same
        container type, anything else is treated as a leaf
    inner_type : Callable
        converter called as inner_type(str(leaf), strict)
    strict : bool, default is True
        forwarded unchanged to inner_type

    Returns
    -------
    a structure of the same shape with converted leaves
    """
    if isinstance(var, list):
        return [iterable_helper(item, inner_type, strict) for item in var]
    if isinstance(var, tuple):
        # generator instead of a throwaway intermediate list
        return tuple(iterable_helper(item, inner_type, strict) for item in var)
    # The leaf is converted back to text because inner_type is called
    # with a string, as the infer_* helpers in this module expect.
    return inner_type(str(var), strict)
82+
83+
84+
class TypeInferencer(object):
    """
    Callable passed to argparse's ``type=`` argument by ArgparseConfig.

    It remembers the type of the original config value and dispatches
    each raw command-line string to the matching infer_* helper.
    """

    def __init__(self, orig_type: Callable, strict: bool=True):
        """
        Parameters
        ----------
        orig_type : Callable type
            type of the original value from the config.
            NoneType is replaced by infer_str.
        strict : bool, default True
            forwarded to the infer_* helpers. If True, values that cannot
            be interpreted as the expected type are rejected there with
            ArgumentTypeError; if False ("fuzzy"), they fall back to str
            and string-like options may additionally take the value None.
        """
        if orig_type is type(None):
            self.orig_type = infer_str
        else:
            self.orig_type = orig_type
        self.strict = strict

    def __call__(self, var):
        """
        Convert one raw command-line value.

        Parameters
        ----------
        var : str
            raw input handed over by argparse

        Returns
        -------
        the converted value; its type depends on orig_type and strict
        """
        if self.orig_type is bool:
            return infer_boolean(var, self.strict)
        if self.orig_type in (int, float):
            return infer_numeric(var, self.strict)
        if self.orig_type in (list, tuple):
            return infer_iterable(var, None, self.strict)
        if self.orig_type is str or self.orig_type is infer_str:
            # Forward strict so that fuzzy mode can map 'none' to None;
            # without it infer_str always runs with its strict default.
            return infer_str(var, self.strict)
        if self.strict:
            # Any other type is kept as plain text in strict mode.
            return infer_str(var)
        return self.orig_type(var)

0 commit comments

Comments
 (0)