|
| 1 | +# Copyright 2025 The JAX Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Test that public APIs are correctly documented.""" |
| 16 | + |
| 17 | +import collections |
| 18 | +from collections.abc import Iterator |
| 19 | +import importlib |
| 20 | +import functools |
| 21 | +import os |
| 22 | +import pkgutil |
| 23 | +import warnings |
| 24 | + |
| 25 | +from absl.testing import absltest |
| 26 | +from absl.testing import parameterized |
| 27 | + |
| 28 | +import jax |
| 29 | +import jax._src.test_util as jtu |
| 30 | +from jax._src import config |
| 31 | + |
| 32 | +config.parse_flags_with_absl() |
| 33 | + |
| 34 | + |
| 35 | +JAX_DOCS_DIR = os.path.abspath(os.path.join(__file__, "..", "..", "docs")) |
| 36 | +CURRENTMODULE_TAG = '.. currentmodule::' |
| 37 | +AUTOMODULE_TAG = '.. automodule::' |
| 38 | +AUTOSUMMARY_TAG = '.. autosummary::' |
| 39 | +AUTOCLASS_TAG = '.. autoclass::' |
| 40 | + |
| 41 | +@functools.lru_cache() |
| 42 | +def undocumented_apis(): |
| 43 | + """Return a dictionary of per-module symbols that are known to be undocumented.""" |
| 44 | + return { |
| 45 | + 'jax': ['NamedSharding', 'P', 'Ref', 'Shard', 'ad_checkpoint', 'api_util', 'checkpoint_policies', 'core', 'custom_derivatives', 'custom_transpose', 'debug_key_reuse', 'device_put_replicated', 'device_put_sharded', 'effects_barrier', 'example_libraries', 'explain_cache_misses', 'experimental', 'extend', 'float0', 'freeze', 'fwd_and_bwd', 'host_count', 'host_id', 'host_ids', 'interpreters', 'jax', 'jax2tf_associative_scan_reductions', 'legacy_prng_key', 'lib', 'make_user_context', 'new_ref', 'no_execution', 'numpy_dtype_promotion', 'remat', 'remove_size_one_mesh_axis_from_type', 'softmax_custom_jvp', 'threefry_partitionable', 'tools', 'transfer_guard_device_to_device', 'transfer_guard_device_to_host', 'transfer_guard_host_to_device', 'typeof', 'version'], |
| 46 | + 'jax.custom_batching': ['custom_vmap', 'sequential_vmap'], |
| 47 | + 'jax.custom_derivatives': ['CustomVJPPrimal', 'SymbolicZero', 'closure_convert', 'custom_gradient', 'custom_jvp', 'custom_jvp_call_p', 'custom_vjp', 'custom_vjp_call_p', 'custom_vjp_primal_tree_values', 'linear_call', 'remat_opt_p', 'zero_from_primal'], |
| 48 | + 'jax.custom_transpose': ['custom_transpose'], |
| 49 | + 'jax.debug': ['DebugEffect', 'log'], |
| 50 | + 'jax.distributed': ['is_initialized'], |
| 51 | + 'jax.dlpack': ['jax'], |
| 52 | + 'jax.dtypes': ['extended', 'finfo', 'iinfo'], |
| 53 | + 'jax.errors': ['JAXIndexError', 'JAXTypeError'], |
| 54 | + 'jax.ffi': ['build_ffi_lowering_function', 'include_dir', 'register_ffi_target_as_batch_partitionable', 'register_ffi_type_id'], |
| 55 | + 'jax.lax': ['all_gather_invariant', 'unreduced_psum', 'dce_sink', 'conv_transpose_shape_tuple', 'reduce_window_shape_tuple', 'preduced', 'conv_general_permutations', 'conv_general_shape_tuple', 'pbroadcast', 'padtype_to_pads', 'conv_shape_tuple', 'unreduced_psum_scatter', 'create_token', 'dtype', 'shape_as_value', 'all_gather_reduced', 'pvary', *(name for name in dir(jax.lax) if name.endswith('_p'))], |
| 56 | + 'jax.lax.linalg': [api for api in dir(jax.lax.linalg) if api.endswith('_p')], |
| 57 | + 'jax.memory': ['Space'], |
| 58 | + 'jax.monitoring': ['clear_event_listeners', 'record_event', 'record_event_duration_secs', 'record_event_time_span', 'record_scalar', 'register_event_duration_secs_listener', 'register_event_listener', 'register_event_time_span_listener', 'register_scalar_listener', 'unregister_event_duration_listener', 'unregister_event_listener', 'unregister_event_time_span_listener', 'unregister_scalar_listener'], |
| 59 | + 'jax.nn': ['tanh'], |
| 60 | + 'jax.nn.initializers': ['Initializer', 'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform'], |
| 61 | + 'jax.numpy': ['bfloat16', 'bool', 'e', 'euler_gamma', 'float4_e2m1fn', 'float8_e3m4', 'float8_e4m3', 'float8_e4m3b11fnuz', 'float8_e4m3fn', 'float8_e4m3fnuz', 'float8_e5m2', 'float8_e5m2fnuz', 'float8_e8m0fnu', 'inf', 'int2', 'int4', 'nan', 'newaxis', 'pi', 'uint2', 'uint4'], |
| 62 | + 'jax.profiler': ['ProfileData', 'ProfileEvent', 'ProfileOptions', 'ProfilePlane', 'stop_server'], |
| 63 | + 'jax.random': ['key_impl', 'random_gamma_p'], |
| 64 | + 'jax.scipy.special': ['bessel_jn', 'sph_harm_y'], |
| 65 | + 'jax.sharding': ['AbstractDevice', 'AbstractMesh', 'AxisType', 'auto_axes', 'explicit_axes', 'get_abstract_mesh', 'reshard', 'set_mesh', 'use_abstract_mesh'], |
| 66 | + 'jax.stages': ['ArgInfo', 'CompilerOptions'], |
| 67 | + 'jax.tree_util': ['DictKey', 'FlattenedIndexKey', 'GetAttrKey', 'PyTreeDef', 'SequenceKey', 'default_registry'], |
| 68 | + } |
| 69 | + |
| 70 | +# A list of modules to skip entirely, either because they cannot be imported |
| 71 | +# or because they are not expected to be documented. |
| 72 | +MODULES_TO_SKIP = [ |
| 73 | + "jax.ad_checkpoint", |
| 74 | + "jax.api_util", |
| 75 | + "jax.collect_profile", # fails when xprof is not available. |
| 76 | + "jax.core", |
| 77 | + "jax.example_libraries", |
| 78 | + "jax.extend", |
| 79 | + "jax.experimental", |
| 80 | + "jax.interpreters", |
| 81 | + "jax.lib", |
| 82 | + "jax.tools", |
| 83 | + "jax.version", |
| 84 | +] |
| 85 | + |
| 86 | + |
| 87 | +def extract_apis_from_rst_file(path: str) -> dict[str, list[str]]: |
| 88 | + """Extract documented APIs from an RST file.""" |
| 89 | + # We could do this more robustly by adding a docutils dependency, but that is |
| 90 | + # pretty heavy. Instead we use simple string-based file parsing, recognizing the |
| 91 | + # particular patterns used within the JAX documentation. |
| 92 | + currentmodule: str = '<none>' |
| 93 | + in_autosummary_block = False |
| 94 | + apis = collections.defaultdict(list) |
| 95 | + with open(path, 'r') as f: |
| 96 | + for line in f: |
| 97 | + stripped_line = line.strip() |
| 98 | + if not stripped_line: |
| 99 | + continue |
| 100 | + if line.startswith(CURRENTMODULE_TAG): |
| 101 | + currentmodule = line.removeprefix(CURRENTMODULE_TAG).strip() |
| 102 | + continue |
| 103 | + if line.startswith(AUTOMODULE_TAG): |
| 104 | + currentmodule = line.removeprefix(AUTOMODULE_TAG).strip() |
| 105 | + continue |
| 106 | + if line.startswith(AUTOCLASS_TAG): |
| 107 | + in_autosummary_block = False |
| 108 | + apis[currentmodule].append(line.removeprefix(AUTOCLASS_TAG).strip()) |
| 109 | + continue |
| 110 | + if line.startswith(AUTOSUMMARY_TAG): |
| 111 | + in_autosummary_block = True |
| 112 | + continue |
| 113 | + if not in_autosummary_block: |
| 114 | + continue |
| 115 | + if not line.startswith(' '): |
| 116 | + in_autosummary_block = False |
| 117 | + continue |
| 118 | + if stripped_line.startswith(':'): |
| 119 | + continue |
| 120 | + apis[currentmodule].append(stripped_line) |
| 121 | + return dict(apis) |
| 122 | + |
| 123 | + |
| 124 | +@functools.lru_cache() |
| 125 | +def get_all_documented_apis(path: str = JAX_DOCS_DIR) -> dict[str, list[str]]: |
| 126 | + """Get the list of APIs documented in all files in a directory (recursive).""" |
| 127 | + apis = collections.defaultdict(list) |
| 128 | + for root, _, files in os.walk(path): |
| 129 | + if (root.startswith(os.path.join(path, 'build')) |
| 130 | + or root.startswith(os.path.join(path, '_autosummary'))): |
| 131 | + continue |
| 132 | + for filename in files: |
| 133 | + if filename.endswith('.rst'): |
| 134 | + new_apis = extract_apis_from_rst_file(os.path.join(root, filename)) |
| 135 | + for key, val in new_apis.items(): |
| 136 | + apis[key].extend(val) |
| 137 | + return {key: sorted(vals) for key, vals in apis.items()} |
| 138 | + |
| 139 | + |
| 140 | +@functools.lru_cache() |
| 141 | +def list_public_jax_modules() -> list[str]: |
| 142 | + """Return a list of the public modules defined in jax.""" |
| 143 | + # We could use pkgutil.walk_packages, but we want to avoid traversing modules |
| 144 | + # like `jax._src`, `jax.example_libraries`, etc. so we implement it manually. |
| 145 | + def walk_public_modules(paths: list[str], parent_package: str) -> Iterator[str]: |
| 146 | + for info in pkgutil.iter_modules(paths): |
| 147 | + pkg_name = f"{parent_package}.{info.name}" |
| 148 | + if pkg_name in MODULES_TO_SKIP or info.name == 'tests' or info.name.startswith('_'): |
| 149 | + continue |
| 150 | + yield pkg_name |
| 151 | + if not info.ispkg: |
| 152 | + continue |
| 153 | + try: |
| 154 | + submodule = importlib.import_module(pkg_name) |
| 155 | + except ImportError as e: |
| 156 | + warnings.warn(f"failed to import {pkg_name}: {e!r}") |
| 157 | + else: |
| 158 | + if path := getattr(submodule, '__path__', None): |
| 159 | + yield from walk_public_modules(path, pkg_name) |
| 160 | + return [jax.__name__, *walk_public_modules(jax.__path__, jax.__name__)] |
| 161 | + |
| 162 | + |
| 163 | +@functools.lru_cache() |
| 164 | +def list_public_apis(module_name: str) -> list[str]: |
| 165 | + """Return a list of public APIs within a specified module. |
| 166 | +
|
| 167 | + This will import the module as a side-effect. |
| 168 | + """ |
| 169 | + module = importlib.import_module(module_name) |
| 170 | + return [api for api in dir(module) |
| 171 | + if not api.startswith('_') # skip private members |
| 172 | + and not api.startswith('@') # skip injected pytest-related symbols |
| 173 | + ] |
| 174 | + |
| 175 | + |
| 176 | +@functools.lru_cache() |
| 177 | +def get_all_public_jax_apis() -> dict[str, list[str]]: |
| 178 | + """Return a dictionary mapping jax submodules to their list of public APIs.""" |
| 179 | + apis = {} |
| 180 | + for module in list_public_jax_modules(): |
| 181 | + try: |
| 182 | + apis[module] = list_public_apis(module) |
| 183 | + except ImportError as e: |
| 184 | + warnings.warn(f"failed to import {module}: {e}") |
| 185 | + return apis |
| 186 | + |
| 187 | + |
| 188 | +class DocumentationCoverageTest(jtu.JaxTestCase): |
| 189 | + def test_list_public_jax_modules(self): |
| 190 | + """Simple smoke test for list_public_jax_modules()""" |
| 191 | + apis = list_public_jax_modules() |
| 192 | + |
| 193 | + # A few submodules which should be included |
| 194 | + self.assertIn("jax", apis) |
| 195 | + self.assertIn("jax.numpy", apis) |
| 196 | + self.assertIn("jax.numpy.linalg", apis) |
| 197 | + |
| 198 | + # A few submodules which should not be included |
| 199 | + self.assertNotIn("jax._src", apis) |
| 200 | + self.assertNotIn("jax._src.numpy", apis) |
| 201 | + self.assertNotIn("jax.example_libraries", apis) |
| 202 | + self.assertNotIn("jax.experimental.jax2tf", apis) |
| 203 | + |
| 204 | + def test_list_public_apis(self): |
| 205 | + """Simple smoketest for list_public_apis()""" |
| 206 | + jnp_apis = list_public_apis('jax.numpy') |
| 207 | + self.assertIn("array", jnp_apis) |
| 208 | + self.assertIn("zeros", jnp_apis) |
| 209 | + self.assertNotIn("jax.numpy.array", jnp_apis) |
| 210 | + self.assertNotIn("np", jnp_apis) |
| 211 | + self.assertNotIn("jax", jnp_apis) |
| 212 | + |
| 213 | + def test_get_all_public_jax_apis(self): |
| 214 | + """Simple smoketest for get_all_public_jax_apis()""" |
| 215 | + apis = get_all_public_jax_apis() |
| 216 | + self.assertIn("Array", apis["jax"]) |
| 217 | + self.assertIn("array", apis["jax.numpy"]) |
| 218 | + self.assertIn("eigh", apis["jax.numpy.linalg"]) |
| 219 | + |
| 220 | + def test_extract_apis_from_rst_file(self): |
| 221 | + """Simple smoketest for extract_apis_from_rst_file()""" |
| 222 | + numpy_docs = os.path.join(JAX_DOCS_DIR, "jax.numpy.rst") |
| 223 | + apis = extract_apis_from_rst_file(numpy_docs) |
| 224 | + |
| 225 | + self.assertIn("jax.numpy", apis.keys()) |
| 226 | + self.assertIn("jax.numpy.linalg", apis.keys()) |
| 227 | + |
| 228 | + self.assertIn("array", apis["jax.numpy"]) |
| 229 | + self.assertIn("asarray", apis["jax.numpy"]) |
| 230 | + self.assertIn("eigh", apis["jax.numpy.linalg"]) |
| 231 | + self.assertNotIn("jax", apis["jax.numpy"]) |
| 232 | + self.assertNotIn("jax.numpy", apis["jax.numpy"]) |
| 233 | + |
| 234 | + def test_get_all_documented_apis(self): |
| 235 | + """Simple smoketest of get_all_documented_apis()""" |
| 236 | + apis = get_all_documented_apis() |
| 237 | + self.assertIn("Array", apis["jax"]) |
| 238 | + self.assertIn("arange", apis["jax.numpy"]) |
| 239 | + self.assertIn("eigh", apis["jax.lax.linalg"]) |
| 240 | + |
| 241 | + @parameterized.parameters(list_public_jax_modules()) |
| 242 | + def test_module_apis_documented(self, module): |
| 243 | + """Test that the APIs in each module are appropriately documented.""" |
| 244 | + public_apis = get_all_public_jax_apis() |
| 245 | + documented_apis = get_all_documented_apis(JAX_DOCS_DIR) |
| 246 | + |
| 247 | + pub_apis = {f"{module}.{api}" for api in public_apis.get(module, ())} |
| 248 | + doc_apis = {f"{module}.{api}" for api in documented_apis.get(module, ())} |
| 249 | + undoc_apis = {f"{module}.{api}" for api in undocumented_apis().get(module, ())} |
| 250 | + |
| 251 | + # Remove submodules from list. |
| 252 | + pub_apis -= public_apis.keys() |
| 253 | + pub_apis -= set(MODULES_TO_SKIP) |
| 254 | + |
| 255 | + # This ensures that undocumented API lists are up-to-date: if this fails, |
| 256 | + # the fix is typically to remove the offending entries from `undocumented_apis()`. |
| 257 | + if (notempty := undoc_apis & doc_apis): |
| 258 | + raise ValueError( |
| 259 | + f"Found stale values in the undocumented_apis() list: {notempty}") |
| 260 | + |
| 261 | + # This asserts that all public APIs are documented. If this fails, it |
| 262 | + # likely means there is a new public API within the jax package, |
| 263 | + # and the fix is to add the new API to the appropriate file in docs/*.rst. |
| 264 | + if (notempty := pub_apis - doc_apis - undoc_apis): |
| 265 | + raise ValueError( |
| 266 | + f"Found public APIs that are not listed within docs: {notempty}") |
| 267 | + |
| 268 | + |
| 269 | +if __name__ == "__main__": |
| 270 | + absltest.main(testLoader=jtu.JaxTestLoader()) |
0 commit comments