Skip to content

Add check_op_perf utility to summarize op performance automatically #1650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions test/microbench/batch_norm_1d.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import torch
from torch.profiler import profile, ProfilerActivity

device = "xpu"
if torch.cuda.is_available():
device = "cuda"
activity = ProfilerActivity.CUDA
table_key = "cuda_time_total"
else:
device = "xpu"
activity = ProfilerActivity.XPU
table_key = "xpu_time_total"


shape_list = [((64, 8), (8)), ((4, 128, 15000), (128)), ((4, 256, 512), (256))]

Expand Down Expand Up @@ -29,12 +37,12 @@
backward,
)
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], record_shapes=True
activities=[ProfilerActivity.CPU, activity], record_shapes=True
) as prof:
for i in range(20):
m = torch.nn.BatchNorm1d(shape[1], device=device)
output = m(input)
if backward:
gy = torch.empty_like(output)
output.backward(gy)
print(prof.key_averages().table(sort_by="xpu_time_total"))
print(prof.key_averages().table(sort_by=table_key))
20 changes: 14 additions & 6 deletions test/microbench/batch_norm_2d.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import torch
from torch.profiler import profile, ProfilerActivity

device = "xpu"
if torch.cuda.is_available():
device = "cuda"
activity = ProfilerActivity.CUDA
table_key = "cuda_time_total"
else:
device = "xpu"
activity = ProfilerActivity.XPU
table_key = "xpu_time_total"


shape_list = [
(256, 256, 56, 56, 256),
Expand All @@ -20,14 +28,14 @@ def BTN2d(shape, dtype, channels_last, backward):
input = (
torch.randn(N, C, H, W)
.to(memory_format=torch.channels_last)
.to(device="xpu", dtype=dtype)
.to(device=device, dtype=dtype)
)
else:
input = torch.randn(N, C, H, W).to(device="xpu", dtype=dtype)
input = torch.randn(N, C, H, W).to(device=device, dtype=dtype)

if backward:
input.requires_grad_(True)
grad = torch.randn([C, H, W]).to(device="xpu", dtype=dtype)
grad = torch.randn([C, H, W]).to(device=device, dtype=dtype)

BTN = torch.nn.BatchNorm2d(shape[4], device=device)

Expand Down Expand Up @@ -59,9 +67,9 @@ def BTN2d(shape, dtype, channels_last, backward):
backward,
)
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.XPU],
activities=[ProfilerActivity.CPU, activity],
record_shapes=True,
) as prof:
for i in range(20):
BTN2d(shape, dtype, channels_last, backward=True)
print(prof.key_averages().table(sort_by="xpu_time_total"))
print(prof.key_averages().table(sort_by=table_key))
20 changes: 14 additions & 6 deletions test/microbench/batch_norm_3d.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import torch
from torch.profiler import profile, ProfilerActivity

device = "xpu"
if torch.cuda.is_available():
device = "cuda"
activity = ProfilerActivity.CUDA
table_key = "cuda_time_total"
else:
device = "xpu"
activity = ProfilerActivity.XPU
table_key = "xpu_time_total"


shape_list = [(2, 5, 6, 3, 5, 5), (2, 8, 64, 64, 64, 8), (16, 16, 128, 128, 256, 16)]

Expand All @@ -20,14 +28,14 @@ def BTN3d(shape, dtype, channels_last, backward):
input = (
torch.randn(N, C, D, H, W)
.to(memory_format=torch.channels_last_3d)
.to(device="xpu", dtype=dtype)
.to(device=device, dtype=dtype)
)
else:
input = torch.randn(N, C, D, H, W).to(device="xpu", dtype=dtype)
input = torch.randn(N, C, D, H, W).to(device=device, dtype=dtype)

if backward:
input.requires_grad_(True)
grad = torch.randn([C, D, H, W]).to(device="xpu", dtype=dtype)
grad = torch.randn([C, D, H, W]).to(device=device, dtype=dtype)

BTN = torch.nn.BatchNorm3d(shape[5], device=device)

