1
+ import functools
2
+ import operator
3
+
4
+ import tvm
5
+
6
+ def _factors (x ):
7
+ res = []
8
+ for i in range (2 , x ):
9
+ if x % i == 0 :
10
+ res .append (i )
11
+ res .append (x // i )
12
+ if i * i > x :
13
+ break
14
+ return [1 , x ] + sorted (res ) if res else sorted (list (range (2 , 9 )), key = lambda v : x % v )
15
+
16
+ def _ceil_div (a , b ):
17
+ return (a - 1 ) // b + 1
18
+
19
+
20
+ def analyze_tiling (op , pattern , max_unroll = 32 , max_parallel = 10000 ):
21
+
22
+ info = list (tvm .arith ._ffi_api .MatchTensorizer (op , pattern ))
23
+ assert info
24
+ loops = {}
25
+ for i , j in zip (info [::2 ], info [1 ::2 ]):
26
+ loops [i ] = j
27
+
28
+ dom = {}
29
+ split = []
30
+ for i in op .axis :
31
+ split .append ([i .dom .extent .value ])
32
+ for i in op .reduce_axis :
33
+ split .append ([i .dom .extent .value ])
34
+
35
+ def tiling_stencil (axis , offset ):
36
+ nonlocal loops
37
+ for i , j in enumerate (axis ):
38
+ if j in loops .keys ():
39
+ factor = loops [j ].dom .extent .value
40
+ split [i + offset ] = [j .dom .extent .value // factor , (factor , 'offload' )]
41
+
42
+ tiling_stencil (op .axis , 0 )
43
+ tiling_stencil (op .reduce_axis , len (op .axis ))
44
+
45
+ # from outer to inner enumerate the loop levels to be parallelized
46
+ for parallel in range (len (op .axis )):
47
+ for tile_parallel in _factors (split [parallel ][0 ]):
48
+ copy_split = split [:]
49
+ fused_prod = functools .reduce (
50
+ operator .mul ,
51
+ [j for i in copy_split [:parallel ] for j in i if isinstance (j , int )],
52
+ 1 ) * _ceil_div (copy_split [parallel ][0 ], tile_parallel )
53
+ if fused_prod > max_parallel :
54
+ continue
55
+ copy_split [parallel ] = [(_ceil_div (copy_split [parallel ][0 ], tile_parallel ), 'parallel' ),
56
+ tile_parallel ] + copy_split [parallel ][1 :]
57
+ for unroll in range (parallel , len (op .axis )):
58
+ j = 1 if unroll == parallel else 0
59
+ for tile_unroll in _factors (copy_split [unroll ][j ]):
60
+ yield_split = copy_split [:]
61
+ yield_split [unroll ] = yield_split [unroll ][:j ] + [
62
+ _ceil_div (yield_split [unroll ][j ], tile_unroll ),
63
+ (tile_unroll , 'unroll' )] + yield_split [unroll ][j + 1 :]
64
+ unroll_prod = functools .reduce (
65
+ operator .mul ,
66
+ [j for i in yield_split [unroll + 1 :len (op .axis )] for j in i if isinstance (j , int )],
67
+ 1 ) * tile_unroll
68
+ if unroll_prod > max_unroll :
69
+ break
70
+ yield [split [parallel ][0 ] % tile_parallel ,
71
+ copy_split [unroll ][j ] % tile_unroll ,
72
+ fused_prod , unroll_prod , yield_split ]
73
+ if len (split [parallel ]) != 1 :
74
+ break
0 commit comments