 // limitations under the License.

 #include "jit_domain/jit_amx_s8s8_dynamic_quant_matmul.hpp"
-
-#include "regs_pool.hpp"
-
 namespace jd {

 #define GET_OFF(field) offsetof(ssd::dynamic_quant_matmul_data_t, field)
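Before the hunks below: GET_OFF turns a field name into a byte offset inside the runtime-parameter struct, so the generated code can address every argument as ptr[reg_param + GET_OFF(field)] off the single kernel parameter. A minimal sketch of the same offsetof pattern; example_rt_data_t is a hypothetical stand-in, not the real layout of ssd::dynamic_quant_matmul_data_t:

// Sketch only: example_rt_data_t is a made-up stand-in for the kernel's
// runtime-parameter struct, whose real layout is not shown in this diff.
#include <cstddef>
#include <cstdint>

struct example_rt_data_t {
  const int8_t* activation;        // matrix A, row-major s8
  const int8_t* reordered_weight;  // matrix B, tile-reordered s8
  int8_t* dst;                     // quantized s8 output
  float* scale_dst;                // per-row output scales
};

// Same shape as the kernel's macro: a field name becomes a byte offset the
// JIT folds into an addressing expression like ptr[reg_param + offset].
#define GET_OFF_EX(field) offsetof(example_rt_data_t, field)

static_assert(GET_OFF_EX(activation) == 0, "first field sits at offset 0");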
@@ -24,162 +21,167 @@ void jit_amx_s8s8_dynamic_quant_matmul_t::generate() {
   Xbyak::Label data_label;
   inLocalLabel();
   {
+    const int const_var_size = 4 + 2 * sizeof(tileconfig_t) + 64;
     const int stack_tmpbuf_offset = 4 + 2 * sizeof(tileconfig_t);
     auto trans_block_col = param_.k / param_.tile_k;
-    const auto does_calc = param_.align_m_loop > 0 || param_.tail_m != 0;
-    regs_pool rp(this, 1, {does_calc ? 11 : 6, does_calc ? 32 : 0, 0});
-    const auto reg_m_loop = rp.reg<Reg64>();
-    const auto reg_n_loop = rp.reg<Reg64>();
-    const auto reg_strideA = rp.reg<Reg64>();
-    const auto reg_strideB = rp.reg<Reg64>();
-    const auto reg_strideC = rp.reg<Reg64>();
+    Xbyak::util::StackFrame st(this, 1, 11, const_var_size);
+    int reg_idx = 0;
+    const Reg64& reg_param = st.p[0];
+    const Reg64& reg_m_loop = st.t[reg_idx++];
+    const Reg64& reg_n_loop = st.t[reg_idx++];
+    const Reg64& reg_strideA = st.t[reg_idx++];
+    const Reg64& reg_strideB = st.t[reg_idx++];
+    const Reg64& reg_strideC = st.t[reg_idx++];
+    auto allocate_reg = [&] {
+      Reg64* reg = &const_cast<Reg64&>(st.t[reg_idx++]);
+      return std::shared_ptr<Reg64>(reg, [&](...) { reg_idx--; });
+    };

     auto prepare_mask = [&]() {
-      const auto reg_tmp = rp.reg<Xbyak::Reg32>();
-      mov(reg_tmp, 0xffff >> param_.write_mask);
-      kmovd(matC_n_mask, reg_tmp);
-      mov(reg_tmp, 0xffff >> (16 - param_.tail_m));
-      kmovd(scaleC_mask, reg_tmp);
+      auto reg_tmp = allocate_reg();
+      mov(reg_tmp->cvt32(), 0xffff >> param_.write_mask);
+      kmovd(matC_n_mask, reg_tmp->cvt32());
+      mov(reg_tmp->cvt32(), 0xffff >> (16 - param_.tail_m));
+      kmovd(scaleC_mask, reg_tmp->cvt32());
     };

     auto ip_16x16 = [&](int block_num) {
-      const auto reg_tmp = rp.reg<Reg64>();
+      auto reg_tmp = allocate_reg();
       // build block
       {
-        const auto reg_matA_addr = rp.reg<Reg64>();
-        const auto reg_matB_addr = rp.reg<Reg64>();
+        auto reg_matA_addr = allocate_reg();
+        auto reg_matB_addr = allocate_reg();
         for (int i = 0; i < block_num; i++) tilezero(Tmm(i));
         // prepare addr & stride;
-        mov(reg_matA_addr, ptr[rp.p[0] + GET_OFF(activation)]);
-        mov(reg_matB_addr, ptr[rp.p[0] + GET_OFF(reordered_weight)]);
-        imul(reg_tmp, reg_m_loop, 16 * param_.k);
-        add(reg_matA_addr, reg_tmp);
-        imul(reg_tmp, reg_n_loop, param_.align_build_block_num * trans_block_col * 64 * (param_.tile_k / 4));
-        add(reg_matB_addr, reg_tmp);
+        mov(*reg_matA_addr, ptr[reg_param + GET_OFF(activation)]);
+        mov(*reg_matB_addr, ptr[reg_param + GET_OFF(reordered_weight)]);
+        imul(*reg_tmp, reg_m_loop, 16 * param_.k);
+        add(*reg_matA_addr, *reg_tmp);
+        imul(*reg_tmp, reg_n_loop, param_.align_build_block_num * trans_block_col * 64 * (param_.tile_k / 4));
+        add(*reg_matB_addr, *reg_tmp);
         for (int k_loop = 0; k_loop < param_.k / param_.tile_k; k_loop++) {
-          tileloadd(Tmm(3), ptr[reg_matA_addr + reg_strideA + k_loop * param_.tile_k]);
+          tileloadd(Tmm(3), ptr[*reg_matA_addr + reg_strideA + k_loop * param_.tile_k]);
           for (int idx = 0; idx < block_num; idx++) {
             int offset = idx * trans_block_col * 64 * (param_.tile_k / 4) + k_loop * 64;
-            tileloadd(Tmm(4 + idx), ptr[reg_matB_addr + reg_strideB + offset]);
+            tileloadd(Tmm(4 + idx), ptr[*reg_matB_addr + reg_strideB + offset]);
             tdpbssd(Tmm(idx), Tmm(3), Tmm(4 + idx));
           }
         }
       }
       // store block to tmp_buf & dequant+add_bias
       {
         // store the block to tmp_buf
-        imul(reg_tmp, reg_n_loop, 16 * param_.align_build_block_num * sizeof(int));
-        const auto reg_tmp_buf = rp.reg<Reg64>();
-        mov(reg_tmp_buf, ptr[rp.p[0] + GET_OFF(tmp_buf)]);
-        add(reg_tmp_buf, reg_tmp);
+        imul(*reg_tmp, reg_n_loop, 16 * param_.align_build_block_num * sizeof(int));
+        auto reg_tmp_buf = allocate_reg();
+        mov(*reg_tmp_buf, ptr[reg_param + GET_OFF(tmp_buf)]);
+        add(*reg_tmp_buf, *reg_tmp);
         for (int idx = 0; idx < block_num; idx++)
-          tilestored(ptr[reg_tmp_buf + reg_strideC + idx * 16 * sizeof(int)], Tmm(idx));
+          tilestored(ptr[*reg_tmp_buf + reg_strideC + idx * 16 * sizeof(int)], Tmm(idx));
         // dequant + add_bias
-        const auto zmms = rp.regs<Zmm, 4>();
-        const auto reg_tmp2 = rp.reg<Reg64>();
-        const auto reg_scale_w = rp.reg<Reg64>();
-        const auto reg_scale_a = rp.reg<Reg64>();
-        const auto reg_bias = rp.reg<Reg64>();
-        mov(reg_scale_w, ptr[rp.p[0] + GET_OFF(scale_w)]);
-        mov(reg_scale_a, ptr[rp.p[0] + GET_OFF(scale_a)]);
-        mov(reg_bias, ptr[rp.p[0] + GET_OFF(bias)]);
-        mov(reg_tmp_buf, ptr[rp.p[0] + GET_OFF(tmp_buf)]);
-
-        imul(reg_tmp2, reg_m_loop, 16 * sizeof(float));  // offset of scale_a
+        auto zmms = regs<Zmm, 4>(0);
+        auto reg_tmp2 = allocate_reg();
+        auto reg_scale_w = allocate_reg();
+        auto reg_scale_a = allocate_reg();
+        auto reg_bias = allocate_reg();
+        mov(*reg_scale_w, ptr[reg_param + GET_OFF(scale_w)]);
+        mov(*reg_scale_a, ptr[reg_param + GET_OFF(scale_a)]);
+        mov(*reg_bias, ptr[reg_param + GET_OFF(bias)]);
+        mov(*reg_tmp_buf, ptr[reg_param + GET_OFF(tmp_buf)]);
+
+        imul(*reg_tmp2, reg_m_loop, 16 * sizeof(float));  // offset of scale_a
         for (int idx = 0; idx < block_num; idx++) {
-          vmovups(zmms[0], ptr[reg_scale_w + reg_tmp + idx * 16 * sizeof(float)]);
-          if (param_.add_bias) vmovups(zmms[1], ptr[reg_bias + reg_tmp + idx * 16 * sizeof(float)]);
+          vmovups(zmms[0], ptr[*reg_scale_w + *reg_tmp + idx * 16 * sizeof(float)]);
+          if (param_.add_bias) vmovups(zmms[1], ptr[*reg_bias + *reg_tmp + idx * 16 * sizeof(float)]);
           for (int row_loop = 0; row_loop < 16; row_loop++) {
-            vcvtdq2ps(zmms[2], ptr[reg_tmp_buf + reg_tmp + (idx * 16 + row_loop * param_.pad_n) * sizeof(float)]);
-            vbroadcastss(zmms[3], dword[reg_scale_a + reg_tmp2 + row_loop * sizeof(float)]);
+            vcvtdq2ps(zmms[2], ptr[*reg_tmp_buf + *reg_tmp + (idx * 16 + row_loop * param_.pad_n) * sizeof(float)]);
+            vbroadcastss(zmms[3], dword[*reg_scale_a + *reg_tmp2 + row_loop * sizeof(float)]);
             vmulps(zmms[2], zmms[2], zmms[3]);
             if (param_.add_bias)
               vfmadd213ps(zmms[2], zmms[0], zmms[1]);
             else
               vmulps(zmms[2], zmms[2], zmms[0]);
-            vmovups(ptr[reg_tmp_buf + reg_tmp + (idx * 16 + row_loop * param_.pad_n) * sizeof(float)], zmms[2]);
+            vmovups(ptr[*reg_tmp_buf + *reg_tmp + (idx * 16 + row_loop * param_.pad_n) * sizeof(float)], zmms[2]);
           }
         }
       }
     };

-    auto tmp_buf_load_M_row = [&](const int M, const Reg64& offset) {
-      const auto reg_tmp_buf = rp.reg<Reg64>();
-      mov(reg_tmp_buf, ptr[rp.p[0] + GET_OFF(tmp_buf)]);
-      for (int i = 0; i < M; i++) vmovups(Zmm(i), ptr[reg_tmp_buf + offset + (i * param_.pad_n) * sizeof(int)]);
+    auto tmp_buf_load_M_row = [&](int M, Reg64& offset) {
+      auto reg_tmp_buf = allocate_reg();
+      mov(*reg_tmp_buf, ptr[reg_param + GET_OFF(tmp_buf)]);
+      for (int i = 0; i < M; i++) vmovups(Zmm(i), ptr[*reg_tmp_buf + offset + (i * param_.pad_n) * sizeof(int)]);
     };

-    auto get_16_abs_max_zmm = [&](const std::array<Zmm, 16>& zmms, const Reg64& reg_max_abs_loop,
-                                  const bool need_mask = false) {
-      const auto reg_tmp = rp.reg<Reg64>();
-      const auto reg_tmp_buf = rp.reg<Reg64>();
-      mov(reg_tmp_buf, ptr[rp.p[0] + GET_OFF(tmp_buf)]);
-      imul(reg_tmp, reg_max_abs_loop, 16 * sizeof(int));
+    auto get_16_abs_max_zmm = [&](std::array<Zmm, 16>& zmms, Reg64& reg_max_abs_loop, bool need_mask = false) {
+      auto reg_tmp = allocate_reg();
+      auto reg_tmp_buf = allocate_reg();
+      mov(*reg_tmp_buf, ptr[reg_param + GET_OFF(tmp_buf)]);
+      imul(*reg_tmp, reg_max_abs_loop, 16 * sizeof(int));
       for (int i = 0; i < 16; i++)
         vrangeps(need_mask ? zmms[i] | matC_n_mask : zmms[i], zmms[i],
-                 ptr[reg_tmp_buf + reg_tmp + i * param_.pad_n * sizeof(int)], 11U);
+                 ptr[*reg_tmp_buf + *reg_tmp + i * param_.pad_n * sizeof(int)], 11U);
     };

-    auto log2n_max_reduce_16x16 = [&](const std::array<Zmm, 16>& zmms) {
+    auto log2n_max_reduce_16x16 = [&](std::array<Zmm, 16>& zmms) {
       int i = 8;
       while (i != 0) {
         for (int ii = 0; ii < i; ii++) vmaxps(zmms[ii], zmms[ii], zmms[ii + i]);
         i /= 2;
       }
     };

-    auto write_back_scale = [&](const Zmm& scale, const int M) {
-      const auto reg_tmp = rp.reg<Reg64>();
-      const auto reg_scale_dst = rp.reg<Reg64>();
-      mov(reg_scale_dst, ptr[rp.p[0] + GET_OFF(scale_dst)]);
+    auto write_back_scale = [&](Zmm& scale, int M) {
+      auto reg_tmp = allocate_reg();
+      auto reg_scale_dst = allocate_reg();
+      mov(*reg_scale_dst, ptr[reg_param + GET_OFF(scale_dst)]);
       vmulps(scale, scale, zword_b[rip + data_label]);
-      imul(reg_tmp, reg_m_loop, 16 * sizeof(float));
-      vmovups(M == 16 ? ptr[reg_scale_dst + reg_tmp] : ptr[reg_scale_dst + reg_tmp] | scaleC_mask, scale);
+      imul(*reg_tmp, reg_m_loop, 16 * sizeof(float));
+      vmovups(M == 16 ? ptr[*reg_scale_dst + *reg_tmp] : ptr[*reg_scale_dst + *reg_tmp] | scaleC_mask, scale);
       vrcp14ps(scale, scale);
       vmovups(ptr[rip + data_label + stack_tmpbuf_offset], scale);
     };

     auto calculate_scale = [&](int M, std::string label_prefix) {
-      const auto zmms = rp.regs<Zmm, 16>();
+      auto zmms = regs<Zmm, 16>(0);
       for (int i = 0; i < 16; i++) vxorps(zmms[i], zmms[i], zmms[i]);
       // calculate 16 row abs max in 16 zmms
       {
-        const auto reg_max_abs_loop = rp.reg<Reg64>();
-        xor_(reg_max_abs_loop, reg_max_abs_loop);
+        auto reg_max_abs_loop = allocate_reg();
+        xor_(*reg_max_abs_loop, *reg_max_abs_loop);
         if (param_.n / 16 > 0) {
           L(label_prefix + "max_abs_loop");
-          get_16_abs_max_zmm(zmms, reg_max_abs_loop);
-          inc(reg_max_abs_loop);
-          cmp(reg_max_abs_loop, param_.n / 16);
+          get_16_abs_max_zmm(zmms, *reg_max_abs_loop);
+          inc(*reg_max_abs_loop);
+          cmp(*reg_max_abs_loop, param_.n / 16);
           jl(label_prefix + "max_abs_loop");
         }
         if (param_.write_mask != 0) {
-          get_16_abs_max_zmm(zmms, reg_max_abs_loop, true);
+          get_16_abs_max_zmm(zmms, *reg_max_abs_loop, true);
         }
       }

       // get scale
-      transpose_16x16_ps(zmms, rp.regs<Zmm, 16>());
+      transpose_16x16_ps(zmms, regs<Zmm, 16>(16));
       log2n_max_reduce_16x16(zmms);
       write_back_scale(zmms[0], M);
     };

-    auto quant_write_back_Mx16 = [&](const int M, const Reg64& store_n_loop, const bool need_mask = false) {
-      const auto reg_tmp = rp.reg<Reg64>();
-      const auto reg_tmp2 = rp.reg<Reg64>();
-      const auto reg_dst = rp.reg<Reg64>();
-      mov(reg_dst, ptr[rp.p[0] + GET_OFF(dst)]);
-      imul(reg_tmp, store_n_loop, 16 * sizeof(float));
-      tmp_buf_load_M_row(M, reg_tmp);
-      imul(reg_tmp, reg_m_loop, 16 * param_.n);
-      imul(reg_tmp2, store_n_loop, 16);
-      add(reg_tmp, reg_tmp2);
+    auto quant_write_back_Mx16 = [&](int M, Reg64& store_n_loop, bool need_mask = false) {
+      auto reg_tmp = allocate_reg();
+      auto reg_tmp2 = allocate_reg();
+      auto reg_dst = allocate_reg();
+      mov(*reg_dst, ptr[reg_param + GET_OFF(dst)]);
+      imul(*reg_tmp, store_n_loop, 16 * sizeof(float));
+      tmp_buf_load_M_row(M, *reg_tmp);
+      imul(*reg_tmp, reg_m_loop, 16 * param_.n);
+      imul(*reg_tmp2, store_n_loop, 16);
+      add(*reg_tmp, *reg_tmp2);
       for (int i = 0; i < M; i++) {
         int quant_scale = i * sizeof(float) + stack_tmpbuf_offset;
         vmulps(Zmm(i), Zmm(i), zword_b[rip + data_label + quant_scale]);
         vcvtps2dq(Zmm(i), Zmm(i));
         vpmovsdb(
-            need_mask ? ptr[reg_dst + reg_tmp + i * param_.n] | matC_n_mask : ptr[reg_dst + reg_tmp + i * param_.n],
+            need_mask ? ptr[*reg_dst + *reg_tmp + i * param_.n] | matC_n_mask : ptr[*reg_dst + *reg_tmp + i * param_.n],
             Zmm(i));
       }
     };
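Between the hunks: the allocate_reg lambda introduced above replaces regs_pool's typed allocator. It hands out the next entry of StackFrame's scratch array st.t[] and wraps it in a std::shared_ptr whose deleter decrements reg_idx, so a register returns to the pool when its last handle leaves scope. This bump-and-decrement scheme is only safe because handles die in reverse order of allocation (strictly nested scopes), which every lambda in this kernel respects. A standalone sketch of the idea; ScratchPool and Handle are illustrative names, not part of the patch:

// Sketch of the scoped scratch-register scheme, assuming handles are
// destroyed in LIFO (strictly nested) order, as in the kernel's lambdas.
#include <array>
#include <cassert>
#include <memory>

class ScratchPool {
 public:
  using Handle = std::shared_ptr<int>;  // stands in for the Reg64 handles

  // Hand out the next free slot; the deleter gives it back on destruction.
  Handle alloc() {
    assert(idx_ < static_cast<int>(slots_.size()) && "scratch regs exhausted");
    return Handle(&slots_[idx_++], [this](int*) { --idx_; });
  }

 private:
  std::array<int, 11> slots_{};  // mirrors StackFrame's 11 temporary regs
  int idx_ = 0;
};

// Usage: nested scopes release registers in reverse allocation order.
// void f(ScratchPool& pool) {
//   auto a = pool.alloc();
//   { auto b = pool.alloc(); }  // b released here, before a
// }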
@@ -197,16 +199,16 @@ void jit_amx_s8s8_dynamic_quant_matmul_t::generate() {

       calculate_scale(M, label_prefix);

-      const auto store_n_loop = rp.reg<Reg64>();
-      xor_(store_n_loop, store_n_loop);
+      auto store_n_loop = allocate_reg();
+      xor_(*store_n_loop, *store_n_loop);
       if (param_.n / 16 > 0) {
         L(label_prefix + "store_n_loop");
-        quant_write_back_Mx16(M, store_n_loop);
-        inc(store_n_loop);
-        cmp(store_n_loop, param_.n / 16);
+        quant_write_back_Mx16(M, *store_n_loop);
+        inc(*store_n_loop);
+        cmp(*store_n_loop, param_.n / 16);
         jl(label_prefix + "store_n_loop");
       }
-      if (param_.write_mask) quant_write_back_Mx16(M, store_n_loop, true);
+      if (param_.write_mask) quant_write_back_Mx16(M, *store_n_loop, true);
     };

     prepare_mask();
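Taken together, calculate_scale, write_back_scale, and quant_write_back_Mx16 implement per-row dynamic s8 quantization: a vrangeps/vmaxps tree finds each row's absolute maximum, the broadcast constant at data_label scales it (presumably 1/127 given the s8 range; the constant's value is not visible in this diff), vrcp14ps caches an approximate reciprocal on the stack, and vcvtps2dq plus vpmovsdb round and saturate on the way out. A scalar C++ sketch of that math, under those assumptions:

// Scalar sketch of the per-row dynamic-quant math the kernel vectorizes.
// The 1/127 factor mirrors the rip-relative constant at data_label; that
// value is an assumption based on the s8 range, not confirmed by this hunk.
#include <algorithm>
#include <cmath>
#include <cstdint>

void quant_row(const float* src, int n, int8_t* dst, float* scale_out) {
  float amax = 0.f;
  for (int j = 0; j < n; ++j)
    amax = std::max(amax, std::fabs(src[j]));       // vrangeps/vmaxps tree
  float scale = amax / 127.f;                       // vmulps by the constant
  float rcp = scale != 0.f ? 1.f / scale : 0.f;     // vrcp14ps (approximate)
  for (int j = 0; j < n; ++j) {
    int v = static_cast<int>(std::nearbyint(src[j] * rcp));  // vcvtps2dq
    dst[j] = static_cast<int8_t>(std::clamp(v, -128, 127));  // vpmovsdb saturates
  }
  *scale_out = scale;  // written to scale_dst by write_back_scale
}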