
Commit 0611c41

bertmaher authored and facebook-github-bot committed
Skip loading triton.nvidia.cublas if not found
Summary: We have an old triton internally that doesn't have the cublasLt bindings.

Reviewed By: adamomainz

Differential Revision: D63643619

fbshipit-source-id: 39aece74b52f7747fe2100d7bb905bad49ba1fa0
1 parent b6b67a4 commit 0611c41

1 file changed: +9 −5

torchbenchmark/operators/fp8_gemm/persistent.py

+9 −5

@@ -5,13 +5,17 @@
 import triton.language as tl
 import triton.tools.experimental_descriptor
 
+cublas = None
 if torch.cuda.is_available():
-    from triton._C.libtriton import nvidia
+    try:
+        from triton._C.libtriton import nvidia
 
-    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
-    cublas = nvidia.cublas.CublasLt(cublas_workspace)
-else:
-    cublas = None
+        cublas_workspace = torch.empty(
+            32 * 1024 * 1024, device="cuda", dtype=torch.uint8
+        )
+        cublas = nvidia.cublas.CublasLt(cublas_workspace)
+    except (ImportError, IOError, AttributeError):
+        pass
 
 
 def is_cuda():