@@ -129,8 +129,7 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
129
129
130
130
const int64_t a_rows = a->Shape ().NumDimensions () > 1 ? a->Shape ()[a->Shape ().NumDimensions () - 2 ] : 1 ;
131
131
TensorShape output_shape_shader ({batch_size, a_rows, helper.N () / components});
132
- Activation activation;
133
- MatMulNaiveProgram program{activation, output_rank, output_number, has_bias};
132
+ MatMulNaiveProgram program{activation_, output_rank, output_number, has_bias};
134
133
135
134
program
136
135
.CacheHint (std::to_string (components), std::to_string (a_components), std::to_string (output_number))
@@ -208,8 +207,7 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
208
207
const TensorShape a_shape_temp = CreateMatMulIntermediateShape (outer_dims_a, dim_a_outer, dim_inner, components);
209
208
const TensorShape b_shape_temp = CreateMatMulIntermediateShape (outer_dims_b, dim_inner, dim_b_outer, components);
210
209
const TensorShape output_shape_temp = TensorShape ({batch_size, dim_a_outer, dim_b_outer / components});
211
- Activation activation;
212
- MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread};
210
+ MatMulProgram program{activation_, has_bias, is_vec4, elements_per_thread};
213
211
program
214
212
.CacheHint (absl::StrJoin (elements_per_thread, " -" ), std::to_string (is_vec4))
215
213
.AddInputs ({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
0 commit comments