Expand Down Expand Up @@ -59,9 +67,9 @@ def BTN3d(shape, dtype, channels_last, backward):
backward,
)
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.XPU],
activities=[ProfilerActivity.CPU, activity],
record_shapes=True,
) as prof:
for i in range(20):
BTN3d(shape, dtype, channels_last, backward=True)
print(prof.key_averages().table(sort_by="xpu_time_total"))
print(prof.key_averages().table(sort_by=table_key))
16 changes: 13 additions & 3 deletions test/microbench/group_norm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import torch
from torch.profiler import profile, ProfilerActivity

device = "xpu"
if torch.cuda.is_available():
device = "cuda"
activity = ProfilerActivity.CUDA
table_key = "cuda_time_total"
else:
device = "xpu"
activity = ProfilerActivity.XPU
table_key = "xpu_time_total"


backward = True


shape_list = [
(1, 32, 128, 32, 32), # all channel for 1 group
(16, 1024, 128, 32, 32), # normal shape, big memory
Expand Down Expand Up @@ -64,7 +74,7 @@
backward,
)
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.XPU],
activities=[ProfilerActivity.CPU, activity],
record_shapes=True,
) as prof:
for i in range(20):
Expand All @@ -73,4 +83,4 @@
if backward:
grad_out = torch.randn_like(output).to(device)
(grad_dpcpp,) = torch.autograd.grad(output, input, grad_out)
print(prof.key_averages().table(sort_by="xpu_time_total"))
print(prof.key_averages().table(sort_by=table_key))
16 changes: 13 additions & 3 deletions test/microbench/layer_norm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import torch
from torch.profiler import profile, ProfilerActivity

device = "xpu"
if torch.cuda.is_available():
device = "cuda"
activity = ProfilerActivity.CUDA
table_key = "cuda_time_total"
else:
device = "xpu"
activity = ProfilerActivity.XPU
table_key = "xpu_time_total"


backward = True


shape_list = [
((1, 1024), (1024)),
((2, 4096, 320), (4096, 320)),
Expand Down Expand Up @@ -38,12 +48,12 @@
backward,
)
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], record_shapes=True
activities=[ProfilerActivity.CPU, activity], record_shapes=True
) as prof:
for i in range(20):
m = torch.nn.LayerNorm(shape[1], device=device, dtype=dtype)
output = m(input)
if backward:
gy = torch.empty_like(output)
output.backward(gy)
print(prof.key_averages().table(sort_by="xpu_time_total"))
print(prof.key_averages().table(sort_by=table_key))
87 changes: 87 additions & 0 deletions tools/check_op_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
import subprocess
from pathlib import Path


def find_pytorch_dir():
    """Locate the nearest enclosing path component named ``torch-xpu-ops``.

    The search starts at this file's resolved path and moves towards the
    filesystem root.

    Returns:
        The matching path as a string, or ``''`` when no component of the
        path is named ``torch-xpu-ops``.
    """
    here = Path(__file__).resolve()
    # ``Path.parents`` is a finite sequence ending at the filesystem root, so
    # the walk always terminates.  Comparing a ``Path`` against the string
    # ``path.root`` is never equal, and the parent of the root is the root
    # itself, so a ``while path != path.root`` loop cannot be used here.
    for candidate in (here, *here.parents):
        if candidate.name == "torch-xpu-ops":
            return str(candidate)
    return ''


# Maps each microbenchmark script name to the operators whose latency is
# collected from that script's profiler output.
# An entry is either a plain operator name or a tuple of alternative names;
# for a tuple, the first name is the one reported in the results.
OP_LIST = {
    'layer_norm.py': [
        'aten::native_layer_norm',
        'aten::native_layer_norm_backward',
    ],
    'group_norm.py': [
        'aten::native_group_norm',
        'aten::native_group_norm_backward',
    ],
    'batch_norm_1d.py': [
        (
            'aten::native_batch_norm',
            'aten::cudnn_batch_norm',
        ),
        (
            'aten::native_batch_norm_backward',
            'aten::cudnn_batch_norm_backward',
        ),
    ],
    'batch_norm_2d.py': [
        (
            'aten::native_batch_norm',
            'aten::cudnn_batch_norm',
        ),
        (
            'aten::native_batch_norm_backward',
            'aten::cudnn_batch_norm_backward',
        ),
    ],
    # NOTE(review): batch_norm_3d.py is currently left out of the run.
    # 'batch_norm_3d.py': ['aten::native_batch_norm', 'aten::native_batch_norm_backward'],
}


