import tvm
import tensorizer
import logging
import sys
import functools
import operator

import numpy as np

from tvm import relay
from tvm import autotvm
from tvm.relay import op
from topi.util import get_const_tuple
from tensorizer import tune

import topi
# Each workload is (in_channels, height, width, out_channels, kernel, stride);
# the commented entries are alternative shapes kept around for benchmarking.
workloads = [
    #(256, 16, 16, 256, 3, 1),
    #(512, 9, 9, 512, 3, 1),
    (128, 30, 30, 128, 3, 1),
    #(64, 56, 56, 128, 1, 2),
    #(64, 58, 58, 64, 3, 1),
    #(128, 28, 28, 256, 1, 2),
    #(256, 16, 16, 512, 3, 2),
    #(64, 58, 58, 128, 3, 2),
    #(4, 230, 230, 64, 7, 2),
    #(128, 30, 30, 256, 3, 2),
    #(256, 14, 14, 512, 1, 2)
]

output = []

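# A hedged NumPy reference for the packed conv3d computed in the loop below:
# a sanity-check sketch, not part of the original benchmark flow. `ref_conv3d`
# is a helper name introduced here for illustration; it assumes the same
# packed layouts as the placeholders built below (activations with a trailing
# channel block, weights with trailing (rcm, 16, rci) blocks).
def ref_conv3d(np_a, np_b, kernel, stride):
    # Split the packed inner-channel dim into (rcm, rci): 16 -> (4, 4), 4 -> (1, 4).
    a2 = np_a.astype('int32').reshape(*np_a.shape[:-1], -1, 4)
    oco = np_b.shape[0]
    od = (np_a.shape[2] - kernel) // stride + 1
    oh = (np_a.shape[3] - kernel) // stride + 1
    ow = (np_a.shape[4] - kernel) // stride + 1
    ref = np.zeros((np_a.shape[0], oco, od, oh, ow, 16), dtype='int32')
    for rd in range(kernel):
        for rh in range(kernel):
            for rw in range(kernel):
                # Strided window of the input for this kernel tap.
                patch = a2[:, :, rd:rd + od * stride:stride,
                              rh:rh + oh * stride:stride,
                              rw:rw + ow * stride:stride]
                # Contract over (channel chunk, rcm, rci) for one tap.
                ref += np.einsum('ncdhwmi,kcmoi->nkdhwo',
                                 patch, np_b[:, :, rd, rh, rw].astype('int32'))
    return ref
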
for ic, h, w, oc, kernel, stride in workloads:

    # Fixed input depth for the 3D convolution.
    d = 16

    # Pad channels so the packed layouts below divide evenly:
    # input channels to a multiple of 4, output channels to a multiple of 16.
    if ic % 4:
        ic += 4 - ic % 4

    if oc % 16:
        oc += 16 - oc % 16

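    # The layout choice below follows the AVX512-VNNI dot-product shape: one
    # VPDPBUSD accumulates 4 int8 pairs into each of 16 int32 lanes, so the
    # weight tensor carries trailing (rcm, 16, rci) blocks in which rci = 4
    # input channels feed each of 16 output channels, and rcm counts how many
    # such 4-channel steps the packed input block holds (4 when ic is a
    # multiple of 16, otherwise 1).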
    if ic % 16 == 0:
        c_chunk = 16
        a = tvm.te.placeholder((1, ic // 16, d, h, w, 16), dtype='int8')
        b = tvm.te.placeholder((oc // 16, ic // 16, kernel, kernel, kernel, 4, 16, 4), dtype='int8')
        rco = tvm.te.reduce_axis((0, ic // 16), name='rco')
        rcm = tvm.te.reduce_axis((0, 4), name='rcm')
        rci = tvm.te.reduce_axis((0, 4), name='rci')
    else:
        assert ic % 4 == 0
        c_chunk = 4
        a = tvm.te.placeholder((1, ic // 4, d, h, w, 4), dtype='int8')
        b = tvm.te.placeholder((oc // 16, ic // 4, kernel, kernel, kernel, 1, 16, 4), dtype='int8')
        rco = tvm.te.reduce_axis((0, ic // 4), name='rco')
        rcm = tvm.te.reduce_axis((0, 1), name='rcm')
        rci = tvm.te.reduce_axis((0, 4), name='rci')

    # Reduction over the kernel window along depth, height, and width.
    rd = tvm.te.reduce_axis((0, kernel), name='rd')
    rh = tvm.te.reduce_axis((0, kernel), name='rh')
    rw = tvm.te.reduce_axis((0, kernel), name='rw')

    # Packed int8 conv3d accumulated into int32; x/y/z walk depth/height/width,
    # so the output extents use d, h, and w in that order.
    c = tvm.te.compute((1, oc // 16, (d - kernel) // stride + 1, (h - kernel) // stride + 1, (w - kernel) // stride + 1, 16),
                       lambda batch, ochunk, x, y, z, oblock:
                           tvm.te.sum(a[batch, rco, stride * x + rd, stride * y + rh, stride * z + rw, rcm * 4 + rci].astype('int32') *
                                      b[ochunk, rco, rd, rh, rw, rcm, oblock, rci].astype('int32'),
                                      axis=[rco, rd, rh, rw, rcm, rci]))

    print(get_const_tuple(a.shape))
    print(get_const_tuple(b.shape))
    print(get_const_tuple(c.shape))

    # Run tensorizer's rewrite as a phase-1 lowering pass, then sweep
    # tune.cpu_idx over the candidate schedules for this workload.
    passes = [(1, tensorizer.rewrite)]
    tune.cpu_idx = 0
    target = -1  # used by the commented-out early-exit heuristic below
    results = []
    while True:
        with tvm.transform.PassContext(opt_level=3, config={'tir.add_lower_pass': passes}), tvm.target.create('llvm -mcpu=cascadelake'):
            sch = tensorizer.INTRINSICS['vnni']['schedule']([c], (stride, stride, stride))

            module = tvm.build(sch, [a, b, c], 'llvm -mcpu=cascadelake')
            # Zero-filled buffers: this loop only measures runtime.
            np_a = np.zeros(get_const_tuple(a.shape), dtype='int8')
            np_b = np.zeros(get_const_tuple(b.shape), dtype='int8')
            np_c = np.zeros(get_const_tuple(c.shape), dtype='int32')
            nd_a = tvm.nd.array(np_a, tvm.cpu())
            nd_b = tvm.nd.array(np_b, tvm.cpu())
            nd_c = tvm.nd.array(np_c, tvm.cpu())
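            # A hedged way to also check correctness (not in the original
            # flow): fill np_a/np_b with small random int8 values before
            # creating the tvm.nd arrays above, then compare after the run:
            #   np.testing.assert_array_equal(
            #       nd_c.asnumpy(), ref_conv3d(np_a, np_b, kernel, stride))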
            fte = module.time_evaluator(module.entry_name, ctx=tvm.cpu(), number=3, repeat=10)
            res = fte(nd_a, nd_b, nd_c)
            # Re-run until the repeat timings are stable.
            while np.var(res.results) > 1e-5:
                res = fte(nd_a, nd_b, nd_c)
            # Count packed reduction steps: output elements x kernel volume x
            # channel chunks, where each step is a c_chunk-wide dot product.
            total = functools.reduce(operator.mul, get_const_tuple(c.shape), 1) * (kernel ** 3) * (ic // c_chunk)
            results.append(res.mean)

        # Drop cached compilations so the next candidate is rebuilt from scratch.
        relay.backend.compile_engine.get().clear()
        tune.cpu_idx += 1
        #if tune.cpu_idx - target > 8:
        #    break
        if tune.cpu_idx >= tune.total_idx - 1:
            break
    #print(results)
    # Record (reduction steps, best time in us, throughput in G-steps/s).
    best = min(results)
    output.append((total, best * 1e6, total / best / 1e9))
    with open('res', 'a') as f:
        f.write(str(output[-1]) + '\n')

print(*output, sep='\n')