|
1 | 1 | import contextlib
|
2 | 2 | from typing import Union, Dict, Any, List
|
3 | 3 |
|
| 4 | + |
4 | 5 | from ..record import Record
|
5 | 6 | from ..source.source import BaseSource
|
6 | 7 | from ..feature import Feature, Features
|
7 | 8 | from ..model import Model, ModelContext
|
8 | 9 | from ..util.internal import records_to_sources, list_records_to_dict
|
9 | 10 | from ..accuracy.accuracy import AccuracyScorer, AccuracyContext
|
| 11 | +from ..tuner import Tuner, TunerContext |
10 | 12 |
|
11 | 13 |
|
12 | 14 | async def train(model, *args: Union[BaseSource, Record, Dict[str, Any], List]):
|
@@ -293,3 +295,150 @@ async def predict(
|
293 | 295 | )
|
294 | 296 | if update:
|
295 | 297 | await sctx.update(record)
|
| 298 | + |
async def tune(
    model,
    tuner: Union[Tuner, TunerContext],
    accuracy_scorer: Union[AccuracyScorer, AccuracyContext],
    features: Union[Feature, Features],
    train_ds: Union[BaseSource, Record, Dict[str, Any], List],
    valid_ds: Union[BaseSource, Record, Dict[str, Any], List],
) -> float:
    """
    Tune the hyperparameters of a model with a given tuner.

    Parameters
    ----------
    model : Model
        Machine Learning model to use. See :doc:`/plugins/dffml_model` for
        models options.
    tuner : Tuner
        Hyperparameter tuning method to use. See :doc:`/plugins/dffml_tuner`
        for tuner options.
    accuracy_scorer : AccuracyScorer
        Accuracy scorer used to evaluate each hyperparameter combination.
    features : Features
        Features the model is scored against.
    train_ds : list
        Input data for training. Could be a ``dict``, :py:class:`Record`,
        filename, one of the data :doc:`/plugins/dffml_source`, or a filename
        with the extension being one of the data sources.
    valid_ds : list
        Validation data for testing. Could be a ``dict``, :py:class:`Record`,
        filename, one of the data :doc:`/plugins/dffml_source`, or a filename
        with the extension being one of the data sources.

    Returns
    -------
    float
        A decimal value representing the result of the accuracy scorer on the
        given test set. For instance, ClassificationAccuracy represents the
        percentage of correct classifications made by the model.

    Raises
    ------
    TypeError
        If ``features``, ``model``, ``accuracy_scorer``, or ``tuner`` is not
        an instance (or context) of the expected plugin type.

    Examples
    --------

    >>> import asyncio
    >>> from dffml import *
    >>> from dffml_model_xgboost.xgbclassifier import XGBClassifierModel
    >>>
    >>> model = XGBClassifierModel(
    ...     features=Features(
    ...         Feature("SepalLength", float, 1),
    ...         Feature("SepalWidth", float, 1),
    ...         Feature("PetalLength", float, 1),
    ...     ),
    ...     predict=Feature("classification", int, 1),
    ...     location="tempdir",
    ... )
    >>>
    >>> async def main():
    ...     await tune(
    ...         model,
    ...         ParameterGrid(
    ...             parameters={
    ...                 "learning_rate": [0.01, 0.05, 0.1],
    ...                 "n_estimators": [20, 100, 200],
    ...                 "max_depth": [3,5,8]
    ...             }
    ...         ),
    ...         MeanSquaredErrorAccuracy(),
    ...         Features(
    ...             Feature("SepalLength", float, 1),
    ...             Feature("SepalWidth", float, 1),
    ...             Feature("PetalLength", float, 1),
    ...         ),
    ...         [CSVSource(filename="iris_training.csv")],
    ...         [CSVSource(filename="iris_test.csv")],
    ...     )
    >>>
    >>> asyncio.run(main())
    Accuracy: 0.0
    """
    if not isinstance(features, (Feature, Features)):
        raise TypeError(
            f"features was {type(features)}: {features!r}. Should have been Feature or Features"
        )
    if isinstance(features, Feature):
        features = Features(features)

    # Names of the feature(s) the model predicts. Default to an empty list so
    # that a config with "features" but no "predict" does not hit a NameError
    # in the list_records_to_dict calls below (previously predict_feature was
    # only bound inside the hasattr branch).
    predict_feature = []
    if hasattr(model.config, "predict"):
        if isinstance(model.config.predict, Features):
            predict_feature = [
                feature.name for feature in model.config.predict
            ]
        else:
            predict_feature = [model.config.predict.name]

    # Convert list-style records into dicts keyed by feature name so sources
    # can consume them uniformly.
    if hasattr(model.config, "features") and any(
        isinstance(td, list) for td in train_ds
    ):
        train_ds = list_records_to_dict(
            [feature.name for feature in model.config.features]
            + predict_feature,
            *train_ds,
            model=model,
        )
    if hasattr(model.config, "features") and any(
        isinstance(td, list) for td in valid_ds
    ):
        valid_ds = list_records_to_dict(
            [feature.name for feature in model.config.features]
            + predict_feature,
            *valid_ds,
            model=model,
        )

    async with contextlib.AsyncExitStack() as astack:
        # Open sources
        train = await astack.enter_async_context(records_to_sources(*train_ds))
        test = await astack.enter_async_context(records_to_sources(*valid_ds))

        # Allow for keep models open
        if isinstance(model, Model):
            model = await astack.enter_async_context(model)
            mctx = await astack.enter_async_context(model())
        elif isinstance(model, ModelContext):
            mctx = model
        else:
            # Previously a wrong-typed model fell through silently and the
            # unbound mctx raised a confusing NameError later; fail fast with
            # a TypeError consistent with the scorer/tuner checks below.
            raise TypeError(f"{model} is not a Model")

        # Allow for keep scorers open
        if isinstance(accuracy_scorer, AccuracyScorer):
            accuracy_scorer = await astack.enter_async_context(accuracy_scorer)
            actx = await astack.enter_async_context(accuracy_scorer())
        elif isinstance(accuracy_scorer, AccuracyContext):
            actx = accuracy_scorer
        else:
            # TODO Replace this with static type checking and maybe dynamic
            # through something like pydantic. See issue #36
            raise TypeError(f"{accuracy_scorer} is not an AccuracyScorer")

        if isinstance(tuner, Tuner):
            tuner = await astack.enter_async_context(tuner)
            tctx = await astack.enter_async_context(tuner())
        elif isinstance(tuner, TunerContext):
            tctx = tuner
        else:
            raise TypeError(f"{tuner} is not a Tuner")

        # NOTE(review): model.config is accessed here even when a
        # ModelContext was passed in directly — assumes ModelContext exposes
        # .config like Model does; confirm against the plugin base classes.
        return float(
            await tctx.optimize(mctx, model.config.predict, actx, train, test)
        )
| 444 | + |
0 commit comments