Skip to content

Commit cf19b1d

Browse files
author
Jian Weng
committedNov 16, 2020
for cgo ae
1 parent 31f86dc commit cf19b1d

File tree

10 files changed

+552
-16
lines changed

10 files changed

+552
-16
lines changed
 

‎apps/gpu/alone.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import tvm
2+
import tensorizer
3+
import logging
4+
import sys
5+
import numpy as np
6+
from tvm import relay
7+
from tvm import autotvm
8+
9+
import topi
10+
from tvm.relay import op
11+
12+
13+
#t0, t1 = eval(input())
14+
#n, c, h, w = map(int, t0)
15+
#oc, ic, kh, kw = map(int, t1)
16+
n, c, h, w, oc, ic, kh, kw, sh, sw = map(int, input().split())
17+
18+
oh = (h - kh) // sh + 1
19+
ow = (w - kw) // sw + 1
20+
21+
import time
22+
timing = -1
23+
24+
def tracer(module, info, is_before):
25+
global timing
26+
if bool(is_before):
27+
timing = time.time()
28+
else:
29+
print('Executes: ', info.name, (time.time() - timing) * 1000)
30+
31+
from tensorizer import tune
32+
tune.enable = False
33+
34+
result = info = 1e9
35+
for i in [None, 'fuse', 'pad'] if ow < 32 else [None]:
36+
j = 16
37+
while True:
38+
diffc = diffoc = diffh = diffw = 0
39+
#if c % 64:
40+
# diffc = 64 - c % 64
41+
42+
#if oc % 32:
43+
# diffoc = 32 - oc % 32
44+
45+
#can_fuse = can_pad = True
46+
#if i == 'pad':
47+
# can_fuse = False
48+
#if i == 'fuse':
49+
# can_pad = False
50+
#if not ((oh * ow % 32 == 0 and 32 % ow == 0) or ow % 32 == 0):
51+
# first_h = sh - (h - kh) % sh
52+
# first_w = sw - (w - kw) % sw
53+
# max_diff_h = 32 - oh % 32
54+
# max_diff_w = 32 - ow % 32
55+
# diffh = diffw = 1e9
56+
# for i in range(max_diff_h + 1):
57+
# for j in range(max_diff_w + 1):
58+
# if (((oh + i) * (ow + j) % 32 == 0 and 32 % (ow + j) == 0 and can_fuse) or ((ow + j) % 32 == 0 and can_pad)) and i + j < diffh + diffw:
59+
# def to_pad(padding, first, stride):
60+
# if padding == 0:
61+
# return 0
62+
# assert padding >= 1
63+
# return (padding - 1) * stride + first
64+
# diffh, diffw = to_pad(i, first_h, sh), to_pad(j, first_w, sw)
65+
# #assert (height + diffh - kh + 1) * (width + diffw - kw + 1) % 32 == 0
66+
67+
68+
#var_x = relay.var('x', shape=(n, (c + diffc) // 16, (h + diffh), (w + diffw), 16), dtype='float16')
69+
#var_w = relay.const(tvm.nd.array((np.random.randn((oc + diffoc) // 16, (c + diffc) // 16, kh, kw, 16, 16) * 128).astype('float16')))
70+
#conv2d = relay.nn.conv2d(var_x, var_w, out_dtype='float32', kernel_size=(kh, kw), channels=oc + diffoc, strides=(sh, sw), data_layout='NCHW16c', kernel_layout='OIHW16i16o')
71+
#if diffc or diffoc or diffh or diffw:
72+
# y = relay.strided_slice(conv2d,
73+
# begin=relay.const(tvm.nd.array([0, 0, 0, 0])),
74+
# end=relay.const(tvm.nd.array([n, oc, oh, ow])))
75+
#else:
76+
# y = conv2d
77+
var_x = relay.var('x', shape=(n, c, h, w), dtype='float32')
78+
var_w = relay.const(tvm.nd.array((np.random.randn(oc, ic, kh, kw) * 128).astype('float32')))
79+
var_b = relay.const(tvm.nd.array((np.random.randn(1, oc, 1, 1) * 128).astype('float32')))
80+
conv2d = relay.nn.conv2d(var_x, var_w, out_dtype='float32', kernel_size=(kh, kw), channels=oc, strides=(sh, sw), out_layout='NCHW16c')
81+
y = conv2d
82+
83+
func = relay.Function([var_x], y)
84+
module = tvm.IRModule()
85+
module['main'] = func
86+
87+
tune.padding = i
88+
tune.splitk = j
89+
passes = [(1, tensorizer.rewrite)]
90+
with tvm.transform.PassContext(opt_level=0, trace=tracer, config={'tir.add_lower_pass': passes}):
91+
#with tvm.transform.PassContext(opt_level=4, trace=tracer):
92+
#graph, lib, params = tvm.relay.build(module, target='cuda -libs=cublas,cudnn')
93+
graph, lib, params = tvm.relay.build(module, target='nvptx -libs=cublas,cudnn')
94+
from tvm.contrib import graph_runtime as runtime
95+
from tvm.contrib.debugger import debug_runtime as runtime
96+
func = runtime.create(graph, lib, tvm.gpu())
97+
98+
x_ =(np.random.randn(n, c, h, w) * 128).astype('float32')
99+
func.set_input('x', x_)
100+
timer = func.module.time_evaluator('run', ctx=tvm.gpu(), number=2, repeat=10)
101+
102+
timed = timer()
103+
while np.var(timed.results) > 1e-5:
104+
timed = timer()
105+
106+
if timed.mean < result:
107+
result = timed.mean
108+
info = (i, j)
109+
110+
111+
relay.backend.compile_engine.get().clear()
112+
j <<= 1
113+
if j > tune.total_idx:
114+
break

