66import sys
77
88from setuptools .command .build_ext import build_ext as _build_ext
9- from wheel .bdist_wheel import bdist_wheel as _bdist_wheel
109
1110
1211def detect_cuda_paths ():
@@ -22,6 +21,9 @@ def detect_cuda_paths():
2221 potential_build_prefixes = (
2322 [os .path .join (p , "nvidia/cuda_runtime" ) for p in sys .path ]
2423 + [os .path .join (p , "nvidia/cuda_nvcc" ) for p in sys .path ]
24+ # internal/bindings depends on cuda_bindings cydriver,
25+ # which introduces dependency on cudaProfiler.h
26+ + [os .path .join (p , "nvidia/cuda_profiler_api" ) for p in sys .path ]
2527 + [os .environ .get ("CUDA_PATH" , os .environ .get ("CUDA_HOME" , "" )), "/usr/local/cuda" ]
2628 )
2729 cuda_paths = []
@@ -38,6 +40,9 @@ def check_path(header):
3840
3941 check_path ("cuda.h" )
4042 check_path ("crt/host_defines.h" )
43+ # internal/bindings depends on cuda_bindings cydriver,
44+ # which introduces dependency on cudaProfiler.h
45+ check_path ("cudaProfiler.h" )
4146 return cuda_paths
4247
4348
@@ -50,16 +55,6 @@ def decide_lib_name(ext_name):
5055 return None
5156
5257
53- building_wheel = False
54-
55-
56- class bdist_wheel (_bdist_wheel ):
57- def run (self ):
58- global building_wheel
59- building_wheel = True
60- super ().run ()
61-
62-
6358class build_ext (_build_ext ):
6459 def __init__ (self , * args , ** kwargs ):
6560 self ._nvmath_cuda_paths = detect_cuda_paths ()
@@ -74,42 +69,43 @@ def _prep_includes_libs_rpaths(self, lib_name):
7469 Set cuda_incl_dir and extra_linker_flags.
7570 """
7671 cuda_incl_dir = [os .path .join (p , "include" ) for p in self ._nvmath_cuda_paths ]
72+ extra_linker_flags = []
73+
74+ site_packages = ["$ORIGIN/../../.." ]
75+ if self .editable_mode :
76+ import site
7777
78- if not building_wheel :
79- # Note: with PEP-517 the editable mode would not build a wheel for installation
80- # (and we purposely do not support PEP-660).
81- extra_linker_flags = []
78+ site_packages = site .getsitepackages ()
8279 else :
83- # Note: soname = library major version
84- # We need to be able to search for cuBLAS/cuSOLVER/... at run time, in case they
85- # are installed via pip wheels.
86- # The rpaths must be adjusted given the following full-wheel installation:
87- # - $ORIGIN: site-packages/nvmath/bindings/_internal/
88- # - cublas: site-packages/nvidia/cublas/lib/
89- # - cusolver: site-packages/nvidia/cusolver/lib/
90- # - ... ...
91- # strip binaries to remove debug symbols which significantly increase wheel size
80+ # strip binaries to remove debug symbols which significantly
81+ # increase wheel size
9282 extra_linker_flags = ["-Wl,--strip-all" ]
93- if lib_name is not None :
94- ldflag = "-Wl,--disable-new-dtags"
95- match lib_name :
96- case "nvpl" :
97- # 1. the nvpl bindings land in
98- # site-packages/nvmath/bindings/nvpl/_internal/ as opposed to other
99- # packages that have their bindings in
100- # site-packages/nvmath/bindings/_internal/, so we need one extra
101- # `..` to get into `site-packages` and then the lib_name=nvpl is not
102- # in nvidia dir but directly in the site-packages.
103- # 2. mkl lib is placed directly in the python `lib` directory, not
104- # in python{ver}/site-packages
105- ldflag += f",-rpath,$ORIGIN/../../../../{ lib_name } /lib:$ORIGIN/../../../../../../"
106- case "cufftMp" :
107- ldflag += ",-rpath,$ORIGIN/../../../nvidia/cufftmp/cu12/lib"
108- case "mathdx" | "cudss" :
109- ldflag += ",-rpath,$ORIGIN/../../../nvidia/cu12/lib"
110- case _:
111- ldflag += f",-rpath,$ORIGIN/../../../nvidia/{ lib_name } /lib"
112- extra_linker_flags .append (ldflag )
83+ nvpl_site_packages = [f"{ p } /.." for p in site_packages ]
84+
85+ # Note: soname = library major version
86+ # We need to be able to search for cuBLAS/cuSOLVER/... at run time, in case they
87+ # are installed via pip wheels.
88+ # The rpaths must be adjusted given the following full-wheel installation:
89+ # - $ORIGIN: site-packages/nvmath/bindings/_internal/
90+ # - cublas: site-packages/nvidia/cublas/lib/
91+ # - cusolver: site-packages/nvidia/cusolver/lib/
92+ # - ... ...
93+ if lib_name is None :
94+ return cuda_incl_dir , extra_linker_flags
95+
96+ ldflag = "-Wl,--disable-new-dtags"
97+ if lib_name == "nvpl" :
98+ # 1. the nvpl bindings land in
99+ # site-packages/nvmath/bindings/nvpl/_internal/ as opposed to other
100+ # packages that have their bindings in
101+ # site-packages/nvmath/bindings/_internal/, so we need one extra
102+ # `..` to get into `site-packages` and then the lib_name=nvpl is not
103+ # in nvidia dir but directly in the site-packages.
104+ # 2. mkl lib is placed directly in the python `lib` directory, not
105+ # in python{ver}/site-packages
106+ rpath = ":" .join ([f"{ pth } /{ lib_name } /lib:{ pth } /../../" for pth in nvpl_site_packages ])
107+ ldflag += f",-rpath,{ rpath } "
108+ extra_linker_flags .append (ldflag )
113109
114110 return cuda_incl_dir , extra_linker_flags
115111
0 commit comments