Skip to content

Commit 0e72fac

Browse files
authored
Improve torch-xpu-ops codegen (#1541)
# Motivation Improve code to avoid `PermissionError` on Windows: Previously, the script n `install_xpu_headers.py` opened the file twice - once for reading and once for writing: ```python for file in files: with open(file) as fr: # modify fr with open(file, "w") as fw: # write fw ``` However, on Windows, this can lead to `PermissionError` due to file access conflicts. It looks like a bug related to the Windows system or Python. I am not sure. To avoid this issue, we now open the file only once using `"r+"` mode, allowing both reading and writing within the same context: ```python for file in files: with open(file, "r+") as f: # modify f # rewrite f ``` Additionally, add dependencies for torch-xpu-ops codegen to ensure proper execution. # Solution Trying to prevent writing back an unchanged file. And open the file only once to allow both reading and writing. # Additional Context Windows CI build pass refer to https://github.com/pytorch/pytorch/actions/runs/14236812185/job/39897594304?pr=139971
1 parent b1c5462 commit 0e72fac

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

cmake/Codegen.cmake

+2
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ function(GEN_XPU file_yaml)
106106
${CODEGEN_XPU_YAML_DIR}/native/${file_yaml}
107107
${XPUFallback_TEMPLATE}
108108
${TORCH_XPU_OPS_ROOT}/tools/codegen/install_xpu_headers.py
109+
${ops_generated_headers}
110+
${ops_generated_sources}
109111
)
110112

111113
# Post codegen delete the copied templates folder only on Windows.

tools/codegen/install_xpu_headers.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -97,18 +97,22 @@ def append_xpu_ops_headers(src_dir, dst_dir, common_headers, xpu_ops_headers):
9797
re.findall(r"struct TORCH_XPU_API.*xpu.*?{.*?};\n", src_text, re.DOTALL)
9898
)
9999

100-
with open(dst) as fr:
101-
dst_lines = fr.readlines()
100+
if not xpu_declarations:
101+
continue
102+
103+
with open(dst, "r+") as f:
104+
dst_lines = f.readlines()
102105
dst_text = "".join(dst_lines)
103-
for line in dst_lines:
106+
for index, line in enumerate(dst_lines):
104107
if re.match(r"^(TORCH_API.*;|struct TORCH_API.*)", line):
105108
for xpu_declaration in xpu_declarations:
106109
if not re.search(re.escape(xpu_declaration), dst_text):
107-
dst_lines.insert(dst_lines.index(line), xpu_declaration)
110+
dst_lines.insert(index, xpu_declaration)
108111
break
109112

110-
with open(dst, "w") as fw:
111-
fw.writelines(dst_lines)
113+
f.seek(0)
114+
f.writelines(dst_lines)
115+
f.truncate()
112116

113117

114118
def main():

0 commit comments

Comments
 (0)