-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathSimplify_Cast.cpp
139 lines (133 loc) · 6.34 KB
/
Simplify_Cast.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include "Simplify_Internal.h"
namespace Halide {
namespace Internal {

// Simplify a Cast node.
//
// The operand is simplified first; when `bounds` is non-null it is filled in
// with what the simplifier learned about the operand, and those facts are then
// translated into facts about the cast's result. After that, a sequence of
// rewrites is tried in order:
//   - propagate a signed_integer_overflow intrinsic,
//   - drop a cast to the operand's own type,
//   - constant-fold casts of float/int/uint constants,
//   - collapse nested casts when the outer cast is the narrowest,
//   - push the cast inside a Broadcast, or inside an Int(32) -> Int(64) Ramp.
// If none applies, the original node is returned when the operand did not
// change, and a new Cast of the simplified operand is returned otherwise.
Expr Simplify::visit(const Cast *op, ExprInfo *bounds) {
    Expr value = mutate(op->value, bounds);

    if (bounds) {
        // The operand's min/max carry over only if both are known and both
        // are representable in the destination type; otherwise clear both.
        const bool bounds_fit =
            bounds->min_defined && op->type.can_represent(bounds->min) &&
            bounds->max_defined && op->type.can_represent(bounds->max);
        if (!bounds_fit) {
            bounds->min_defined = false;
            bounds->max_defined = false;
        }

        // A cast to an integer type that cannot hold every value of the source
        // type may wrap, i.e. add an unknown multiple of 2^bits to the value.
        // If x == r (mod m), then x + k * 2^bits is only known to be
        // r (mod gcd(m, 2^bits)), so the alignment must be weakened unless the
        // bounds prove that no wrap happens. (Keeping the operand's alignment
        // unchanged here was unsound: e.g. an Int(32) known to be 1 mod 6,
        // cast to UInt(8), is no longer 1 mod 6 once it wraps.)
        // NOTE(review): float destinations are left as before.
        const bool may_wrap = (op->type.is_int() || op->type.is_uint()) &&
                              !op->type.can_represent(op->value.type()) &&
                              !bounds_fit;
        if (may_wrap) {
            if (op->type.bits() < 64) {
                // ModulusRemainder addition takes the gcd of the moduli.
                bounds->alignment = bounds->alignment +
                                    ModulusRemainder(int64_t{1} << op->type.bits(), 0);
            } else if (bounds->alignment.modulus > 0) {
                // 2^64 does not fit in an int64, so keep only the largest
                // power-of-two factor of the modulus (which divides 2^64).
                // A modulus of 0 denotes a single known value; it is left to
                // the representability check below.
                const int64_t m = bounds->alignment.modulus & -bounds->alignment.modulus;
                const int64_t r = ((bounds->alignment.remainder % m) + m) % m;
                bounds->alignment = ModulusRemainder(m, r);
            }
        }

        if (!op->type.can_represent(bounds->alignment.modulus) ||
            !op->type.can_represent(bounds->alignment.remainder)) {
            bounds->alignment = ModulusRemainder();
        }
    }

    if (may_simplify(op->type) && may_simplify(op->value.type())) {
        const Cast *cast = value.as<Cast>();
        const Broadcast *broadcast_value = value.as<Broadcast>();
        const Ramp *ramp_value = value.as<Ramp>();
        double f = 0.0;
        int64_t i = 0;
        uint64_t u = 0;
        if (Call::as_intrinsic(value, {Call::signed_integer_overflow})) {
            clear_bounds_info(bounds);
            return make_signed_integer_overflow(op->type);
        } else if (value.type() == op->type) {
            return value;
        } else if (op->type.is_int() &&
                   const_float(value, &f) &&
                   std::isfinite(f)) {
            // float -> int
            // Recursively call mutate just to set the bounds
            return mutate(make_const(op->type, safe_numeric_cast<int64_t>(f)), bounds);
        } else if (op->type.is_uint() &&
                   const_float(value, &f) &&
                   std::isfinite(f)) {
            // float -> uint
            return make_const(op->type, safe_numeric_cast<uint64_t>(f));
        } else if (op->type.is_float() &&
                   const_float(value, &f)) {
            // float -> float
            return make_const(op->type, f);
        } else if (op->type.is_int() &&
                   const_int(value, &i)) {
            // int -> int
            // Recursively call mutate just to set the bounds
            return mutate(make_const(op->type, i), bounds);
        } else if (op->type.is_uint() &&
                   const_int(value, &i)) {
            // int -> uint
            return make_const(op->type, safe_numeric_cast<uint64_t>(i));
        } else if (op->type.is_float() &&
                   const_int(value, &i)) {
            // int -> float
            return make_const(op->type, safe_numeric_cast<double>(i));
        } else if (op->type.is_int() &&
                   const_uint(value, &u) &&
                   op->type.bits() <= value.type().bits()) {
            // uint -> int narrowing or same-width reinterpret: both are
            // folded the same way.
            // Recursively call mutate just to set the bounds
            return mutate(make_const(op->type, safe_numeric_cast<int64_t>(u)), bounds);
        } else if (op->type.is_int() &&
                   const_uint(value, &u) &&
                   op->type.bits() > value.type().bits()) {
            // uint -> int widening
            if (op->type.can_represent(u) || op->type.bits() < 32) {
                // If the type can represent the value or overflow is well-defined.
                // Recursively call mutate just to set the bounds
                return mutate(make_const(op->type, safe_numeric_cast<int64_t>(u)), bounds);
            } else {
                return make_signed_integer_overflow(op->type);
            }
        } else if (op->type.is_uint() &&
                   const_uint(value, &u)) {
            // uint -> uint
            return make_const(op->type, u);
        } else if (op->type.is_float() &&
                   const_uint(value, &u)) {
            // uint -> float
            return make_const(op->type, safe_numeric_cast<double>(u));
        } else if (cast &&
                   op->type.code() == cast->type.code() &&
                   op->type.bits() < cast->type.bits()) {
            // If this is a cast of a cast of the same type, where the
            // outer cast is narrower, the inner cast can be
            // eliminated.
            return mutate(Cast::make(op->type, cast->value), bounds);
        } else if (cast &&
                   (op->type.is_int() || op->type.is_uint()) &&
                   (cast->type.is_int() || cast->type.is_uint()) &&
                   op->type.bits() <= cast->type.bits() &&
                   op->type.bits() <= op->value.type().bits()) {
            // A cast between integer types where the outer cast is no wider
            // than the inner one: the inner cast is a sign extend, zero
            // extend or truncation, and the outer cast keeps only the low
            // bits, so the inner cast can be eliminated.
            // NOTE(review): the last condition looks at op->value's type,
            // which is the inner cast's own type (mutation preserves types),
            // not the type of the inner cast's argument; it is therefore
            // redundant with the previous condition.
            return mutate(Cast::make(op->type, cast->value), bounds);
        } else if (broadcast_value) {
            // cast(broadcast(x)) -> broadcast(cast(x))
            return mutate(Broadcast::make(Cast::make(op->type.with_lanes(broadcast_value->value.type().lanes()), broadcast_value->value), broadcast_value->lanes), bounds);
        } else if (ramp_value &&
                   op->type.element_of() == Int(64) &&
                   op->value.type().element_of() == Int(32)) {
            // cast(ramp(a, b, w)) -> ramp(cast(a), cast(b), w)
            return mutate(Ramp::make(Cast::make(op->type.with_lanes(ramp_value->base.type().lanes()),
                                                ramp_value->base),
                                     Cast::make(op->type.with_lanes(ramp_value->stride.type().lanes()),
                                                ramp_value->stride),
                                     ramp_value->lanes),
                          bounds);
        }
    }

    if (value.same_as(op->value)) {
        return op;
    } else {
        return Cast::make(op->type, value);
    }
}

}  // namespace Internal
}  // namespace Halide