
Commit 934bbc9

Merge branch 'master' into gma/add_autotp_workflow
2 parents e5b0b18 + d5fa87f commit 934bbc9

File tree

30 files changed: +565 -131 lines


Diff for: .github/workflows/nv-inference.yml (+4 -4)

@@ -32,7 +32,7 @@ jobs:
       - name: Install pytorch
         run: |
-          pip install -U --cache-dir $TORCH_CACHE torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116
+          pip install -U --cache-dir $TORCH_CACHE torch==2.1.2 torchvision==0.16.2 --extra-index-url https://download.pytorch.org/whl/cu118
           python -c "import torch; print('torch:', torch.__version__, torch)"
           python -c "import torch; print('CUDA available:', torch.cuda.is_available())"

@@ -57,6 +57,6 @@ jobs:
         run: |
           unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
           cd tests
-          #pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="1.13" --cuda_ver="11.6"
-          pytest $PYTEST_OPTS -m 'inference_ops' unit/ --torch_ver="1.13" --cuda_ver="11.6"
-          pytest $PYTEST_OPTS --forked -n 4 -m 'inference' unit/ --torch_ver="1.13" --cuda_ver="11.6"
+          #pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="2.1" --cuda_ver="11.8"
+          pytest $PYTEST_OPTS -m 'inference_ops' unit/ --torch_ver="2.1" --cuda_ver="11.8"
+          pytest $PYTEST_OPTS --forked -n 4 -m 'inference' unit/ --torch_ver="2.1" --cuda_ver="11.8"
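
Outside the workflow itself, a quick way to confirm a local environment matches the versions the pytest flags now assume. This snippet is illustrative only and not part of the CI change:

import torch

# Values passed to pytest above via --torch_ver / --cuda_ver
expected_torch, expected_cuda = "2.1", "11.8"

assert torch.__version__.startswith(expected_torch), torch.__version__
assert torch.version.cuda == expected_cuda, torch.version.cuda
print("torch:", torch.__version__,
      "| CUDA build:", torch.version.cuda,
      "| CUDA available:", torch.cuda.is_available())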

Diff for: accelerator/cuda_accelerator.py (+8 -7)

@@ -3,6 +3,7 @@

 # DeepSpeed Team

+import functools
 import os
 import pkgutil
 import importlib

@@ -260,31 +261,31 @@ def replay_graph(self, graph):

     @property
     def BFloat16Tensor(self):
-        return torch.cuda.BFloat16Tensor
+        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='cuda')

     @property
     def ByteTensor(self):
-        return torch.cuda.ByteTensor
+        return functools.partial(torch.tensor, dtype=torch.uint8, device='cuda')

     @property
     def DoubleTensor(self):
-        return torch.cuda.DoubleTensor
+        return functools.partial(torch.tensor, dtype=torch.double, device='cuda')

     @property
     def FloatTensor(self):
-        return torch.cuda.FloatTensor
+        return functools.partial(torch.tensor, dtype=torch.float, device='cuda')

     @property
     def HalfTensor(self):
-        return torch.cuda.HalfTensor
+        return functools.partial(torch.tensor, dtype=torch.half, device='cuda')

     @property
     def IntTensor(self):
-        return torch.cuda.IntTensor
+        return functools.partial(torch.tensor, dtype=torch.int, device='cuda')

     @property
     def LongTensor(self):
-        return torch.cuda.LongTensor
+        return functools.partial(torch.tensor, dtype=torch.long, device='cuda')

     def pin_memory(self, tensor, align_bytes=1):
         return tensor.pin_memory()
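
The legacy torch.cuda.*Tensor constructors are deprecated in recent PyTorch releases, so each property now returns a functools.partial over torch.tensor instead. A minimal sketch of how the replacement behaves (assumes a CUDA device is available; the variable names are illustrative):

import functools
import torch

# What the accelerator property now returns for FloatTensor
FloatTensor = functools.partial(torch.tensor, dtype=torch.float, device='cuda')

x = FloatTensor([1.0, 2.0])   # tensor([1., 2.], device='cuda:0')
print(x.dtype, x.device)      # torch.float32 cuda:0

Note the two are not identical for every call pattern: torch.cuda.FloatTensor(3) allocated an uninitialized length-3 tensor, whereas torch.tensor(3) produces a 0-d tensor holding the value 3, so callers that built tensors from a bare size would see different behavior.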

Diff for: deepspeed/inference/v2/checkpoint/huggingface_engine.py (+2 -5)

@@ -40,16 +40,13 @@ def _fetch_checkpoint_files(self):
         # currently coming from the ckpt engine init but maybe a catch all kwargs for other
         # snapshot download parameters would be more flexible.

-        # NOTE(jeff): allow_patterns here are explicitly not using safetensors or other
-        # checkpoint files that may be present. Example of all files in the llama-2-7b
-        # repo here: https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main
-        from huggingface_hub import snapshot_download, list_files_info
+        from huggingface_hub import snapshot_download, list_repo_tree

         def model_has_safetensors(model_name_or_path: str) -> bool:
             if os.path.isdir(model_name_or_path):
                 file_list = os.listdir(model_name_or_path)
             else:
-                file_list = [rf.rfilename for rf in list_files_info(model_name_or_path)]
+                file_list = [rf.path for rf in list_repo_tree(model_name_or_path)]
             for f in file_list:
                 if f.endswith(".safetensors"):
                     return True
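
Newer huggingface_hub releases replace list_files_info with list_repo_tree, whose entries expose their repo-relative name as .path rather than .rfilename, which is why both the import and the attribute change here. A standalone sketch of the updated helper (mirrors the diff; assumes a huggingface_hub version that provides list_repo_tree):

import os
from huggingface_hub import list_repo_tree

def model_has_safetensors(model_name_or_path: str) -> bool:
    # Local checkout: just list the directory contents.
    if os.path.isdir(model_name_or_path):
        file_list = os.listdir(model_name_or_path)
    else:
        # Hub repo: RepoFile/RepoFolder entries carry their path in .path
        file_list = [rf.path for rf in list_repo_tree(model_name_or_path)]
    return any(f.endswith(".safetensors") for f in file_list)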

Diff for: deepspeed/profiling/flops_profiler/profiler.py (+13 -0)

@@ -827,6 +827,15 @@ def _elementwise_flops_compute(input, other):
     return flops, 0


+def _attn_flops_compute(q, k, v, *args, **kwargs):
+    """
+    Count flops for the scaled_dot_product_attention operation.
+    """
+    macs = _prod(q.shape) * k.shape[-2]
+    macs += _prod(q.shape[:-1]) * k.shape[-2] * v.shape[-1]
+    return 2 * macs, macs
+
+
 def wrapFunc(func, funcFlopCompute):
     oldFunc = func
     name = func.__str__

@@ -899,10 +908,14 @@ def _patch_functionals():
     # embedding
     F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)

+    # attn
+    F.scaled_dot_product_attention = wrapFunc(F.scaled_dot_product_attention, _attn_flops_compute)
+

 def _patch_tensor_methods():
     torch.matmul = wrapFunc(torch.matmul, _matmul_flops_compute)
     torch.Tensor.matmul = wrapFunc(torch.Tensor.matmul, _matmul_flops_compute)
+    torch.Tensor.__matmul__ = wrapFunc(torch.Tensor.__matmul__, _matmul_flops_compute)
     torch.mm = wrapFunc(torch.mm, _matmul_flops_compute)
     torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute)
     torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute)