|
13 | 13 | import os |
14 | 14 | import sys |
15 | 15 | import platform |
| 16 | +import ctypes |
16 | 17 | from ._utils import _import_dotted_name |
17 | | -from ._utils_internal import get_file_path, prepare_multiprocessing_environment |
| 18 | +from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \ |
| 19 | + USE_RTLD_GLOBAL_WITH_LIBTORCH |
18 | 20 | from .version import __version__ |
19 | 21 | from ._six import string_classes as _string_classes |
20 | 22 |
|
|
33 | 35 | # Load the extension module |
34 | 36 | ################################################################################ |
35 | 37 |
|
36 | | -# Loading the extension with RTLD_GLOBAL option allows to not link extension |
37 | | -# modules against the _C shared object. Their missing THP symbols will be |
38 | | -# automatically filled by the dynamic loader. |
39 | | -import os as _dl_flags |
40 | | - |
41 | | -# if we have numpy, it *must* be imported before the call to setdlopenflags() |
42 | | -# or there is risk that later c modules will segfault when importing numpy |
43 | | -try: |
44 | | - import numpy as _np |
45 | | -except ImportError: |
46 | | - pass |
47 | | - |
48 | 38 | if platform.system() == 'Windows': |
49 | | - # first get nvToolsExt PATH |
50 | | - def get_nvToolsExt_path(): |
51 | | - NVTOOLEXT_HOME = _dl_flags.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt') |
| 39 | + NVTOOLSEXT_PATH = os.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt') |
52 | 40 |
|
53 | | - if _dl_flags.path.exists(NVTOOLEXT_HOME): |
54 | | - return _dl_flags.path.join(NVTOOLEXT_HOME, 'bin', 'x64') |
55 | | - else: |
56 | | - return '' |
| 41 | + if os.path.exists(NVTOOLSEXT_PATH): |
| 42 | + nvtoolsext_lib_path = os.path.join(NVTOOLSEXT_PATH, 'bin', 'x64') |
| 43 | + else: |
| 44 | + nvtoolsext_lib_path = '' |
57 | 45 |
|
58 | | - py_dll_path = _dl_flags.path.join(sys.exec_prefix, 'Library', 'bin') |
59 | | - th_dll_path = _dl_flags.path.join(_dl_flags.path.dirname(__file__), 'lib') |
| 46 | + py_dll_path = os.path.join(sys.exec_prefix, 'Library', 'bin') |
| 47 | + th_dll_path = os.path.join(os.path.dirname(__file__), 'lib') |
60 | 48 |
|
61 | | - dll_paths = [th_dll_path, py_dll_path, get_nvToolsExt_path(), _dl_flags.environ['PATH']] |
| 49 | + dll_paths = [th_dll_path, py_dll_path, nvtoolsext_lib_path, os.environ['PATH']] |
62 | 50 |
|
63 | 51 | # then add the path to env |
64 | | - _dl_flags.environ['PATH'] = ';'.join(dll_paths) |
| 52 | + os.environ['PATH'] = ';'.join(dll_paths) |
65 | 53 |
|
66 | | -else: |
67 | | - # first check if the os package has the required flags |
| 54 | + |
| 55 | +# See Note [Global dependencies] |
| 56 | +def _load_global_deps(): |
| 57 | + if platform.system() == 'Windows': |
| 58 | + return |
| 59 | + |
| 60 | + lib_name = 'libtorch_global_deps' + ('.dylib' if platform.system() == 'Darwin' else '.so') |
| 61 | + here = os.path.abspath(__file__) |
| 62 | + lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name) |
| 63 | + |
| 64 | + ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) |
| 65 | + |
| 66 | + |
| 67 | +if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \ |
| 68 | + platform.system() != 'Windows': |
| 69 | + # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a |
| 70 | + # few circumstances: |
| 71 | + # |
| 72 | + # 1. You're in a build environment (e.g., fbcode) where |
| 73 | + # libtorch_global_deps is not available, but you still need |
| 74 | + # to get mkl to link in with RTLD_GLOBAL or it will just |
| 75 | + # not work. |
| 76 | + # |
| 77 | + # 2. You're trying to run PyTorch under UBSAN and you need |
| 78 | + # to ensure that only one copy of libtorch is loaded, so |
| 79 | + # vptr checks work properly |
| 80 | + # |
| 81 | + # If you're using this setting, you must verify that all the libraries |
| 82 | + # you load consistently use the same libstdc++, or you may have |
| 83 | + # mysterious segfaults. |
| 84 | + # |
| 85 | + import os as _dl_flags |
68 | 86 | if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'): |
69 | 87 | try: |
70 | 88 | # next try if DLFCN exists |
71 | 89 | import DLFCN as _dl_flags |
72 | 90 | except ImportError: |
73 | 91 | # as a last attempt, use compile-time constants |
74 | 92 | import torch._dl as _dl_flags |
75 | | - |
76 | 93 | old_flags = sys.getdlopenflags() |
77 | 94 | sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY) |
| 95 | + from torch._C import * |
| 96 | + sys.setdlopenflags(old_flags) |
| 97 | + del old_flags |
| 98 | + del _dl_flags |
78 | 99 |
|
79 | | -del _dl_flags |
80 | | - |
81 | | -from torch._C import * |
| 100 | +else: |
| 101 | + # Easy way. You want this most of the time, because it will prevent |
| 102 | + # C++ symbols from libtorch clobbering C++ symbols from other |
| 103 | + # libraries, leading to mysterious segfaults. |
| 104 | + # |
| 105 | + # See Note [Global dependencies] |
| 106 | + _load_global_deps() |
| 107 | + from torch._C import * |
82 | 108 |
|
83 | 109 | __all__ += [name for name in dir(_C) |
84 | 110 | if name[0] != '_' and |
85 | 111 | not name.endswith('Base')] |
86 | 112 |
|
87 | | -if platform.system() != 'Windows': |
88 | | - sys.setdlopenflags(old_flags) |
89 | | - del old_flags |
90 | | - |
91 | 113 | ################################################################################ |
92 | 114 | # Define basic utilities |
93 | 115 | ################################################################################ |
|
0 commit comments