|
8 | 8 | import setuptools
|
9 | 9 | import sys
|
10 | 10 | import unittest
|
| 11 | +import warnings |
11 | 12 | from setuptools.command.build_ext import build_ext as orig_build_ext
|
12 | 13 |
|
13 | 14 | # We need to import tensorflow to find where its include directory is.
|
|
56 | 57 | warp_ctc_includes = [os.path.join(root_path, '../include')]
|
57 | 58 | include_dirs = tf_includes + warp_ctc_includes
|
58 | 59 |
|
59 |
| -extra_compile_args = ['-std=c++11', '-fPIC'] |
| 60 | +if tf.__version__ >= '1.4': |
| 61 | + include_dirs += [tf_include + '/../../external/nsync/public'] |
| 62 | + |
| 63 | +if os.getenv("TF_CXX11_ABI") is not None: |
| 64 | + TF_CXX11_ABI = os.getenv("TF_CXX11_ABI") |
| 65 | +else: |
| 66 | + warnings.warn("Assuming tensorflow was compiled without C++11 ABI. " |
| 67 | + "It is generally true if you are using binary pip package. " |
| 68 | + "If you compiled tensorflow from source with gcc >= 5 and didn't set " |
| 69 | + "-D_GLIBCXX_USE_CXX11_ABI=0 during compilation, you need to set " |
| 70 | + "environment variable TF_CXX11_ABI=1 when compiling this bindings. " |
| 71 | + "Also be sure to touch some files in src to trigger recompilation. " |
| 72 | + "Also, you need to set (or unsed) this environment variable if getting " |
| 73 | + "undefined symbol: _ZN10tensorflow... errors") |
| 74 | + TF_CXX11_ABI = "0" |
| 75 | + |
| 76 | +extra_compile_args = ['-std=c++11', '-fPIC', '-D_GLIBCXX_USE_CXX11_ABI=' + TF_CXX11_ABI] |
60 | 77 | # current tensorflow code triggers return type errors, silence those for now
|
61 | 78 | extra_compile_args += ['-Wno-return-type']
|
62 | 79 |
|
| 80 | +extra_link_args = [] |
| 81 | +if tf.__version__ >= '1.4': |
| 82 | + if os.path.exists(os.path.join(tf_src_dir, 'libtensorflow_framework.so')): |
| 83 | + extra_link_args = ['-L' + tf.sysconfig.get_lib(), '-ltensorflow_framework'] |
| 84 | + |
63 | 85 | if (enable_gpu):
|
64 | 86 | extra_compile_args += ['-DWARPCTC_ENABLE_GPU']
|
65 | 87 | include_dirs += [os.path.join(os.environ["CUDA_HOME"], 'include')]
|
|
91 | 113 | include_dirs = include_dirs,
|
92 | 114 | library_dirs = [warp_ctc_path],
|
93 | 115 | runtime_library_dirs = [os.path.realpath(warp_ctc_path)],
|
94 |
| - libraries = ['warpctc'], |
95 |
| - extra_compile_args = extra_compile_args) |
| 116 | + libraries = ['warpctc', 'tensorflow_framework'], |
| 117 | + extra_compile_args = extra_compile_args, |
| 118 | + extra_link_args = extra_link_args) |
96 | 119 |
|
97 | 120 | class build_tf_ext(orig_build_ext):
|
98 | 121 | def build_extensions(self):
|
|
0 commit comments