Skip to content

Commit c4ab306

Browse files
committed
fedp drl shared mul
1 parent a4ab447 commit c4ab306

File tree

9 files changed

+304
-125
lines changed

9 files changed

+304
-125
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Copyright © 2019-2023
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
`include "VX_define.vh"
15+
16+
module VX_tcu_drl_exp_bias (
17+
input wire enable,
18+
input wire [2:0] fmt_s,
19+
input wire [15:0] a,
20+
input wire [15:0] b,
21+
output logic [7:0] raw_exp_y,
22+
output logic exp_low_larger,
23+
output logic [6:0] raw_exp_diff
24+
);
25+
//NOTE: exception handling neglected for now
26+
`UNUSED_VAR({a, b, enable});
27+
28+
//FP16 exponent addition and bias
29+
wire [7:0] raw_exp_fp16;
30+
wire [7:0] fp16_32_conv_bias = 8'd98; //127-30 + 1
31+
VX_csa_tree #(
32+
.N(3),
33+
.W(8),
34+
.S(8)
35+
) biasexp_fp16(
36+
.operands({{3'd0, a[14:10]}, {3'd0, b[14:10]}, fp16_32_conv_bias}),
37+
.sum (raw_exp_fp16),
38+
`UNUSED_PIN (cout)
39+
);
40+
41+
//BF16 exponent addition and bias
42+
wire [7:0] raw_exp_bf16;
43+
wire [9:0] neg_bias = 10'b1110000010; //-127+1
44+
wire [9:0] raw_exp_bf16_signed;
45+
`UNUSED_VAR(raw_exp_bf16_signed);
46+
VX_csa_tree #(
47+
.N(3),
48+
.W(10),
49+
.S(10)
50+
) biasexp_bf16(
51+
.operands({{2'd0, a[14:7]}, {2'd0, b[14:7]}, neg_bias}),
52+
.sum (raw_exp_bf16_signed),
53+
`UNUSED_PIN (cout)
54+
);
55+
assign raw_exp_bf16 = raw_exp_bf16_signed[9] ? -raw_exp_bf16_signed[7:0] : raw_exp_bf16_signed[7:0];
56+
57+
//FP8 (E4M3) exponent addition and bias
58+
wire [7:0] raw_exp_fp8;
59+
wire [1:0][4:0] raw_exp_fp8_sub;
60+
for (genvar i = 0; i < 2; i++) begin : g_fp8_sub
61+
VX_ks_adder #(
62+
.N(4)
63+
) raw_exp_fp8_sub_add (
64+
.dataa (a[(i*8)+6 -: 4]),
65+
.datab (b[(i*8)+6 -: 4]),
66+
.sum (raw_exp_fp8_sub[i][3:0]),
67+
.cout (raw_exp_fp8_sub[i][4])
68+
);
69+
end
70+
wire [5:0] raw_exp_fp8_diff = {1'b0, raw_exp_fp8_sub[1]} - {1'b0, raw_exp_fp8_sub[0]};
71+
wire fp8_exp_low_larger = raw_exp_fp8_diff[5];
72+
wire [4:0] raw_exp_fp8_unbiased = fp8_exp_low_larger ? raw_exp_fp8_sub[0] : raw_exp_fp8_sub[1];
73+
wire [7:0] fp8_conv_bias_fp32 = 8'd115; //127-14+2
74+
VX_ks_adder #(
75+
.N(8)
76+
) biasexp_fp8 (
77+
.dataa ({3'd0, raw_exp_fp8_unbiased}),
78+
.datab (fp8_conv_bias_fp32),
79+
.sum (raw_exp_fp8),
80+
`UNUSED_PIN (cout)
81+
);
82+
83+
//BF8 (E5M2) exponent addition and bias
84+
wire [7:0] raw_exp_bf8;
85+
wire [1:0][5:0] raw_exp_bf8_sub;
86+
for (genvar j = 0; j < 2; j++) begin : g_bf8_sub
87+
VX_ks_adder #(
88+
.N(5)
89+
) raw_exp_bf8_sub_add (
90+
.dataa (a[(j*8)+6 -: 5]),
91+
.datab (b[(j*8)+6 -: 5]),
92+
.sum (raw_exp_bf8_sub[j][4:0]),
93+
.cout (raw_exp_bf8_sub[j][5])
94+
);
95+
end
96+
wire [6:0] raw_exp_bf8_diff = {1'b0, raw_exp_bf8_sub[1]} - {1'b0, raw_exp_bf8_sub[0]};
97+
wire bf8_exp_low_larger = raw_exp_bf8_diff[6];
98+
wire [5:0] raw_exp_bf8_unbiased = bf8_exp_low_larger ? raw_exp_bf8_sub[0] : raw_exp_bf8_sub[1];
99+
wire [7:0] bf8_conv_bias_fp32 = 8'd99; //127-30+2
100+
VX_ks_adder #(
101+
.N(8)
102+
) biasexp_bf8 (
103+
.dataa ({2'd0, raw_exp_bf8_unbiased}),
104+
.datab (bf8_conv_bias_fp32),
105+
.sum (raw_exp_bf8),
106+
`UNUSED_PIN (cout)
107+
);
108+
109+
//Select exp out based on datatype
110+
always_comb begin
111+
case(fmt_s[2:0])
112+
3'd1: begin
113+
raw_exp_y = raw_exp_fp16;
114+
exp_low_larger = 1'bx;
115+
raw_exp_diff = 7'dx;
116+
end
117+
3'd2: begin
118+
raw_exp_y = raw_exp_bf16;
119+
exp_low_larger = 1'bx;
120+
raw_exp_diff = 7'dx;
121+
end
122+
3'd3: begin
123+
raw_exp_y = raw_exp_fp8;
124+
exp_low_larger = fp8_exp_low_larger;
125+
raw_exp_diff = {raw_exp_fp8_diff[5], raw_exp_fp8_diff};
126+
end
127+
3'd4: begin
128+
raw_exp_y = raw_exp_bf8;
129+
exp_low_larger = bf8_exp_low_larger;
130+
raw_exp_diff = raw_exp_bf8_diff;
131+
end
132+
default: begin
133+
raw_exp_y = 8'dx;
134+
exp_low_larger = 1'bx;
135+
raw_exp_diff = 7'dx;
136+
end
137+
endcase
138+
end
139+
140+
endmodule

hw/rtl/tcu/drl/VX_tcu_drl_mul_exp.sv

Lines changed: 26 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -26,146 +26,47 @@ module VX_tcu_drl_mul_exp #(
2626
output logic [N-1:0][24:0] raw_sigs
2727
);
2828

29-
//raw fp signals
30-
wire [N-2:0] mul_sign_fp16, mul_sign_bf16, mul_sign_fp8, mul_sign_bf8;
31-
wire [N-2:0][7:0] mul_exp_fp16, mul_exp_bf16, mul_exp_fp8, mul_exp_bf8;
32-
wire [N-2:0][23:0] mul_sig_fp16, mul_sig_bf16, mul_sig_fp8, mul_sig_bf8;
33-
34-
//raw int signals
35-
wire [N-2:0][16:0] mul_sig_int8;
36-
wire [N-2:0][9:0] mul_sig_uint4;
37-
3829
//muxed signals
39-
logic [N-1:0] mul_sign_mux;
40-
logic [N-1:0][7:0] mul_exp_mux;
41-
logic [N-1:0][23:0] mul_sig_mux;
42-
logic [N-1:0][24:0] int_vals_mux;
30+
logic [N-1:0][7:0] raw_exps;
4331

4432
for (genvar i = 0; i < N-1; i++) begin : g_prod
45-
// FP16 multiplication
46-
VX_tcu_drl_fp16mul fp16mul (
47-
.enable (enable),
48-
.a (a_rows[i]),
49-
.b (b_cols[i]),
50-
.sign_y (mul_sign_fp16[i]),
51-
.raw_exp_y (mul_exp_fp16[i]),
52-
.raw_sig_y (mul_sig_fp16[i])
53-
);
54-
55-
// BF16 multiplication
56-
VX_tcu_drl_bf16mul bf16mul (
57-
.enable (enable),
58-
.a (a_rows[i]),
59-
.b (b_cols[i]),
60-
.sign_y (mul_sign_bf16[i]),
61-
.raw_exp_y (mul_exp_bf16[i]),
62-
.raw_sig_y (mul_sig_bf16[i])
63-
);
64-
65-
// FP8 E4M3 multiplication
66-
VX_tcu_drl_fp8mul fp8mul (
67-
.enable (enable),
68-
.a (a_rows[i]),
69-
.b (b_cols[i]),
70-
.sign_y (mul_sign_fp8[i]),
71-
.raw_exp_y (mul_exp_fp8[i]),
72-
.raw_sig_y (mul_sig_fp8[i])
73-
);
74-
75-
// FP8 E5M2 multiplication
76-
VX_tcu_drl_bf8mul bf8mul (
77-
.enable (enable),
78-
.a (a_rows[i]),
79-
.b (b_cols[i]),
80-
.sign_y (mul_sign_bf8[i]),
81-
.raw_exp_y (mul_exp_bf8[i]),
82-
.raw_sig_y (mul_sig_bf8[i])
33+
wire exp_low_larger;
34+
wire [6:0] raw_exp_diff;
35+
36+
//shared significand multiplier
37+
VX_tcu_drl_shared_mul shared_mul_inst (
38+
.enable (enable),
39+
.fmt_s (fmt_s),
40+
.a (a_rows[i]),
41+
.b (b_cols[i]),
42+
.exp_low_larger (exp_low_larger),
43+
.raw_exp_diff (raw_exp_diff),
44+
.y (raw_sigs[i])
8345
);
8446

85-
// INT8 multiplication
86-
VX_tcu_drl_int8mul int8mul (
87-
.enable (enable),
88-
.a (a_rows[i]),
89-
.b (b_cols[i]),
90-
.signed_y (mul_sig_int8[i])
47+
//exponent add and bias
48+
VX_tcu_drl_exp_bias exp_bias_inst (
49+
.enable (enable),
50+
.fmt_s (fmt_s[2:0]),
51+
.a (a_rows[i]),
52+
.b (b_cols[i]),
53+
.raw_exp_y (raw_exps[i]),
54+
.exp_low_larger (exp_low_larger),
55+
.raw_exp_diff (raw_exp_diff)
9156
);
92-
93-
// UINT4 multiplication
94-
VX_tcu_drl_uint4mul uint4mul (
95-
.enable (enable),
96-
.a (a_rows[i]),
97-
.b (b_cols[i]),
98-
.unsigned_y (mul_sig_uint4[i])
99-
);
100-
101-
//FP Format selection
102-
always_comb begin
103-
case(fmt_s[2:0])
104-
3'd1: begin
105-
mul_sign_mux[i] = mul_sign_fp16[i];
106-
mul_exp_mux[i] = mul_exp_fp16[i];
107-
mul_sig_mux[i] = mul_sig_fp16[i];
108-
end
109-
3'd2: begin
110-
mul_sign_mux[i] = mul_sign_bf16[i];
111-
mul_exp_mux[i] = mul_exp_bf16[i];
112-
mul_sig_mux[i] = mul_sig_bf16[i];
113-
end
114-
3'd3: begin
115-
mul_sign_mux[i] = mul_sign_fp8[i];
116-
mul_exp_mux[i] = mul_exp_fp8[i];
117-
mul_sig_mux[i] = mul_sig_fp8[i];
118-
end
119-
3'd4: begin
120-
mul_sign_mux[i] = mul_sign_bf8[i];
121-
mul_exp_mux[i] = mul_exp_bf8[i];
122-
mul_sig_mux[i] = mul_sig_bf8[i];
123-
end
124-
default: begin
125-
mul_sign_mux[i] = 1'bx;
126-
mul_exp_mux[i] = 8'hxx;
127-
mul_sig_mux[i] = 24'hxxxxxx;
128-
end
129-
endcase
130-
end
131-
132-
//INT Format selection (sign extend)
133-
always_comb begin
134-
case(fmt_s[2:0])
135-
3'd1: begin
136-
int_vals_mux[i] = 25'($signed(mul_sig_int8[i]));
137-
end
138-
3'd4: begin
139-
int_vals_mux[i] = {15'd0, mul_sig_uint4[i]};
140-
end
141-
default: begin
142-
int_vals_mux[i] = 25'hxxxxxxx;
143-
end
144-
endcase
145-
end
14657
end
14758

148-
//FP c_val integration
149-
always_comb begin
150-
mul_sign_mux[N-1] = c_val[31];
151-
mul_exp_mux[N-1] = c_val[30:23];
152-
mul_sig_mux[N-1] = {1'b1, c_val[22:0]};
153-
end
154-
155-
//INT c_val integration
156-
assign int_vals_mux[N-1] = c_val[24:0];
59+
//c_val integration
60+
assign raw_exps[N-1] = c_val[30:23];
61+
assign raw_sigs[N-1] = fmt_s[3] ? c_val[24:0] : {c_val[31], 1'b1, c_val[22:0]};
15762

15863
//Raw maximum exponent finder (in parallel to mul) and shift amounts
15964
VX_tcu_drl_max_exp #(
16065
.N(N)
16166
) find_max_exp (
162-
.exponents (mul_exp_mux),
67+
.exponents (raw_exps),
16368
.max_exp (raw_max_exp),
16469
.shift_amounts (shift_amounts)
16570
);
16671

167-
for (genvar i = 0; i < N; i++) begin : g_fp_int_sig_sel
168-
assign raw_sigs[i] = fmt_s[3] ? int_vals_mux[i] : {mul_sign_mux[i], mul_sig_mux[i]};
169-
end
170-
17172
endmodule

0 commit comments

Comments
 (0)