@@ -26,146 +26,47 @@ module VX_tcu_drl_mul_exp #(
2626 output logic [N - 1 : 0 ][24 : 0 ] raw_sigs
2727);
2828
// Declarations hunk: the per-format product wires (fp16/bf16/fp8/bf8/int8/uint4)
// and the four *_mux arrays are removed; the only internal array left is
// raw_exps, holding one 8-bit exponent for each of the N lanes.
29- // raw fp signals
30- wire [N - 2 : 0 ] mul_sign_fp16, mul_sign_bf16, mul_sign_fp8, mul_sign_bf8;
31- wire [N - 2 : 0 ][7 : 0 ] mul_exp_fp16, mul_exp_bf16, mul_exp_fp8, mul_exp_bf8;
32- wire [N - 2 : 0 ][23 : 0 ] mul_sig_fp16, mul_sig_bf16, mul_sig_fp8, mul_sig_bf8;
33-
34- // raw int signals
35- wire [N - 2 : 0 ][16 : 0 ] mul_sig_int8;
36- wire [N - 2 : 0 ][9 : 0 ] mul_sig_uint4;
37-
3829 // muxed signals
39- logic [N - 1 : 0 ] mul_sign_mux;
40- logic [N - 1 : 0 ][7 : 0 ] mul_exp_mux;
41- logic [N - 1 : 0 ][23 : 0 ] mul_sig_mux;
42- logic [N - 1 : 0 ][24 : 0 ] int_vals_mux;
30+ logic [N - 1 : 0 ][7 : 0 ] raw_exps;
// NOTE(review): raw_exps is no longer selected by a format mux in this module.
// Entries 0..N-2 connect to exp_bias_inst.raw_exp_y and entry N-1 is assigned
// c_val[30:23], so the retained "muxed signals" heading above looks stale.
4331
// g_prod: one generate iteration per product lane, i = 0 .. N-2 (lane N-1 is
// not instantiated here; it is filled from c_val after the loop).
// Before this change each lane instantiated six format-specific multipliers
// (fp16, bf16, fp8, bf8, int8, uint4) and selected among them with two
// always_comb case statements on fmt_s[2:0]. After this change each lane
// instantiates VX_tcu_drl_shared_mul (connected to raw_sigs[i]) and
// VX_tcu_drl_exp_bias (connected to raw_exps[i]).
4432 for (genvar i = 0 ; i < N - 1 ; i++ ) begin : g_prod
45- // FP16 multiplication
46- VX_tcu_drl_fp16mul fp16mul (
47- .enable (enable),
48- .a (a_rows[i]),
49- .b (b_cols[i]),
50- .sign_y (mul_sign_fp16[i]),
51- .raw_exp_y (mul_exp_fp16[i]),
52- .raw_sig_y (mul_sig_fp16[i])
53- );
54-
55- // BF16 multiplication
56- VX_tcu_drl_bf16mul bf16mul (
57- .enable (enable),
58- .a (a_rows[i]),
59- .b (b_cols[i]),
60- .sign_y (mul_sign_bf16[i]),
61- .raw_exp_y (mul_exp_bf16[i]),
62- .raw_sig_y (mul_sig_bf16[i])
63- );
64-
65- // FP8 E4M3 multiplication
66- VX_tcu_drl_fp8mul fp8mul (
67- .enable (enable),
68- .a (a_rows[i]),
69- .b (b_cols[i]),
70- .sign_y (mul_sign_fp8[i]),
71- .raw_exp_y (mul_exp_fp8[i]),
72- .raw_sig_y (mul_sig_fp8[i])
73- );
74-
75- // FP8 E5M2 multiplication
76- VX_tcu_drl_bf8mul bf8mul (
77- .enable (enable),
78- .a (a_rows[i]),
79- .b (b_cols[i]),
80- .sign_y (mul_sign_bf8[i]),
81- .raw_exp_y (mul_exp_bf8[i]),
82- .raw_sig_y (mul_sig_bf8[i])
// Per-lane wires shared by the two new instances below. Both instances
// connect to them; the port directions are not visible in this file.
// NOTE(review): presumably exp_bias_inst drives them and shared_mul_inst
// consumes them - confirm against the submodule port declarations.
33+ wire exp_low_larger;
34+ wire [6 : 0 ] raw_exp_diff;
35+
36+ // shared significand multiplier
// Receives the full fmt_s (exp_bias_inst below receives only fmt_s[2:0]);
// its y port is wired straight to the 25-bit raw_sigs[i] module output.
37+ VX_tcu_drl_shared_mul shared_mul_inst (
38+ .enable (enable),
39+ .fmt_s (fmt_s),
40+ .a (a_rows[i]),
41+ .b (b_cols[i]),
42+ .exp_low_larger (exp_low_larger),
43+ .raw_exp_diff (raw_exp_diff),
44+ .y (raw_sigs[i])
8345 );
8446
85- // INT8 multiplication
86- VX_tcu_drl_int8mul int8mul (
87- .enable (enable),
88- .a (a_rows[i]),
89- .b (b_cols[i]),
90- .signed_y (mul_sig_int8[i])
47+ // exponent add and bias
// raw_exp_y is wired to raw_exps[i], the array passed to find_max_exp.
48+ VX_tcu_drl_exp_bias exp_bias_inst (
49+ .enable (enable),
50+ .fmt_s (fmt_s[2 : 0 ]),
51+ .a (a_rows[i]),
52+ .b (b_cols[i]),
53+ .raw_exp_y (raw_exps[i]),
54+ .exp_low_larger (exp_low_larger),
55+ .raw_exp_diff (raw_exp_diff)
9156 );
92-
93- // UINT4 multiplication
94- VX_tcu_drl_uint4mul uint4mul (
95- .enable (enable),
96- .a (a_rows[i]),
97- .b (b_cols[i]),
98- .unsigned_y (mul_sig_uint4[i])
99- );
100-
// Removed FP selection: fmt_s[2:0] values 1, 2, 3, 4 chose the fp16, bf16,
// fp8 and bf8 multiplier outputs respectively; any other value drove X.
101- // FP Format selection
102- always_comb begin
103- case (fmt_s[2 : 0 ])
104- 3'd1 : begin
105- mul_sign_mux[i] = mul_sign_fp16[i];
106- mul_exp_mux[i] = mul_exp_fp16[i];
107- mul_sig_mux[i] = mul_sig_fp16[i];
108- end
109- 3'd2 : begin
110- mul_sign_mux[i] = mul_sign_bf16[i];
111- mul_exp_mux[i] = mul_exp_bf16[i];
112- mul_sig_mux[i] = mul_sig_bf16[i];
113- end
114- 3'd3 : begin
115- mul_sign_mux[i] = mul_sign_fp8[i];
116- mul_exp_mux[i] = mul_exp_fp8[i];
117- mul_sig_mux[i] = mul_sig_fp8[i];
118- end
119- 3'd4 : begin
120- mul_sign_mux[i] = mul_sign_bf8[i];
121- mul_exp_mux[i] = mul_exp_bf8[i];
122- mul_sig_mux[i] = mul_sig_bf8[i];
123- end
124- default : begin
125- mul_sign_mux[i] = 1'bx ;
126- mul_exp_mux[i] = 8'hxx ;
127- mul_sig_mux[i] = 24'hxxxxxx ;
128- end
129- endcase
130- end
131-
// Removed INT selection: fmt_s[2:0] == 1 took the int8 product sign-extended
// to 25 bits, fmt_s[2:0] == 4 took the uint4 product zero-extended to 25 bits.
132- // INT Format selection (sign extend)
133- always_comb begin
134- case (fmt_s[2 : 0 ])
135- 3'd1 : begin
136- int_vals_mux[i] = 25 '($signed (mul_sig_int8[i]));
137- end
138- 3'd4 : begin
139- int_vals_mux[i] = { 15'd0 , mul_sig_uint4[i]} ;
140- end
141- default : begin
142- int_vals_mux[i] = 25'hxxxxxxx ;
143- end
144- endcase
145- end
14657 end
14758
// Fills the last lane (index N-1), which the g_prod loop does not cover,
// from c_val. The removed code wrote the sign/exp/sig mux arrays and
// int_vals_mux; the added code assigns raw_exps and raw_sigs directly.
148- // FP c_val integration
149- always_comb begin
150- mul_sign_mux[N - 1 ] = c_val[31 ];
151- mul_exp_mux[N - 1 ] = c_val[30 : 23 ];
152- mul_sig_mux[N - 1 ] = { 1'b1 , c_val[22 : 0 ]} ;
153- end
154-
155- // INT c_val integration
156- assign int_vals_mux[N - 1 ] = c_val[24 : 0 ];
59+ // c_val integration
// raw_exps[N-1] takes c_val[30:23] regardless of fmt_s[3].
// raw_sigs[N-1] is 25 bits: when fmt_s[3] is set it is c_val[24:0] unchanged;
// otherwise it is {c_val[31], 1'b1, c_val[22:0]}, i.e. bit 31, a constant 1,
// then the low 23 bits.
// NOTE(review): the 31 / 30:23 / 22:0 slices match IEEE 754 binary32 field
// positions, so c_val is presumably fp32 in the non-integer case - confirm.
// The constant 1 is inserted unconditionally, including when c_val[30:23]
// is zero; verify that downstream logic accounts for that.
60+ assign raw_exps[N - 1 ] = c_val[30 : 23 ];
61+ assign raw_sigs[N - 1 ] = fmt_s[3 ] ? c_val[24 : 0 ] : { c_val[31 ], 1'b1 , c_val[22 : 0 ]} ;
15762
// Instantiates VX_tcu_drl_max_exp over all N lanes. Its exponents port now
// takes raw_exps (lanes 0..N-2 from exp_bias_inst, lane N-1 from c_val[30:23])
// instead of the removed mul_exp_mux array; max_exp and shift_amounts connect
// to raw_max_exp and shift_amounts.
15863 // Raw maximum exponent finder (in parallel to mul) and shift amounts
15964 VX_tcu_drl_max_exp # (
16065 .N (N )
16166 ) find_max_exp (
162- .exponents (mul_exp_mux ),
67+ .exponents (raw_exps ),
16368 .max_exp (raw_max_exp),
16469 .shift_amounts (shift_amounts)
16570 );
16671
// Removed generate loop: for every lane it selected int_vals_mux[i] when
// fmt_s[3] was set and {mul_sign_mux[i], mul_sig_mux[i]} otherwise.
// After this change raw_sigs[0..N-2] connect to shared_mul_inst.y inside
// g_prod and raw_sigs[N-1] is assigned in the c_val integration section.
167- for (genvar i = 0 ; i < N ; i++ ) begin : g_fp_int_sig_sel
168- assign raw_sigs[i] = fmt_s[3 ] ? int_vals_mux[i] : { mul_sign_mux[i], mul_sig_mux[i]} ;
169- end
170-
17172endmodule
0 commit comments