Skip to content

Commit 96d7ec8

Browse files
committed
centralize some XML helper functions
1 parent f6791c0 commit 96d7ec8

4 files changed

Lines changed: 76 additions & 63 deletions

File tree

dev_tools/docs/nxdl.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from ..globals.errors import NXDLParseError
1313
from ..globals.nxdl import NXDL_NAMESPACE
1414
from ..globals.urls import REPO_URL
15+
from ..utils import nxdl_utils
16+
from ..utils import xml_utils
1517
from ..utils.github import get_file_contributors_via_api
16-
from ..utils.nxdl_utils import get_inherited_nodes
17-
from ..utils.nxdl_utils import get_rst_formatted_name
1818
from ..utils.types import PathLike
1919
from .anchor_list import AnchorRegistry
2020

@@ -363,8 +363,8 @@ def _get_required_or_optional_text(self, node):
363363
:param obj node: instance of lxml.etree._Element
364364
:returns: formatted text
365365
"""
366-
tag = node.tag.split("}")[-1]
367-
if tag in ("field", "group"):
366+
nxdl_element_type = nxdl_utils.get_nxdl_element_type(node)
367+
if nxdl_element_type in ("field", "group"):
368368
optional_default = not self._use_application_defaults
369369
optional = node.get("optional", optional_default) in (True, "true", "1", 1)
370370
recommended = node.get("recommended", None) in (True, "true", "1", 1)
@@ -379,15 +379,15 @@ def _get_required_or_optional_text(self, node):
379379
# this is unexpected and remarkable
380380
# TODO: add a remark to the log
381381
optional_text = f"(``minOccurs={str(minOccurs)}``) "
382-
elif tag in ("attribute",):
382+
elif nxdl_element_type in ("attribute",):
383383
optional_default = not self._use_application_defaults
384384
optional = node.get("optional", optional_default) in (True, "true", "1", 1)
385385
recommended = node.get("recommended", None) in (True, "true", "1", 1)
386386
optional_text = {True: "(optional) ", False: "(required) "}[optional]
387387
if recommended:
388388
optional_text = "(recommended) "
389389
else:
390-
optional_text = "(unknown tag: " + str(tag) + ") "
390+
optional_text = "(unknown tag: " + str(nxdl_element_type) + ") "
391391
return optional_text
392392

393393
def _analyze_dimensions(self, ns, parent) -> str:
@@ -596,7 +596,7 @@ def _print_doc_enum(self, indent, ns, node, required=False):
596596

597597
def _print_attribute(self, ns, kind, node, optional, indent, parent_path):
598598
name = node.get("name")
599-
formatted_name = get_rst_formatted_name(node)
599+
formatted_name = nxdl_utils.get_rst_formatted_name(node)
600600
index_name = name
601601
self._print(
602602
f"{indent}" f"{self._hyperlink_target(parent_path, name, 'attribute')}"
@@ -626,7 +626,7 @@ def _print_full_tree(self, ns, parent, name, indent, parent_path):
626626
"""
627627
for node in parent.xpath("nx:field", namespaces=ns):
628628
name = node.get("name")
629-
formatted_name = get_rst_formatted_name(node)
629+
formatted_name = nxdl_utils.get_rst_formatted_name(node)
630630
index_name = name
631631
dims = self._analyze_dimensions(ns, node)
632632

@@ -659,7 +659,7 @@ def _print_full_tree(self, ns, parent, name, indent, parent_path):
659659

660660
for node in parent.xpath("nx:group", namespaces=ns):
661661
name = node.get("name", "")
662-
formatted_name = get_rst_formatted_name(node)
662+
formatted_name = nxdl_utils.get_rst_formatted_name(node)
663663
typ = node.get("type", "untyped (this is an error; please report)")
664664

665665
optional_text = self._get_required_or_optional_text(node)
@@ -700,7 +700,7 @@ def _print_full_tree(self, ns, parent, name, indent, parent_path):
700700

