This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 84f1726

Revert "[Kernels] Register pool utility (#799)"
This reverts commit aab68c6.
1 parent 0025e21 commit 84f1726

21 files changed: +378, -581 lines changed
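Note on the change: this revert swaps the regs_pool register-allocation utility back out for the earlier Xbyak::util::StackFrame scheme, in which temporary registers are handed out by an allocate_reg lambda that returns a std::shared_ptr whose deleter frees the slot when the handle leaves scope (see the jit_amx_s8s8_dynamic_quant_matmul.cpp diff below). The following is a minimal standalone sketch of that checkout/release idea, using plain int slots instead of Xbyak registers; the names here are illustrative, not taken from the repository.

#include <array>
#include <cassert>
#include <iostream>
#include <memory>

int main() {
  // Stand-in for the StackFrame's fixed temporaries st.t[...]: ten
  // pre-existing "registers" identified by their slot number.
  std::array<int, 10> slots{};
  for (int i = 0; i < 10; ++i) slots[i] = i;

  int reg_idx = 0;  // next free temporary, mirrors the reg_idx in the diff

  // Checkout: hand out a pointer to the next slot. The shared_ptr's custom
  // deleter owns nothing and only decrements the index, so the slot becomes
  // reusable once the handle goes out of scope -- the same trick as the
  // allocate_reg lambda this revert restores.
  auto allocate_slot = [&] {
    return std::shared_ptr<int>(&slots[reg_idx++], [&](int*) { --reg_idx; });
  };

  {
    auto a = allocate_slot();  // takes slot 0
    auto b = allocate_slot();  // takes slot 1
    std::cout << "in use: " << *a << " and " << *b
              << ", next free = " << reg_idx << "\n";
  }  // b, then a, released here; reg_idx drops back to 0

  assert(reg_idx == 0);
  std::cout << "all released, next free = " << reg_idx << "\n";
  return 0;
}

Because the handle is a shared_ptr, every use site has to dereference it, which is why the restored lines in the diff below put a '*' in front of each register name.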

intel_extension_for_transformers/backends/neural_engine/CMakePresets.json

Lines changed: 0 additions & 13 deletions
@@ -20,12 +20,6 @@
       },
       "vendor": { "microsoft.com/VisualStudioRemoteSettings/CMake/1.0": { "sourceDir": "$env{HOME}/.vs/$ms{projectDirName}/intel_extension_for_transformers/backends/neural_engine" } }
     },
-    {
-      "name": "linux-debug-kernels",
-      "displayName": "Linux Debuge Kernels",
-      "inherits": "linux-debug",
-      "cacheVariables": { "NE_WITH_SPARSELIB_ONLY": "ON" }
-    },
     {
       "name": "linux-release",
       "displayName": "Linux Release",
@@ -83,13 +77,6 @@
         "NE_WITH_SPARSELIB_BENCHMARK": "ON"
       }
     },
-    {
-      "name": "x64-debug-kernels",
-      "displayName": "x64 Debug Kernels",
-      "description": "Windows x64 Debug Kernels",
-      "inherits": "x64-debug",
-      "cacheVariables": { "NE_WITH_SPARSELIB_ONLY": "ON" }
-    },
     {
       "name": "x64-release",
       "displayName": "x64 Release",

intel_extension_for_transformers/backends/neural_engine/kernels/include/operator_desc.hpp

Lines changed: 0 additions & 1 deletion
@@ -44,7 +44,6 @@ class operator_desc {
       : ker_kind_(ker_kind),
         ker_prop_(ker_prop),
         engine_kind_(eng_kind),
-        runtime_kind_(jd::runtime_kind::undef),
         impl_nthr_(omp_get_max_threads()),
         ts_descs_(ts_descs),
         attrs_(attrs),

intel_extension_for_transformers/backends/neural_engine/kernels/src/jit_domain/jit_amx_s8s8_dynamic_quant_matmul.cpp

Lines changed: 92 additions & 90 deletions
@@ -13,9 +13,6 @@
 // limitations under the License.
 
 #include "jit_domain/jit_amx_s8s8_dynamic_quant_matmul.hpp"
-
-#include "regs_pool.hpp"
-
 namespace jd {
 
 #define GET_OFF(field) offsetof(ssd::dynamic_quant_matmul_data_t, field)
@@ -24,162 +21,167 @@ void jit_amx_s8s8_dynamic_quant_matmul_t::generate() {
   Xbyak::Label data_label;
   inLocalLabel();
   {
+    const int const_var_size = 4 + 2 * sizeof(tileconfig_t) + 64;
     const int stack_tmpbuf_offset = 4 + 2 * sizeof(tileconfig_t);
     auto trans_block_col = param_.k / param_.tile_k;
-    const auto does_calc = param_.align_m_loop > 0 || param_.tail_m != 0;
-    regs_pool rp(this, 1, {does_calc ? 11 : 6, does_calc ? 32 : 0, 0});
-    const auto reg_m_loop = rp.reg<Reg64>();
-    const auto reg_n_loop = rp.reg<Reg64>();
-    const auto reg_strideA = rp.reg<Reg64>();
-    const auto reg_strideB = rp.reg<Reg64>();
-    const auto reg_strideC = rp.reg<Reg64>();
+    Xbyak::util::StackFrame st(this, 1, 11, const_var_size);
+    int reg_idx = 0;
+    const Reg64& reg_param = st.p[0];
+    const Reg64& reg_m_loop = st.t[reg_idx++];
+    const Reg64& reg_n_loop = st.t[reg_idx++];
+    const Reg64& reg_strideA = st.t[reg_idx++];
+    const Reg64& reg_strideB = st.t[reg_idx++];
+    const Reg64& reg_strideC = st.t[reg_idx++];
+    auto allocate_reg = [&] {
+      Reg64* reg = &const_cast<Reg64&>(st.t[reg_idx++]);
+      return std::shared_ptr<Reg64>(reg, [&](...) { reg_idx--; });
+    };
 
     auto prepare_mask = [&]() {
-      const auto reg_tmp = rp.reg<Xbyak::Reg32>();
-      mov(reg_tmp, 0xffff >> param_.write_mask);
-      kmovd(matC_n_mask, reg_tmp);
-      mov(reg_tmp, 0xffff >> (16 - param_.tail_m));
-      kmovd(scaleC_mask, reg_tmp);
+      auto reg_tmp = allocate_reg();
+      mov(reg_tmp->cvt32(), 0xffff >> param_.write_mask);
+      kmovd(matC_n_mask, reg_tmp->cvt32());
+      mov(reg_tmp->cvt32(), 0xffff >> (16 - param_.tail_m));
+      kmovd(scaleC_mask, reg_tmp->cvt32());
     };
 
     auto ip_16x16 = [&](int block_num) {
-      const auto reg_tmp = rp.reg<Reg64>();
+      auto reg_tmp = allocate_reg();
       // build block
       {
-        const auto reg_matA_addr = rp.reg<Reg64>();
-        const auto reg_matB_addr = rp.reg<Reg64>();
+        auto reg_matA_addr = allocate_reg();
+        auto reg_matB_addr = allocate_reg();
         for (int i = 0; i < block_num; i++) tilezero(Tmm(i));
         // prepare addr & stride;
-        mov(reg_matA_addr, ptr[rp.p[0] + GET_OFF(activation)]);
-        mov(reg_matB_addr, ptr[rp.p[0] + GET_OFF(reordered_weight)]);
-        imul(reg_tmp, reg_m_loop, 16 * param_.k);
-        add(reg_matA_addr, reg_tmp);
-        imul(reg_tmp, reg_n_loop, param_.align_build_block_num * trans_block_col * 64 * (param_.tile_k / 4));
-        add(reg_matB_addr, reg_tmp);
+        mov(*reg_matA_addr, ptr[reg_param + GET_OFF(activation)]);
+        mov(*reg_matB_addr, ptr[reg_param + GET_OFF(reordered_weight)]);
+        imul(*reg_tmp, reg_m_loop, 16 * param_.k);
+        add(*reg_matA_addr, *reg_tmp);
+        imul(*reg_tmp, reg_n_loop, param_.align_build_block_num * trans_block_col * 64 * (param_.tile_k / 4));
+        add(*reg_matB_addr, *reg_tmp);
         for (int k_loop = 0; k_loop < param_.k / param_.tile_k; k_loop++) {
-          tileloadd(Tmm(3), ptr[reg_matA_addr + reg_strideA + k_loop * param_.tile_k]);
+          tileloadd(Tmm(3), ptr[*reg_matA_addr + reg_strideA + k_loop * param_.tile_k]);
           for (int idx = 0; idx < block_num; idx++) {
            int offset = idx * trans_block_col * 64 * (param_.tile_k / 4) + k_loop * 64;
-            tileloadd(Tmm(4 + idx), ptr[reg_matB_addr + reg_strideB + offset]);
+            tileloadd(Tmm(4 + idx), ptr[*reg_matB_addr + reg_strideB + offset]);
            tdpbssd(Tmm(idx), Tmm(3), Tmm(4 + idx));
          }
        }
      }
      // store block to tmp_buf & dequant+add_bias
      {
        // store the block to tmp_buf
-        imul(reg_tmp, reg_n_loop, 16 * param_.align_build_block_num * sizeof(int));
-        const auto reg_tmp_buf = rp.reg<Reg64>();
-        mov(reg_tmp_buf, ptr[rp.p[0] + GET_OFF(tmp_buf)]);
-        add(reg_tmp_buf, reg_tmp);
+        imul(*reg_tmp, reg_n_loop, 16 * param_.align_build_block_num * sizeof(int));
+        auto reg_tmp_buf = allocate_reg();
+        mov(*reg_tmp_buf, ptr[reg_param + GET_OFF(tmp_buf)]);
+        add(*reg_tmp_buf, *reg_tmp);
        for (int idx = 0; idx < block_num; idx++)
-          tilestored(ptr[reg_tmp_buf + reg_strideC + idx * 16 * sizeof(int)], Tmm(idx));
+          tilestored(ptr[*reg_tmp_buf + reg_strideC + idx * 16 * sizeof(int)], Tmm(idx));
        // dequant + add_bias
-        const auto zmms = rp.regs<Zmm, 4>();
-        const auto reg_tmp2 = rp.reg<Reg64>();
-        const auto reg_scale_w = rp.reg<Reg64>();
-        const auto reg_scale_a = rp.reg<Reg64>();
-        const auto reg_bias = rp.reg<Reg64>();
-        mov(reg_scale_w, ptr[rp.p[0] + GET_OFF(scale_w)]);
-        mov(reg_scale_a, ptr[rp.p[0] + GET_OFF(scale_a)]);
-        mov(reg_bias, ptr[rp.p[0] + GET_OFF(bias)]);
-        mov(reg_tmp_buf, ptr[rp.p[0] + GET_OFF(tmp_buf)]);
-
-        imul(reg_tmp2, reg_m_loop, 16 * sizeof(float));  // offset of scale_a
+        auto zmms = regs<Zmm, 4>(0);
+        auto reg_tmp2 = allocate_reg();
+        auto reg_scale_w = allocate_reg();
+        auto reg_scale_a = allocate_reg();
+        auto reg_bias = allocate_reg();
+        mov(*reg_scale_w, ptr[reg_param + GET_OFF(scale_w)]);
+        mov(*reg_scale_a, ptr[reg_param + GET_OFF(scale_a)]);
+        mov(*reg_bias, ptr[reg_param + GET_OFF(bias)]);
+        mov(*reg_tmp_buf, ptr[reg_param + GET_OFF(tmp_buf)]);
+
+        imul(*reg_tmp2, reg_m_loop, 16 * sizeof(float));  // offset of scale_a
        for (int idx = 0; idx < block_num; idx++) {
-          vmovups(zmms[0], ptr[reg_scale_w + reg_tmp + idx * 16 * sizeof(float)]);
-          if (param_.add_bias) vmovups(zmms[1], ptr[reg_bias + reg_tmp + idx * 16 * sizeof(float)]);
+          vmovups(zmms[0], ptr[*reg_scale_w + *reg_tmp + idx * 16 * sizeof(float)]);
+          if (param_.add_bias) vmovups(zmms[1], ptr[*reg_bias + *reg_tmp + idx * 16 * sizeof(float)]);
          for (int row_loop = 0; row_loop < 16; row_loop++) {
-            vcvtdq2ps(zmms[2], ptr[reg_tmp_buf + reg_tmp + (idx * 16 + row_loop * param_.pad_n) * sizeof(float)]);
-            vbroadcastss(zmms[3], dword[reg_scale_a + reg_tmp2 + row_loop * sizeof(float)]);
+            vcvtdq2ps(zmms[2], ptr[*reg_tmp_buf + *reg_tmp + (idx * 16 + row_loop * param_.pad_n) * sizeof(float)]);
+            vbroadcastss(zmms[3], dword[*reg_scale_a + *reg_tmp2 + row_loop * sizeof(float)]);
            vmulps(zmms[2], zmms[2], zmms[3]);
            if (param_.add_bias)
              vfmadd213ps(zmms[2], zmms[0], zmms[1]);
            else
              vmulps(zmms[2], zmms[2], zmms[0]);
-            vmovups(ptr[reg_tmp_buf + reg_tmp + (idx * 16 + row_loop * param_.pad_n) * sizeof(float)], zmms[2]);
+            vmovups(ptr[*reg_tmp_buf + *reg_tmp + (idx * 16 + row_loop * param_.pad_n) * sizeof(float)], zmms[2]);
          }
        }
      }
    };
 
-    auto tmp_buf_load_M_row = [&](const int M, const Reg64& offset) {
-      const auto reg_tmp_buf = rp.reg<Reg64>();
-      mov(reg_tmp_buf, ptr[rp.p[0] + GET_OFF(tmp_buf)]);
-      for (int i = 0; i < M; i++) vmovups(Zmm(i), ptr[reg_tmp_buf + offset + (i * param_.pad_n) * sizeof(int)]);
+    auto tmp_buf_load_M_row = [&](int M, Reg64& offset) {
+      auto reg_tmp_buf = allocate_reg();
+      mov(*reg_tmp_buf, ptr[reg_param + GET_OFF(tmp_buf)]);
+      for (int i = 0; i < M; i++) vmovups(Zmm(i), ptr[*reg_tmp_buf + offset + (i * param_.pad_n) * sizeof(int)]);
    };
 
-    auto get_16_abs_max_zmm = [&](const std::array<Zmm, 16>& zmms, const Reg64& reg_max_abs_loop,
-                                  const bool need_mask = false) {
-      const auto reg_tmp = rp.reg<Reg64>();
-      const auto reg_tmp_buf = rp.reg<Reg64>();
-      mov(reg_tmp_buf, ptr[rp.p[0] + GET_OFF(tmp_buf)]);
-      imul(reg_tmp, reg_max_abs_loop, 16 * sizeof(int));
+    auto get_16_abs_max_zmm = [&](std::array<Zmm, 16>& zmms, Reg64& reg_max_abs_loop, bool need_mask = false) {
+      auto reg_tmp = allocate_reg();
+      auto reg_tmp_buf = allocate_reg();
+      mov(*reg_tmp_buf, ptr[reg_param + GET_OFF(tmp_buf)]);
+      imul(*reg_tmp, reg_max_abs_loop, 16 * sizeof(int));
      for (int i = 0; i < 16; i++)
        vrangeps(need_mask ? zmms[i] | matC_n_mask : zmms[i], zmms[i],
-                 ptr[reg_tmp_buf + reg_tmp + i * param_.pad_n * sizeof(int)], 11U);
+                 ptr[*reg_tmp_buf + *reg_tmp + i * param_.pad_n * sizeof(int)], 11U);
    };
 
-    auto log2n_max_reduce_16x16 = [&](const std::array<Zmm, 16>& zmms) {
+    auto log2n_max_reduce_16x16 = [&](std::array<Zmm, 16>& zmms) {
      int i = 8;
      while (i != 0) {
        for (int ii = 0; ii < i; ii++) vmaxps(zmms[ii], zmms[ii], zmms[ii + i]);
        i /= 2;
      }
    };
 
-    auto write_back_scale = [&](const Zmm& scale, const int M) {
-      const auto reg_tmp = rp.reg<Reg64>();
-      const auto reg_scale_dst = rp.reg<Reg64>();
-      mov(reg_scale_dst, ptr[rp.p[0] + GET_OFF(scale_dst)]);
+    auto write_back_scale = [&](Zmm& scale, int M) {
+      auto reg_tmp = allocate_reg();
+      auto reg_scale_dst = allocate_reg();
+      mov(*reg_scale_dst, ptr[reg_param + GET_OFF(scale_dst)]);
      vmulps(scale, scale, zword_b[rip + data_label]);
-      imul(reg_tmp, reg_m_loop, 16 * sizeof(float));
-      vmovups(M == 16 ? ptr[reg_scale_dst + reg_tmp] : ptr[reg_scale_dst + reg_tmp] | scaleC_mask, scale);
+      imul(*reg_tmp, reg_m_loop, 16 * sizeof(float));
+      vmovups(M == 16 ? ptr[*reg_scale_dst + *reg_tmp] : ptr[*reg_scale_dst + *reg_tmp] | scaleC_mask, scale);
      vrcp14ps(scale, scale);
      vmovups(ptr[rip + data_label + stack_tmpbuf_offset], scale);
    };
 
    auto calculate_scale = [&](int M, std::string label_prefix) {
-      const auto zmms = rp.regs<Zmm, 16>();
+      auto zmms = regs<Zmm, 16>(0);
      for (int i = 0; i < 16; i++) vxorps(zmms[i], zmms[i], zmms[i]);
      // calculate 16 row abs max in 16 zmms
      {
-        const auto reg_max_abs_loop = rp.reg<Reg64>();
-        xor_(reg_max_abs_loop, reg_max_abs_loop);
+        auto reg_max_abs_loop = allocate_reg();
+        xor_(*reg_max_abs_loop, *reg_max_abs_loop);
        if (param_.n / 16 > 0) {
          L(label_prefix + "max_abs_loop");
-          get_16_abs_max_zmm(zmms, reg_max_abs_loop);
-          inc(reg_max_abs_loop);
-          cmp(reg_max_abs_loop, param_.n / 16);
+          get_16_abs_max_zmm(zmms, *reg_max_abs_loop);
+          inc(*reg_max_abs_loop);
+          cmp(*reg_max_abs_loop, param_.n / 16);
          jl(label_prefix + "max_abs_loop");
        }
        if (param_.write_mask != 0) {
-          get_16_abs_max_zmm(zmms, reg_max_abs_loop, true);
+          get_16_abs_max_zmm(zmms, *reg_max_abs_loop, true);
        }
      }
 
      // get scale
-      transpose_16x16_ps(zmms, rp.regs<Zmm, 16>());
+      transpose_16x16_ps(zmms, regs<Zmm, 16>(16));
      log2n_max_reduce_16x16(zmms);
      write_back_scale(zmms[0], M);
    };
 
-    auto quant_write_back_Mx16 = [&](const int M, const Reg64& store_n_loop, const bool need_mask = false) {
-      const auto reg_tmp = rp.reg<Reg64>();
-      const auto reg_tmp2 = rp.reg<Reg64>();
-      const auto reg_dst = rp.reg<Reg64>();
-      mov(reg_dst, ptr[rp.p[0] + GET_OFF(dst)]);
-      imul(reg_tmp, store_n_loop, 16 * sizeof(float));
-      tmp_buf_load_M_row(M, reg_tmp);
-      imul(reg_tmp, reg_m_loop, 16 * param_.n);
-      imul(reg_tmp2, store_n_loop, 16);
-      add(reg_tmp, reg_tmp2);
+    auto quant_write_back_Mx16 = [&](int M, Reg64& store_n_loop, bool need_mask = false) {
+      auto reg_tmp = allocate_reg();
+      auto reg_tmp2 = allocate_reg();
+      auto reg_dst = allocate_reg();
+      mov(*reg_dst, ptr[reg_param + GET_OFF(dst)]);
+      imul(*reg_tmp, store_n_loop, 16 * sizeof(float));
+      tmp_buf_load_M_row(M, *reg_tmp);
+      imul(*reg_tmp, reg_m_loop, 16 * param_.n);
+      imul(*reg_tmp2, store_n_loop, 16);
+      add(*reg_tmp, *reg_tmp2);
      for (int i = 0; i < M; i++) {
        int quant_scale = i * sizeof(float) + stack_tmpbuf_offset;
        vmulps(Zmm(i), Zmm(i), zword_b[rip + data_label + quant_scale]);
        vcvtps2dq(Zmm(i), Zmm(i));
        vpmovsdb(
-            need_mask ? ptr[reg_dst + reg_tmp + i * param_.n] | matC_n_mask : ptr[reg_dst + reg_tmp + i * param_.n],
+            need_mask ? ptr[*reg_dst + *reg_tmp + i * param_.n] | matC_n_mask : ptr[*reg_dst + *reg_tmp + i * param_.n],
            Zmm(i));
      }
    };
@@ -197,16 +199,16 @@ void jit_amx_s8s8_dynamic_quant_matmul_t::generate() {
 
      calculate_scale(M, label_prefix);
 
-      const auto store_n_loop = rp.reg<Reg64>();
-      xor_(store_n_loop, store_n_loop);
+      auto store_n_loop = allocate_reg();
+      xor_(*store_n_loop, *store_n_loop);
      if (param_.n / 16 > 0) {
        L(label_prefix + "store_n_loop");
-        quant_write_back_Mx16(M, store_n_loop);
-        inc(store_n_loop);
-        cmp(store_n_loop, param_.n / 16);
+        quant_write_back_Mx16(M, *store_n_loop);
+        inc(*store_n_loop);
+        cmp(*store_n_loop, param_.n / 16);
        jl(label_prefix + "store_n_loop");
      }
-      if (param_.write_mask) quant_write_back_Mx16(M, store_n_loop, true);
+      if (param_.write_mask) quant_write_back_Mx16(M, *store_n_loop, true);
    };
 
    prepare_mask();
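For contrast, here is a conceptual mock of the scope-based pool idea that the removed regs_pool utility provided: an RAII handle that converts to the underlying register, so call sites read like plain values and the dereferences disappear. This is only a sketch of the idea under assumed semantics, not the kernels' actual regs_pool API; mock_pool, handle, and add_into are invented names.

#include <array>
#include <iostream>

// Conceptual mock only: a tiny scope-based pool whose handles convert
// implicitly to the underlying "register" (an int here), so the caller
// never writes '*' and slots are released automatically at scope exit.
class mock_pool {
 public:
  class handle {
   public:
    explicit handle(mock_pool& pool) : pool_(pool), slot_(pool.next_++) {}
    ~handle() { --pool_.next_; }                      // slot returns to the pool
    operator int&() const { return pool_.slots_[slot_]; }
   private:
    mock_pool& pool_;
    int slot_;
  };

  handle reg() { return handle(*this); }              // checkout the next free slot

 private:
  std::array<int, 10> slots_{};
  int next_ = 0;
};

// A consumer that expects a plain reference, like the jit helpers taking const Reg64&.
void add_into(int& dst, int value) { dst += value; }

int main() {  // compile as C++17 so the returned handle is not copied
  mock_pool pool;
  {
    auto acc = pool.reg();
    add_into(acc, 3);   // handle converts to int&; no dereference at the call site
    add_into(acc, 4);
    std::cout << "acc = " << static_cast<int&>(acc) << "\n";  // prints 7
  }  // acc's slot is released when the inner scope ends
  return 0;
}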
