Skip to content

Commit 758530b

Browse files
author
Jian Weng
committed
i am not sure how good/bad it is
1 parent 87f7c1a commit 758530b

File tree

9 files changed

+177
-250
lines changed

9 files changed

+177
-250
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
*.ll
33
*.exe
44
*.s
5+
.vscode/*
56
__pycache__

python/tensorizer/__init__.py

-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from . import tensorcore
21
from . import util
3-
from . import dse
4-
from .alter import AlterOpLayout
52
from .generic import *
63
from .intrinsics import *

python/tensorizer/alter.py

-43
This file was deleted.

python/tensorizer/analyzer.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import functools
2+
import operator
3+
4+
import tvm
5+
6+
def _factors(x):
7+
res = []
8+
for i in range(2, x):
9+
if x % i == 0:
10+
res.append(i)
11+
res.append(x // i)
12+
if i * i > x:
13+
break
14+
return [1, x] + sorted(res) if res else sorted(list(range(2, 9)), key=lambda v: x % v)
15+
16+
def _ceil_div(a, b):
17+
return (a - 1) // b + 1
18+
19+
20+
def analyze_tiling(op, pattern, max_unroll=32, max_parallel=10000):
# Generator that enumerates candidate tilings of `op`'s axes, given the loops that
# MatchTensorizer pairs between `op` and `pattern`.
# NOTE(review): this copy has lost its indentation and has diff-gutter markers (`21+`, ...)
# interleaved, so any loop nesting mentioned in these comments is inferred -- confirm
# against the repository.
# Parameters:
#   op          -- passed to MatchTensorizer; its `axis` and `reduce_axis` are iterated and each
#                  entry's `dom.extent.value` is read.
#   pattern     -- passed to MatchTensorizer together with `op`.
#   max_unroll  -- a candidate whose unroll_prod exceeds this triggers `break`.
#   max_parallel-- a candidate whose fused_prod exceeds this triggers `continue`.
# Yields a 5-element list:
#   [split[parallel][0] % tile_parallel, copy_split[unroll][j] % tile_unroll,
#    fused_prod, unroll_prod, yield_split]
21+
22+
# The match result is consumed as a flat sequence; an empty result fails the assert below.
info = list(tvm.arith._ffi_api.MatchTensorizer(op, pattern))
23+
assert info
24+
# Pair even-indexed entries (keys) with the odd-indexed entries that follow them (values).
loops = {}
25+
for i, j in zip(info[::2], info[1::2]):
26+
loops[i] = j
27+
28+
# NOTE(review): `dom` is assigned but never read in this function.
dom = {}
29+
# `split` holds one list per axis: entries for op.axis first, then entries for op.reduce_axis.
# Each starts out as the single-element list [extent].
split = []
30+
for i in op.axis:
31+
split.append([i.dom.extent.value])
32+
for i in op.reduce_axis:
33+
split.append([i.dom.extent.value])
34+
35+
# For every axis that is a key of `loops`, replace its entry in `split` (at index i + offset)
# with [extent // factor, (factor, 'offload')], where factor is the extent of the mapped loop.
def tiling_stencil(axis, offset):
36+
nonlocal loops
37+
for i, j in enumerate(axis):
38+
if j in loops.keys():
39+
factor = loops[j].dom.extent.value
40+
split[i + offset] = [j.dom.extent.value // factor, (factor, 'offload')]
41+
42+
# Reduce axes are stored after the op.axis entries, hence the len(op.axis) offset.
tiling_stencil(op.axis, 0)
43+
tiling_stencil(op.reduce_axis, len(op.axis))
44+
45+
# from outer to inner enumerate the loop levels to be parallelized
46+
for parallel in range(len(op.axis)):
47+
for tile_parallel in _factors(split[parallel][0]):
48+
# Shallow copy: the per-axis lists are shared with `split`, but entries are replaced below,
# not mutated in place.
copy_split = split[:]
49+
# fused_prod = product of every int entry of the axes before `parallel` (tuple entries are
# filtered out by the isinstance check) times ceil(split[parallel][0] / tile_parallel).
fused_prod = functools.reduce(
50+
operator.mul,
51+
[j for i in copy_split[:parallel] for j in i if isinstance(j, int)],
52+
1) * _ceil_div(copy_split[parallel][0], tile_parallel)
53+
if fused_prod > max_parallel:
54+
continue
55+
# Replace the leading extent of this axis with a ('parallel'-tagged outer count, tile_parallel)
# pair; the rest of the axis entry is kept.
copy_split[parallel] = [(_ceil_div(copy_split[parallel][0], tile_parallel), 'parallel'),
56+
tile_parallel] + copy_split[parallel][1:]
57+
for unroll in range(parallel, len(op.axis)):
58+
# On the parallel axis index 0 now holds the 'parallel' tuple, so the int to tile is at index 1.
j = 1 if unroll == parallel else 0
59+
for tile_unroll in _factors(copy_split[unroll][j]):
60+
yield_split = copy_split[:]
61+
# Split entry j of the unroll axis into ceil(extent / tile_unroll) and (tile_unroll, 'unroll').
yield_split[unroll] = yield_split[unroll][:j] + [
62+
_ceil_div(yield_split[unroll][j], tile_unroll),
63+
(tile_unroll, 'unroll')] + yield_split[unroll][j+1:]
64+
# unroll_prod = product of every int entry of the op.axis entries after `unroll`, times tile_unroll.
unroll_prod = functools.reduce(
65+
operator.mul,
66+
[j for i in yield_split[unroll+1:len(op.axis)] for j in i if isinstance(j, int)],
67+
1) * tile_unroll
68+
if unroll_prod > max_unroll:
69+
break
70+
# The first two items are remainders: non-zero when the chosen tile does not divide the extent.
yield [split[parallel][0] % tile_parallel,
71+
copy_split[unroll][j] % tile_unroll,
72+
fused_prod, unroll_prod, yield_split]
73+
# len(split[parallel]) != 1 holds exactly when tiling_stencil rewrote this axis.
# NOTE(review): which loop this `break` exits cannot be determined without indentation;
# presumably it ends the `parallel` loop after such an axis has been processed -- confirm.
if len(split[parallel]) != 1:
74+
break

python/tensorizer/dse.py

-59
This file was deleted.

python/tensorizer/generic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def visitor(op):
3232
def rewrite(f, mod, ctx):
3333
is_init = [False]
3434
stmt = f.body
35-
#print(stmt)
35+
print(stmt)
3636

3737
def detector(op):
3838
nonlocal is_init
@@ -61,7 +61,7 @@ def visitor(op):
6161
return None
6262

6363
res = f.with_body(tvm.tir.stmt_functor.ir_transform(f.body, detector, visitor, ['For', 'AttrStmt']))
64-
print(res)
64+
#print(res)
6565
return res
6666

6767
def analyze(op, stencil):

0 commit comments

Comments
 (0)