
Commit 4a7de3a

"tune function and CLI command"
1 parent a9bdd58 commit 4a7de3a

12 files changed: +243 -16 lines changed


dffml/__init__.py (+1)

@@ -57,6 +57,7 @@ class DuplicateName(Exception):
     "train": "high_level.ml",
     "predict": "high_level.ml",
     "score": "high_level.ml",
+    "tune": "high_level.ml",
     "load": "high_level.source",
     "save": "high_level.source",
     "run": "high_level.dataflow",

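This mapping drives DFFML's lazy top-level imports, so the one-line addition above is what exposes the new high-level helper alongside ``train``, ``predict``, and ``score``. A minimal doctest-style sketch (assuming the new entry resolves the same way as its neighbours):

    >>> from dffml import tune  # resolved lazily to dffml.high_level.ml.tune
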
dffml/cli/cli.py (+2 -1)

@@ -39,7 +39,7 @@
 
 from .dataflow import Dataflow
 from .config import Config
-from .ml import Train, Accuracy, Predict
+from .ml import Train, Accuracy, Predict, Tune
 from .list import List
 
 version = VERSION
@@ -366,6 +366,7 @@ class CLI(CMD):
     train = Train
     accuracy = Accuracy
     predict = Predict
+    tune = Tune
     service = services()
     dataflow = Dataflow
     config = Config

dffml/cli/ml.py (+37 -1)

@@ -1,9 +1,10 @@
 import inspect
 
 from ..model.model import Model
+from ..tuner.tuner import Tuner
 from ..source.source import Sources, SubsetSources
 from ..util.cli.cmd import CMD, CMDOutputOverride
-from ..high_level.ml import train, predict, score
+from ..high_level.ml import train, predict, score, tune
 from ..util.config.fields import FIELD_SOURCES
 from ..util.cli.cmds import (
     SourcesCMD,
@@ -15,6 +16,7 @@
 )
 from ..base import config, field
 from ..accuracy import AccuracyScorer
+
 from ..feature import Features
 
 
@@ -118,3 +120,37 @@ class Predict(CMD):
 
     record = PredictRecord
     _all = PredictAll
+
+
+@config
+class TuneCMDConfig:
+    model: Model = field("Model used for ML", required=True)
+    tuner: Tuner = field("Tuner to optimize hyperparameters", required=True)
+    scorer: AccuracyScorer = field(
+        "Method to use to score accuracy", required=True
+    )
+    features: Features = field("Predict Feature(s)", default=Features())
+    sources: Sources = FIELD_SOURCES
+
+
+class Tune(MLCMD):
+    """Optimize hyperparameters of model with given sources"""
+
+    CONFIG = TuneCMDConfig
+
+    async def run(self):
+        # Instantiate the accuracy scorer class if for some reason it is a class
+        # at this point rather than an instance.
+        if inspect.isclass(self.scorer):
+            self.scorer = self.scorer.withconfig(self.extra_config)
+        if inspect.isclass(self.tuner):
+            self.tuner = self.tuner.withconfig(self.extra_config)
+
+        return await tune(
+            self.model,
+            self.tuner,
+            self.scorer,
+            self.features,
+            [self.sources[0]],
+            [self.sources[1]],
+        )
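For orientation, the new command is a thin wrapper over the high-level ``tune`` added below: the first configured source is used for training and the second for validation. A rough Python equivalent of what ``Tune.run`` ends up calling, sketched under assumptions (the ``train.csv``/``valid.csv`` files are hypothetical, the ``SLRModel``/``ParameterGrid``/``MeanSquaredErrorAccuracy`` choices are borrowed from the docstring example below, and ``ParameterGrid`` is imported from the module path referenced in ``dffml/tuner/__init__.py`` at the end of this diff):

    import asyncio

    from dffml import CSVSource, Feature, Features, MeanSquaredErrorAccuracy, SLRModel, tune
    from dffml.tuner.parameter_grid import ParameterGrid

    async def main():
        # Mirrors Tune.run: sources[0] becomes the training set, sources[1] the validation set.
        score = await tune(
            SLRModel(
                features=Features(Feature("Years", int, 1)),
                predict=Feature("Salary", int, 1),
                location="tempdir",
            ),
            ParameterGrid(objective="min"),
            MeanSquaredErrorAccuracy(),
            Features(Feature("Years", float, 1)),
            [CSVSource(filename="train.csv")],  # hypothetical training data
            [CSVSource(filename="valid.csv")],  # hypothetical validation data
        )
        print(f"Tuner score: {score}")

    asyncio.run(main())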

dffml/high_level/ml.py (+148)

@@ -1,12 +1,14 @@
 import contextlib
 from typing import Union, Dict, Any, List
 
+
 from ..record import Record
 from ..source.source import BaseSource
 from ..feature import Feature, Features
 from ..model import Model, ModelContext
 from ..util.internal import records_to_sources, list_records_to_dict
 from ..accuracy.accuracy import AccuracyScorer, AccuracyContext
+from ..tuner import Tuner, TunerContext
 
 
 async def train(model, *args: Union[BaseSource, Record, Dict[str, Any], List]):
@@ -293,3 +295,149 @@ async def predict(
             )
             if update:
                 await sctx.update(record)
+
+
+async def tune(
+    model,
+    tuner: Union[Tuner, TunerContext],
+    accuracy_scorer: Union[AccuracyScorer, AccuracyContext],
+    features: Union[Feature, Features],
+    train_ds: Union[BaseSource, Record, Dict[str, Any], List],
+    valid_ds: Union[BaseSource, Record, Dict[str, Any], List],
+) -> float:
+
+    """
+    Tune the hyperparameters of a model with a given tuner.
+
+
+    Parameters
+    ----------
+    model : Model
+        Machine Learning model to use. See :doc:`/plugins/dffml_model` for
+        models options.
+    tuner : Tuner
+        Hyperparameter tuning method to use. See :doc:`/plugins/dffml_tuner` for
+        tuner options.
+    train_ds : list
+        Input data for training. Could be a ``dict``, :py:class:`Record`,
+        filename, one of the data :doc:`/plugins/dffml_source`, or a filename
+        with the extension being one of the data sources.
+    valid_ds : list
+        Validation data for testing. Could be a ``dict``, :py:class:`Record`,
+        filename, one of the data :doc:`/plugins/dffml_source`, or a filename
+        with the extension being one of the data sources.
+
+
+    Returns
+    -------
+    float
+        A decimal value representing the result of the accuracy scorer on the given
+        test set. For instance, ClassificationAccuracy represents the percentage of correct
+        classifications made by the model.
+
+    Examples
+    --------
+
+    >>> import asyncio
+    >>> from dffml import *
+    >>>
+    >>> model = SLRModel(
+    ...     features=Features(
+    ...         Feature("Years", int, 1),
+    ...     ),
+    ...     predict=Feature("Salary", int, 1),
+    ...     location="tempdir",
+    ... )
+    >>>
+    >>> async def main():
+    ...     score = await tune(
+    ...         model,
+    ...         ParameterGrid(objective="min"),
+    ...         MeanSquaredErrorAccuracy(),
+    ...         Features(
+    ...             Feature("Years", float, 1),
+    ...         ),
+    ...         [
+    ...             {"Years": 0, "Salary": 10},
+    ...             {"Years": 1, "Salary": 20},
+    ...             {"Years": 2, "Salary": 30},
+    ...             {"Years": 3, "Salary": 40}
+    ...         ],
+    ...         [
+    ...             {"Years": 6, "Salary": 70},
+    ...             {"Years": 7, "Salary": 80}
+    ...         ]
+    ...
+    ...     )
+    ...     print(f"Tuner score: {score}")
+    ...
+    >>> asyncio.run(main())
+    Tuner score: 0.0
+    """
+
+    if not isinstance(features, (Feature, Features)):
+        raise TypeError(
+            f"features was {type(features)}: {features!r}. Should have been Feature or Features"
+        )
+    if isinstance(features, Feature):
+        features = Features(features)
+    if hasattr(model.config, "predict"):
+        if isinstance(model.config.predict, Features):
+            predict_feature = [
+                feature.name for feature in model.config.predict
+            ]
+        else:
+            predict_feature = [model.config.predict.name]
+
+    if hasattr(model.config, "features") and any(
+        isinstance(td, list) for td in train_ds
+    ):
+        train_ds = list_records_to_dict(
+            [feature.name for feature in model.config.features]
+            + predict_feature,
+            *train_ds,
+            model=model,
+        )
+    if hasattr(model.config, "features") and any(
+        isinstance(td, list) for td in valid_ds
+    ):
+        valid_ds = list_records_to_dict(
+            [feature.name for feature in model.config.features]
+            + predict_feature,
+            *valid_ds,
+            model=model,
+        )
+
+    async with contextlib.AsyncExitStack() as astack:
+        # Open sources
+        train = await astack.enter_async_context(records_to_sources(*train_ds))
+        test = await astack.enter_async_context(records_to_sources(*valid_ds))
+        # Allow for keep models open
+        if isinstance(model, Model):
+            model = await astack.enter_async_context(model)
+            mctx = await astack.enter_async_context(model())
+        elif isinstance(model, ModelContext):
+            mctx = model
+
+        # Allow for keep scorers open
+        if isinstance(accuracy_scorer, AccuracyScorer):
+            accuracy_scorer = await astack.enter_async_context(accuracy_scorer)
+            actx = await astack.enter_async_context(accuracy_scorer())
+        elif isinstance(accuracy_scorer, AccuracyContext):
+            actx = accuracy_scorer
+        else:
+            # TODO Replace this with static type checking and maybe dynamic
+            # through something like pydantic. See issue #36
+            raise TypeError(f"{accuracy_scorer} is not an AccuracyScorer")
+
+        if isinstance(tuner, Tuner):
+            tuner = await astack.enter_async_context(tuner)
+            tctx = await astack.enter_async_context(tuner())
+        elif isinstance(tuner, TunerContext):
+            tctx = tuner
+        else:
+            raise TypeError(f"{tuner} is not a Tuner")
+
+        return float(
+            await tctx.optimize(mctx, model.config.predict, actx, train, test)
+        )

dffml/noasync.py (+16)

@@ -6,6 +6,7 @@
     train as high_level_train,
     score as high_level_score,
     predict as high_level_predict,
+    tune as high_level_tune,
 )
 
 
@@ -24,6 +25,21 @@ def train(*args, **kwargs):
     )
 )
 
+def tune(*args, **kwargs):
+    return asyncio.run(high_level_tune(*args, **kwargs))
+
+
+tune.__doc__ = (
+    high_level_tune.__doc__.replace("await ", "")
+    .replace("async ", "")
+    .replace("asyncio.run(main())", "main()")
+    .replace("    >>> import asyncio\n", "")
+    .replace(
+        "    >>> from dffml import *\n",
+        "    >>> from dffml import *\n    >>> from dffml.noasync import tune\n",
+    )
+)
+
 
 def score(*args, **kwargs):
     return asyncio.run(high_level_score(*args, **kwargs))
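The string surgery above rewrites the async docstring example into a synchronous one. A sketch of what the rewritten example reads like (same model and data as the ``tune`` docstring in dffml/high_level/ml.py, with the expected output taken from there):

    >>> from dffml import *
    >>> from dffml.noasync import tune
    >>>
    >>> model = SLRModel(
    ...     features=Features(Feature("Years", int, 1)),
    ...     predict=Feature("Salary", int, 1),
    ...     location="tempdir",
    ... )
    >>> score = tune(
    ...     model,
    ...     ParameterGrid(objective="min"),
    ...     MeanSquaredErrorAccuracy(),
    ...     Features(Feature("Years", float, 1)),
    ...     [{"Years": 0, "Salary": 10}, {"Years": 1, "Salary": 20},
    ...      {"Years": 2, "Salary": 30}, {"Years": 3, "Salary": 40}],
    ...     [{"Years": 6, "Salary": 70}, {"Years": 7, "Salary": 80}],
    ... )
    >>> print(f"Tuner score: {score}")
    Tuner score: 0.0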

dffml/skel/config/README.rst (-1)
This file was deleted.

dffml/skel/config/README.rst (+1)

@@ -0,0 +1 @@
+../common/README.rst

dffml/skel/model/README.rst (-1)
This file was deleted.

dffml/skel/model/README.rst (+1)

@@ -0,0 +1 @@
+../common/README.rst

dffml/skel/operations/README.rst (-1)
This file was deleted.

dffml/skel/operations/README.rst (+1)

@@ -0,0 +1 @@
+../common/README.rst

dffml/skel/service/README.rst (-1)
This file was deleted.

dffml/skel/service/README.rst (+1)

@@ -0,0 +1 @@
+../common/README.rst

dffml/skel/source/README.rst (-1)
This file was deleted.

dffml/skel/source/README.rst (+1)

@@ -0,0 +1 @@
+../common/README.rst

dffml/tuner/__init__.py (-1)

@@ -8,4 +8,3 @@
     TunerContext,
     Tuner,
 )
-from .parameter_grid import ParameterGrid
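With the package-level re-export removed, ``ParameterGrid`` is presumably picked up from its own module (or via the top-level ``dffml`` namespace, as the docstring example above suggests). A one-line sketch using the module path shown in the removed import:

    from dffml.tuner.parameter_grid import ParameterGrid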
