Skip to content

Commit 8effab6

Browse files
author
Jian Weng
committed
cgo ae update
1 parent f7d54fa commit 8effab6

File tree

11 files changed

+119
-85
lines changed

11 files changed

+119
-85
lines changed

apps/cpu/kernel/.gitignore

+1
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
input

apps/cpu/kernel/conv2d.py

+36 -24
Original file line number | Diff line number | Diff line change
@@ -10,38 +10,43 @@
1010
import topi
1111
from tvm.relay import op
1212

13-
##ic, h, w, oc, _, kh, kw, sh, sw = map(int, input().split())
14-
#ic, h, w, oc, kh, sh = map(int, input().split())
15-
#kw = kh
16-
#sw = sh
17-
#
18-
#if ic % 4:
19-
# ic += 4 - ic % 4
20-
#
21-
#if oc % 16:
22-
# oc += 16 - oc % 16
23-
#
24-
#a = tvm.te.placeholder((1, ic // 16, h, w, 16), dtype='int8')
25-
#if ic % 16 == 0:
26-
# b = tvm.te.placeholder((oc // 16, ic // 16, kh, kw, 4, 16, 4), dtype='int8')
27-
#else:
28-
# assert ic % 4 == 0
29-
# a = tvm.te.placeholder((1, ic // 4, h, w, 4), dtype='int8')
30-
# b = tvm.te.placeholder((oc // 16, ic // 4, kh, kw, 1, 16, 4), dtype='int8')
13+
#ic, h, w, oc, _, kh, kw, sh, sw = map(int, input().split())
14+
_, ic, h, w, oc, _, kh, _, sh, _ = map(int, input().split())
15+
kw = kh
16+
sw = sh
3117

32-
N, C, H, W, c, O, I, KH, KW, e, o, i, sh, sw = map(int, input().split())
18+
if ic % 4:
19+
ic += 4 - ic % 4
3320

