Commit 5693e38
Author: Zalman Stern

Add support for multiple algorithms, including naive and three pass.

1 parent 0e99743

1 file changed, +166 -61 lines

apps/hallmark/src/ml_ops/softmax.h (+166 -61)
@@ -43,90 +43,152 @@ Expr evaluate_polynomial(Expr x, float *coeff, int n) {
   }
 }

-// Copied from halide_ext, plan is to add this to Halide.
-Halide::Tuple halide_ext_exp(const Expr &x_full) {
-  // Type type = x_full.type();
-  // CHECK_EQ(type.element_of(), Float(32));
-
-  const float ln2_part1 = 0.6931457519f;
-  const float ln2_part2 = 1.4286067653e-6f;
-  const float one_over_ln2 = 1.0f / logf(2.0f);
+/* Extended exponential which produces two output values,
+ * each of the same precision as the input, as described in
+ * "The Two-Pass Softmax Algorithm" by Marat Dukhan and
+ * Artsiom Ablavatski [https://arxiv.org/abs/2001.04438].
+ *
+ * The first element of the returned Tuple is a pseudo-mantissa while
+ * the second is an exponent which is an integer. The product of the
+ * pseudo-mantissa and 2 raised to the returned exponent is the
+ * desired result e^a. For arguments up to slightly greater than
+ * 11629079, the pseudo-mantissa is guaranteed to be within the
+ * interval (-e, e). For larger arguments, the exponent result of the
+ * tuple may not be able to represent the exact integer necessary to
+ * keep the pseudo-mantissa within bounds. Thus it can become
+ * progressively larger in magnitude as the argument increases.
+ *
+ * Ideally this routine will maintain a degree of accuracy through the
+ * entire range and be able to produce results out to the end of the
+ * numeric range. At present neither of these properties is true due to
+ * the following issues:
+ * - Range reduction may overflow when scaling the argument.
+ * - Range reduction is increasingly inaccurate in reducing the value
+ *   due to the implementation. This results in overflow in the polynomial
+ *   evaluation.
+ * - Even if the above two issues were resolved, the approximation polynomial
+ *   would have to run on values outside its intended approximation range.
+ */
+Halide::Tuple extended_exp(const Expr &x_full) {
+  float ln2_part1 = 0.6931457519f;
+  float ln2_part2 = 1.4286067653e-6f;
+  float one_over_ln2 = 1.0f / logf(2.0f);

   Expr scaled = x_full * one_over_ln2;
   Expr k_real = floor(scaled);

   Expr x = x_full - k_real * ln2_part1;
-  x -= k_real * ln2_part2;
-
-  float coeff[] = {0.00031965933071842413f,
-                   0.00119156835564003744f,
-                   0.00848988645943932717f,
-                   0.04160188091348320655f,
-                   0.16667983794100929562f,
-                   0.49999899033463041098f,
-                   1.0f,
-                   1.0f};
+  x = x - k_real * ln2_part2;
+
+  float coeff[] = {
+      0.00031965933071842413f,
+      0.00119156835564003744f,
+      0.00848988645943932717f,
+      0.04160188091348320655f,
+      0.16667983794100929562f,
+      0.49999899033463041098f,
+      1.0f,
+      1.0f};
   Expr result = evaluate_polynomial(x, coeff, sizeof(coeff) / sizeof(coeff[0]));

-  result = Halide::Internal::common_subexpression_elimination(result);
+  // Ensure that the mantissa part is not a NaN or itself an infinity.
+  result = strict_float(select(!is_finite(k_real), 1, result));
+  result = common_subexpression_elimination(result);

   return {result, k_real};
 }

 } // anonymous namespace

 struct Softmax : public Halide::NamesInterface {
-  Softmax(const std::string &base_name)
+  enum class Algorithm {
+    Naive,
+    TwoPass,
+    ThreePass,
+  };
+
+  Softmax(const std::string &base_name,
+          Algorithm algorithm = Algorithm::TwoPass)
       : base_name(base_name),
+        algorithm(algorithm),
         result(base_name + "_softmax"),
         ext_exp(base_name + "_softmax_ext_exp"),
         exponentials(base_name + "_softmax_exponentials"),
-        softmax_sums(base_name + "_softmax_sum") {
+        softmax_sum(base_name + "_softmax_sum") {
   }
   std::string base_name;
+  Algorithm algorithm;
   Func result;
-  Func ext_exp;
+
+  // Naive algorithm
   Func exponentials;
-  Func softmax_sums;
+
+  // Two pass algorithm
+  Func ext_exp;
+
+  // Three pass algorithm
+  Func max_bias;
+  Func biased_exp;
+
+  // Common to different algorithms
+  Func softmax_sum;
   Var result_inner;
   RVar softmax_sum_inner; // TODO: Remove this.
   Var softmax_sum_inner_var;
   LoopLevel softmax_sum_compute_at;

-  // Keeping this to either use for testing or turn into a comment.
-#if 0
-  void naive_algorithm(Func input, const Type &generating_type) {
-    auto args = input.args();
-    RDom r(0, size);
-
-    exponentials(args) =
-        default_exp(cast<double>(clamp(input(args), -1e12f, 1e12f)));
-
-    std::vector<Var> args_sum(args.begin() + 1, args.end());
-    std::vector<Expr> args_reduction;
-    args_reduction.emplace_back(r.x);
-    args_reduction.insert(args_reduction.end(), args_sum.begin(),
-                          args_sum.end());
-
-    softmax_sum(args_sum) = Expr(0.0);
-    softmax_sum(args_sum) += exponentials(args_reduction);
-    softmax_sum_inner = r.x;
-
-    result(args) = cast(generating_type,
-                        input(args) / select(softmax_sum(args_sum) < Expr(1e-5),
-                                             1, softmax_sum(args_sum)));
-    result_inner = args[0];
-  }
-#endif
+  void apply(Func input, Expr size, const Type &generating_type) {
+    switch (algorithm) {
+      case Algorithm::Naive:
+        naive_algorithm(input, size, generating_type);
+        break;
+      case Algorithm::TwoPass:
+        two_pass_algorithm(input, size, generating_type);
+        break;
+      case Algorithm::ThreePass:
+        three_pass_algorithm(input, size, generating_type);
+        break;
+    };
+  }
+
+  void naive_algorithm(Func input, Expr size, const Type &generating_type) {
+    auto args = input.args();
+    RDom r(0, size);
+
+    exponentials(args) =
+        default_exp(cast<double>(clamp(input(args), -1e12f, 1e12f)));
+
+    std::vector<Var> args_sum(args.begin() + 1, args.end());
+    std::vector<Expr> args_reduction;
+    args_reduction.emplace_back(r.x);
+    args_reduction.insert(args_reduction.end(), args_sum.begin(),
+                          args_sum.end());
+
+    softmax_sum(args_sum) = Expr(0.0);
+    softmax_sum(args_sum) += exponentials(args_reduction);
+    softmax_sum_inner = r.x;
+    softmax_sum_inner_var = args_sum[0];
+
+    result(args) = cast(generating_type,
+                        input(args) / select(softmax_sum(args_sum) < Expr(1e-5),
+                                             1, softmax_sum(args_sum)));
+    result_inner = args[0];
+    softmax_sum_compute_at = LoopLevel(result, args[1]);
+  }

   // Implementation based on the algorithm in
   // https://arxiv.org/pdf/2001.04438.pdf
-  void apply(Func input, Expr size, const Type &generating_type) {
+  void two_pass_algorithm(Func input, Expr size, const Type &generating_type) {
     auto args = input.args();
     RDom r(0, size);

-    // TODO: avoid needing double here
-    ext_exp(args) = halide_ext_exp(cast<double>(input(args)));
+    // TODO: It should not be necessary to use double for computation here.
+#define USE_DOUBLE 1
+#if USE_DOUBLE
+    ext_exp(args) = extended_exp(cast<double>(input(args)));
+#else
+    ext_exp(args) = extended_exp(input(args));
+#endif

     std::vector<Var> args_inner(args.begin() + 1, args.end());
     std::vector<Expr> args_reduction;
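
The first hunk replaces halide_ext_exp with extended_exp, whose new comment documents the contract: the returned pair satisfies e^a = mantissa * 2^exponent, where the exponent is floor(a / ln 2) so the residual fed to the polynomial stays small. A minimal scalar model of that range reduction in plain C++ (extended_exp_model is a hypothetical name for illustration, not code from the commit):

#include <cmath>
#include <cstdio>
#include <utility>

// e^a == m * 2^k with k = floor(a / ln 2); the residual a - k*ln2 lies
// in [0, ln 2), so m = e^residual stays in a narrow interval. The diff's
// two-part ln2 constants perform this subtraction in two steps to reduce
// rounding error.
std::pair<double, double> extended_exp_model(double a) {
  double k = std::floor(a / std::log(2.0));
  double m = std::exp(a - k * std::log(2.0));
  return {m, k};
}

int main() {
  auto [m, k] = extended_exp_model(10.0);
  // Reconstruct e^10 from the pair; ldexp(1.0, k) is 2^k.
  std::printf("%.4f * 2^%.0f = %.2f (exp(10) = %.2f)\n",
              m, k, m * std::ldexp(1.0, (int)k), std::exp(10.0));
}
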
@@ -136,32 +198,71 @@ struct Softmax : public Halide::NamesInterface {

     // This reduction maintains a Tuple with the sum and the maximum exponent
     // so far, both as floating point numbers.
-    softmax_sums(args_inner) =
-        Tuple(cast<double>(0), Expr(std::numeric_limits<double>::lowest()));
+    softmax_sum(args_inner) =
+#if USE_DOUBLE
+        Halide::Tuple(Expr(0.0), Expr(std::numeric_limits<double>::lowest()));
+#else
+        Halide::Tuple(0.0f, Expr(std::numeric_limits<float>::lowest()));
+#endif
     Expr running_max_exp =
-        max(softmax_sums(args_inner)[1], ext_exp(args_reduction)[1]);
+        max(softmax_sum(args_inner)[1], ext_exp(args_reduction)[1]);
     Expr m_sub_i_term = ext_exp(args_reduction)[0] *
                         pow(2.0f, ext_exp(args_reduction)[1] - running_max_exp);
-    Expr m_sum_term = softmax_sums(args_inner)[0] *
-                      pow(2.0f, softmax_sums(args_inner)[1] - running_max_exp);
+    Expr m_sum_term = softmax_sum(args_inner)[0] *
+                      pow(2.0f, softmax_sum(args_inner)[1] - running_max_exp);
     Expr running_sum = m_sub_i_term + m_sum_term;
-    softmax_sums(args_inner) = Tuple(running_sum, running_max_exp);
-    Expr lambda = 1 / softmax_sums(args_inner)[0];
+    softmax_sum(args_inner) = Tuple(running_sum, running_max_exp);
+    Expr lambda = 1 / softmax_sum(args_inner)[0];
     Expr t =
         cast(generating_type,
              ext_exp(args)[0] * lambda *
-                 pow(2.0f, ext_exp(args)[1] - softmax_sums(args_inner)[1]));
+                 pow(2.0f, ext_exp(args)[1] - softmax_sum(args_inner)[1]));
     result(args) = t;
     result_inner = args[0];
     softmax_sum_inner = r;
     softmax_sum_inner_var = args_inner[0];
     softmax_sum_compute_at = LoopLevel(result, args[1]);
   }

+  void three_pass_algorithm(Func input, Expr size, const Type &generating_type) {
+    auto args = input.args();
+    RDom r(0, size);
+
+    std::vector<Var> args_inner(args.begin() + 1, args.end());
+    std::vector<Expr> args_reduction;
+    args_reduction.emplace_back(r.x);
+    args_reduction.insert(args_reduction.end(), args_inner.begin(),
+                          args_inner.end());
+
+    max_bias(args_inner) = std::numeric_limits<float>::lowest();
+    max_bias(args_inner) = max(max_bias(args_inner), input(args_reduction));
+
+    biased_exp(args) = halide_exp(input(args) - max_bias(args_inner));
+    softmax_sum(args_inner) = 0.0f;
+    softmax_sum(args_inner) += biased_exp(args_reduction);
+
+    Expr lambda = 1 / softmax_sum(args_inner);
+    result(args) = halide_exp(input(args) - max_bias(args_inner)) * lambda;
+    result_inner = args[0];
+    softmax_sum_inner = r;
+    softmax_sum_inner_var = args_inner[0];
+    softmax_sum_compute_at = LoopLevel(result, args[1]);
+  }
+
+  // TODO: add support for reuse vs. recompute scheduling on exp operations.
+
   void default_schedule(LoopLevel result_loop_level, const Target &t,
                         bool vectorize) {
-    ext_exp.compute_inline();
-    softmax_sums.compute_at(softmax_sum_compute_at)
+    if (algorithm == Algorithm::Naive) {
+      exponentials.compute_at(softmax_sum_compute_at);
+    } else if (algorithm == Algorithm::TwoPass) {
+      ext_exp.compute_inline();
+    } else if (algorithm == Algorithm::ThreePass) {
+      max_bias.compute_at(softmax_sum_compute_at);
+      // TODO: vectorize max loop, maybe parallelize
+      biased_exp.compute_at(softmax_sum_compute_at);
+    }
+    softmax_sum.compute_at(softmax_sum_compute_at)
         .store_in(MemoryType::Register)
         .vectorize(softmax_sum_inner_var, t.natural_vector_size<float>())
         .update(0)
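
The reduction in this hunk keeps a running (sum, max exponent) pair: each step rescales both the accumulated sum and the incoming term by powers of two onto the larger of the two exponents before adding, which is what keeps the pseudo-mantissas bounded. A scalar sketch of the same update step (illustrative C++; SumState and accumulate are hypothetical names, not code from the commit):

#include <algorithm>
#include <cmath>

// Running state of the two-pass reduction: the sum of rescaled
// pseudo-mantissas and the largest exponent seen so far.
// Seed it with {0.0, std::numeric_limits<double>::lowest()}.
struct SumState {
  double sum;
  double max_exp;
};

// Fold one (mantissa, exponent) term into the state, mirroring the
// m_sub_i_term / m_sum_term rescaling in the diff.
SumState accumulate(SumState s, double mantissa, double exponent) {
  double running_max = std::max(s.max_exp, exponent);
  double term = mantissa * std::pow(2.0, exponent - running_max);
  double rescaled_sum = s.sum * std::pow(2.0, s.max_exp - running_max);
  return {term + rescaled_sum, running_max};
}
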
@@ -170,7 +271,11 @@ struct Softmax : public Halide::NamesInterface {
     if (vectorize) {
       // In some modes, this dimension is narrow and we don't want to vectorize
       // it
+#if USE_DOUBLE
       result.vectorize(result_inner, t.natural_vector_size<double>());
+#else
+      result.vectorize(result_inner, t.natural_vector_size<float>());
+#endif
     }
   }
 };
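
For comparison, the new three_pass_algorithm is the classic max-subtraction softmax: one pass to find the maximum (max_bias), one to sum the biased exponentials, and one to normalize. A scalar reference version of the same math (softmax_three_pass is a hypothetical name, not code from the commit):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Classic three-pass softmax. Subtracting the max makes every exp()
// argument <= 0, so no exponential can overflow. Assumes x is non-empty.
std::vector<float> softmax_three_pass(const std::vector<float> &x) {
  float max_bias = *std::max_element(x.begin(), x.end());  // pass 1
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - max_bias);         // pass 2
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)               // pass 3
    out[i] = std::exp(x[i] - max_bias) / sum;
  return out;
}
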
