
Commit d254d75

Authored by YizhouZ, loadams, and tjruwase
[XPU] support op builder from intel_extension_for_pytorch kernel path (#5425)
Motivation: starting from our next release, XPU DeepSpeed-related kernels will be placed into intel_extension_for_pytorch. This PR adds new op builders and uses the kernel path from intel_extension_for_pytorch. More ops such as MoE and WOQ will be added.

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 0c979d6 commit d254d75

File tree: 6 files changed, +149 -25 lines

accelerator/xpu_accelerator.py (+22 -24)

@@ -9,6 +9,9 @@
 import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
 import functools
 
+import importlib
+import inspect
+
 
 class XPU_Accelerator(DeepSpeedAccelerator):
 
@@ -17,6 +20,7 @@ def __init__(self):
         self._communication_backend_name = 'ccl'
         self._compile_backend = "inductor"
         self.aligned_tensors = []
+        self.class_dict = None
 
     def is_synchronized_device(self):
         return False
@@ -257,35 +261,29 @@ def on_accelerator(self, tensor):
         else:
             return False
 
+    def _lazy_init_class_dict(self):
+        if self.class_dict:
+            return
+
+        op_builder_module = importlib.import_module(self.op_builder_dir())
+
+        # get op builder class from op_builder/xpu/__init__.py
+        self.class_dict = {}
+        for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass):
+            self.class_dict[class_name] = class_obj
+
     # create an instance of op builder and return, name specified by class_name
-    def create_op_builder(self, op_name):
-        builder_class = self.get_op_builder(op_name)
-        if builder_class != None:
-            return builder_class()
-        return None
+    def create_op_builder(self, class_name):
+        builder_class = self.get_op_builder(class_name)
+        return builder_class()
 
     # return an op builder class, name specified by class_name
     def get_op_builder(self, class_name):
-        try:
-            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
-            # if successful this also means we're doing a local install and not JIT compile path
-            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
-            from op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder, PackbitsBuilder
-        except ImportError:
-            from deepspeed.ops.op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder, PackbitsBuilder
-
-        if class_name == "AsyncIOBuilder":
-            return AsyncIOBuilder
-        elif class_name == "CPUAdagradBuilder":
-            return CPUAdagradBuilder
-        elif class_name == "CPUAdamBuilder":
-            return CPUAdamBuilder
-        elif class_name == "FusedAdamBuilder":
-            return FusedAdamBuilder
-        elif class_name == "PackbitsBuilder":
-            return PackbitsBuilder
+        self._lazy_init_class_dict()
+        if class_name in self.class_dict:
+            return self.class_dict[class_name]
         else:
-            return None
+            return self.class_dict['NotImplementedBuilder']
 
     def build_extension(self):
         try:
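The dynamic lookup above replaces the hard-coded if/elif chain: the accelerator imports its op_builder package once, caches every class it exports by name, and falls back to NotImplementedBuilder for names it does not know. A minimal, self-contained sketch of the same discovery pattern (the registry class and the package name in the usage comment are illustrative, not DeepSpeed's API):

import importlib
import inspect


class LazyBuilderRegistry:
    """Cache every class exported by a builder package, keyed by class name."""

    def __init__(self, package_name):
        self.package_name = package_name
        self.class_dict = None

    def _lazy_init_class_dict(self):
        # Scan the package only once; later lookups hit the cached dict.
        if self.class_dict is not None:
            return
        module = importlib.import_module(self.package_name)
        self.class_dict = {name: cls for name, cls in inspect.getmembers(module, inspect.isclass)}

    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        # Unknown names fall back to a 'NotImplementedBuilder' entry when the package provides one.
        return self.class_dict.get(class_name, self.class_dict.get("NotImplementedBuilder"))


# Hypothetical usage:
#   registry = LazyBuilderRegistry("deepspeed.ops.op_builder.xpu")
#   fused_adam = registry.get_op_builder("FusedAdamBuilder")()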

deepspeed/comm/ccl.py (+2 -1)

@@ -8,13 +8,14 @@
 
 import torch
 from deepspeed.accelerator import get_accelerator
+from deepspeed.ops.op_builder import NotImplementedBuilder
 from .reduce_op import ReduceOp
 from .torch import TorchBackend
 
 
 def build_ccl_op():
     builder = get_accelerator().create_op_builder("CCLCommBuilder")
-    if builder is None:
+    if builder is None or NotImplementedBuilder:
         return None
     ccl_cpp_module = builder.load()
     print(f'DeepSpeed {builder.absolute_name()} built successfully')
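The new guard is intended to skip building the CCL op when the XPU accelerator hands back the NotImplementedBuilder fallback instead of a real CCLCommBuilder. For reference, a hedged sketch of how such a fallback check can be written with an explicit instance test (an illustration of the intent, not the code merged in this commit):

from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import NotImplementedBuilder


def build_ccl_op_sketch():
    builder = get_accelerator().create_op_builder("CCLCommBuilder")
    # Skip the build when no builder exists or the op is not implemented on this backend.
    if builder is None or isinstance(builder, NotImplementedBuilder):
        return None
    return builder.load()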

op_builder/xpu/__init__.py (+3)

@@ -7,4 +7,7 @@
 from .cpu_adagrad import CPUAdagradBuilder
 from .fused_adam import FusedAdamBuilder
 from .async_io import AsyncIOBuilder
+from .inference import InferenceBuilder
+from .flash_attn import FlashAttentionBuilder
+from .no_impl import NotImplementedBuilder
 from .packbits import PackbitsBuilder
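Because the accelerator now scans this package with inspect.getmembers, the three new builders become reachable by name without any extra registration. A hedged usage sketch (assumes an XPU install where get_accelerator() resolves to XPU_Accelerator):

from deepspeed.accelerator import get_accelerator

accel = get_accelerator()  # XPU_Accelerator on an XPU machine (assumption)
flash_builder_cls = accel.get_op_builder("FlashAttentionBuilder")    # found via the lazy class dict
fallback_cls = accel.get_op_builder("HypotheticalMissingBuilder")    # hypothetical name -> NotImplementedBuilder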

op_builder/xpu/flash_attn.py (new file, +53)

@@ -0,0 +1,53 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+from .builder import SYCLOpBuilder
+
+
+class FlashAttentionBuilderObject():
+
+    def __init__(self):
+        pass
+
+    # general functions
+    def flash_attn_func_v2(self, q, k, v, dropout_p, softmax_scale, is_causal):
+        try:
+            import torch
+            import intel_extension_for_pytorch  # noqa
+            return torch.nn.functional.scaled_dot_product_attention(q,
+                                                                    k,
+                                                                    v,
+                                                                    dropout_p=dropout_p,
+                                                                    is_causal=is_causal,
+                                                                    scale=softmax_scale)
+        except ImportError:
+            raise ImportError(
+                "Please install pytorch and intel_extension_for_pytorch to include scaled dot product attention.")
+
+
+class FlashAttentionBuilder(SYCLOpBuilder):
+    BUILD_VAR = "DS_BUILD_FlashAttention"
+    NAME = "flash_attn"
+
+    def __init__(self, name=None):
+        name = self.NAME if name is None else name
+        super().__init__(name=name)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.{self.NAME}_op'
+
+    def sources(self):
+        return
+
+    def include_paths(self):
+        return []
+
+    def extra_ldflags(self):
+        return []
+
+    def cxx_args(self):
+        return []
+
+    def load(self):
+        return FlashAttentionBuilderObject()
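Here load() returns a plain Python object rather than a compiled extension, and flash_attn_func_v2 dispatches to PyTorch's scaled_dot_product_attention with intel_extension_for_pytorch providing the XPU backend. A hedged usage sketch (the shapes, dtype, and device string are illustrative assumptions, and require an XPU install):

import torch
from deepspeed.accelerator import get_accelerator

# Assumes an XPU install with intel_extension_for_pytorch available.
flash_attn = get_accelerator().create_op_builder("FlashAttentionBuilder").load()

q = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.float16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="xpu", dtype=torch.float16)

# Dispatches to torch.nn.functional.scaled_dot_product_attention under the hood.
out = flash_attn.flash_attn_func_v2(q, k, v, dropout_p=0.0, softmax_scale=None, is_causal=True)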

op_builder/xpu/inference.py (new file, +36)

@@ -0,0 +1,36 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+from .builder import SYCLOpBuilder
+
+
+class InferenceBuilder(SYCLOpBuilder):
+    BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE"
+    NAME = "transformer_inference"
+
+    def __init__(self, name=None):
+        name = self.NAME if name is None else name
+        super().__init__(name=name)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.transformer.inference.{self.NAME}_op'
+
+    def sources(self):
+        return
+
+    def include_paths(self):
+        return []
+
+    def extra_ldflags(self):
+        return []
+
+    def cxx_args(self):
+        return []
+
+    def load(self):
+        try:
+            import intel_extension_for_pytorch.deepspeed
+            return intel_extension_for_pytorch.deepspeed.transformer_inference.transformer_inference
+        except ImportError:
+            raise ImportError("Please install intel-extension-for-pytorch >= 2.1.30 to include DeepSpeed kernels.")
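Instead of JIT-compiling sources, this builder's load() hands back the prebuilt kernel module shipped inside intel_extension_for_pytorch. A hedged sketch of how a caller might load it through the accelerator and handle the missing-package case (variable names are illustrative):

from deepspeed.accelerator import get_accelerator

try:
    # On success this is intel_extension_for_pytorch.deepspeed.transformer_inference.transformer_inference.
    inference_module = get_accelerator().create_op_builder("InferenceBuilder").load()
except ImportError as err:
    # Raised when intel-extension-for-pytorch >= 2.1.30 (with the DeepSpeed kernels) is not installed.
    inference_module = None
    print(f"XPU transformer-inference kernels unavailable: {err}")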

op_builder/xpu/no_impl.py (new file, +33)

@@ -0,0 +1,33 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .builder import SYCLOpBuilder
+
+
+class NotImplementedBuilder(SYCLOpBuilder):
+    BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED"
+    NAME = "deepspeed_not_implemented"
+
+    def __init__(self, name=None):
+        name = self.NAME if name is None else name
+        super().__init__(name=name)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.{self.NAME}_op'
+
+    def load(self, verbose=True):
+        raise ValueError("This op had not been implemented on XPU backend.")
+
+    def sources(self):
+        return []
+
+    def cxx_args(self):
+        return []
+
+    def extra_ldflags(self):
+        return []
+
+    def include_paths(self):
+        return []
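NotImplementedBuilder is the fallback returned by get_op_builder for any op the XPU builder package does not export; its load() raises so callers can detect and disable the optional feature. A hedged sketch of that pattern (assumes CCLCommBuilder is not provided by the XPU package, as in this commit):

from deepspeed.accelerator import get_accelerator

# "CCLCommBuilder" is not in op_builder/xpu here, so the lookup resolves to
# NotImplementedBuilder and load() raises ValueError.
builder = get_accelerator().create_op_builder("CCLCommBuilder")
try:
    module = builder.load()
except ValueError:
    module = None  # op not available on the XPU backend yet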