‎apps/gpu/input

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1 608 14 14 192 608 1 1 1 1

‎apps/gpu/intersect

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
(1, 576, 14, 14, 192, 576, 1, 1, 1, 1, 8.4071, 15.396700000000001)
2+
(1, 160, 9, 9, 224, 160, 3, 3, 1, 1, 9.589333333333332, 13.485766666666668)
3+
(1, 2048, 8, 8, 384, 2048, 1, 1, 1, 1, 13.051733333333335, 17.1916)
4+
(1, 64, 58, 58, 128, 64, 3, 3, 2, 2, 15.326533333333337, 18.845066666666664)
5+
(1, 64, 56, 56, 128, 64, 1, 1, 2, 2, 5.128066666666667, 6.184633333333333)
6+
(1, 1056, 7, 7, 192, 1056, 1, 1, 1, 1, 6.848499999999999, 8.254633333333333)
7+
(1, 64, 29, 29, 96, 64, 3, 3, 1, 1, 19.699066666666667, 22.661566666666666)
8+
(1, 576, 14, 14, 128, 576, 1, 1, 1, 1, 7.0441, 8.0889)
9+
(1, 1024, 14, 14, 512, 1024, 1, 1, 1, 1, 28.08936666666667, 32.23866666666667)
10+
(1, 160, 17, 23, 192, 160, 1, 7, 1, 1, 21.459366666666668, 24.52113333333334)
11+
(1, 288, 35, 35, 384, 288, 3, 3, 2, 2, 86.23503333333333, 98.33573333333334)
12+
(1, 128, 16, 16, 128, 128, 3, 3, 1, 1, 10.165566666666667, 11.5552)
13+
(1, 192, 23, 17, 192, 192, 7, 1, 1, 1, 26.066133333333333, 29.171766666666667)
14+
(1, 768, 17, 17, 128, 768, 1, 1, 1, 1, 12.296766666666663, 13.711066666666666)
15+
(1, 768, 17, 17, 160, 768, 1, 1, 1, 1, 14.967966666666669, 16.653766666666666)
16+
(1, 256, 56, 56, 128, 256, 1, 1, 2, 2, 9.581933333333334, 10.642633333333334)
17+
(1, 192, 17, 17, 320, 192, 3, 3, 2, 2, 11.066666666666668, 12.182533333333334)
18+
(1, 96, 16, 16, 128, 96, 3, 3, 1, 1, 8.426166666666667, 9.266100000000003)
19+
(1, 128, 16, 16, 160, 128, 3, 3, 1, 1, 11.457933333333335, 12.278133333333335)
20+
(1, 1280, 8, 8, 384, 1280, 1, 1, 1, 1, 11.535666666666666, 12.33676666666667)
21+
(1, 64, 16, 16, 96, 64, 3, 3, 1, 1, 6.273666666666665, 6.7090666666666685)
22+
(1, 192, 16, 16, 192, 192, 3, 3, 1, 1, 17.286700000000003, 18.317766666666667)
23+
(1, 160, 23, 17, 192, 160, 7, 1, 1, 1, 25.39276666666667, 26.673099999999998)
24+
(1, 768, 17, 17, 192, 768, 1, 1, 1, 1, 17.378100000000003, 18.132166666666663)
25+
(1, 1024, 14, 14, 256, 1024, 1, 1, 1, 1, 17.09733333333333, 17.7697)
26+
(1, 256, 16, 16, 256, 256, 3, 3, 1, 1, 29.66086666666666, 30.810766666666673)
27+
(1, 576, 14, 14, 64, 576, 1, 1, 1, 1, 5.460666666666666, 5.6329666666666665)
28+
(1, 608, 14, 14, 192, 608, 1, 1, 1, 1, 8.621333333333334, 8.864866666666666)
29+
(1, 192, 27, 27, 64, 192, 1, 1, 1, 1, 6.770499999999999, 6.957433333333333)
30+
(1, 192, 16, 16, 256, 192, 3, 3, 1, 1, 23.1803, 23.818533333333335)
31+
(1, 160, 17, 23, 160, 160, 1, 7, 1, 1, 19.4532, 19.969299999999997)
32+
(1, 256, 27, 27, 64, 256, 1, 1, 1, 1, 8.121599999999999, 8.312833333333332)
33+
(1, 448, 10, 10, 384, 448, 3, 3, 1, 1, 21.378566666666664, 21.486066666666662)
34+
(1, 2048, 8, 8, 448, 2048, 1, 1, 1, 1, 20.103566666666666, 20.196166666666663)
35+
(1, 128, 28, 28, 512, 128, 1, 1, 1, 1, 14.981066666666665, 15.042333333333334)
36+
(1, 32, 149, 149, 32, 32, 3, 3, 1, 1, 42.16560000000001, 42.2847)
37+
(1, 80, 73, 73, 192, 80, 3, 3, 1, 1, 362.71533333333326, 363.31000000000006)

