
Commit c7fe54f

"tune function and CLI command"
1 parent a9bdd58 commit c7fe54f

7 files changed: +237 -15 lines changed

dffml/__init__.py

+1

@@ -57,6 +57,7 @@ class DuplicateName(Exception):
     "train": "high_level.ml",
     "predict": "high_level.ml",
     "score": "high_level.ml",
+    "tune": "high_level.ml",
     "load": "high_level.source",
     "save": "high_level.source",
     "run": "high_level.dataflow",

dffml/cli/cli.py

+2 -1

@@ -39,7 +39,7 @@
 
 from .dataflow import Dataflow
 from .config import Config
-from .ml import Train, Accuracy, Predict
+from .ml import Train, Accuracy, Predict, Tune
 from .list import List
 
 version = VERSION
@@ -366,6 +366,7 @@ class CLI(CMD):
     train = Train
     accuracy = Accuracy
     predict = Predict
+    tune = Tune
     service = services()
     dataflow = Dataflow
     config = Config

dffml/cli/ml.py

+37 -1

@@ -1,9 +1,10 @@
 import inspect
 
 from ..model.model import Model
+from ..tuner.tuner import Tuner
 from ..source.source import Sources, SubsetSources
 from ..util.cli.cmd import CMD, CMDOutputOverride
-from ..high_level.ml import train, predict, score
+from ..high_level.ml import train, predict, score, tune
 from ..util.config.fields import FIELD_SOURCES
 from ..util.cli.cmds import (
     SourcesCMD,
@@ -15,6 +16,7 @@
 )
 from ..base import config, field
 from ..accuracy import AccuracyScorer
+
 from ..feature import Features
 
 
@@ -118,3 +120,37 @@ class Predict(CMD):
 
     record = PredictRecord
     _all = PredictAll
+
+
+@config
+class TuneCMDConfig:
+    model: Model = field("Model used for ML", required=True)
+    tuner: Tuner = field("Tuner to optimize hyperparameters", required=True)
+    scorer: AccuracyScorer = field(
+        "Method to use to score accuracy", required=True
+    )
+    features: Features = field("Predict Feature(s)", default=Features())
+    sources: Sources = FIELD_SOURCES
+
+
+class Tune(MLCMD):
+    """Optimize hyperparameters of model with given sources"""
+
+    CONFIG = TuneCMDConfig
+
+    async def run(self):
+        # Instantiate the accuracy scorer class if for some reason it is a class
+        # at this point rather than an instance.
+        if inspect.isclass(self.scorer):
+            self.scorer = self.scorer.withconfig(self.extra_config)
+        if inspect.isclass(self.tuner):
+            self.tuner = self.tuner.withconfig(self.extra_config)
+
+        return await tune(
+            self.model,
+            self.tuner,
+            self.scorer,
+            self.features,
+            [self.sources[0]],
+            [self.sources[1]],
+        )
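
The command simply wires its configured plugins into the high-level tune() added below, and run() treats the first configured source as the training set and the second as the validation set. A rough sketch of that call order (illustrative only; the CSV file names are placeholders):

    from dffml import CSVSource, tune

    async def run_like_tune_cmd(model, tuner, scorer, features):
        # Mirrors Tune.run(): sources[0] -> training data, sources[1] -> validation data.
        return await tune(
            model,
            tuner,
            scorer,
            features,
            [CSVSource(filename="iris_training.csv")],
            [CSVSource(filename="iris_test.csv")],
        )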

dffml/high_level/ml.py

+149

@@ -1,12 +1,14 @@
 import contextlib
 from typing import Union, Dict, Any, List
 
+
 from ..record import Record
 from ..source.source import BaseSource
 from ..feature import Feature, Features
 from ..model import Model, ModelContext
 from ..util.internal import records_to_sources, list_records_to_dict
 from ..accuracy.accuracy import AccuracyScorer, AccuracyContext
+from ..tuner import Tuner, TunerContext
 
 
 async def train(model, *args: Union[BaseSource, Record, Dict[str, Any], List]):
@@ -293,3 +295,150 @@ async def predict(
             )
             if update:
                 await sctx.update(record)
+
+
+async def tune(
+    model,
+    tuner: Union[Tuner, TunerContext],
+    accuracy_scorer: Union[AccuracyScorer, AccuracyContext],
+    features: Union[Feature, Features],
+    train_ds: Union[BaseSource, Record, Dict[str, Any], List],
+    valid_ds: Union[BaseSource, Record, Dict[str, Any], List],
+) -> float:
+
+    """
+    Tune the hyperparameters of a model with a given tuner.
+
+
+    Parameters
+    ----------
+    model : Model
+        Machine Learning model to use. See :doc:`/plugins/dffml_model` for
+        model options.
+    tuner : Tuner
+        Hyperparameter tuning method to use. See :doc:`/plugins/dffml_tuner` for
+        tuner options.
+    train_ds : list
+        Input data for training. Could be a ``dict``, :py:class:`Record`,
+        filename, one of the data :doc:`/plugins/dffml_source`, or a filename
+        with the extension being one of the data sources.
+    valid_ds : list
+        Validation data for testing. Could be a ``dict``, :py:class:`Record`,
+        filename, one of the data :doc:`/plugins/dffml_source`, or a filename
+        with the extension being one of the data sources.
+
+
+    Returns
+    -------
+    float
+        A decimal value representing the result of the accuracy scorer on the given
+        test set. For instance, ClassificationAccuracy represents the percentage of correct
+        classifications made by the model.
+
+    Examples
+    --------
+
+    >>> import asyncio
+    >>> from dffml import *
+    >>> from dffml_model_xgboost.xgbclassifier import XGBClassifierModel
+    >>>
+    >>> model = XGBClassifierModel(
+    ...     features=Features(
+    ...         Feature("SepalLength", float, 1),
+    ...         Feature("SepalWidth", float, 1),
+    ...         Feature("PetalLength", float, 1),
+    ...     ),
+    ...     predict=Feature("classification", int, 1),
+    ...     location="tempdir",
+    ... )
+    >>>
+    >>> async def main():
+    ...     await tune(
+    ...         model,
+    ...         ParameterGrid(
+    ...             parameters={
+    ...                 "learning_rate": [0.01, 0.05, 0.1],
+    ...                 "n_estimators": [20, 100, 200],
+    ...                 "max_depth": [3, 5, 8]
+    ...             }
+    ...         ),
+    ...         MeanSquaredErrorAccuracy(),
+    ...         Features(
+    ...             Feature("SepalLength", float, 1),
+    ...             Feature("SepalWidth", float, 1),
+    ...             Feature("PetalLength", float, 1),
+    ...         ),
+    ...         [CSVSource(filename="iris_training.csv")],
+    ...         [CSVSource(filename="iris_test.csv")],
+    ...     )
+    >>>
+    >>> asyncio.run(main())
+    Accuracy: 0.0
+    """
+
+    if not isinstance(features, (Feature, Features)):
+        raise TypeError(
+            f"features was {type(features)}: {features!r}. Should have been Feature or Features"
+        )
+    if isinstance(features, Feature):
+        features = Features(features)
+    if hasattr(model.config, "predict"):
+        if isinstance(model.config.predict, Features):
+            predict_feature = [
+                feature.name for feature in model.config.predict
+            ]
+        else:
+            predict_feature = [model.config.predict.name]
+
+    if hasattr(model.config, "features") and any(
+        isinstance(td, list) for td in train_ds
+    ):
+        train_ds = list_records_to_dict(
+            [feature.name for feature in model.config.features]
+            + predict_feature,
+            *train_ds,
+            model=model,
+        )
+    if hasattr(model.config, "features") and any(
+        isinstance(td, list) for td in valid_ds
+    ):
+        valid_ds = list_records_to_dict(
+            [feature.name for feature in model.config.features]
+            + predict_feature,
+            *valid_ds,
+            model=model,
+        )
+
+    async with contextlib.AsyncExitStack() as astack:
+        # Open sources
+        train = await astack.enter_async_context(records_to_sources(*train_ds))
+        test = await astack.enter_async_context(records_to_sources(*valid_ds))
+        # Allow for keep models open
+        if isinstance(model, Model):
+            model = await astack.enter_async_context(model)
+            mctx = await astack.enter_async_context(model())
+        elif isinstance(model, ModelContext):
+            mctx = model
+
+        # Allow for keep scorers open
+        if isinstance(accuracy_scorer, AccuracyScorer):
+            accuracy_scorer = await astack.enter_async_context(accuracy_scorer)
+            actx = await astack.enter_async_context(accuracy_scorer())
+        elif isinstance(accuracy_scorer, AccuracyContext):
+            actx = accuracy_scorer
+        else:
+            # TODO Replace this with static type checking and maybe dynamic
+            # through something like pydantic. See issue #36
+            raise TypeError(f"{accuracy_scorer} is not an AccuracyScorer")
+
+        if isinstance(tuner, Tuner):
+            tuner = await astack.enter_async_context(tuner)
+            tctx = await astack.enter_async_context(tuner())
+        elif isinstance(tuner, TunerContext):
+            tctx = tuner
+        else:
+            raise TypeError(f"{tuner} is not a Tuner")
+
+        return float(
+            await tctx.optimize(mctx, model.config.predict, actx, train, test)
+        )
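
Besides sources such as CSVSource, train_ds and valid_ds may also be plain Python lists of rows; tune() then routes them through list_records_to_dict, which pairs each row with the model's feature names plus the predict feature. A small sketch of that path, reusing the model from the docstring example (assumptions: dffml_model_xgboost is installed, the row values are made up, and ParameterGrid is imported from its module path since this commit drops the re-export from dffml.tuner):

    import asyncio

    from dffml import Feature, Features, MeanSquaredErrorAccuracy, tune
    from dffml.tuner.parameter_grid import ParameterGrid
    from dffml_model_xgboost.xgbclassifier import XGBClassifierModel

    model = XGBClassifierModel(
        features=Features(
            Feature("SepalLength", float, 1),
            Feature("SepalWidth", float, 1),
            Feature("PetalLength", float, 1),
        ),
        predict=Feature("classification", int, 1),
        location="tempdir",
    )

    # Each row is [SepalLength, SepalWidth, PetalLength, classification],
    # i.e. the model's features followed by the predict feature.
    train_rows = [[5.1, 3.5, 1.4, 0], [7.0, 3.2, 4.7, 1], [6.3, 3.3, 6.0, 1]]
    valid_rows = [[4.9, 3.0, 1.4, 0], [6.4, 3.2, 4.5, 1]]

    async def main():
        return await tune(
            model,
            ParameterGrid(parameters={"max_depth": [3, 5]}),
            MeanSquaredErrorAccuracy(),
            model.config.features,  # same role as the Features(...) argument above
            train_rows,
            valid_rows,
        )

    print(asyncio.run(main()))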

dffml/noasync.py

+16

@@ -6,6 +6,7 @@
     train as high_level_train,
     score as high_level_score,
     predict as high_level_predict,
+    tune as high_level_tune,
 )
 
 
@@ -24,6 +25,21 @@ def train(*args, **kwargs):
         )
 )
 
+
+def tune(*args, **kwargs):
+    return asyncio.run(high_level_tune(*args, **kwargs))
+
+
+tune.__doc__ = (
+    high_level_tune.__doc__.replace("await ", "")
+    .replace("async ", "")
+    .replace("asyncio.run(main())", "main()")
+    .replace("    >>> import asyncio\n", "")
+    .replace(
+        "    >>> from dffml import *\n",
+        "    >>> from dffml import *\n    >>> from dffml.noasync import tune\n",
+    )
+)
+
 
 def score(*args, **kwargs):
     return asyncio.run(high_level_score(*args, **kwargs))
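
The wrapper lets synchronous code call tune() without managing an event loop, and the chain of .replace() calls rewrites the inherited docstring so its doctest reads synchronously. A tiny illustration of that behaviour (not from the commit):

    from dffml.noasync import tune

    # The async-only pieces of the high-level docstring are stripped or swapped.
    assert "await " not in tune.__doc__
    assert "asyncio.run(main())" not in tune.__doc__
    assert ">>> from dffml.noasync import tune" in tune.__doc__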

dffml/tuner/__init__.py

+1 -2

@@ -7,5 +7,4 @@
     TunerConfig,
     TunerContext,
     Tuner,
-)
-from .parameter_grid import ParameterGrid
+)

dffml/tuner/parameter_grid.py

+31 -11

@@ -18,6 +18,7 @@
 @config
 class ParameterGridConfig:
     parameters: dict = field("Parameters to be optimized")
+    objective: str = field("How to optimize for the scorer", default="max")
 
 
 class ParameterGridContext(TunerContext):
@@ -61,31 +62,50 @@ async def optimize(
         float
             The highest score value
         """
-        highest_acc = -1
+        if self.parent.config.objective == "min":
+            highest_acc = float("inf")
+        elif self.parent.config.objective == "max":
+            highest_acc = -1
+
         best_config = dict()
         logging.info(
             f"Optimizing model with parameter grid: {self.parent.config.parameters}"
         )
+
         names = list(self.parent.config.parameters.keys())
         logging.info(names)
-        with model.config.no_enforce_immutable():
+
+        with model.parent.config.no_enforce_immutable():
             for combination in itertools.product(
                 *list(self.parent.config.parameters.values())
             ):
                 logging.info(combination)
+
                 for i in range(len(combination)):
                     param = names[i]
-                    setattr(model.config, names[i], combination[i])
-                await train(model, *train_data)
-                acc = await score(model, accuracy_scorer, feature, *test_data)
+                    setattr(model.parent.config, names[i], combination[i])
+
+                await train(model.parent, *train_data)
+
+                acc = await score(
+                    model.parent, accuracy_scorer, feature, *test_data
+                )
+
                 logging.info(f"Accuracy of the tuned model: {acc}")
-                if acc > highest_acc:
-                    highest_acc = acc
-                    for param in names:
-                        best_config[param] = getattr(model.config, param)
+                if self.parent.config.objective == "min":
+                    if acc < highest_acc:
+                        highest_acc = acc
+
+                elif self.parent.config.objective == "max":
+                    if acc > highest_acc:
+                        highest_acc = acc
+                        for param in names:
+                            best_config[param] = getattr(
+                                model.parent.config, param
+                            )
             for param in names:
-                setattr(model.config, param, best_config[param])
-            await train(model, *train_data)
+                setattr(model.parent.config, param, best_config[param])
+            await train(model.parent, *train_data)
         logging.info(f"\nOptimal Hyper-parameters: {best_config}")
         logging.info(f"Accuracy of Optimized model: {highest_acc}")
         return highest_acc
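
The new objective option controls the comparison direction: "min" keeps the lowest score (useful for error-style scorers such as mean squared error), while the default "max" keeps the highest. A short sketch of configuring it (parameter values borrowed from the docstring example above, import path assumed since the re-export from dffml.tuner was removed):

    from dffml.tuner.parameter_grid import ParameterGrid

    tuner = ParameterGrid(
        parameters={"learning_rate": [0.01, 0.05, 0.1], "max_depth": [3, 5, 8]},
        objective="min",  # keep the hyperparameter combination with the lowest score
    )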
