diff --git a/array_api_tests/test_inspection_functions.py b/array_api_tests/test_inspection_functions.py index 4bbcfa2b..d210535e 100644 --- a/array_api_tests/test_inspection_functions.py +++ b/array_api_tests/test_inspection_functions.py @@ -16,13 +16,14 @@ def test_capabilities(self): expected_attr = {"boolean indexing": bool, "data-dependent shapes": bool} if xp.__array_api_version__ >= "2024.12": - expected_attr.update(**{"max dimensions": int}) + expected_attr.update(**{"max dimensions": type(None) | int}) for attr, typ in expected_attr.items(): assert attr in capabilities, f'capabilites is missing "{attr}".' assert isinstance(capabilities[attr], typ) - assert capabilities.get("max dimensions", 100500) > 0 + max_dims = capabilities.get("max dimensions", 100500) + assert (max_dims is None) or (max_dims > 0) def test_devices(self): out = xp.__array_namespace_info__()