‎apps/gpu/run.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import subprocess
2+
3+
with open('./intersect', 'r') as f:
4+
shapes = []
5+
for i in f.readlines():
6+
a = eval(i)
7+
shapes.append(' '.join(map(str, a[:-2])))
8+
shapes = set(shapes)
9+
for i in shapes:
10+
with open('input', 'w') as f:
11+
f.write(i)
12+
print('tuning:', i)
13+
subprocess.check_output('python relay.py < input', shell=True)

‎apps/gpu/tune.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import subprocess
2+
3+
#with open('/home/ubuntu/shapes.raw', 'r') as f:
4+
# shapes = []
5+
# for i in f.readlines():
6+
# shapes.append(i)
7+
# shapes = set(shapes)
8+
# for i in shapes:
9+
# with open('input', 'w') as f:
10+
# f.write(i)
11+
# print('tuning:', i)
12+
# subprocess.check_output('python relay.py < input', shell=True)
13+
#
14+
15+
shapes = [(1, 288, 35, 35, 384, 288, 3, 3, 2, 2), (1, 160, 9, 9, 224, 160, 3, 3, 1, 1), (1, 1056, 7, 7, 192, 1056, 1, 1, 1, 1), (1, 80, 73, 73, 192, 80, 3, 3, 1, 1), (1, 128, 16, 16, 128, 128, 3, 3, 1, 1), (1, 192, 16, 16, 192, 192, 3, 3, 1, 1), (1, 256, 16, 16, 256, 256, 3, 3, 1, 1), (1, 1024, 14, 14, 512, 1024, 1, 1, 1, 1), (1, 128, 16, 16, 160, 128, 3, 3, 1, 1), (1, 576, 14, 14, 192, 576, 1, 1, 1, 1), (1, 96, 16, 16, 128, 96, 3, 3, 1, 1), (1, 1024, 14, 14, 256, 1024, 1, 1, 1, 1), (1, 576, 14, 14, 128, 576, 1, 1, 1, 1), (1, 64, 29, 29, 96, 64, 3, 3, 1, 1), (1, 64, 56, 56, 128, 64, 1, 1, 2, 2), (1, 608, 14, 14, 192, 608, 1, 1, 1, 1)]
16+
17+
for i in shapes:
18+
with open('input', 'w') as f:
19+
f.write(' '.join(map(str, i)))
20+
print('tuning:', i)
21+
subprocess.check_output('python relay.py < input', shell=True)

