@@ -43,90 +43,152 @@ Expr evaluate_polynomial(Expr x, float *coeff, int n) {
43
43
}
44
44
}
45
45
46
- // Copied from halide_ext, plan is to add this to Halide.
47
- Halide::Tuple halide_ext_exp (const Expr &x_full) {
48
- // Type type = x_full.type();
49
- // CHECK_EQ(type.element_of(), Float(32));
50
-
51
- const float ln2_part1 = 0 .6931457519f ;
52
- const float ln2_part2 = 1 .4286067653e-6f ;
53
- const float one_over_ln2 = 1 .0f / logf (2 .0f );
46
+ /* Extended exponential which produces two output values,
47
+ * each of the same precision as the input, as described in
48
+ * "The Two-Pass Softmax Algorithm" by Marat Dukhan and
49
+ * Artsiom Ablavatski [https://arxiv.org/abs/2001.04438].
50
+ *
51
+ * The first element of the returned Tuple is a pseudo-mantissa while
52
+ * the second is an exponent which is an integer. The product of the
53
+ * pseudo-mantissa and 2 raised to the returned exponent is the
54
+ * desired result e^a. For arguments up to slightly greater than
55
+ * 11629079, the pseudo-mantissa is guaranteed to be within the
56
+ * interval (-e, e). For larger arguments, the exponent result of the
57
+ * tuple may not be able to represent the exact integer necessary to
58
+ * keep the pseudo-mantissa within bounds. Thus it can become
59
+ * progressively larger in magnitude as the argument increases.
60
+ *
61
+ * Ideally this routine will maintain a degree of accuracy through the
62
+ * entire range and be able to produce results out to the end of the
63
+ * numeric range. At present neither of these properties are true due to
64
+ * the following issues:
65
+ * - Range reduction may overflow when scaling the argument.
66
+ * - Range reduction is increasingly inaccurate in reducing the value
67
+ * due to the implementation. This results in overflow in the polynomial
68
+ * evaluation.
69
+ * - Even if the above two issues were resolved, the approximation polynomial
70
+ * would have to run on values outside its intended approximation range.
71
+ */
72
+ Halide::Tuple extended_exp (const Expr &x_full) {
73
+ float ln2_part1 = 0 .6931457519f ;
74
+ float ln2_part2 = 1 .4286067653e-6f ;
75
+ float one_over_ln2 = 1 .0f / logf (2 .0f );
54
76
55
77
Expr scaled = x_full * one_over_ln2;
56
78
Expr k_real = floor (scaled);
57
79
58
80
Expr x = x_full - k_real * ln2_part1;
59
- x -= k_real * ln2_part2;
60
-
61
- float coeff[] = {0 .00031965933071842413f ,
62
- 0 .00119156835564003744f ,
63
- 0 .00848988645943932717f ,
64
- 0 .04160188091348320655f ,
65
- 0 .16667983794100929562f ,
66
- 0 .49999899033463041098f ,
67
- 1 .0f ,
68
- 1 .0f };
81
+ x = x - k_real * ln2_part2;
82
+
83
+ float coeff[] = {
84
+ 0 .00031965933071842413f ,
85
+ 0 .00119156835564003744f ,
86
+ 0 .00848988645943932717f ,
87
+ 0 .04160188091348320655f ,
88
+ 0 .16667983794100929562f ,
89
+ 0 .49999899033463041098f ,
90
+ 1 .0f ,
91
+ 1 .0f };
69
92
Expr result = evaluate_polynomial (x, coeff, sizeof (coeff) / sizeof (coeff[0 ]));
70
93
71
- result = Halide::Internal::common_subexpression_elimination (result);
94
+ // Ensure that the mantissa part is not a NaN or itself an infinity.
95
+ result = strict_float (select (!is_finite (k_real), 1 , result));
96
+ result = common_subexpression_elimination (result);
72
97
73
98
return {result, k_real};
74
99
}
75
100
76
101
} // anonymous namespace
77
102
78
103
struct Softmax : public Halide ::NamesInterface {
79
- Softmax (const std::string &base_name)
104
+ enum class Algorithm {
105
+ Naive,
106
+ TwoPass,
107
+ ThreePass,
108
+ };
109
+
110
+ Softmax (const std::string &base_name,
111
+ Algorithm algorithm = Algorithm::TwoPass)
80
112
: base_name(base_name),
113
+ algorithm (algorithm),
81
114
result(base_name + " _softmax" ),
82
115
ext_exp(base_name + " _softmax_ext_exp" ),
83
116
exponentials(base_name + " _softmax_exponentials" ),
84
- softmax_sums (base_name + " _softmax_sum" ) {
117
+ softmax_sum (base_name + " _softmax_sum" ) {
85
118
}
86
119
std::string base_name;
120
+ Algorithm algorithm;
87
121
Func result;
88
- Func ext_exp;
122
+
123
+ // Naive algorithm
89
124
Func exponentials;
90
- Func softmax_sums;
125
+
126
+ // Two pass algorithm
127
+ Func ext_exp;
128
+
129
+ // Three pass algorithm
130
+ Func max_bias;
131
+ Func biased_exp;
132
+
133
+ // Common to different algorithms
134
+ Func softmax_sum;
91
135
Var result_inner;
92
136
RVar softmax_sum_inner; // TODO: Remove this.
93
137
Var softmax_sum_inner_var;
94
138
LoopLevel softmax_sum_compute_at;
95
139
96
- // Keeping this to either use for testing or turn into a comment.
97
- #if 0
98
- void naive_algorithm(Func input, const Type &generating_type) {
99
- auto args = input.args();
100
- RDom r(0, size);
101
-
102
- exponentials(args) =
103
- default_exp(cast<double>(clamp(input(args), -1e12f, 1e12f)));
104
-
105
- std::vector<Var> args_sum(args.begin() + 1, args.end());
106
- std::vector<Expr> args_reduction;
107
- args_reduction.emplace_back(r.x);
108
- args_reduction.insert(args_reduction.end(), args_sum.begin(),
109
- args_sum.end());
110
-
111
- softmax_sum(args_sum) = Expr(0.0);
112
- softmax_sum(args_sum) += exponentials(args_reduction);
113
- softmax_sum_inner = r.x;
114
-
115
- result(args) = cast(generating_type,
116
- input(args) / select(softmax_sum(args_sum) < Expr(1e-5),
117
- 1, softmax_sum(args_sum)));
118
- result_inner = args[0];
119
- }
120
- #endif
140
+ void apply (Func input, Expr size, const Type &generating_type) {
141
+ switch (algorithm) {
142
+ case Algorithm::Naive:
143
+ naive_algorithm (input, size, generating_type);
144
+ break ;
145
+ case Algorithm::TwoPass:
146
+ two_pass_algorithm (input, size, generating_type);
147
+ break ;
148
+ case Algorithm::ThreePass:
149
+ three_pass_algorithm (input, size, generating_type);
150
+ break ;
151
+ };
152
+ }
153
+
154
+ void naive_algorithm (Func input, Expr size, const Type &generating_type) {
155
+ auto args = input.args ();
156
+ RDom r (0 , size);
157
+
158
+ exponentials (args) =
159
+ default_exp (cast<double >(clamp (input (args), -1e12f, 1e12f)));
160
+
161
+ std::vector<Var> args_sum (args.begin () + 1 , args.end ());
162
+ std::vector<Expr> args_reduction;
163
+ args_reduction.emplace_back (r.x );
164
+ args_reduction.insert (args_reduction.end (), args_sum.begin (),
165
+ args_sum.end ());
166
+
167
+ softmax_sum (args_sum) = Expr (0.0 );
168
+ softmax_sum (args_sum) += exponentials (args_reduction);
169
+ softmax_sum_inner = r.x ;
170
+ softmax_sum_inner_var = args_sum[0 ];
171
+
172
+ result (args) = cast (generating_type,
173
+ input (args) / select (softmax_sum (args_sum) < Expr (1e-5 ),
174
+ 1 , softmax_sum (args_sum)));
175
+ result_inner = args[0 ];
176
+ softmax_sum_compute_at = LoopLevel (result, args[1 ]);
177
+ }
121
178
122
179
// Implementation based on the algorithm in
123
180
// https://arxiv.org/pdf/2001.04438.pdf
124
- void apply (Func input, Expr size, const Type &generating_type) {
181
+ void two_pass_algorithm (Func input, Expr size, const Type &generating_type) {
125
182
auto args = input.args ();
126
183
RDom r (0 , size);
127
184
128
- // TODO: avoid needing double here
129
- ext_exp (args) = halide_ext_exp (cast<double >(input (args)));
185
+ // TODO: It should not be necessary to use double for computation here.
186
+ #define USE_DOUBLE 1
187
+ #if USE_DOUBLE
188
+ ext_exp (args) = extended_exp (cast<double >(input (args)));
189
+ #else
190
+ ext_exp (args) = extended_exp (input (args));
191
+ #endif
130
192
131
193
std::vector<Var> args_inner (args.begin () + 1 , args.end ());
132
194
std::vector<Expr> args_reduction;
@@ -136,32 +198,71 @@ struct Softmax : public Halide::NamesInterface {
136
198
137
199
// This reduction maintains a Tuple with the sum and the maximum exponent
138
200
// so far, both as floating point numbers.
139
- softmax_sums (args_inner) =
140
- Tuple (cast<double >(0 ), Expr (std::numeric_limits<double >::lowest ()));
201
+ softmax_sum (args_inner) =
202
+ #if USE_DOUBLE
203
+ Halide::Tuple (Expr (0.0 ), Expr (std::numeric_limits<double >::lowest ()));
204
+ #else
205
+ Halide::Tuple (0 .0f , Expr (std::numeric_limits<float >::lowest ()));
206
+ #endif
141
207
Expr running_max_exp =
142
- max (softmax_sums (args_inner)[1 ], ext_exp (args_reduction)[1 ]);
208
+ max (softmax_sum (args_inner)[1 ], ext_exp (args_reduction)[1 ]);
143
209
Expr m_sub_i_term = ext_exp (args_reduction)[0 ] *
144
210
pow (2 .0f , ext_exp (args_reduction)[1 ] - running_max_exp);
145
- Expr m_sum_term = softmax_sums (args_inner)[0 ] *
146
- pow (2 .0f , softmax_sums (args_inner)[1 ] - running_max_exp);
211
+ Expr m_sum_term = softmax_sum (args_inner)[0 ] *
212
+ pow (2 .0f , softmax_sum (args_inner)[1 ] - running_max_exp);
147
213
Expr running_sum = m_sub_i_term + m_sum_term;
148
- softmax_sums (args_inner) = Tuple (running_sum, running_max_exp);
149
- Expr lambda = 1 / softmax_sums (args_inner)[0 ];
214
+ softmax_sum (args_inner) = Tuple (running_sum, running_max_exp);
215
+ Expr lambda = 1 / softmax_sum (args_inner)[0 ];
150
216
Expr t =
151
217
cast (generating_type,
152
218
ext_exp (args)[0 ] * lambda *
153
- pow (2 .0f , ext_exp (args)[1 ] - softmax_sums (args_inner)[1 ]));
219
+ pow (2 .0f , ext_exp (args)[1 ] - softmax_sum (args_inner)[1 ]));
154
220
result (args) = t;
155
221
result_inner = args[0 ];
156
222
softmax_sum_inner = r;
157
223
softmax_sum_inner_var = args_inner[0 ];
158
224
softmax_sum_compute_at = LoopLevel (result, args[1 ]);
159
225
}
160
226
227
+ void three_pass_algorithm (Func input, Expr size, const Type &generating_type) {
228
+ auto args = input.args ();
229
+ RDom r (0 , size);
230
+
231
+ std::vector<Var> args_inner (args.begin () + 1 , args.end ());
232
+ std::vector<Expr> args_reduction;
233
+ args_reduction.emplace_back (r.x );
234
+ args_reduction.insert (args_reduction.end (), args_inner.begin (),
235
+ args_inner.end ());
236
+
237
+ max_bias (args_inner) = std::numeric_limits<float >::lowest ();
238
+ max_bias (args_inner) = max (max_bias (args_inner), input (args_reduction));
239
+
240
+ biased_exp (args) = halide_exp (input (args) - max_bias (args_inner));
241
+ softmax_sum (args_inner) = 0 .0f ;
242
+ softmax_sum (args_inner) += biased_exp (args_reduction);
243
+
244
+ Expr lambda = 1 / softmax_sum (args_inner);
245
+ result (args) = halide_exp (input (args) - max_bias (args_inner)) * lambda;
246
+ result_inner = args[0 ];
247
+ softmax_sum_inner = r;
248
+ softmax_sum_inner_var = args_inner[0 ];
249
+ softmax_sum_compute_at = LoopLevel (result, args[1 ]);
250
+ }
251
+
252
+ // TODO: add support for reuse vs. recompute scheduling on exp operations.
253
+
161
254
void default_schedule (LoopLevel result_loop_level, const Target &t,
162
255
bool vectorize) {
163
- ext_exp.compute_inline ();
164
- softmax_sums.compute_at (softmax_sum_compute_at)
256
+ if (algorithm == Algorithm::Naive) {
257
+ exponentials.compute_at (softmax_sum_compute_at);
258
+ } else if (algorithm == Algorithm::TwoPass) {
259
+ ext_exp.compute_inline ();
260
+ } else if (algorithm == Algorithm::ThreePass) {
261
+ max_bias.compute_at (softmax_sum_compute_at);
262
+ // TODO: vectorize max loop, maybe parallelize
263
+ biased_exp.compute_at (softmax_sum_compute_at);
264
+ }
265
+ softmax_sum.compute_at (softmax_sum_compute_at)
165
266
.store_in (MemoryType::Register)
166
267
.vectorize (softmax_sum_inner_var, t.natural_vector_size <float >())
167
268
.update (0 )
@@ -170,7 +271,11 @@ struct Softmax : public Halide::NamesInterface {
170
271
if (vectorize) {
171
272
// In some modes, this dimension is narrow and we don't want to vectorize
172
273
// it
274
+ #if USE_DOUBLE
173
275
result.vectorize (result_inner, t.natural_vector_size <double >());
276
+ #else
277
+ result.vectorize (result_inner, t.natural_vector_size <float >());
278
+ #endif
174
279
}
175
280
}
176
281
};
0 commit comments