Skip to content

Commit 6d5b8fa

Browse files
authored
Merge pull request #120 from HawkAaron/master
Support tf version >= 1.4
2 parents 14858fe + 41b1bc2 commit 6d5b8fa

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

tensorflow_binding/setup.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import setuptools
99
import sys
1010
import unittest
11+
import warnings
1112
from setuptools.command.build_ext import build_ext as orig_build_ext
1213

1314
# We need to import tensorflow to find where its include directory is.
@@ -56,10 +57,31 @@
5657
warp_ctc_includes = [os.path.join(root_path, '../include')]
5758
include_dirs = tf_includes + warp_ctc_includes
5859

59-
extra_compile_args = ['-std=c++11', '-fPIC']
60+
if tf.__version__ >= '1.4':
61+
include_dirs += [tf_include + '/../../external/nsync/public']
62+
63+
if os.getenv("TF_CXX11_ABI") is not None:
64+
TF_CXX11_ABI = os.getenv("TF_CXX11_ABI")
65+
else:
66+
warnings.warn("Assuming tensorflow was compiled without C++11 ABI. "
67+
"It is generally true if you are using binary pip package. "
68+
"If you compiled tensorflow from source with gcc >= 5 and didn't set "
69+
"-D_GLIBCXX_USE_CXX11_ABI=0 during compilation, you need to set "
70+
"environment variable TF_CXX11_ABI=1 when compiling this bindings. "
71+
"Also be sure to touch some files in src to trigger recompilation. "
72+
"Also, you need to set (or unsed) this environment variable if getting "
73+
"undefined symbol: _ZN10tensorflow... errors")
74+
TF_CXX11_ABI = "0"
75+
76+
extra_compile_args = ['-std=c++11', '-fPIC', '-D_GLIBCXX_USE_CXX11_ABI=' + TF_CXX11_ABI]
6077
# current tensorflow code triggers return type errors, silence those for now
6178
extra_compile_args += ['-Wno-return-type']
6279

80+
extra_link_args = []
81+
if tf.__version__ >= '1.4':
82+
if os.path.exists(os.path.join(tf_src_dir, 'libtensorflow_framework.so')):
83+
extra_link_args = ['-L' + tf.sysconfig.get_lib(), '-ltensorflow_framework']
84+
6385
if (enable_gpu):
6486
extra_compile_args += ['-DWARPCTC_ENABLE_GPU']
6587
include_dirs += [os.path.join(os.environ["CUDA_HOME"], 'include')]
@@ -91,8 +113,9 @@
91113
include_dirs = include_dirs,
92114
library_dirs = [warp_ctc_path],
93115
runtime_library_dirs = [os.path.realpath(warp_ctc_path)],
94-
libraries = ['warpctc'],
95-
extra_compile_args = extra_compile_args)
116+
libraries = ['warpctc', 'tensorflow_framework'],
117+
extra_compile_args = extra_compile_args,
118+
extra_link_args = extra_link_args)
96119

97120
class build_tf_ext(orig_build_ext):
98121
def build_extensions(self):

0 commit comments

Comments
 (0)