‎gpu-tune.log

+288
Large diffs are not rendered by default.

‎python/tensorizer/intrinsics/cpu.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,15 @@ def callback(op):
7575

7676
if tune.cpu_idx is None:
7777
to_apply = points[0][-1]
78-
#with open('/home/ubuntu/Tensorization-PoC/cpu-shapes.log', 'a') as f:
79-
# f.write(f'{tune.ashape} {tune.bshape} {tune.strides}\n')
78+
import os
79+
HOME = os.getenv("HOME")
80+
try:
81+
f = open(HOME + '/Tensorization-PoC/cpu-shapes.log', 'a')
82+
except:
83+
f = open(HOME + '/UNIT/cpu-shapes.log', 'a')
84+
except:
85+
assert False
86+
f.write(f'{tune.ashape} {tune.bshape} {tune.strides}\n')
8087
if (tune.ashape, tune.bshape, tune.strides) in tune.x86.keys():
8188
to_apply = points[tune.x86[(tune.ashape, tune.bshape, tune.strides)]][-1]
8289
else:
@@ -182,4 +189,4 @@ def callback(op):
182189
arm_operand = functools.partial(loader, cast_type='int8x16')
183190
arm_writeback = functools.partial(writer, llvm_intrin='llvm.aarch64.neon.sdot.v4i32.v16i8', dtype='int32x4')
184191
from .pattern import arm_sdot128_i8i16
185-
arm_schedule = functools.partial(schedule, pattern=arm_sdot128_i8i16, pragma='vdot', max_threads=10000)
192+
arm_schedule = functools.partial(schedule, pattern=arm_sdot128_i8i16, pragma='vdot', max_threads=10000)

‎python/tensorizer/intrinsics/gpu.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -213,4 +213,5 @@ def cleanup(store, axis, operands):
213213
res = [res, tvm.tir.Evaluate(tvm.tir.call_llvm_intrin('handle', 'llvm.nvvm.barrier0', tvm.tir.const(0, 'int32')))]
214214

215215
res = tvm.tir.SeqStmt(res)
216-
return res
216+
return res
217+

‎python/tensorizer/ops/gpu.py

+36-10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from tvm import te
44
from tvm import autotvm
55
import tvm
6+
from tensorizer import tune
67

78
@autotvm.register_topi_compute('conv2d_NCHW16c_OHWI16o.nvptx')
89
def _conv2d_NCHW16c_OHWI16o_impl(cfg, a, b, stride_h, stride_w, out_type):
@@ -70,7 +71,17 @@ def schedule_fetcher(sch, buffer, y, x):
7071
sch[buffer].vectorize(xi)
7172

