Skip to content

Commit 85f9160

Browse files
authored
Fix argparse config type inferencer (#2)
* update with new ruamel types * fix bug w custom ruamel yaml types * fix infer iterable signature and clean up TypeInferencer API
1 parent 46394af commit 85f9160

File tree

4 files changed

+46
-18
lines changed

4 files changed

+46
-18
lines changed

src/configmypy/argparse_config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def read_conf(self, config=None, **additional_config):
8383
strict=True
8484
elif self.infer_types == 'fuzzy':
8585
strict=False
86-
86+
print(f"creating inferencer for {key} of type {type(value)}")
8787
type_inferencer = TypeInferencer(orig_type=type(value), strict=strict)
8888
else:
8989
type_inferencer = type(value)

src/configmypy/tests/test_argparse_config.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,17 @@
1010

1111

1212
def test_ArgparseConfig(monkeypatch):
13-
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24'])
13+
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24',\
14+
'--data.test_resolutions', '[16]'])
1415
config = Bunch(TEST_CONFIG_DICT['default'])
1516
parser = ArgparseConfig(infer_types='fuzzy')
1617
args, kwargs = parser.read_conf(config)
18+
assert args.data.test_resolutions == [16]
1719
config.data.batch_size = 24
20+
config.data.test_resolutions = [16]
1821
assert config == args
22+
# reset config data test_res
23+
config.data.test_resolutions = [16,32]
1924
assert kwargs == {}
2025

2126
# test ability to infer None, int-->float and iterables

src/configmypy/tests/test_pipeline_config.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
data:
1313
dataset: 'ns'
1414
batch_size: 12
15+
test_batch_sizes: [16,32]
16+
other:
17+
int_to_float: 1
18+
float_to_none: 0.5
1519
1620
test:
1721
opt:
@@ -20,23 +24,31 @@
2024

2125
TEST_CONFIG_DICT = {
2226
'default': {'opt': {'optimizer': 'adam', 'lr': 0.1},
23-
'data': {'dataset': 'ns', 'batch_size': 12}},
27+
'data': {'dataset': 'ns', 'batch_size': 12, 'test_batch_sizes':[16,16]},
28+
'other': {'int_to_float': 1, 'float_to_none': 0.5}},
2429
'test': {'opt': {'optimizer': 'SGD'}}
2530
}
2631

2732

2833
def test_ConfigPipeline(mocker, monkeypatch):
2934
""""Test for ConfigPipeline"""
3035
mocker.patch("builtins.open", mocker.mock_open(read_data=TEST_CONFIG_FILE))
31-
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24', '--config_file', 'config.yaml', '--config_name', 'test'])
36+
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24', \
37+
'--data.test_batch_sizes', '[16,16]',\
38+
'--other.int_to_float', '0.05', \
39+
'--other.float_to_none', 'None', \
40+
'--config_file', 'config.yaml', \
41+
'--config_name', 'test'])
3242

3343
true_config = Bunch(TEST_CONFIG_DICT['default'])
3444
true_config.data.batch_size = 24
3545
true_config.opt.optimizer = 'SGD'
46+
true_config.other.int_to_float = 0.05
47+
true_config.other.float_to_none = None
3648

3749
pipe = ConfigPipeline([
3850
YamlConfig('./config.yaml', config_name='default'),
39-
ArgparseConfig(config_file=None, config_name=None),
51+
ArgparseConfig(config_file=None, config_name=None, infer_types='fuzzy'),
4052
YamlConfig()
4153
])
4254
config = pipe.read_conf()

src/configmypy/type_inference.py

+24-13
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
from ast import literal_eval
55
from typing import Callable
66

7+
# import custom Yaml types to handle sequences
8+
from ruamel.yaml import CommentedSeq, CommentedMap
9+
from ruamel.yaml.scalarfloat import ScalarFloat
10+
from ruamel.yaml.scalarint import ScalarInt
11+
from ruamel.yaml.scalarbool import ScalarBoolean
12+
713

814
def infer_boolean(var, strict: bool=True):
915
"""
@@ -56,7 +62,7 @@ def infer_str(var, strict:bool=True):
5662
else:
5763
return str(var)
5864

59-
def infer_iterable(var, inner_type: Callable=None, strict: bool=True):
65+
def infer_iterable(var, strict: bool=True, inner_type: Callable=None):
6066
# Use ast.literal_eval to parse the iterable tree,
6167
# then use custom type handling to infer the inner types
6268
raw_ast_iter = literal_eval(var)
@@ -103,25 +109,30 @@ def __init__(self, orig_type: Callable, strict: bool=True):
103109
104110
"""
105111
if orig_type == type(None):
106-
self.orig_type = infer_str
112+
self.orig_type = str
107113
else:
108114
self.orig_type = orig_type
109115
self.strict = strict
116+
117+
if self.orig_type == str or self.orig_type == type(None):
118+
self.type_callable = infer_str
119+
elif self.orig_type == bool or self.orig_type == ScalarBoolean:
120+
self.type_callable = infer_boolean
121+
elif self.orig_type == float or self.orig_type == int or self.orig_type == ScalarFloat or self.orig_type == ScalarInt:
122+
self.type_callable = infer_numeric
123+
elif self.orig_type == tuple or self.orig_type == list or self.orig_type == CommentedMap or self.orig_type == CommentedSeq:
124+
self.type_callable = infer_iterable
125+
else:
126+
self.type_callable = self.orig_type
110127

111128
def __call__(self, var):
112129
"""
113130
Callable method passed to argparse's builtin Callable type argument.
114131
var: original variable (any type)
115132
133+
calls the proper type inferencer
116134
"""
117-
if self.orig_type == bool:
118-
return infer_boolean(var, self.strict)
119-
elif self.orig_type == float or self.orig_type == int:
120-
return infer_numeric(var, self.strict)
121-
elif self.orig_type == tuple or self.orig_type == list:
122-
return infer_iterable(var, None, self.strict)
123-
else:
124-
if self.strict:
125-
return infer_str(var)
126-
else:
127-
return self.orig_type(var)
135+
return self.type_callable(var, self.strict)
136+
137+
def __repr__(self):
138+
return f"TypeInferencer[{self.orig_type}]"

0 commit comments

Comments
 (0)