Skip to content

Commit ba23ec8

Browse files
committed
Add test of documentation coverage
1 parent 457242d commit ba23ec8

File tree

2 files changed

+283
-0
lines changed

2 files changed

+283
-0
lines changed

tests/BUILD

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,19 @@ jax_py_test(
957957
] + py_deps("absl/testing"),
958958
)
959959

960+
jax_py_test(
961+
name = "documentation_coverage_test",
962+
srcs = [
963+
"documentation_coverage_test.py",
964+
],
965+
deps = [
966+
"//jax",
967+
"//jax/_src:config",
968+
"//jax/_src:internal_test_util",
969+
"//jax/_src:test_util",
970+
] + py_deps("absl/testing"),
971+
)
972+
960973
jax_multiplatform_test(
961974
name = "linalg_test",
962975
srcs = ["linalg_test.py"],
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test that public APIs are correctly documented."""
16+
17+
import collections
18+
from collections.abc import Iterator
19+
import importlib
20+
import functools
21+
import os
22+
import pkgutil
23+
import warnings
24+
25+
from absl.testing import absltest
26+
from absl.testing import parameterized
27+
28+
import jax
29+
import jax._src.test_util as jtu
30+
from jax._src import config
31+
32+
config.parse_flags_with_absl()
33+
34+
35+
JAX_DOCS_DIR = os.path.abspath(os.path.join(__file__, "..", "..", "docs"))
36+
CURRENTMODULE_TAG = '.. currentmodule::'
37+
AUTOMODULE_TAG = '.. automodule::'
38+
AUTOSUMMARY_TAG = '.. autosummary::'
39+
AUTOCLASS_TAG = '.. autoclass::'
40+
41+
@functools.lru_cache()
42+
def undocumented_apis():
43+
"""Return a dictionary of per-module symbols that are known to be undocumented."""
44+
return {
45+
'jax': ['NamedSharding', 'P', 'Ref', 'Shard', 'ad_checkpoint', 'api_util', 'checkpoint_policies', 'core', 'custom_derivatives', 'custom_transpose', 'debug_key_reuse', 'device_put_replicated', 'device_put_sharded', 'effects_barrier', 'example_libraries', 'explain_cache_misses', 'experimental', 'extend', 'float0', 'freeze', 'fwd_and_bwd', 'host_count', 'host_id', 'host_ids', 'interpreters', 'jax', 'jax2tf_associative_scan_reductions', 'legacy_prng_key', 'lib', 'make_user_context', 'new_ref', 'no_execution', 'numpy_dtype_promotion', 'remat', 'remove_size_one_mesh_axis_from_type', 'softmax_custom_jvp', 'threefry_partitionable', 'tools', 'transfer_guard_device_to_device', 'transfer_guard_device_to_host', 'transfer_guard_host_to_device', 'typeof', 'version'],
46+
'jax.custom_batching': ['custom_vmap', 'sequential_vmap'],
47+
'jax.custom_derivatives': ['CustomVJPPrimal', 'SymbolicZero', 'closure_convert', 'custom_gradient', 'custom_jvp', 'custom_jvp_call_p', 'custom_vjp', 'custom_vjp_call_p', 'custom_vjp_primal_tree_values', 'linear_call', 'remat_opt_p', 'zero_from_primal'],
48+
'jax.custom_transpose': ['custom_transpose'],
49+
'jax.debug': ['DebugEffect', 'log'],
50+
'jax.distributed': ['is_initialized'],
51+
'jax.dlpack': ['jax'],
52+
'jax.dtypes': ['extended', 'finfo', 'iinfo'],
53+
'jax.errors': ['JAXIndexError', 'JAXTypeError'],
54+
'jax.ffi': ['build_ffi_lowering_function', 'include_dir', 'register_ffi_target_as_batch_partitionable', 'register_ffi_type_id'],
55+
'jax.lax': ['all_gather_invariant', 'unreduced_psum', 'dce_sink', 'conv_transpose_shape_tuple', 'reduce_window_shape_tuple', 'preduced', 'conv_general_permutations', 'conv_general_shape_tuple', 'pbroadcast', 'padtype_to_pads', 'conv_shape_tuple', 'unreduced_psum_scatter', 'create_token', 'dtype', 'shape_as_value', 'all_gather_reduced', 'pvary', *(name for name in dir(jax.lax) if name.endswith('_p'))],
56+
'jax.lax.linalg': [api for api in dir(jax.lax.linalg) if api.endswith('_p')],
57+
'jax.memory': ['Space'],
58+
'jax.monitoring': ['clear_event_listeners', 'record_event', 'record_event_duration_secs', 'record_event_time_span', 'record_scalar', 'register_event_duration_secs_listener', 'register_event_listener', 'register_event_time_span_listener', 'register_scalar_listener', 'unregister_event_duration_listener', 'unregister_event_listener', 'unregister_event_time_span_listener', 'unregister_scalar_listener'],
59+
'jax.nn': ['tanh'],
60+
'jax.nn.initializers': ['Initializer', 'kaiming_normal', 'kaiming_uniform', 'xavier_normal', 'xavier_uniform'],
61+
'jax.numpy': ['bfloat16', 'bool', 'e', 'euler_gamma', 'float4_e2m1fn', 'float8_e3m4', 'float8_e4m3', 'float8_e4m3b11fnuz', 'float8_e4m3fn', 'float8_e4m3fnuz', 'float8_e5m2', 'float8_e5m2fnuz', 'float8_e8m0fnu', 'inf', 'int2', 'int4', 'nan', 'newaxis', 'pi', 'uint2', 'uint4'],
62+
'jax.profiler': ['ProfileData', 'ProfileEvent', 'ProfileOptions', 'ProfilePlane', 'stop_server'],
63+
'jax.random': ['key_impl', 'random_gamma_p'],
64+
'jax.scipy.special': ['bessel_jn', 'sph_harm_y'],
65+
'jax.sharding': ['AbstractDevice', 'AbstractMesh', 'AxisType', 'auto_axes', 'explicit_axes', 'get_abstract_mesh', 'reshard', 'set_mesh', 'use_abstract_mesh'],
66+
'jax.stages': ['ArgInfo', 'CompilerOptions'],
67+
'jax.tree_util': ['DictKey', 'FlattenedIndexKey', 'GetAttrKey', 'PyTreeDef', 'SequenceKey', 'default_registry'],
68+
}
69+
70+
# A list of modules to skip entirely, either because they cannot be imported
71+
# or because they are not expected to be documented.
72+
MODULES_TO_SKIP = [
73+
"jax.ad_checkpoint",
74+
"jax.api_util",
75+
"jax.collect_profile", # fails when xprof is not available.
76+
"jax.core",
77+
"jax.example_libraries",
78+
"jax.extend",
79+
"jax.experimental",
80+
"jax.interpreters",
81+
"jax.lib",
82+
"jax.tools",
83+
"jax.version",
84+
]
85+
86+
87+
def extract_apis_from_rst_file(path: str) -> dict[str, list[str]]:
88+
"""Extract documented APIs from an RST file."""
89+
# We could do this more robustly by adding a docutils dependency, but that is
90+
# pretty heavy. Instead we use simple string-based file parsing, recognizing the
91+
# particular patterns used within the JAX documentation.
92+
currentmodule: str = '<none>'
93+
in_autosummary_block = False
94+
apis = collections.defaultdict(list)
95+
with open(path, 'r') as f:
96+
for line in f:
97+
stripped_line = line.strip()
98+
if not stripped_line:
99+
continue
100+
if line.startswith(CURRENTMODULE_TAG):
101+
currentmodule = line.removeprefix(CURRENTMODULE_TAG).strip()
102+
continue
103+
if line.startswith(AUTOMODULE_TAG):
104+
currentmodule = line.removeprefix(AUTOMODULE_TAG).strip()
105+
continue
106+
if line.startswith(AUTOCLASS_TAG):
107+
in_autosummary_block = False
108+
apis[currentmodule].append(line.removeprefix(AUTOCLASS_TAG).strip())
109+
continue
110+
if line.startswith(AUTOSUMMARY_TAG):
111+
in_autosummary_block = True
112+
continue
113+
if not in_autosummary_block:
114+
continue
115+
if not line.startswith(' '):
116+
in_autosummary_block = False
117+
continue
118+
if stripped_line.startswith(':'):
119+
continue
120+
apis[currentmodule].append(stripped_line)
121+
return dict(apis)
122+
123+
124+
@functools.lru_cache()
125+
def get_all_documented_apis(path: str = JAX_DOCS_DIR) -> dict[str, list[str]]:
126+
"""Get the list of APIs documented in all files in a directory (recursive)."""
127+
apis = collections.defaultdict(list)
128+
for root, _, files in os.walk(path):
129+
if (root.startswith(os.path.join(path, 'build'))
130+
or root.startswith(os.path.join(path, '_autosummary'))):
131+
continue
132+
for filename in files:
133+
if filename.endswith('.rst'):
134+
new_apis = extract_apis_from_rst_file(os.path.join(root, filename))
135+
for key, val in new_apis.items():
136+
apis[key].extend(val)
137+
return {key: sorted(vals) for key, vals in apis.items()}
138+
139+
140+
@functools.lru_cache()
141+
def list_public_jax_modules() -> list[str]:
142+
"""Return a list of the public modules defined in jax."""
143+
# We could use pkgutil.walk_packages, but we want to avoid traversing modules
144+
# like `jax._src`, `jax.example_libraries`, etc. so we implement it manually.
145+
def walk_public_modules(paths: list[str], parent_package: str) -> Iterator[str]:
146+
for info in pkgutil.iter_modules(paths):
147+
pkg_name = f"{parent_package}.{info.name}"
148+
if pkg_name in MODULES_TO_SKIP or info.name == 'tests' or info.name.startswith('_'):
149+
continue
150+
yield pkg_name
151+
if not info.ispkg:
152+
continue
153+
try:
154+
submodule = importlib.import_module(pkg_name)
155+
except ImportError as e:
156+
warnings.warn(f"failed to import {pkg_name}: {e!r}")
157+
else:
158+
if path := getattr(submodule, '__path__', None):
159+
yield from walk_public_modules(path, pkg_name)
160+
return [jax.__name__, *walk_public_modules(jax.__path__, jax.__name__)]
161+
162+
163+
@functools.lru_cache()
164+
def list_public_apis(module_name: str) -> list[str]:
165+
"""Return a list of public APIs within a specified module.
166+
167+
This will import the module as a side-effect.
168+
"""
169+
module = importlib.import_module(module_name)
170+
return [api for api in dir(module)
171+
if not api.startswith('_') # skip private members
172+
and not api.startswith('@') # skip injected pytest-related symbols
173+
]
174+
175+
176+
@functools.lru_cache()
177+
def get_all_public_jax_apis() -> dict[str, list[str]]:
178+
"""Return a dictionary mapping jax submodules to their list of public APIs."""
179+
apis = {}
180+
for module in list_public_jax_modules():
181+
try:
182+
apis[module] = list_public_apis(module)
183+
except ImportError as e:
184+
warnings.warn(f"failed to import {module}: {e}")
185+
return apis
186+
187+
188+
class DocumentationCoverageTest(jtu.JaxTestCase):
189+
def test_list_public_jax_modules(self):
190+
"""Simple smoke test for list_public_jax_modules()"""
191+
apis = list_public_jax_modules()
192+
193+
# A few submodules which should be included
194+
self.assertIn("jax", apis)
195+
self.assertIn("jax.numpy", apis)
196+
self.assertIn("jax.numpy.linalg", apis)
197+
198+
# A few submodules which should not be included
199+
self.assertNotIn("jax._src", apis)
200+
self.assertNotIn("jax._src.numpy", apis)
201+
self.assertNotIn("jax.example_libraries", apis)
202+
self.assertNotIn("jax.experimental.jax2tf", apis)
203+
204+
def test_list_public_apis(self):
205+
"""Simple smoketest for list_public_apis()"""
206+
jnp_apis = list_public_apis('jax.numpy')
207+
self.assertIn("array", jnp_apis)
208+
self.assertIn("zeros", jnp_apis)
209+
self.assertNotIn("jax.numpy.array", jnp_apis)
210+
self.assertNotIn("np", jnp_apis)
211+
self.assertNotIn("jax", jnp_apis)
212+
213+
def test_get_all_public_jax_apis(self):
214+
"""Simple smoketest for get_all_public_jax_apis()"""
215+
apis = get_all_public_jax_apis()
216+
self.assertIn("Array", apis["jax"])
217+
self.assertIn("array", apis["jax.numpy"])
218+
self.assertIn("eigh", apis["jax.numpy.linalg"])
219+
220+
def test_extract_apis_from_rst_file(self):
221+
"""Simple smoketest for extract_apis_from_rst_file()"""
222+
numpy_docs = os.path.join(JAX_DOCS_DIR, "jax.numpy.rst")
223+
apis = extract_apis_from_rst_file(numpy_docs)
224+
225+
self.assertIn("jax.numpy", apis.keys())
226+
self.assertIn("jax.numpy.linalg", apis.keys())
227+
228+
self.assertIn("array", apis["jax.numpy"])
229+
self.assertIn("asarray", apis["jax.numpy"])
230+
self.assertIn("eigh", apis["jax.numpy.linalg"])
231+
self.assertNotIn("jax", apis["jax.numpy"])
232+
self.assertNotIn("jax.numpy", apis["jax.numpy"])
233+
234+
def test_get_all_documented_apis(self):
235+
"""Simple smoketest of get_all_documented_apis()"""
236+
apis = get_all_documented_apis()
237+
self.assertIn("Array", apis["jax"])
238+
self.assertIn("arange", apis["jax.numpy"])
239+
self.assertIn("eigh", apis["jax.lax.linalg"])
240+
241+
@parameterized.parameters(list_public_jax_modules())
242+
def test_module_apis_documented(self, module):
243+
"""Test that the APIs in each module are appropriately documented."""
244+
public_apis = get_all_public_jax_apis()
245+
documented_apis = get_all_documented_apis(JAX_DOCS_DIR)
246+
247+
pub_apis = {f"{module}.{api}" for api in public_apis.get(module, ())}
248+
doc_apis = {f"{module}.{api}" for api in documented_apis.get(module, ())}
249+
undoc_apis = {f"{module}.{api}" for api in undocumented_apis().get(module, ())}
250+
251+
# Remove submodules from list.
252+
pub_apis -= public_apis.keys()
253+
pub_apis -= set(MODULES_TO_SKIP)
254+
255+
# This ensures that undocumented API lists are up-to-date: if this fails,
256+
# the fix is typically to remove the offending entries from `undocumented_apis()`.
257+
if (notempty := undoc_apis & doc_apis):
258+
raise ValueError(
259+
f"Found stale values in the undocumented_apis() list: {notempty}")
260+
261+
# This asserts that all public APIs are documented. If this fails, it
262+
# likely means there is a new public API within the jax package,
263+
# and the fix is to add the new API to the appropriate file in docs/*.rst.
264+
if (notempty := pub_apis - doc_apis - undoc_apis):
265+
raise ValueError(
266+
f"Found public APIs that are not listed within docs: {notempty}")
267+
268+
269+
if __name__ == "__main__":
270+
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)