@@ -22,7 +22,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
2222
2323 if vec_dtype .numpy_dtype == np .float32 :
2424 name_in_target = "cblas_sgemv"
25- elif vec_dtype . numpy_dtype == np .float64 :
25+ elif vec_dtype .numpy_dtype == np .float64 :
2626 name_in_target = "cblas_dgemv"
2727 else :
2828 raise LoopyError ("GEMV is only supported for float32 and float64 "
@@ -47,30 +47,37 @@ def with_descrs(self, arg_id_to_descr, callables_table):
4747 assert mat_descr .shape [0 ] == res_descr .shape [0 ]
4848 assert len (vec_descr .shape ) == len (res_descr .shape ) == 1
4949 # handling only the easy case when stride == 1
50- assert vec_descr .dim_tags [0 ].stride == 1
5150 assert mat_descr .dim_tags [1 ].stride == 1
52- assert res_descr .dim_tags [0 ].stride == 1
5351
5452 return self .copy (arg_id_to_descr = arg_id_to_descr ), callables_table
5553
5654 def emit_call_insn (self , insn , target , expression_to_code_mapper ):
5755 from pymbolic import var
56+ from loopy .codegen import UnvectorizableError
5857 mat_descr = self .arg_id_to_descr [0 ]
58+ vec_descr = self .arg_id_to_descr [1 ]
59+ res_descr = self .arg_id_to_descr [- 1 ]
5960 m , n = mat_descr .shape
6061 ecm = expression_to_code_mapper
62+
63+ if ecm .codegen_state .vectorization_info is not None :
64+ raise UnvectorizableError ("cannot vectorize BLAS-gemv." )
65+
6166 mat , vec = insn .expression .parameters
6267 result , = insn .assignees
6368
6469 c_parameters = [var ("CblasRowMajor" ),
6570 var ("CblasNoTrans" ),
6671 m , n ,
67- 1 ,
72+ 1 , # alpha
6873 ecm (mat ).expr ,
69- 1 ,
74+ 1 , # LDA
7075 ecm (vec ).expr ,
71- 1 ,
76+ vec_descr .dim_tags [0 ].stride , # INCX
77+ 1 , # beta
7278 ecm (result ).expr ,
73- 1 ]
79+ res_descr .dim_tags [0 ].stride # INCY
80+ ]
7481 return (var (self .name_in_target )(* c_parameters ),
7582 False # cblas_gemv does not return anything
7683 )
@@ -83,17 +90,66 @@ def generate_preambles(self, target):
8390# }}}
8491
8592
86- n = 10
93+ def transform_1 (knl ):
94+ return knl
95+
96+
97+ def transform_2 (knl ):
98+ # A similar transformation is applied to kernels containing
99+ # SLATE <https://www.firedrakeproject.org/firedrake.slate.html>
100+ # callables.
101+ knl = lp .split_iname (knl , "e" , 4 , inner_iname = "e_inner" , slabs = (0 , 1 ))
102+ knl = lp .privatize_temporaries_with_inames (knl , "e_inner" )
103+ knl = lp .tag_inames (knl , {"e_inner" : "vec" })
104+ if 0 :
105+ # Easy codegen exercise, but misses vectorizing certain instructions.
106+ knl = lp .tag_array_axes (knl , "tmp3" , "c,vec" )
107+ else :
108+ knl = lp .tag_array_axes (knl , "tmp3,tmp2" , "c,vec" )
109+ return knl
110+
111+
112+ def main ():
87113
88- knl = lp .make_kernel (
89- "{: }" ,
114+ knl = lp .make_kernel (
115+ "{[e,i1,i2]: 0<=e<n and 0<=i1,i2<4 }" ,
90116 """
91- y[:] = gemv(A[:, :], x[:])
92- """ , [
93- lp .GlobalArg ("A" , dtype = np .float64 , shape = (n , n )),
94- lp .GlobalArg ("x" , dtype = np .float64 , shape = (n , )),
95- lp .GlobalArg ("y" , shape = (n , )), ...],
96- target = CTarget ())
97-
98- knl = lp .register_callable (knl , "gemv" , CBLASGEMV (name = "gemv" ))
99- print (lp .generate_code_v2 (knl ).device_code ())
117+ for e
118+ for i1
119+ tmp1[i1] = 3*x[e, i1]
120+ end
121+ tmp2[:] = matvec(A[:, :], tmp1[:])
122+ for i2
123+ <> tmp3[i2] = 2 * tmp2[i2]
124+ out[e, i2] = tmp3[i2]
125+ end
126+ end
127+ """ ,
128+ kernel_data = [
129+ lp .TemporaryVariable ("tmp1" ,
130+ shape = (4 , ),
131+ dtype = None ),
132+ lp .TemporaryVariable ("tmp2" ,
133+ shape = (4 , ),
134+ dtype = None ),
135+ lp .GlobalArg ("A" ,
136+ shape = (4 , 4 ),
137+ dtype = "float64" ),
138+ lp .GlobalArg ("x" ,
139+ shape = lp .auto ,
140+ dtype = "float64" ),
141+ ...],
142+ target = lp .CVectorExtensionsTarget (),
143+ lang_version = (2018 , 2 ))
144+
145+ knl = lp .register_callable (knl , "matvec" , CBLASGEMV ("matvec" ))
146+
147+ for transform_func in [transform_1 , transform_2 ]:
148+ knl = transform_func (knl )
149+ print ("Generated code from '{transform_func.__name__} -----'" )
150+ print (lp .generate_code_v2 (knl ).device_code ())
151+ print (75 * "-" )
152+
153+
154+ if __name__ == "__main__" :
155+ main ()
0 commit comments