From 377d6888e1dbe553e72c5b135250c6ee7ddd080e Mon Sep 17 00:00:00 2001 From: ursk Date: Wed, 27 Mar 2024 19:25:28 -0700 Subject: [PATCH] Make spinoffs/autobnn a pip installable package. PiperOrigin-RevId: 619748042 --- .../spinoffs => spinoffs}/autobnn/README.md | 0 .../autobnn}/autobnn/BUILD | 61 +++++++-------- .../autobnn}/autobnn/__init__.py | 23 +++--- .../autobnn}/autobnn/bnn.py | 2 +- .../autobnn}/autobnn/bnn_test.py | 2 +- .../autobnn}/autobnn/bnn_tree.py | 8 +- .../autobnn}/autobnn/bnn_tree_test.py | 4 +- .../autobnn}/autobnn/estimators.py | 8 +- .../autobnn}/autobnn/estimators_test.py | 8 +- .../autobnn}/autobnn/kernels.py | 2 +- .../autobnn}/autobnn/kernels_test.py | 4 +- .../autobnn}/autobnn/likelihoods.py | 0 .../autobnn}/autobnn/likelihoods_test.py | 2 +- .../autobnn}/autobnn/models.py | 10 +-- .../autobnn}/autobnn/models_test.py | 4 +- .../autobnn}/autobnn/operators.py | 4 +- .../autobnn}/autobnn/operators_test.py | 6 +- .../autobnn}/autobnn/training_util.py | 78 +++++++++++++++++-- .../autobnn}/autobnn/training_util_test.py | 8 +- .../autobnn}/autobnn/util.py | 2 +- .../autobnn}/autobnn/util_test.py | 4 +- spinoffs/autobnn/autobnn/version.py | 36 +++++++++ spinoffs/autobnn/setup.py | 72 +++++++++++++++++ .../autobnn/setup_autobnn.sh | 0 24 files changed, 259 insertions(+), 89 deletions(-) rename {tensorflow_probability/spinoffs => spinoffs}/autobnn/README.md (100%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/BUILD (71%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/__init__.py (56%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/bnn.py (98%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/bnn_test.py (97%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/bnn_tree.py (95%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/bnn_tree_test.py (97%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/estimators.py (97%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/estimators_test.py (94%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/kernels.py (99%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/kernels_test.py (98%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/likelihoods.py (100%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/likelihoods_test.py (97%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/models.py (96%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/models_test.py (94%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/operators.py (98%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/operators_test.py (97%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/training_util.py (85%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/training_util_test.py (96%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/util.py (97%) rename {tensorflow_probability/spinoffs => spinoffs/autobnn}/autobnn/util_test.py (95%) create mode 100644 spinoffs/autobnn/autobnn/version.py create mode 100644 spinoffs/autobnn/setup.py rename {tensorflow_probability/spinoffs => spinoffs}/autobnn/setup_autobnn.sh (100%) diff --git a/tensorflow_probability/spinoffs/autobnn/README.md b/spinoffs/autobnn/README.md similarity index 100% rename from tensorflow_probability/spinoffs/autobnn/README.md rename to spinoffs/autobnn/README.md diff --git a/tensorflow_probability/spinoffs/autobnn/BUILD b/spinoffs/autobnn/autobnn/BUILD similarity index 71% rename from tensorflow_probability/spinoffs/autobnn/BUILD rename to spinoffs/autobnn/autobnn/BUILD index 0fac911f7c..669f3d936d 100644 --- a/tensorflow_probability/spinoffs/autobnn/BUILD +++ b/spinoffs/autobnn/autobnn/BUILD @@ -37,7 +37,7 @@ py_library( ":operators", ":training_util", ":util", - "//tensorflow_probability/python/internal:all_util", + # tensorflow_probability/python/internal:all_util dep, ], ) @@ -49,7 +49,7 @@ py_library( # flax:core dep, # jax dep, # jaxtyping dep, - "//tensorflow_probability/python/distributions:distribution.jax", + # tensorflow_probability/python/distributions:distribution.jax dep, ], ) @@ -62,8 +62,8 @@ py_test( # google/protobuf:use_fast_cpp_protos dep, # jax dep, "//tensorflow_probability:jax", - "//tensorflow_probability/python/distributions:lognormal.jax", - "//tensorflow_probability/python/distributions:normal.jax", + # tensorflow_probability/python/distributions:lognormal.jax dep, + # tensorflow_probability/python/distributions:normal.jax dep, ], ) @@ -118,7 +118,7 @@ py_test( ":estimators", ":kernels", ":operators", - "//tensorflow_probability/python/internal:test_util", + # tensorflow_probability/python/internal:test_util dep, ], ) @@ -130,10 +130,10 @@ py_library( # flax dep, # flax:core dep, # jax dep, - "//tensorflow_probability/python/distributions:lognormal.jax", - "//tensorflow_probability/python/distributions:normal.jax", - "//tensorflow_probability/python/distributions:student_t.jax", - "//tensorflow_probability/python/distributions:uniform.jax", + # tensorflow_probability/python/distributions:lognormal.jax dep, + # tensorflow_probability/python/distributions:normal.jax dep, + # tensorflow_probability/python/distributions:student_t.jax dep, + # tensorflow_probability/python/distributions:uniform.jax dep, ], ) @@ -147,7 +147,7 @@ py_test( # absl/testing:parameterized dep, # google/protobuf:use_fast_cpp_protos dep, # jax dep, - "//tensorflow_probability/python/distributions:lognormal.jax", + # tensorflow_probability/python/distributions:lognormal.jax dep, ], ) @@ -158,14 +158,14 @@ py_library( # flax:core dep, # jax dep, # jaxtyping dep, - "//tensorflow_probability/python/bijectors:softplus.jax", - "//tensorflow_probability/python/distributions:distribution.jax", - "//tensorflow_probability/python/distributions:inflated.jax", - "//tensorflow_probability/python/distributions:logistic.jax", - "//tensorflow_probability/python/distributions:lognormal.jax", - "//tensorflow_probability/python/distributions:negative_binomial.jax", - "//tensorflow_probability/python/distributions:normal.jax", - "//tensorflow_probability/python/distributions:transformed_distribution.jax", + # tensorflow_probability/python/bijectors:softplus.jax dep, + # tensorflow_probability/python/distributions:distribution.jax dep, + # tensorflow_probability/python/distributions:inflated.jax dep, + # tensorflow_probability/python/distributions:logistic.jax dep, + # tensorflow_probability/python/distributions:lognormal.jax dep, + # tensorflow_probability/python/distributions:negative_binomial.jax dep, + # tensorflow_probability/python/distributions:normal.jax dep, + # tensorflow_probability/python/distributions:transformed_distribution.jax dep, ], ) @@ -216,14 +216,14 @@ py_library( ":likelihoods", # flax:core dep, # jax dep, - "//tensorflow_probability/python/bijectors:chain.jax", - "//tensorflow_probability/python/bijectors:scale.jax", - "//tensorflow_probability/python/bijectors:shift.jax", - "//tensorflow_probability/python/distributions:beta.jax", - "//tensorflow_probability/python/distributions:dirichlet.jax", - "//tensorflow_probability/python/distributions:half_normal.jax", - "//tensorflow_probability/python/distributions:normal.jax", - "//tensorflow_probability/python/distributions:transformed_distribution.jax", + # tensorflow_probability/python/bijectors:chain.jax dep, + # tensorflow_probability/python/bijectors:scale.jax dep, + # tensorflow_probability/python/bijectors:shift.jax dep, + # tensorflow_probability/python/distributions:beta.jax dep, + # tensorflow_probability/python/distributions:dirichlet.jax dep, + # tensorflow_probability/python/distributions:half_normal.jax dep, + # tensorflow_probability/python/distributions:normal.jax dep, + # tensorflow_probability/python/distributions:transformed_distribution.jax dep, ], ) @@ -242,7 +242,7 @@ py_test( # google/protobuf:use_fast_cpp_protos dep, # jax dep, # numpy dep, - "//tensorflow_probability/python/distributions:distribution.jax", + # tensorflow_probability/python/distributions:distribution.jax dep, ], ) @@ -258,7 +258,6 @@ py_library( # matplotlib dep, # numpy dep, # pandas dep, - "//tensorflow_probability/python/experimental/timeseries:metrics", ], ) @@ -276,7 +275,7 @@ py_test( # google/protobuf:use_fast_cpp_protos dep, # jax dep, # numpy dep, - "//tensorflow_probability/python/internal:test_util", + # tensorflow_probability/python/internal:test_util dep, ], ) @@ -288,7 +287,7 @@ py_library( # jax dep, # numpy dep, # scipy dep, - "//tensorflow_probability/python/distributions:distribution.jax", + # tensorflow_probability/python/distributions:distribution.jax dep, ], ) @@ -301,6 +300,6 @@ py_test( # google/protobuf:use_fast_cpp_protos dep, # jax dep, # numpy dep, - "//tensorflow_probability/python/internal:test_util", + # tensorflow_probability/python/internal:test_util dep, ], ) diff --git a/tensorflow_probability/spinoffs/autobnn/__init__.py b/spinoffs/autobnn/autobnn/__init__.py similarity index 56% rename from tensorflow_probability/spinoffs/autobnn/__init__.py rename to spinoffs/autobnn/autobnn/__init__.py index 7d7e5492db..8cfaa4c862 100644 --- a/tensorflow_probability/spinoffs/autobnn/__init__.py +++ b/spinoffs/autobnn/autobnn/__init__.py @@ -14,18 +14,17 @@ # ============================================================================ """Package for training GP-like Bayesian Neural Nets w/ composite structure.""" -from tensorflow_probability.python.internal import all_util -from tensorflow_probability.spinoffs.autobnn import bnn -from tensorflow_probability.spinoffs.autobnn import bnn_tree -from tensorflow_probability.spinoffs.autobnn import estimators -from tensorflow_probability.spinoffs.autobnn import kernels -from tensorflow_probability.spinoffs.autobnn import likelihoods -from tensorflow_probability.spinoffs.autobnn import models -from tensorflow_probability.spinoffs.autobnn import operators -from tensorflow_probability.spinoffs.autobnn import training_util -from tensorflow_probability.spinoffs.autobnn import util +from autobnn import bnn +from autobnn import bnn_tree +from autobnn import estimators +from autobnn import kernels +from autobnn import likelihoods +from autobnn import models +from autobnn import operators +from autobnn import training_util +from autobnn import util -_allowed_symbols = [ +__all__ = [ 'bnn', 'bnn_tree', 'estimators', @@ -36,5 +35,3 @@ 'training_util', 'util', ] - -all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow_probability/spinoffs/autobnn/bnn.py b/spinoffs/autobnn/autobnn/bnn.py similarity index 98% rename from tensorflow_probability/spinoffs/autobnn/bnn.py rename to spinoffs/autobnn/autobnn/bnn.py index 20d6695861..4f947b357c 100644 --- a/tensorflow_probability/spinoffs/autobnn/bnn.py +++ b/spinoffs/autobnn/autobnn/bnn.py @@ -21,7 +21,7 @@ import jax import jax.numpy as jnp from jaxtyping import Array, Float, PyTree # pylint: disable=g-importing-member,g-multiple-import -from tensorflow_probability.spinoffs.autobnn import likelihoods +from autobnn import likelihoods from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib diff --git a/tensorflow_probability/spinoffs/autobnn/bnn_test.py b/spinoffs/autobnn/autobnn/bnn_test.py similarity index 97% rename from tensorflow_probability/spinoffs/autobnn/bnn_test.py rename to spinoffs/autobnn/autobnn/bnn_test.py index 4d0fe04831..3f83bfc049 100644 --- a/tensorflow_probability/spinoffs/autobnn/bnn_test.py +++ b/spinoffs/autobnn/autobnn/bnn_test.py @@ -17,7 +17,7 @@ from flax import linen as nn import jax import jax.numpy as jnp -from tensorflow_probability.spinoffs.autobnn import bnn +from autobnn import bnn from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib from tensorflow_probability.substrates.jax.distributions import normal as normal_lib from absl.testing import absltest diff --git a/tensorflow_probability/spinoffs/autobnn/bnn_tree.py b/spinoffs/autobnn/autobnn/bnn_tree.py similarity index 95% rename from tensorflow_probability/spinoffs/autobnn/bnn_tree.py rename to spinoffs/autobnn/autobnn/bnn_tree.py index 644e128307..2ea3227db0 100644 --- a/tensorflow_probability/spinoffs/autobnn/bnn_tree.py +++ b/spinoffs/autobnn/autobnn/bnn_tree.py @@ -19,10 +19,10 @@ from flax import linen as nn import jax import jax.numpy as jnp -from tensorflow_probability.spinoffs.autobnn import bnn -from tensorflow_probability.spinoffs.autobnn import kernels -from tensorflow_probability.spinoffs.autobnn import operators -from tensorflow_probability.spinoffs.autobnn import util +from autobnn import bnn +from autobnn import kernels +from autobnn import operators +from autobnn import util Array = jnp.ndarray diff --git a/tensorflow_probability/spinoffs/autobnn/bnn_tree_test.py b/spinoffs/autobnn/autobnn/bnn_tree_test.py similarity index 97% rename from tensorflow_probability/spinoffs/autobnn/bnn_tree_test.py rename to spinoffs/autobnn/autobnn/bnn_tree_test.py index ca1c23d9cf..24c18f2df8 100644 --- a/tensorflow_probability/spinoffs/autobnn/bnn_tree_test.py +++ b/spinoffs/autobnn/autobnn/bnn_tree_test.py @@ -18,8 +18,8 @@ from flax import linen as nn import jax import jax.numpy as jnp -from tensorflow_probability.spinoffs.autobnn import bnn_tree -from tensorflow_probability.spinoffs.autobnn import kernels +from autobnn import bnn_tree +from autobnn import kernels from absl.testing import absltest diff --git a/tensorflow_probability/spinoffs/autobnn/estimators.py b/spinoffs/autobnn/autobnn/estimators.py similarity index 97% rename from tensorflow_probability/spinoffs/autobnn/estimators.py rename to spinoffs/autobnn/autobnn/estimators.py index 5002ed2cc5..f1ce360351 100644 --- a/tensorflow_probability/spinoffs/autobnn/estimators.py +++ b/spinoffs/autobnn/autobnn/estimators.py @@ -19,10 +19,10 @@ import jax import jax.numpy as jnp from jaxtyping import ArrayLike, PyTree # pylint: disable=g-importing-member,g-multiple-import -from tensorflow_probability.spinoffs.autobnn import bnn -from tensorflow_probability.spinoffs.autobnn import likelihoods -from tensorflow_probability.spinoffs.autobnn import models -from tensorflow_probability.spinoffs.autobnn import training_util +from autobnn import bnn +from autobnn import likelihoods +from autobnn import models +from autobnn import training_util class _AutoBnnEstimator: diff --git a/tensorflow_probability/spinoffs/autobnn/estimators_test.py b/spinoffs/autobnn/autobnn/estimators_test.py similarity index 94% rename from tensorflow_probability/spinoffs/autobnn/estimators_test.py rename to spinoffs/autobnn/autobnn/estimators_test.py index 0f5f6c2b44..6fc50456f8 100644 --- a/tensorflow_probability/spinoffs/autobnn/estimators_test.py +++ b/spinoffs/autobnn/autobnn/estimators_test.py @@ -17,10 +17,10 @@ import jax import numpy as np from tensorflow_probability.python.internal import test_util -from tensorflow_probability.spinoffs.autobnn import estimators -from tensorflow_probability.spinoffs.autobnn import kernels -from tensorflow_probability.spinoffs.autobnn import operators -from tensorflow_probability.spinoffs.autobnn import util +from autobnn import estimators +from autobnn import kernels +from autobnn import operators +from autobnn import util class AutoBNNTest(test_util.TestCase): diff --git a/tensorflow_probability/spinoffs/autobnn/kernels.py b/spinoffs/autobnn/autobnn/kernels.py similarity index 99% rename from tensorflow_probability/spinoffs/autobnn/kernels.py rename to spinoffs/autobnn/autobnn/kernels.py index 3c9ad8019e..fcdf9d1f5c 100644 --- a/tensorflow_probability/spinoffs/autobnn/kernels.py +++ b/spinoffs/autobnn/autobnn/kernels.py @@ -19,7 +19,7 @@ from flax.linen import initializers import jax import jax.numpy as jnp -from tensorflow_probability.spinoffs.autobnn import bnn +from autobnn import bnn from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib from tensorflow_probability.substrates.jax.distributions import normal as normal_lib from tensorflow_probability.substrates.jax.distributions import student_t as student_t_lib diff --git a/tensorflow_probability/spinoffs/autobnn/kernels_test.py b/spinoffs/autobnn/autobnn/kernels_test.py similarity index 98% rename from tensorflow_probability/spinoffs/autobnn/kernels_test.py rename to spinoffs/autobnn/autobnn/kernels_test.py index 96362e580b..550d51b742 100644 --- a/tensorflow_probability/spinoffs/autobnn/kernels_test.py +++ b/spinoffs/autobnn/autobnn/kernels_test.py @@ -18,8 +18,8 @@ import jax import jax.numpy as jnp import numpy as np -from tensorflow_probability.spinoffs.autobnn import kernels -from tensorflow_probability.spinoffs.autobnn import util +from autobnn import kernels +from autobnn import util from tensorflow_probability.substrates.jax.distributions import lognormal as lognormal_lib from absl.testing import absltest diff --git a/tensorflow_probability/spinoffs/autobnn/likelihoods.py b/spinoffs/autobnn/autobnn/likelihoods.py similarity index 100% rename from tensorflow_probability/spinoffs/autobnn/likelihoods.py rename to spinoffs/autobnn/autobnn/likelihoods.py diff --git a/tensorflow_probability/spinoffs/autobnn/likelihoods_test.py b/spinoffs/autobnn/autobnn/likelihoods_test.py similarity index 97% rename from tensorflow_probability/spinoffs/autobnn/likelihoods_test.py rename to spinoffs/autobnn/autobnn/likelihoods_test.py index a36a415a68..c9b91edf62 100644 --- a/tensorflow_probability/spinoffs/autobnn/likelihoods_test.py +++ b/spinoffs/autobnn/autobnn/likelihoods_test.py @@ -16,7 +16,7 @@ from absl.testing import parameterized import jax.numpy as jnp -from tensorflow_probability.spinoffs.autobnn import likelihoods +from autobnn import likelihoods from absl.testing import absltest diff --git a/tensorflow_probability/spinoffs/autobnn/models.py b/spinoffs/autobnn/autobnn/models.py similarity index 96% rename from tensorflow_probability/spinoffs/autobnn/models.py rename to spinoffs/autobnn/autobnn/models.py index 8eb77d849d..0166794317 100644 --- a/tensorflow_probability/spinoffs/autobnn/models.py +++ b/spinoffs/autobnn/autobnn/models.py @@ -22,11 +22,11 @@ import functools from typing import Sequence, Union import jax.numpy as jnp -from tensorflow_probability.spinoffs.autobnn import bnn -from tensorflow_probability.spinoffs.autobnn import bnn_tree -from tensorflow_probability.spinoffs.autobnn import kernels -from tensorflow_probability.spinoffs.autobnn import likelihoods -from tensorflow_probability.spinoffs.autobnn import operators +from autobnn import bnn +from autobnn import bnn_tree +from autobnn import kernels +from autobnn import likelihoods +from autobnn import operators Array = jnp.ndarray diff --git a/tensorflow_probability/spinoffs/autobnn/models_test.py b/spinoffs/autobnn/autobnn/models_test.py similarity index 94% rename from tensorflow_probability/spinoffs/autobnn/models_test.py rename to spinoffs/autobnn/autobnn/models_test.py index 68b3a40044..c8dcd63ba0 100644 --- a/tensorflow_probability/spinoffs/autobnn/models_test.py +++ b/spinoffs/autobnn/autobnn/models_test.py @@ -17,8 +17,8 @@ from absl.testing import parameterized import jax import jax.numpy as jnp -from tensorflow_probability.spinoffs.autobnn import likelihoods -from tensorflow_probability.spinoffs.autobnn import models +from autobnn import likelihoods +from autobnn import models from absl.testing import absltest diff --git a/tensorflow_probability/spinoffs/autobnn/operators.py b/spinoffs/autobnn/autobnn/operators.py similarity index 98% rename from tensorflow_probability/spinoffs/autobnn/operators.py rename to spinoffs/autobnn/autobnn/operators.py index 43e37b3250..2891ee5373 100644 --- a/tensorflow_probability/spinoffs/autobnn/operators.py +++ b/spinoffs/autobnn/autobnn/operators.py @@ -19,8 +19,8 @@ from flax import linen as nn import jax import jax.numpy as jnp -from tensorflow_probability.spinoffs.autobnn import bnn -from tensorflow_probability.spinoffs.autobnn import likelihoods +from autobnn import bnn +from autobnn import likelihoods from tensorflow_probability.substrates.jax.bijectors import chain as chain_lib from tensorflow_probability.substrates.jax.bijectors import scale as scale_lib from tensorflow_probability.substrates.jax.bijectors import shift as shift_lib diff --git a/tensorflow_probability/spinoffs/autobnn/operators_test.py b/spinoffs/autobnn/autobnn/operators_test.py similarity index 97% rename from tensorflow_probability/spinoffs/autobnn/operators_test.py rename to spinoffs/autobnn/autobnn/operators_test.py index d22ec6e85e..7cec80efe7 100644 --- a/tensorflow_probability/spinoffs/autobnn/operators_test.py +++ b/spinoffs/autobnn/autobnn/operators_test.py @@ -19,9 +19,9 @@ import jax import jax.numpy as jnp import numpy as np -from tensorflow_probability.spinoffs.autobnn import kernels -from tensorflow_probability.spinoffs.autobnn import operators -from tensorflow_probability.spinoffs.autobnn import util +from autobnn import kernels +from autobnn import operators +from autobnn import util from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib from absl.testing import absltest diff --git a/tensorflow_probability/spinoffs/autobnn/training_util.py b/spinoffs/autobnn/autobnn/training_util.py similarity index 85% rename from tensorflow_probability/spinoffs/autobnn/training_util.py rename to spinoffs/autobnn/autobnn/training_util.py index 1167d157d5..55134bad58 100644 --- a/tensorflow_probability/spinoffs/autobnn/training_util.py +++ b/spinoffs/autobnn/autobnn/training_util.py @@ -23,9 +23,75 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -from tensorflow_probability.python.experimental.timeseries import metrics -from tensorflow_probability.spinoffs.autobnn import bnn -from tensorflow_probability.spinoffs.autobnn import util +from autobnn import bnn +from autobnn import util + + +def smape(y, yhat): + """Return the symmetric mean absolute percentage error. + + Args: + y: An array containing the true values. + yhat: An array containing the predicted values. + + Returns: + The scalar SMAPE. + """ + # https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error + assert len(yhat) == len(y) + h = len(y) + errors = np.abs(y - yhat) / (np.abs(y) + np.abs(yhat)) * 100 + return 2/h * np.sum(errors) + + +def mase(y, yhat, y_obs, m): + """Return the mean absolute scaled error. + + Args: + y: An array containing the true values. + yhat: An array containing the predicted values. + y_obs: An array containing the training values. + m: The season length. + + Returns: + The scalar MASE. + """ + # https://en.wikipedia.org/wiki/Mean_absolute_scaled_error + assert len(yhat) == len(y) + n = len(y_obs) + h = len(y) + assert 0 < m < len(y_obs) + numer = np.sum(np.abs(y - yhat)) + denom = np.sum(np.abs(y_obs[m:] - y_obs[:-m])) / (n - m) + return (1 / h) * (numer / denom) + + +def msis(y, yhat_lower, yhat_upper, y_obs, m, a=0.05): + """Return the mean scaled interval score. + + Args: + y: An array containing the true values. + yhat_lower: An array containing the a% quantile of the predicted + distribution. + yhat_upper: An array containing the (1-a)% quantile of the + predicted distribution. + y_obs: An array containing the training values. + m: The season length. + a: A scalar in [0, 1] specifying the quantile window to evaluate. + + Returns: + The scalar MSIS. + """ + # https://www.uber.com/blog/m4-forecasting-competition/ + assert len(y) == len(yhat_lower) == len(yhat_upper) + n = len(y_obs) + h = len(y) + numer = np.sum( + (yhat_upper - yhat_lower) + + (2 / a) * (yhat_lower - y) * (y < yhat_lower) + + (2 / a) * (y - yhat_upper) * (yhat_upper < y)) + denom = np.sum(np.abs(y_obs[m:] - y_obs[:-m])) / (n - m) + return (1 / h) * (numer / denom) def _make_bayeux_model( @@ -247,12 +313,12 @@ def make_results_dataframe( # 'm3': [smape, mase, msis], # 'traffic': [wmppl, wmape], # 'm5': [wrmsse, wspl] - smapes = np.array([metrics.smape(y_test[:i], predictions[:i]) + smapes = np.array([smape(y_test[:i], predictions[:i]) for i in range(1, n_test+1)]) - mases = np.array([metrics.mase(y_test[:i], predictions[:i], y_train, 12) + mases = np.array([mase(y_test[:i], predictions[:i], y_train, 12) for i in range(1, n_test+1)]) msises = np.array( - [metrics.msis(y_test[:i], p2_5[:i], p97_5[:i], y_train, 12) + [msis(y_test[:i], p2_5[:i], p97_5[:i], y_train, 12) for i in range(1, n_test + 1)]) return pd.DataFrame( data=np.array([predictions, p2_5, p90, p97_5, diff --git a/tensorflow_probability/spinoffs/autobnn/training_util_test.py b/spinoffs/autobnn/autobnn/training_util_test.py similarity index 96% rename from tensorflow_probability/spinoffs/autobnn/training_util_test.py rename to spinoffs/autobnn/autobnn/training_util_test.py index 52d7762080..f139a8468d 100644 --- a/tensorflow_probability/spinoffs/autobnn/training_util_test.py +++ b/spinoffs/autobnn/autobnn/training_util_test.py @@ -19,10 +19,10 @@ import jax.numpy as jnp import numpy as np from tensorflow_probability.python.internal import test_util -from tensorflow_probability.spinoffs.autobnn import kernels -from tensorflow_probability.spinoffs.autobnn import operators -from tensorflow_probability.spinoffs.autobnn import training_util -from tensorflow_probability.spinoffs.autobnn import util +from autobnn import kernels +from autobnn import operators +from autobnn import training_util +from autobnn import util class TrainingUtilTest(test_util.TestCase): diff --git a/tensorflow_probability/spinoffs/autobnn/util.py b/spinoffs/autobnn/autobnn/util.py similarity index 97% rename from tensorflow_probability/spinoffs/autobnn/util.py rename to spinoffs/autobnn/autobnn/util.py index 10674126d9..b244221071 100644 --- a/tensorflow_probability/spinoffs/autobnn/util.py +++ b/spinoffs/autobnn/autobnn/util.py @@ -19,7 +19,7 @@ import jax import jax.numpy as jnp import scipy -from tensorflow_probability.spinoffs.autobnn import bnn +from autobnn import bnn from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib diff --git a/tensorflow_probability/spinoffs/autobnn/util_test.py b/spinoffs/autobnn/autobnn/util_test.py similarity index 95% rename from tensorflow_probability/spinoffs/autobnn/util_test.py rename to spinoffs/autobnn/autobnn/util_test.py index 250d5099e9..491adb290f 100644 --- a/tensorflow_probability/spinoffs/autobnn/util_test.py +++ b/spinoffs/autobnn/autobnn/util_test.py @@ -18,8 +18,8 @@ import jax.numpy as jnp import numpy as np from tensorflow_probability.python.internal import test_util -from tensorflow_probability.spinoffs.autobnn import kernels -from tensorflow_probability.spinoffs.autobnn import util +from autobnn import kernels +from autobnn import util class UtilTest(test_util.TestCase): diff --git a/spinoffs/autobnn/autobnn/version.py b/spinoffs/autobnn/autobnn/version.py new file mode 100644 index 0000000000..2f0f2402a5 --- /dev/null +++ b/spinoffs/autobnn/autobnn/version.py @@ -0,0 +1,36 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Define AutoBNN version information.""" + +# We follow Semantic Versioning (https://semver.org/) +_MAJOR_VERSION = '0' +_MINOR_VERSION = '0' +_PATCH_VERSION = '2' + +# When building releases, we can update this value on the release branch to +# reflect the current release candidate ('rc0', 'rc1') or, finally, the official +# stable release (indicated by `_VERSION_SUFFIX = ''`). Outside the context of a +# release branch, the current version is by default assumed to be a +# 'development' version, labeled 'dev'. +_VERSION_SUFFIX = 'dev' + +# Example, '0.4.0-dev' +__version__ = '.'.join([ + _MAJOR_VERSION, + _MINOR_VERSION, + _PATCH_VERSION, +]) +if _VERSION_SUFFIX: + __version__ = '{}-{}'.format(__version__, _VERSION_SUFFIX) diff --git a/spinoffs/autobnn/setup.py b/spinoffs/autobnn/setup.py new file mode 100644 index 0000000000..184fb98611 --- /dev/null +++ b/spinoffs/autobnn/setup.py @@ -0,0 +1,72 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Install AutoBNN.""" +import os +import sys +from setuptools import find_packages +from setuptools import setup + +# To enable importing version.py directly, we add its path to sys.path. +version_path = os.path.join( + os.path.dirname(__file__), 'autobnn') +sys.path.append(version_path) +from version import __version__ # pylint: disable=g-import-not-at-top + +with open('README.md', 'r') as fh: + oryx_long_description = fh.read() + +setup( + name='autobnn', + python_requires='>=3.6', + version=__version__, + description=( + 'Package for training Gaussian process-like Bayesian Neural Networks' + ' with composite structure.' + ), + long_description=oryx_long_description, + long_description_content_type='text/markdown', + author='Google LLC', + author_email='no-reply@google.com', + url='https://github.com/tensorflow/probability/tree/main/spinoffs/autobnn', + license='Apache 2.0', + packages=find_packages('.'), + install_requires=[ + 'bayeux-ml', + 'chex', + 'flax', + 'jaxtyping', + 'matplotlib', + 'pandas', + 'scipy', + ], + # Add in any packaged data. + exclude_package_data={'': ['BUILD']}, + zip_safe=False, + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 3', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Mathematics', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + keywords='tensorflow jax probability statistics bayesian machine learning', +) diff --git a/tensorflow_probability/spinoffs/autobnn/setup_autobnn.sh b/spinoffs/autobnn/setup_autobnn.sh similarity index 100% rename from tensorflow_probability/spinoffs/autobnn/setup_autobnn.sh rename to spinoffs/autobnn/setup_autobnn.sh