7273
rc = sch[conv].op.reduce_axis[0]
73-
rco, rci = sch[conv].split(rc, 64)
74+
for i in [16, 32, 64]:
75+
if rc.dom.extent.value % i == 0:
76+
split_k = i
77+
print('!!!!!!!!!!!!!!!')
78+
print(tune.splitk)
79+
if tune.splitk is not None:
80+
tune.total_idx = split_k
81+
split_k = tune.splitk
82+
83+
rc = sch[conv].op.reduce_axis[0]
84+
rco, rci = sch[conv].split(rc, split_k)
7485
rcio, rcii = sch[conv].split(rci, 16)
7586
rf = sch.rfactor(conv, rcio)
7687
cc = sch.cache_write(rf, 'wmma.accumulator')
@@ -83,7 +94,9 @@ def schedule_fetcher(sch, buffer, y, x):
8394
xyio, xyii = sch[conv].split(xyi, 16)
8495
obo, obi = sch[conv].split(ob, 8)
8596
sch[conv].reorder(batch, oco, xyo, oci, xyio, xyii, obo, obi)
86-
sch[conv].bind(sch[conv].fuse(oci, xyio), te.thread_axis('threadIdx.y'))
97+
fused = sch[conv].fuse(oci, xyio)
98+
fo, fi = sch[conv].split(fused, split_k // 16)
99+
sch[conv].bind(fi, te.thread_axis('threadIdx.y'))
87100
sch[conv].bind(sch[conv].fuse(xyii, obo), te.thread_axis('threadIdx.x'))
88101
sch[conv].vectorize(obi)
89102
sch[rf].compute_at(sch[conv], xyo)
@@ -99,7 +112,9 @@ def schedule_fetcher(sch, buffer, y, x):
99112
xyio, xyii = sch[output].split(xyi, 16)
100113
obo, obi = sch[output].split(ob, 8)
101114
sch[output].reorder(batch, oco, xyo, oci, xyio, xyii, obo, obi)
102-
sch[output].bind(sch[output].fuse(oci, xyio), te.thread_axis('threadIdx.y'))
115+
fused = sch[output].fuse(oci, xyio)
116+
fo, fi = sch[output].split(fused, split_k // 16)
117+
sch[output].bind(fi, te.thread_axis('threadIdx.y'))
103118
sch[output].bind(sch[output].fuse(xyii, obo), te.thread_axis('threadIdx.x'))
104119
sch[output].vectorize(obi)
105120
sch[output].bind(oco, te.thread_axis('blockIdx.y'))
@@ -134,7 +149,7 @@ def schedule_fetcher(sch, buffer, y, x):
134149
sch[aaii].compute_at(sch[cc], crw)
135150
sch[a_icol].compute_inline()
136151
fused = sch[aaii].fuse(sch[aaii].op.axis[1], sch[aaii].op.axis[2], sch[aaii].op.axis[3])
137-
fo, fi = sch[aaii].split(fused, nparts=4)
152+
fo, fi = sch[aaii].split(fused, nparts=split_k // 16)
138153
fio, fii = sch[aaii].split(fi, nparts=32)
139154
sch[aaii].bind(fo, te.thread_axis('threadIdx.y'))
140155
sch[aaii].bind(fio, te.thread_axis('threadIdx.x'))
@@ -143,10 +158,10 @@ def schedule_fetcher(sch, buffer, y, x):
143158
else:
144159
a_reuse = sch.cache_read(a, 'shared', [cc])
145160
sch[a_reuse].compute_at(sch[cc], crcio)
146-
schedule_fetcher(sch, a_reuse, 4, 32)
161+
schedule_fetcher(sch, a_reuse, split_k // 16, 32)
147162
a_shared = sch.cache_read(a_reuse, 'shared', [cc])
148163
sch[a_shared].compute_at(sch[cc], crw)
149-
schedule_fetcher(sch, a_shared, 4, 32)
164+
schedule_fetcher(sch, a_shared, split_k // 16, 32)
150165

151166
aa = sch.cache_read(a_shared, 'wmma.matrix_a', [cc])
152167
#aa = sch.cache_read(a, 'wmma.matrix_a', [cc])
@@ -168,9 +183,11 @@ def _conv2d_schedule_wdim(sch, conv, output, stride_h, stride_w):
168183
if rc.dom.extent.value % i == 0:
169184
split_k = i
170185

171-
if stride_h != 1 or stride_w != 1:
172-
split_k = 16
173-
186+
print('!!!!!!!!!!!!!!!!!')
187+
print(tune.splitk)
188+
if tune.splitk is not None:
189+
tune.total_idx = split_k
190+
split_k = tune.splitk
174191

175192
rc = sch[conv].op.reduce_axis[0]
176193
rco, rci = sch[conv].split(rc, split_k)
@@ -186,7 +203,8 @@ def _conv2d_schedule_wdim(sch, conv, output, stride_h, stride_w):
186203
sch[conv].reorder(batch, x, yo, oco, oo, oci, yio, oio, yii, oii)
187204
sch[rf].compute_at(sch[conv], oo)
188205
fused = sch[conv].fuse(oci, yio, oio)
189-
sch[conv].bind(fused, te.thread_axis('threadIdx.y'))
206+
fo, fi = sch[conv].split(fused, split_k // 16)
207+
sch[conv].bind(fi, te.thread_axis('threadIdx.y'))
190208
vo, vi = sch[conv].split(oii, 8)
191209
sch[conv].vectorize(vi)
192210
fused = sch[conv].fuse(yii, vo)
@@ -269,10 +287,16 @@ def callback(op):
269287
nonlocal sch
270288
if len(list(op.reduce_axis)):
271289
a, b = op.input_tensors
290+
tune.ashape = get_const_tuple(a.shape)
291+
tune.bshape = get_const_tuple(b.shape)
272292

273293
conv = op.output(0)
274294
n, c, h, w, _ = get_const_tuple(conv.shape)
275295
stride_h, stride_w = attrs.get_int_tuple('strides')
296+
tune.strides = (stride_h, stride_w)
297+
ky = tune.ashape, tune.bshape, (stride_h, stride_w)
298+
if tune.enable and ky in tune.cuda_kernel.keys():
299+
tune.splitk = int(tune.cuda_kernel[ky])
276300
if w % 32 == 0:
277301
_conv2d_schedule_wdim(sch, conv, output, stride_h, stride_w)
278302
else:
@@ -281,4 +305,6 @@ def callback(op):
281305

282306
traverse_inline(sch, output, callback)
283307

308+
tune.splitk = None
309+
284310
return sch

‎python/tensorizer/tune.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import os
2+
3+
enable = False
4+
15
cpu_idx = None
26
total_idx = None
37
ashape = None
@@ -8,13 +12,37 @@
812
splitk = None
913
x86 = {}
1014

15+
HOME = os.getenv("HOME")
16+
1117
def load_x86():
12-
for i in open('/home/ubuntu/Tensorization-PoC/cpu-tune.log').readlines():
18+
try:
19+
f = open(HOME + '/Tensorization-PoC/cpu-tune.log')
20+
except:
21+
f = open(HOME + '/UNIT/cpu-tune.log')
22+
for i in f.readlines():
1323
i = i.replace(') ', '), ')
1424
try:
1525
a, b, s, v = eval(i)
1626
except:
1727
a, b, s, v, _, _ = eval(i)
1828
x86[(a, b, s)] = v
1929

20-
load_x86()
30+
cuda_kernel = {}
31+
cuda_relay = {}
32+
def load_cuda():
33+
try:
34+
f = open(HOME + '/Tensorization-PoC/gpu-tune.log')
35+
except:
36+
f = open(HOME + '/UNIT/gpu-tune.log')
37+
raw = f.readlines()
38+
for i in raw[::2]:
39+
i = i.replace(') ', '), ')
40+
a, b, s, v, _ = eval(i)
41+
cuda_kernel[(a, b, s)] = v
42+
for i in raw[1::2]:
43+
N, C, H, W, O, I, KH, KW, SH, SW, v, _ = i.split()
44+
v = v.strip(',')
45+
cuda_relay[tuple(map(int, (N, C, H, W, O, I, KH, KW, SH, SW)))] = v
46+
47+
load_x86()
48+
load_cuda()

0 commit comments

Comments
 (0)
Please sign in to comment.