34-
a = tvm.te.placeholder((N, C, H, W, c), dtype='int8')
35-
b = tvm.te.placeholder((O, I, KH, KW, e, o, i), dtype='int8')
21+
if oc % 16:
22+
oc += 16 - oc % 16
23+
24+
a = tvm.te.placeholder((1, ic // 16, h, w, 16), dtype='int8')
25+
if ic % 16 == 0:
26+
b = tvm.te.placeholder((oc // 16, ic // 16, kh, kw, 4, 16, 4), dtype='int8')
27+
else:
28+
assert ic % 4 == 0
29+
a = tvm.te.placeholder((1, ic // 4, h, w, 4), dtype='int8')
30+
b = tvm.te.placeholder((oc // 16, ic // 4, kh, kw, 1, 16, 4), dtype='int8')
31+
32+
#N, C, H, W, c, O, I, KH, KW, e, o, i, sh, sw = map(int, input().split())
33+
34+
#a = tvm.te.placeholder((N, C, H, W, c), dtype='int8')
35+
#b = tvm.te.placeholder((O, I, KH, KW, e, o, i), dtype='int8')
3636

3737
passes = [(1, tensorizer.rewrite)]
3838
from tensorizer import tune
39-
tune.cpu_idx = 0
39+
tune.cpu_idx = -1
4040
target = -1
4141
results = []
42-
virgin = True
42+
result = 1e9
43+
4344
while True:
4445
with tvm.transform.PassContext(opt_level=3, config={'tir.add_lower_pass': passes}), tvm.target.create('llvm -mcpu=cascadelake'):
46+
if tune.cpu_idx == -1:
47+
tune.cpu_idx = 0
48+
tune.parallel_only = True
49+
4550
conv = topi.nn.conv2d_NCHWc_int8(a, b, stride=(sh, sw), padding=0, dilation=1, out_dtype='int32',
4651
layout='NCHW4c', out_layout='NCHW16c')
4752
sch = tensorizer.INTRINSICS['vnni']['schedule']([conv], (sh, sw))
@@ -59,9 +64,16 @@
5964
res = fte(nd_a, nd_b, nd_c)
6065
results.append(res.mean)
6166

67+
if tune.parallel_only:
68+
tune.cpu_idx = -1
69+
tune.parallel_only = False
70+
71+
if res.mean < result:
72+
target = tune.cpu_idx
73+
result = res.mean
74+
6275
relay.backend.compile_engine.get().clear()
6376
tune.cpu_idx += 1
64-
break
6577
if tune.cpu_idx - target > 8:
6678
break
6779
if tune.cpu_idx >= tune.total_idx:

apps/cpu/kernel/cpu-tune.log

+16
Original file line number | Diff line number | Diff line change
@@ -430,3 +430,19 @@
430430
(1, 2, 112, 112, 16) (1, 2, 1, 1, 4, 16, 4) (1, 1) [5.283333333333333e-06]
431431
(1, 128, 7, 7, 16) (32, 128, 1, 1, 4, 16, 4) (1, 1) [6.08829e-05]
432432
(1, 36, 14, 14, 16) (4, 36, 1, 1, 4, 16, 4) (1, 1) [1.0045566666666667e-05]
433+
(1, 18, 35, 35, 16) (24, 18, 3, 3, 4, 16, 4) (2, 2) [0.0003514215333333333, 9.668606666666667e-05, 0.00011963883333333332, 0.0001122919, 0.0002835423, 8.759403333333333e-05, 0.0001860404, 0.00014346666666666667, 0.00039427413333333327, 0.00039437946666666667, 0.00016885596666666667, 0.00028357029999999996, 0.000281523, 0.0004038882]
434+
(1, 10, 9, 9, 16) (14, 10, 3, 3, 4, 16, 4) (1, 1) [2.265226666666667e-05, 8.446466666666666e-06, 1.0891100000000001e-05, 1.31841e-05, 1.0353866666666666e-05, 1.2766766666666668e-05, 1.1706900000000001e-05, 1.1118366666666667e-05, 2.4449199999999996e-05, 2.41494e-05]
435+
(1, 66, 7, 7, 16) (12, 66, 1, 1, 4, 16, 4) (1, 1) [1.4224633333333333e-05, 6.878800000000001e-06, 8.318566666666667e-06, 9.253066666666668e-06, 8.469466666666667e-06, 9.702333333333332e-06, 8.547600000000001e-06, 9.105799999999998e-06, 1.4786199999999999e-05, 1.5234766666666663e-05]
436+
(1, 5, 73, 73, 16) (12, 5, 3, 3, 4, 16, 4) (1, 1) [0.0007445264666666666, 0.00036293423333333337, 0.0004282294666666667, 0.00040284786666666667, 0.00036327886666666663, 0.00042169116666666666, 0.00047538076666666675, 0.00042853806666666666, 0.0004993923, 0.00036451126666666665]
437+
(1, 8, 16, 16, 16) (8, 8, 3, 3, 4, 16, 4) (1, 1) [3.388693333333333e-05, 1.1279599999999998e-05, 1.0211000000000003e-05, 1.0718e-05, 4.224366666666667e-05, 2.1665866666666665e-05, 1.1584133333333333e-05, 3.2161733333333334e-05, 3.2217433333333334e-05, 3.2244733333333326e-05, 3.0116033333333334e-05]
438+
(1, 12, 16, 16, 16) (12, 12, 3, 3, 4, 16, 4) (1, 1) [6.998316666666667e-05, 1.8598966666666666e-05, 1.8014533333333332e-05, 1.8093833333333333e-05, 7.989396666666665e-05, 4.3606333333333336e-05, 1.838603333333333e-05, 7.076240000000001e-05, 7.021e-05, 7.07758e-05, 5.9090766666666673e-05]
439+
(1, 16, 16, 16, 16) (16, 16, 3, 3, 4, 16, 4) (1, 1) [0.0001271999, 2.9112400000000003e-05, 2.99537e-05, 3.089303333333333e-05, 0.00012914833333333334, 7.562293333333334e-05, 2.9749866666666664e-05, 0.00012641846666666662, 0.0001263090666666667, 0.00012664806666666664]
440+
(1, 64, 14, 14, 16) (32, 64, 1, 1, 4, 16, 4) (1, 1) [0.00010905626666666667, 2.81171e-05, 2.7718833333333336e-05, 2.8254033333333328e-05, 0.0001004875, 7.107966666666667e-05, 2.9853599999999995e-05, 9.300396666666667e-05, 0.00010172613333333334, 0.00010611056666666666, 3.0567733333333336e-05]
441+
(1, 8, 16, 16, 16) (10, 8, 3, 3, 4, 16, 4) (1, 1) [4.00728e-05, 1.3028833333333335e-05, 1.2608533333333333e-05, 1.2518066666666666e-05, 4.4550733333333334e-05, 2.5933866666666666e-05, 1.2598300000000001e-05, 4.01845e-05, 4.0060066666666666e-05, 3.942363333333333e-05, 3.655853333333333e-05, 3.844476666666666e-05]
442+
(1, 36, 14, 14, 16) (12, 36, 1, 1, 4, 16, 4) (1, 1) [2.2557133333333335e-05, 9.7363e-06, 1.0471766666666667e-05, 9.455833333333335e-06, 2.3511866666666668e-05, 1.597273333333333e-05, 9.989866666666667e-06, 2.2900466666666662e-05, 2.26288e-05, 2.29245e-05, 2.1308e-05, 2.338776666666667e-05]
443+
(1, 6, 16, 16, 16) (8, 6, 3, 3, 4, 16, 4) (1, 1) [2.5126666666666668e-05, 9.6417e-06, 9.640433333333333e-06, 9.442599999999998e-06, 3.2072099999999994e-05, 1.667773333333333e-05, 9.598066666666667e-06, 2.48075e-05, 2.417616666666667e-05, 2.4684066666666666e-05, 2.3512433333333335e-05, 2.6022999999999998e-05]
444+
(1, 64, 14, 14, 16) (16, 64, 1, 1, 4, 16, 4) (1, 1) [5.522673333333333e-05, 1.6551866666666663e-05, 1.7416066666666667e-05, 1.63774e-05, 5.077093333333333e-05, 3.9541566666666674e-05, 1.7163833333333333e-05, 5.5439433333333343e-05, 5.500383333333334e-05, 5.518076666666666e-05, 5.034303333333334e-05, 5.2097466666666674e-05]
445+
(1, 36, 14, 14, 16) (8, 36, 1, 1, 4, 16, 4) (1, 1) [1.65873e-05, 6.972e-06, 7.9898e-06, 8.129266666666667e-06, 1.7917033333333335e-05, 1.3090033333333333e-05, 8.531433333333334e-06, 1.59983e-05, 1.6602433333333336e-05, 1.6452366666666665e-05]
446+
(1, 4, 29, 29, 16) (6, 4, 3, 3, 4, 16, 4) (1, 1) [3.99176e-05, 2.10778e-05, 3.882036666666667e-05, 1.9675066666666664e-05, 2.0670900000000002e-05, 3.98366e-05, 3.8784699999999995e-05, 3.0413866666666667e-05, 2.063113333333333e-05, 3.92953e-05, 2.0253333333333335e-05, 4.06086e-05]
447+
(1, 4, 56, 56, 16) (8, 4, 1, 1, 4, 16, 4) (2, 2) [5.222933333333333e-06, 6.210866666666667e-06, 5.7302e-06, 6.315266666666668e-06, 1.03406e-05, 6.2129e-06, 5.113633333333333e-06, 6.020233333333333e-06, 7.795133333333333e-06, 7.409666666666667e-06, 6.312966666666667e-06, 5.339066666666666e-06, 1.11085e-05, 5.6410666666666675e-06, 6.255299999999999e-06]
448+
(1, 38, 14, 14, 16) (12, 38, 1, 1, 4, 16, 4) (1, 1) [2.3831933333333335e-05, 8.716766666666668e-06, 9.773e-06, 9.623833333333333e-06, 2.4328300000000002e-05, 1.7445633333333328e-05, 1.0776666666666668e-05, 2.384416666666666e-05, 2.3965666666666667e-05, 2.4390266666666666e-05]

apps/cpu/kernel/input

-1
This file was deleted.

apps/cpu/kernel/run.py

+17 -23
Original file line number | Diff line number | Diff line change
@@ -15,34 +15,28 @@
1515
#(1024,15,15,2048,1024,1,1,2,2),
1616
#(128,65,65,256,128,3,3,2,2)]
1717

18-
workloads = [
19-
(256, 16, 16, 256, 3, 1),
20-
(512, 9, 9, 512, 3, 1),
21-
(128, 30, 30, 128, 3, 1),
22-
(64, 56, 56, 128, 1, 2),
23-
(64, 58, 58, 64, 3, 1),
24-
(128, 28, 28, 256, 1, 2),
25-
(256, 16, 16, 512, 3, 2),
26-
(64, 58, 58, 128, 3, 2),
27-
(4, 230, 230, 64, 7, 2),
28-
(128, 30, 30, 256, 3, 2),
29-
(256, 14, 14, 512, 1, 2)
30-
]
18+
#workloads = [
19+
#(256, 16, 16, 256, 3, 1),
20+
#(512, 9, 9, 512, 3, 1),
21+
#(128, 30, 30, 128, 3, 1),
22+
#(64, 56, 56, 128, 1, 2),
23+
#(64, 58, 58, 64, 3, 1),
24+
#(128, 28, 28, 256, 1, 2),
25+
#(256, 16, 16, 512, 3, 2),
26+
#(64, 58, 58, 128, 3, 2),
27+
#(4, 230, 230, 64, 7, 2),
28+
#(128, 30, 30, 256, 3, 2),
29+
#(256, 14, 14, 512, 1, 2)
30+
#]
31+
32+
workloads = [(1, 288, 35, 35, 384, 288, 3, 3, 2, 2), (1, 160, 9, 9, 224, 160, 3, 3, 1, 1), (1, 1056, 7, 7, 192, 1056, 1, 1, 1, 1), (1, 80, 73, 73, 192, 80, 3, 3, 1, 1), (1, 128, 16, 16, 128, 128, 3, 3, 1, 1), (1, 192, 16, 16, 192, 192, 3, 3, 1, 1), (1, 256, 16, 16, 256, 256, 3, 3, 1, 1), (1, 1024, 14, 14, 512, 1024, 1, 1, 1, 1), (1, 128, 16, 16, 160, 128, 3, 3, 1, 1), (1, 576, 14, 14, 192, 576, 1, 1, 1, 1), (1, 96, 16, 16, 128, 96, 3, 3, 1, 1), (1, 1024, 14, 14, 256, 1024, 1, 1, 1, 1), (1, 576, 14, 14, 128, 576, 1, 1, 1, 1), (1, 64, 29, 29, 96, 64, 3, 3, 1, 1), (1, 64, 56, 56, 128, 64, 1, 1, 2, 2), (1, 608, 14, 14, 192, 608, 1, 1, 1, 1)]
3133

3234
for i in workloads:
3335
exec_time = []
3436
with open('input', 'w') as f:
3537
f.write(' '.join(map(str, i)))
3638
try:
37-
avg = []
38-
for j in range(10):
39-
res = subprocess.check_output('python ./conv2d.py < input', shell=True).decode('utf-8')
40-
for k in res.split('\n'):
41-
if k.startswith('exec: '):
42-
res = float(k.strip('exec: '))
43-
avg.append(res)
44-
exec_time.append(sum(avg) / len(avg))
39+
res = subprocess.check_output('python ./conv2d.py < input', shell=True).decode('utf-8')
40+
print(i, 'done')
4541
except:
4642
print(i, 'fails')
47-
print('!!!', i, min(exec_time))
48-

apps/cpu/relay/conv2d.py

+6 -3
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,9 @@ def tracer(module, info, is_before):
3939
result = 1e9
4040
target = -1
4141
from tensorizer import tune
42-
tune.cpu_idx = 0
42+
tune.cpu_idx = -1
43+
results = []
44+
4345
while True:
4446
with tvm.transform.PassContext(opt_level=3, trace=tracer, config={'tir.add_lower_pass': [(1, tensorizer.rewrite)]}):
4547
graph, lib, params = tvm.relay.build(module, target='llvm -mcpu=cascadelake')
@@ -52,6 +54,7 @@ def tracer(module, info, is_before):
5254

5355
timer = func.module.time_evaluator('run', ctx=tvm.cpu(0), number=3, repeat=10)
5456
timed = timer()
57+
results.append(timed.mean)
5558

5659
if timed.mean < result:
5760
result = timed.mean
@@ -64,7 +67,7 @@ def tracer(module, info, is_before):
6467
if tune.cpu_idx >= tune.total_idx:
6568
break
6669

67-
with open('/home/ubuntu/Tensorization-PoC/cpu-tune.log', 'a') as f:
70+
with open('./cpu-tune.log', 'a') as f:
6871
f.write(f'{tune.ashape} {tune.bshape} {tune.strides} {target}\n')
6972

70-
print(result, target, tune.cpu_idx)
73+
#print(result, target, tune.cpu_idx)

poc/vnni/input

+4 -4
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
1-
256
1+
608
22
14
33
14
4-
512
4+
192
5+
1
6+
1
57
1
68
1
7-
2
8-
2

poc/vnni/run.py

+25 -25
Original file line number | Diff line number | Diff line change
@@ -14,36 +14,36 @@
1414
#(1024,15,15,2048,1024,1,1,2,2),
1515
#(128,65,65,256,128,3,3,2,2)]
1616
#
17-
#workloads = [(1, 288, 35, 35, 384, 288, 3, 3, 2, 2), (1, 160, 9, 9, 224, 160, 3, 3, 1, 1), (1, 1056, 7, 7, 192, 1056, 1, 1, 1, 1), (1, 80, 73, 73, 192, 80, 3, 3, 1, 1), (1, 128, 16, 16, 128, 128, 3, 3, 1, 1), (1, 192, 16, 16, 192, 192, 3, 3, 1, 1), (1, 256, 16, 16, 256, 256, 3, 3, 1, 1), (1, 1024, 14, 14, 512, 1024, 1, 1, 1, 1), (1, 128, 16, 16, 160, 128, 3, 3, 1, 1), (1, 576, 14, 14, 192, 576, 1, 1, 1, 1), (1, 96, 16, 16, 128, 96, 3, 3, 1, 1), (1, 1024, 14, 14, 256, 1024, 1, 1, 1, 1), (1, 576, 14, 14, 128, 576, 1, 1, 1, 1), (1, 64, 29, 29, 96, 64, 3, 3, 1, 1), (1, 64, 56, 56, 128, 64, 1, 1, 2, 2), (1, 608, 14, 14, 192, 608, 1, 1, 1, 1)]
18-
19-
workloads = [
20-
(256, 16, 16, 256, 3, 1),
21-
(512, 9, 9, 512, 3, 1),
22-
(128, 30, 30, 128, 3, 1),
23-
(64, 56, 56, 128, 1, 2),
24-
(64, 58, 58, 64, 3, 1),
25-
(128, 28, 28, 256, 1, 2),
26-
(256, 16, 16, 512, 3, 2),
27-
(64, 58, 58, 128, 3, 2),
28-
(4, 230, 230, 64, 7, 2),
29-
(128, 30, 30, 256, 3, 2),
30-
(256, 14, 14, 512, 1, 2)
31-
]
17+
workloads = [(1, 288, 35, 35, 384, 288, 3, 3, 2, 2), (1, 160, 9, 9, 224, 160, 3, 3, 1, 1), (1, 1056, 7, 7, 192, 1056, 1, 1, 1, 1), (1, 80, 73, 73, 192, 80, 3, 3, 1, 1), (1, 128, 16, 16, 128, 128, 3, 3, 1, 1), (1, 192, 16, 16, 192, 192, 3, 3, 1, 1), (1, 256, 16, 16, 256, 256, 3, 3, 1, 1), (1, 1024, 14, 14, 512, 1024, 1, 1, 1, 1), (1, 128, 16, 16, 160, 128, 3, 3, 1, 1), (1, 576, 14, 14, 192, 576, 1, 1, 1, 1), (1, 96, 16, 16, 128, 96, 3, 3, 1, 1), (1, 1024, 14, 14, 256, 1024, 1, 1, 1, 1), (1, 576, 14, 14, 128, 576, 1, 1, 1, 1), (1, 64, 29, 29, 96, 64, 3, 3, 1, 1), (1, 64, 56, 56, 128, 64, 1, 1, 2, 2), (1, 608, 14, 14, 192, 608, 1, 1, 1, 1)]
18+
#
19+
#workloads = [
20+
#(256, 16, 16, 256, 3, 1),
21+
#(512, 9, 9, 512, 3, 1),
22+
#(128, 30, 30, 128, 3, 1),
23+
#(64, 56, 56, 128, 1, 2),
24+
#(64, 58, 58, 64, 3, 1),
25+
#(128, 28, 28, 256, 1, 2),
26+
#(256, 16, 16, 512, 3, 2),
27+
#(64, 58, 58, 128, 3, 2),
28+
#(4, 230, 230, 64, 7, 2),
29+
#(128, 30, 30, 256, 3, 2),
30+
#(256, 14, 14, 512, 1, 2)
31+
#]
3232

3333
for i in workloads:
3434
exec_time = []
3535
with open('input', 'w') as f:
3636
lst = list(i)
37-
#for j in (i[1:5] + i[6:]):
38-
# f.write(str(j) + '\n')
39-
f.write(str(i[0]) + '\n')
40-
f.write(str(i[1]) + '\n')
41-
f.write(str(i[2]) + '\n')
42-
f.write(str(i[3]) + '\n')
43-
f.write(str(i[4]) + '\n')
44-
f.write(str(i[4]) + '\n')
45-
f.write(str(i[5]) + '\n')
46-
f.write(str(i[5]) + '\n')
37+
for j in (i[1:5] + i[6:]):
38+
f.write(str(j) + '\n')
39+
#f.write(str(i[0]) + '\n')
40+
#f.write(str(i[1]) + '\n')
41+
#f.write(str(i[2]) + '\n')
42+
#f.write(str(i[3]) + '\n')
43+
#f.write(str(i[4]) + '\n')
44+
#f.write(str(i[4]) + '\n')
45+
#f.write(str(i[5]) + '\n')
46+
#f.write(str(i[5]) + '\n')
4747
try:
4848
avg = []
4949
for j in range(10):

python/tensorizer/intrinsics/cpu.py

+11 -5
Original file line number | Diff line number | Diff line change
@@ -163,18 +163,24 @@ def callback(op):
163163
reduction.append(axis)
164164
else:
165165
simple.append(axis)
166+
sch[op].pragma(stencil[0], 'tensorize', pragma)
167+
168+
if tune.parallel_only:
169+
if str(op) != str(output):
170+
sch[op].reorder(*(simple + unroll + reduction + stencil))
171+
else:
172+
sch[op].reorder(*([fusion] + unroll + simple + reduction + stencil))
173+
return
174+
166175
for i in unroll:
167176
sch[op].unroll(i)
168-
sch[op].pragma(stencil[0], 'tensorize', pragma)
169177
#if simple:
170178
# unroll = [simple[-1]] + unroll
171179
# simple = simple[:-1]
172180
if str(op) != str(output):
173-
#sch[op].reorder(*(simple + reduction + unroll + stencil))
174-
sch[op].reorder(*(simple + unroll + reduction + stencil))
181+
sch[op].reorder(*(simple + reduction + unroll + stencil))
175182
else:
176-
#sch[op].reorder(*([fusion] + simple + reduction + unroll + stencil))
177-
sch[op].reorder(*([fusion] + unroll + simple + reduction + stencil))
183+
sch[op].reorder(*([fusion] + simple + reduction + unroll + stencil))
178184

179185
traverse_inline(sch, output, callback)
180186

python/tensorizer/intrinsics/looptiler.py

+1
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ def _ceil_div(a, b):
2121

2222
def analyze_tiling(op, pattern, max_unroll=32, max_parallel=3000):
2323

24+
print(tvm.arith._ffi_api)
2425
info = list(tvm.arith._ffi_api.MatchTensorizer(op, pattern))
2526
assert info
2627
loops = {}

python/tensorizer/tune.py

+2
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313
splitk = None
1414
x86 = {}
1515

16+
parallel_only = None
17+
1618
HOME = os.getenv("HOME")
1719

1820
def load_x86():

0 commit comments

Comments
 (0)