Skip to content

Commit 6ffe97a

Browse files
committed
test vectorizability of blas-callables with vector inputs
1 parent 7fe71a2 commit 6ffe97a

File tree

1 file changed

+75
-19
lines changed

1 file changed

+75
-19
lines changed

examples/python/call-external.py

Lines changed: 75 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
2222

2323
if vec_dtype.numpy_dtype == np.float32:
2424
name_in_target = "cblas_sgemv"
25-
elif vec_dtype. numpy_dtype == np.float64:
25+
elif vec_dtype.numpy_dtype == np.float64:
2626
name_in_target = "cblas_dgemv"
2727
else:
2828
raise LoopyError("GEMV is only supported for float32 and float64 "
@@ -47,30 +47,37 @@ def with_descrs(self, arg_id_to_descr, callables_table):
4747
assert mat_descr.shape[0] == res_descr.shape[0]
4848
assert len(vec_descr.shape) == len(res_descr.shape) == 1
4949
# handling only the easy case when stride == 1
50-
assert vec_descr.dim_tags[0].stride == 1
5150
assert mat_descr.dim_tags[1].stride == 1
52-
assert res_descr.dim_tags[0].stride == 1
5351

5452
return self.copy(arg_id_to_descr=arg_id_to_descr), callables_table
5553

5654
def emit_call_insn(self, insn, target, expression_to_code_mapper):
    """Build the call expression for the CBLAS GEMV routine.

    :arg insn: the call instruction; its expression parameters are the
        matrix and the input vector, and its single assignee is the result.
    :arg target: unused here.
    :arg expression_to_code_mapper: maps the instruction's sub-array
        expressions to target code expressions.
    :returns: a tuple ``(call_expression, False)``; the second entry is
        ``False`` because cblas_gemv does not return anything.
    :raises UnvectorizableError: if code is being generated in a
        vectorized context.
    """
    from pymbolic import var
    from loopy.codegen import UnvectorizableError
    mat_descr = self.arg_id_to_descr[0]
    vec_descr = self.arg_id_to_descr[1]
    res_descr = self.arg_id_to_descr[-1]
    m, n = mat_descr.shape
    ecm = expression_to_code_mapper

    if ecm.codegen_state.vectorization_info is not None:
        raise UnvectorizableError("cannot vectorize BLAS-gemv.")

    mat, vec = insn.expression.parameters
    result, = insn.assignees

    c_parameters = [var("CblasRowMajor"),
                    var("CblasNoTrans"),
                    m, n,
                    1,  # alpha
                    ecm(mat).expr,
                    # LDA: distance between consecutive rows. CBLAS requires
                    # LDA >= n for a row-major matrix, so a literal 1 is
                    # only valid when n == 1.
                    mat_descr.dim_tags[0].stride,
                    ecm(vec).expr,
                    vec_descr.dim_tags[0].stride,  # INCX
                    # beta: 0 so that the result is overwritten (y <- A x)
                    # rather than accumulated into whatever the output
                    # array held before the call.
                    0,
                    ecm(result).expr,
                    res_descr.dim_tags[0].stride  # INCY
                    ]
    return (var(self.name_in_target)(*c_parameters),
            False  # cblas_gemv does not return anything
            )
@@ -83,17 +90,66 @@ def generate_preambles(self, target):
8390
# }}}
8491

8592

86-
n = 10
93+
def transform_1(knl):
    """Return *knl* unchanged (the no-op transformation)."""
    return knl
95+
96+
97+
def transform_2(knl):
    """Split iname ``e`` by 4 and tag the inner iname ``e_inner`` as ``vec``.

    The temporaries are privatized with respect to ``e_inner`` and the
    axes of ``tmp3`` and ``tmp2`` are tagged ``"c,vec"``.

    :arg knl: the kernel to transform.
    :returns: the transformed kernel.
    """
    # A similar transformation is applied to kernels containing
    # SLATE <https://www.firedrakeproject.org/firedrake.slate.html>
    # callables.
    knl = lp.split_iname(knl, "e", 4, inner_iname="e_inner", slabs=(0, 1))
    knl = lp.privatize_temporaries_with_inames(knl, "e_inner")
    knl = lp.tag_inames(knl, {"e_inner": "vec"})
    if 0:
        # Easy codegen exercise, but misses vectorizing certain instructions.
        knl = lp.tag_array_axes(knl, "tmp3", "c,vec")
    else:
        knl = lp.tag_array_axes(knl, "tmp3,tmp2", "c,vec")
    return knl
110+
111+
112+
def main():
    """Build the example kernel and print the code generated for it.

    The kernel calls ``matvec`` (registered as a ``CBLASGEMV`` callable) on
    a temporary vector inside a loop over ``e``. The transformations are
    applied one after the other to the same kernel, and the device code is
    printed after each step.
    """
    knl = lp.make_kernel(
        "{[e,i1,i2]: 0<=e<n and 0<=i1,i2<4}",
        """
        for e
            for i1
                tmp1[i1] = 3*x[e, i1]
            end
            tmp2[:] = matvec(A[:, :], tmp1[:])
            for i2
                <> tmp3[i2] = 2 * tmp2[i2]
                out[e, i2] = tmp3[i2]
            end
        end
        """,
        kernel_data=[
            lp.TemporaryVariable("tmp1",
                                 shape=(4, ),
                                 dtype=None),
            lp.TemporaryVariable("tmp2",
                                 shape=(4, ),
                                 dtype=None),
            lp.GlobalArg("A",
                         shape=(4, 4),
                         dtype="float64"),
            lp.GlobalArg("x",
                         shape=lp.auto,
                         dtype="float64"),
            ...],
        target=lp.CVectorExtensionsTarget(),
        lang_version=(2018, 2))

    knl = lp.register_callable(knl, "matvec", CBLASGEMV("matvec"))

    for transform_func in [transform_1, transform_2]:
        knl = transform_func(knl)
        # This has to be an f-string: without the prefix the placeholder
        # text itself is printed instead of the function name.
        print(f"Generated code from '{transform_func.__name__}' -----")
        print(lp.generate_code_v2(knl).device_code())
        print(75 * "-")
152+
153+
154+
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)