701701
for node in parent.xpath("nx:link", namespaces=ns):
702702
name = node.get("name")
703-
formatted_name = get_rst_formatted_name(node)
703+
formatted_name = nxdl_utils.get_rst_formatted_name(node)
704704
self._print(f"{indent}{self._hyperlink_target(parent_path, name, 'link')}")
705705
self._print(
706706
f"{indent}{formatted_name}: "
@@ -719,15 +719,13 @@ def get_first_parent_ref(self, path, tag):
719719
path = path[path.find("/", 1) :]
720720

721721
try:
722-
parents = get_inherited_nodes(path, nx_name)[2]
722+
parents = nxdl_utils.get_inherited_nodes(path, nx_name)[2]
723723
except FileNotFoundError:
724724
return ""
725725
if len(parents) > 1:
726726
for parent in parents:
727727
# iterate back and check tag matches
728-
if not parent.tag.endswith(tag) and not parent.tag.endswith(
729-
"definition"
730-
):
728+
if xml_utils.get_local_name(parent) not in (tag, "definition"):
731729
print(
732730
f"Warning: {path} has a mismatching inherited node - {parent.tag} cf {tag}"
733731
)

dev_tools/docs/xsd.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ..globals import directories
88
from ..globals.errors import NXDLParseError
99
from ..globals.nxdl import XSD_NAMESPACE
10+
from ..utils import xml_utils
1011
from ..utils.types import PathLike
1112

1213

@@ -116,9 +117,7 @@ def general_handler(self, parent=None, indentLevel=0):
116117
if parent_name is None:
117118
return
118119

119-
simple_tag = parent.tag[
120-
parent.tag.find("}") + 1 :
121-
] # cut off the namespace identifier
120+
simple_tag = xml_utils.get_local_name(parent)
122121

123122
# <varlistentry> ...
124123
name = parent_name # + ' data type'

dev_tools/utils/nxdl_utils.py

Lines changed: 24 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import lxml.etree as ET
1414
from lxml.etree import ParseError as xmlER
1515

16+
from . import xml_utils
17+
1618

1719
def decode_or_not(elem, encoding: str = "utf-8", decode: bool = True):
1820
"""
@@ -52,10 +54,9 @@ def decode_or_not(elem, encoding: str = "utf-8", decode: bool = True):
5254
return elem
5355

5456

55-
def remove_namespace_from_tag(tag):
56-
"""Helper function to remove the namespace from an XML tag."""
57-
58-
return tag.split("}")[-1]
57+
def get_nxdl_element_type(element):
58+
type = xml_utils.get_local_name(element)
59+
return "field" if type == "link" else type
5960

6061

6162
class NxdlAttributeNotFoundError(Exception):
@@ -87,20 +88,13 @@ def get_app_defs_names():
8788

8889
files = sorted(glob(str(app_def_path_glob)))
8990
for nexus_file in sorted(glob(str(contrib_def_path_glob))):
90-
root = get_xml_root(nexus_file)
91+
root = xml_utils.read_xml_file(nexus_file)
9192
if root.attrib["category"] == "application":
9293
files.append(nexus_file)
9394

9495
return [Path(file).name[:-9] for file in files] + ["NXroot"]
9596

9697

97-
@lru_cache(maxsize=None)
98-
def get_xml_root(file_path):
99-
"""Reducing I/O time by caching technique"""
100-
101-
return ET.parse(file_path).getroot()
102-
103-
10498
def get_hdf_root(hdf_node):
10599
"""Get the root HDF5 node"""
106100
node = hdf_node
@@ -243,7 +237,7 @@ def get_nx_classes():
243237
nx_class = []
244238
for nexus_file in base_classes + applications + contributed:
245239
try:
246-
root = get_xml_root(nexus_file)
240+
root = xml_utils.read_xml_file(nexus_file)
247241
except xmlER as e:
248242
raise ValueError(f"Getting an issue while parsing file {nexus_file}") from e
249243
if root.attrib["category"] == "base":
@@ -254,7 +248,7 @@ def get_nx_classes():
254248
def get_nx_units():
255249
"""Read unit kinds from the NeXus definition/nxdlTypes.xsd file"""
256250
filepath = nexus_def_path / "nxdlTypes.xsd"
257-
root = get_xml_root(filepath)
251+
root = xml_utils.read_xml_file(filepath)
258252
units_and_type_list = []
259253
for child in root:
260254
units_and_type_list.extend(child.attrib.values())
@@ -275,7 +269,7 @@ def get_nx_attribute_type():
275269
"""Read attribute types from the NeXus definition/nxdlTypes.xsd file"""
276270
filepath = nexus_def_path / "nxdlTypes.xsd"
277271

278-
root = get_xml_root(filepath)
272+
root = xml_utils.read_xml_file(filepath)
279273
units_and_type_list = []
280274
for child in root:
281275
units_and_type_list.extend(child.attrib.values())
@@ -321,7 +315,7 @@ def is_name_type(child, name_type_value: str) -> bool:
321315
return True
322316

323317
if name_type_value == "any" and (
324-
get_local_name_from_xml(child) == "group"
318+
get_nxdl_element_type(child) == "group"
325319
and "nameType" not in child.attrib
326320
and "name" not in child.attrib
327321
):
@@ -355,7 +349,7 @@ def belongs_to(nxdl_elem, child, name, class_type=None, hdf_name=None):
355349
if not isinstance(child2.tag, str):
356350
continue
357351
if (
358-
get_local_name_from_xml(child) != get_local_name_from_xml(child2)
352+
get_nxdl_element_type(child) != get_nxdl_element_type(child2)
359353
or get_node_name(child2) == act_htmlname
360354
):
361355
continue
@@ -376,15 +370,9 @@ def belongs_to(nxdl_elem, child, name, class_type=None, hdf_name=None):
376370
return False
377371

378372

379-
def get_local_name_from_xml(element):
380-
"""Helper function to extract the element tag without the namespace."""
381-
type = remove_namespace_from_tag(element.tag)
382-
return "field" if type == "link" else type
383-
384-
385373
def get_own_nxdl_child_reserved_elements(child, name, nxdl_elem):
386374
"""checking reserved elements, like doc, enumeration"""
387-
local_name = get_local_name_from_xml(child)
375+
local_name = get_nxdl_element_type(child)
388376
if local_name == "doc" and name == "doc":
389377
return set_nxdlpath(child, nxdl_elem, tag_name=name)
390378

@@ -395,16 +383,16 @@ def get_own_nxdl_child_reserved_elements(child, name, nxdl_elem):
395383

396384
def get_own_nxdl_child_base_types(child, class_type, nxdl_elem, name, hdf_name):
397385
"""checking base types of group, field, attribute"""
398-
if get_local_name_from_xml(child) == "group":
386+
if get_nxdl_element_type(child) == "group":
399387
if (
400388
class_type is None or (class_type and get_nx_class(child) == class_type)
401389
) and belongs_to(nxdl_elem, child, name, class_type, hdf_name):
402390
return set_nxdlpath(child, nxdl_elem)
403-
if get_local_name_from_xml(child) == "field" and belongs_to(
391+
if get_nxdl_element_type(child) == "field" and belongs_to(
404392
nxdl_elem, child, name, None, hdf_name
405393
):
406394
return set_nxdlpath(child, nxdl_elem)
407-
if get_local_name_from_xml(child) == "attribute" and belongs_to(
395+
if get_nxdl_element_type(child) == "attribute" and belongs_to(
408396
nxdl_elem, child, name, None, hdf_name
409397
):
410398
return set_nxdlpath(child, nxdl_elem)
@@ -424,7 +412,7 @@ def get_own_nxdl_child(
424412
result = get_own_nxdl_child_reserved_elements(child, name, nxdl_elem)
425413
if result is not False:
426414
return result
427-
if nexus_type and get_local_name_from_xml(child) != nexus_type:
415+
if nexus_type and get_nxdl_element_type(child) != nexus_type:
428416
continue
429417
result = get_own_nxdl_child_base_types(
430418
child, class_type, nxdl_elem, name, hdf_name
@@ -470,7 +458,7 @@ def get_nxdl_child(
470458
bc_filename = find_definition_file(bc_name)
471459
if not bc_filename:
472460
raise ValueError("nxdl file not found in definitions folder!")
473-
bc_obj = ET.parse(bc_filename).getroot()
461+
bc_obj = xml_utils.read_xml_file(bc_filename)
474462
bc_obj.set("nxdlbase", bc_filename)
475463
if "category" in bc_obj.attrib:
476464
bc_obj.set("nxdlbase_class", bc_obj.attrib["category"])
@@ -692,11 +680,6 @@ def print_doc(node, ntype, level, nxhtml, nxpath):
692680
print(wrapper.fill(par))
693681

694682

695-
def get_namespace(element):
696-
"""Extracts the namespace for elements in the NXDL"""
697-
return element.tag[element.tag.index("{") : element.tag.rindex("}") + 1]
698-
699-
700683
def get_enums(node: ET._Element) -> Optional[List[str]]:
701684
"""
702685
Makes list of enumerations, if node contains any.
@@ -709,7 +692,7 @@ def get_enums(node: ET._Element) -> Optional[List[str]]:
709692
Returns a list of the enumeration values if an enumeration was found.
710693
If no enumeration was found it returns None.
711694
"""
712-
namespace = get_namespace(node)
695+
namespace = xml_utils.get_namespace(node)
713696
enums = []
714697
for enumeration in node.findall(f"{namespace}enumeration"):
715698
for item in enumeration.findall(f"{namespace}item"):
@@ -736,12 +719,7 @@ def add_base_classes(elist, nx_name=None, elem: ET.Element = None):
736719
if nxdl_file_path is None:
737720
nxdl_file_path = f"{nx_name}.nxdl.xml"
738721

739-
try:
740-
elem = ET.parse(os.path.abspath(nxdl_file_path)).getroot()
741-
# elem = ET.parse(nxdl_file_path).getroot()
742-
except OSError:
743-
with open(nxdl_file_path, "r") as f:
744-
elem = ET.parse(f).getroot()
722+
elem = xml_utils.read_xml_file(nxdl_file_path)
745723

746724
if not isinstance(nxdl_file_path, str):
747725
nxdl_file_path = str(nxdl_file_path)
@@ -781,7 +759,7 @@ def get_direct_child(nxdl_elem, html_name):
781759
for child in nxdl_elem:
782760
if not isinstance(child.tag, str):
783761
continue
784-
if get_local_name_from_xml(child) in (
762+
if get_nxdl_element_type(child) in (
785763
"group",
786764
"field",
787765
"attribute",
@@ -798,7 +776,7 @@ def get_field_child(nxdl_elem, html_name):
798776
for child in nxdl_elem:
799777
if not isinstance(child.tag, str):
800778
continue
801-
if get_local_name_from_xml(child) != "field":
779+
if get_nxdl_element_type(child) != "field":
802780
continue
803781
if get_node_name(child) == html_name:
804782
data_child = set_nxdlpath(child, nxdl_elem)
@@ -853,7 +831,7 @@ def get_best_child(nxdl_elem, hdf_node, hdf_name, hdf_class_name, nexus_type):
853831
if not isinstance(child.tag, str):
854832
continue
855833
fit = -2
856-
if get_local_name_from_xml(child) == nexus_type and (
834+
if get_nxdl_element_type(child) == nexus_type and (
857835
nexus_type != "group" or get_nx_class(child) == hdf_class_name
858836
):
859837
name_any = is_name_type(child, "any")
@@ -885,7 +863,7 @@ def walk_elist(elist, html_name):
885863
None,
886864
html_name,
887865
get_nx_class(main_child),
888-
get_local_name_from_xml(main_child),
866+
get_nxdl_element_type(main_child),
889867
)
890868
if fitting_child is not None:
891869
child = fitting_child
@@ -973,7 +951,7 @@ def get_rst_formatted_name(node):
973951
name = node.get("name", "")
974952
nameType = node.get("nameType", "")
975953

976-
node_type = get_local_name_from_xml(node)
954+
node_type = get_nxdl_element_type(node)
977955

978956
if not name and node_type == "group":
979957
# Derive the name from the type without the NX prefix

dev_tools/utils/xml_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from functools import lru_cache
2+
from pathlib import Path
3+
4+
import lxml.etree as ET
5+
6+
7+
def read_xml_file(file_path: Path | str) -> ET.Element:
8+
"""Read XML file with caching."""
9+
normalized_path = Path(file_path).resolve()
10+
return _read_xml_file(normalized_path)
11+
12+
13+
@lru_cache(maxsize=None)
14+
def _read_xml_file(normalized_path: Path) -> ET.Element:
15+
try:
16+
return ET.parse(normalized_path).getroot()
17+
except OSError:
18+
# Not sure this is still necessary
19+
with open(normalized_path, "r") as f:
20+
return ET.parse(f).getroot()
21+
22+
23+
def get_local_name(element: ET.Element) -> str:
24+
"""
25+
Return the local XML tag name of an element (without its namespace).
26+
27+
'{http://example.org/ns}field' -> 'field'
28+
"""
29+
return ET.QName(element).localname
30+
31+
32+
def get_namespace(element: ET.Element) -> str:
33+
"""
34+
Return the namespace URI of an XML element.
35+
36+
'{http://example.org/ns}field' -> 'http://example.org/ns'
37+
"""
38+
return ET.QName(element).namespace

0 commit comments

Comments
 (0)