def find_op_time(text, ops):
    """Extract per-operator latencies from captured benchmark output.

    The text is scanned line by line.  A line starting with ``shape:`` becomes
    the tag attached to every operator row that follows it (the tag is the
    string ``"None"`` until the first such line).  A row is matched when its
    first whitespace-separated column equals one of the requested operator
    names; its second-to-last column is read as the latency.

    Args:
        text: Captured stdout of a benchmark script.  It is echoed to stdout.
        ops: Iterable of operator names.  An entry may be a tuple of
            alternative names, in which case the first name is the one
            reported in the result.

    Returns:
        A de-duplicated list of ``[operator, tag, latency_us]`` rows, all
        strings, ordered by operator, then tag, then latency.

    Raises:
        ValueError: If a matched row's latency column does not end in
            ``us``, ``ms`` or ``s``, or its numeric part is not a float.
    """

    def transform_to_us(time):
        """Convert a duration such as ``'1.5ms'`` to microseconds."""
        # 'us' and 'ms' must be tested before the bare 's' suffix.
        if time.endswith('us'):
            return float(time[:-2])
        if time.endswith('ms'):
            return float(time[:-2]) * 1000.0
        if time.endswith('s'):
            return float(time[:-1]) * 1000000.0
        raise ValueError("time format not support")

    # Normalise every entry to a tuple of aliases once, outside the line loop.
    alias_groups = [op if isinstance(op, tuple) else (op,) for op in ops]

    flag = "None"
    print(text)
    rows = set()
    for line in text.split('\n'):
        line = line.strip()
        if line.startswith('shape:'):
            flag = line
        # Split on any run of whitespace and keep every column, including
        # single-character ones (e.g. a call count of "5"); dropping those
        # would shift which column ``tokens[-2]`` refers to.
        tokens = line.split()
        if len(tokens) < 2:
            continue
        for aliases in alias_groups:
            if tokens[0] in aliases:
                op_time = transform_to_us(tokens[-2])
                rows.add((aliases[0], flag, str(op_time)))

    # The set removes duplicates; the full sort key makes the output order
    # reproducible even for rows that share operator and tag.
    ordered = sorted(rows, key=lambda row: (row[0], row[1], float(row[2])))
    return [list(row) for row in ordered]


if __name__ == '__main__':
    # Driver: run every benchmark script listed in OP_LIST, collect the
    # latency rows found in its output and write them to check_op_perf.csv
    # in the current working directory.
    import csv
    import sys

    root_folder = find_pytorch_dir().strip()
    if not root_folder:
        # With an empty root the join below yields a path relative to the
        # current working directory; say so instead of failing silently.
        print(
            "warning: torch-xpu-ops directory not found; "
            "resolving test/microbench/ relative to the current directory",
            file=sys.stderr,
        )
    perf_suit = os.path.join(root_folder, 'test/microbench/')

    csv_data = [
        ["Operator", "Tag", "Latency(us)"],
    ]
    for script, ops in OP_LIST.items():
        print(script)
        script_path = os.path.join(perf_suit, script)
        # Launch the benchmark with the interpreter running this tool so it
        # sees the same environment, rather than whatever "python" is on PATH.
        result = subprocess.run(
            [sys.executable or "python", script_path],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            # Keep going (any partial stdout is still parsed), but surface
            # the failure and the script's stderr.
            print(
                f"warning: {script} exited with code {result.returncode}",
                file=sys.stderr,
            )
            print(result.stderr, file=sys.stderr)
        rows = find_op_time(result.stdout, ops)
        csv_data += rows
        for row in rows:
            print(row)

    with open("check_op_perf.csv", mode="w", newline="", encoding="utf-8") as file:
        csv.writer(file).writerows(csv_data)
Loading