diff --git a/BUILD b/BUILD index e9a0d19546..097665c738 100644 --- a/BUILD +++ b/BUILD @@ -20,6 +20,5 @@ sh_binary( "required_packages.py", "setup.py", "//tensorflow_probability", - "//tensorflow_probability/substrates", ], ) diff --git a/conftest.py b/conftest.py index 56206d8ea8..abc9bec2cc 100644 --- a/conftest.py +++ b/conftest.py @@ -20,6 +20,7 @@ collect_ignore = [ "discussion/", "setup.py", + "tensorflow_probability/python/experimental/substrates/" ] diff --git a/tensorflow_probability/python/BUILD b/tensorflow_probability/python/BUILD index 1431949e91..3a11d8d473 100644 --- a/tensorflow_probability/python/BUILD +++ b/tensorflow_probability/python/BUILD @@ -41,7 +41,6 @@ py_library( "//tensorflow_probability/python/glm", "//tensorflow_probability/python/internal", "//tensorflow_probability/python/internal:all_util", - "//tensorflow_probability/python/internal:lazy_loader", "//tensorflow_probability/python/layers", "//tensorflow_probability/python/math", "//tensorflow_probability/python/mcmc", diff --git a/tensorflow_probability/python/experimental/BUILD b/tensorflow_probability/python/experimental/BUILD index f955ce118a..6ae164fef5 100644 --- a/tensorflow_probability/python/experimental/BUILD +++ b/tensorflow_probability/python/experimental/BUILD @@ -41,6 +41,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/experimental/auto_batching", "//tensorflow_probability/python/experimental/marginalize", "//tensorflow_probability/python/experimental/nn", + "//tensorflow_probability/python/experimental/substrates", "//tensorflow_probability/python/experimental/timeseries", "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/experimental/util:composite_tensor", @@ -63,6 +64,7 @@ multi_substrate_py_library( "//tensorflow_probability/python/experimental/sequential", "//tensorflow_probability/python/experimental/stats", "//tensorflow_probability/python/experimental/sts_gibbs", + "//tensorflow_probability/python/experimental/substrates", "//tensorflow_probability/python/experimental/tangent_spaces", "//tensorflow_probability/python/experimental/timeseries", "//tensorflow_probability/python/experimental/util", diff --git a/tensorflow_probability/python/experimental/__init__.py b/tensorflow_probability/python/experimental/__init__.py index d316a44d90..58a72d64ae 100644 --- a/tensorflow_probability/python/experimental/__init__.py +++ b/tensorflow_probability/python/experimental/__init__.py @@ -43,6 +43,7 @@ from tensorflow_probability.python.experimental import sequential from tensorflow_probability.python.experimental import stats from tensorflow_probability.python.experimental import sts_gibbs +from tensorflow_probability.python.experimental import substrates from tensorflow_probability.python.experimental import tangent_spaces from tensorflow_probability.python.experimental import timeseries from tensorflow_probability.python.experimental import util @@ -75,6 +76,7 @@ 'sequential', 'sts_gibbs', 'stats', + 'substrates', 'tangent_spaces', 'timeseries', 'unnest', diff --git a/tensorflow_probability/python/experimental/substrates/BUILD b/tensorflow_probability/python/experimental/substrates/BUILD new file mode 100644 index 0000000000..c7f3bd836d --- /dev/null +++ b/tensorflow_probability/python/experimental/substrates/BUILD @@ -0,0 +1,37 @@ +# Copyright 2019 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# Description: +# API-unstable code that is part of the TFP package. + +# Placeholder: py_library + +package( + # default_applicable_licenses + default_visibility = [ + "//tensorflow_probability:__subpackages__", + ], +) + +licenses(["notice"]) + +py_library( + name = "substrates", + srcs = ["__init__.py"], + deps = [ + "//tensorflow_probability/python/internal:all_util", + "//tensorflow_probability/python/internal:lazy_loader", + "//tensorflow_probability/substrates", + ], +) diff --git a/tensorflow_probability/python/experimental/substrates/__init__.py b/tensorflow_probability/python/experimental/substrates/__init__.py new file mode 100644 index 0000000000..6cc75af087 --- /dev/null +++ b/tensorflow_probability/python/experimental/substrates/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2019 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""TensorFlow Probability alternative substrates.""" + +from tensorflow_probability.python.internal import all_util +from tensorflow_probability.python.internal import lazy_loader + +jax = lazy_loader.LazyLoader( + 'jax', globals(), + 'tensorflow_probability.substrates.jax') +numpy = lazy_loader.LazyLoader( + 'numpy', globals(), + 'tensorflow_probability.substrates.numpy') + + +_allowed_symbols = [ + 'jax', + 'numpy', +] + +all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow_probability/python/internal/lazy_loader.py b/tensorflow_probability/python/internal/lazy_loader.py index c42f0cca5d..9cac7eceac 100644 --- a/tensorflow_probability/python/internal/lazy_loader.py +++ b/tensorflow_probability/python/internal/lazy_loader.py @@ -19,9 +19,6 @@ import types -__all__ = ['LazyLoader'] - - class LazyLoader(types.ModuleType): """Lazily import a module to avoid pulling in large deps, defer checks.""" diff --git a/tensorflow_probability/substrates/BUILD b/tensorflow_probability/substrates/BUILD index 5a186f4cfa..3f99a48001 100644 --- a/tensorflow_probability/substrates/BUILD +++ b/tensorflow_probability/substrates/BUILD @@ -32,8 +32,6 @@ py_library( srcs = [ "__init__.py", ], - # :substrates needs to be visible to the external :pip_pkg target. - visibility = ["//visibility:public"], # EnableOnExport deps = [ ":jax", ":numpy", diff --git a/tensorflow_probability/substrates/meta/rewrite.py b/tensorflow_probability/substrates/meta/rewrite.py index f856bdaf1a..fd6dea9444 100644 --- a/tensorflow_probability/substrates/meta/rewrite.py +++ b/tensorflow_probability/substrates/meta/rewrite.py @@ -258,9 +258,6 @@ def disable_all(name): if FLAGS.numpy_to_jax: contents = contents.replace('tfp.substrates.numpy', 'tfp.substrates.jax') contents = contents.replace('substrates.numpy', 'substrates.jax') - contents = contents.replace( - 'tensorflow_probability.substrates import numpy', - 'tensorflow_probability.substrates import jax') contents = contents.replace('backend.numpy', 'backend.jax') contents = contents.replace('backend import numpy as tf', 'backend import jax as tf') diff --git a/tensorflow_probability/tools/build_docs.py b/tensorflow_probability/tools/build_docs.py index 2adb3b21a6..d832748ca3 100644 --- a/tensorflow_probability/tools/build_docs.py +++ b/tensorflow_probability/tools/build_docs.py @@ -46,7 +46,12 @@ FLAGS = flags.FLAGS -DO_NOT_GENERATE_DOCS_FOR = [] +DO_NOT_GENERATE_DOCS_FOR = [ + tfp.experimental.substrates.jax.tf2jax, + tfp.experimental.substrates.jax.experimental, + tfp.experimental.substrates.numpy.tf2numpy, + tfp.experimental.substrates.numpy.experimental, +] def internal_filter(path, parent, children):