@@ -1325,15 +1325,30 @@ constexpr int const_min(int a, int b) {
1325
1325
return a < b ? a : b;
1326
1326
}
1327
1327
1328
- template <typename ... Args>
1328
+ template <Call::IntrinsicOp intrin>
1329
+ struct OptionalIntrinType {
1330
+ bool check (const Type &) const {
1331
+ return true ;
1332
+ }
1333
+ };
1334
+
1335
+ template <>
1336
+ struct OptionalIntrinType <Call::saturating_cast> {
1337
+ halide_type_t type;
1338
+ bool check (const Type &t) const {
1339
+ return t == Type (type);
1340
+ }
1341
+ };
1342
+
1343
+ template <Call::IntrinsicOp intrin, typename ... Args>
1329
1344
struct Intrin {
1330
1345
struct pattern_tag {};
1331
- Call::IntrinsicOp intrin;
1332
1346
std::tuple<Args...> args;
1333
1347
// The type of the output of the intrinsic node.
1334
1348
// Only necessary in cases where it can't be inferred
1335
1349
// from the input types (e.g. saturating_cast).
1336
- Type optional_type_hint;
1350
+
1351
+ OptionalIntrinType<intrin> optional_type_hint;
1337
1352
1338
1353
static constexpr uint32_t binds = bitwise_or_reduce((bindings<Args>::mask)...);
1339
1354
@@ -1362,7 +1377,7 @@ struct Intrin {
1362
1377
}
1363
1378
const Call &c = (const Call &)e;
1364
1379
return (c.is_intrinsic (intrin) &&
1365
- (( optional_type_hint == Type ()) || optional_type_hint == e.type ) &&
1380
+ optional_type_hint. check ( e.type ) &&
1366
1381
match_args<0 , bound>(0 , c, state));
1367
1382
}
1368
1383
@@ -1394,8 +1409,8 @@ struct Intrin {
1394
1409
return likely_if_innermost (std::move (arg0));
1395
1410
} else if (intrin == Call::abs ) {
1396
1411
return abs (std::move (arg0));
1397
- } else if (intrin == Call::saturating_cast) {
1398
- return saturating_cast (optional_type_hint, std::move (arg0));
1412
+ } else if constexpr (intrin == Call::saturating_cast) {
1413
+ return saturating_cast (optional_type_hint. type , std::move (arg0));
1399
1414
}
1400
1415
1401
1416
Expr arg1 = std::get<const_min (1 , sizeof ...(Args) - 1 )>(args).make (state, type_hint);
@@ -1489,98 +1504,113 @@ struct Intrin {
1489
1504
}
1490
1505
1491
1506
HALIDE_ALWAYS_INLINE
1492
- Intrin (Call::IntrinsicOp intrin, Args... args) noexcept
1493
- : intrin(intrin), args(args...) {
1507
+ Intrin (Args... args) noexcept
1508
+ : args(args...) {
1494
1509
}
1495
1510
};
1496
1511
1497
- template <typename ... Args>
1498
- std::ostream &operator <<(std::ostream &s, const Intrin<Args...> &op) {
1499
- s << op. intrin << " (" ;
1512
+ template <Call::IntrinsicOp intrin, typename ... Args>
1513
+ std::ostream &operator <<(std::ostream &s, const Intrin<intrin, Args...> &op) {
1514
+ s << intrin << " (" ;
1500
1515
op.print_args (s);
1501
1516
s << " )" ;
1502
1517
return s;
1503
1518
}
1504
1519
1505
- template <typename ... Args>
1506
- HALIDE_ALWAYS_INLINE auto intrin (Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
1507
- return {intrinsic_op, pattern_arg (args)...};
1508
- }
1509
-
1510
1520
template <typename A, typename B>
1511
- auto widen_right_add (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1512
- return {Call::widen_right_add, pattern_arg (a), pattern_arg (b)};
1521
+ auto widen_right_add (A &&a, B &&b) noexcept -> Intrin<Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1522
+ return {pattern_arg (a), pattern_arg (b)};
1513
1523
}
1514
1524
template <typename A, typename B>
1515
- auto widen_right_mul (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1516
- return {Call::widen_right_mul, pattern_arg (a), pattern_arg (b)};
1525
+ auto widen_right_mul (A &&a, B &&b) noexcept -> Intrin<Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1526
+ return {pattern_arg (a), pattern_arg (b)};
1517
1527
}
1518
1528
template <typename A, typename B>
1519
- auto widen_right_sub (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1520
- return {Call::widen_right_sub, pattern_arg (a), pattern_arg (b)};
1529
+ auto widen_right_sub (A &&a, B &&b) noexcept -> Intrin<Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1530
+ return {pattern_arg (a), pattern_arg (b)};
1521
1531
}
1522
1532
1523
1533
template <typename A, typename B>
1524
- auto widening_add (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1525
- return {Call::widening_add, pattern_arg (a), pattern_arg (b)};
1534
+ auto widening_add (A &&a, B &&b) noexcept -> Intrin<Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1535
+ return {pattern_arg (a), pattern_arg (b)};
1526
1536
}
1527
1537
template <typename A, typename B>
1528
- auto widening_sub (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1529
- return {Call::widening_sub, pattern_arg (a), pattern_arg (b)};
1538
+ auto widening_sub (A &&a, B &&b) noexcept -> Intrin<Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1539
+ return {pattern_arg (a), pattern_arg (b)};
1530
1540
}
1531
1541
template <typename A, typename B>
1532
- auto widening_mul (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1533
- return {Call::widening_mul, pattern_arg (a), pattern_arg (b)};
1542
+ auto widening_mul (A &&a, B &&b) noexcept -> Intrin<Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1543
+ return {pattern_arg (a), pattern_arg (b)};
1534
1544
}
1535
1545
template <typename A, typename B>
1536
- auto saturating_add (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1537
- return {Call::saturating_add, pattern_arg (a), pattern_arg (b)};
1546
+ auto saturating_add (A &&a, B &&b) noexcept -> Intrin<Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1547
+ return {pattern_arg (a), pattern_arg (b)};
1538
1548
}
1539
1549
template <typename A, typename B>
1540
- auto saturating_sub (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1541
- return {Call::saturating_sub, pattern_arg (a), pattern_arg (b)};
1550
+ auto saturating_sub (A &&a, B &&b) noexcept -> Intrin<Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1551
+ return {pattern_arg (a), pattern_arg (b)};
1542
1552
}
1543
1553
template <typename A>
1544
- auto saturating_cast (const Type &t, A &&a) noexcept -> Intrin<decltype(pattern_arg(a))> {
1545
- Intrin<decltype (pattern_arg (a))> p = {Call::saturating_cast, pattern_arg (a)};
1546
- p.optional_type_hint = t;
1554
+ auto saturating_cast (const Type &t, A &&a) noexcept -> Intrin<Call::saturating_cast, decltype(pattern_arg(a))> {
1555
+ Intrin<Call::saturating_cast, decltype (pattern_arg (a))> p = {pattern_arg (a)};
1556
+ p.optional_type_hint . type = t;
1547
1557
return p;
1548
1558
}
1549
1559
template <typename A, typename B>
1550
- auto halving_add (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1551
- return {Call::halving_add, pattern_arg (a), pattern_arg (b)};
1560
+ auto halving_add (A &&a, B &&b) noexcept -> Intrin<Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1561
+ return {pattern_arg (a), pattern_arg (b)};
1552
1562
}
1553
1563
template <typename A, typename B>
1554
- auto halving_sub (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1555
- return {Call::halving_sub, pattern_arg (a), pattern_arg (b)};
1564
+ auto halving_sub (A &&a, B &&b) noexcept -> Intrin<Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1565
+ return {pattern_arg (a), pattern_arg (b)};
1556
1566
}
1557
1567
template <typename A, typename B>
1558
- auto rounding_halving_add (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1559
- return {Call::rounding_halving_add, pattern_arg (a), pattern_arg (b)};
1568
+ auto rounding_halving_add (A &&a, B &&b) noexcept -> Intrin<Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1569
+ return {pattern_arg (a), pattern_arg (b)};
1560
1570
}
1561
1571
template <typename A, typename B>
1562
- auto shift_left (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1563
- return {Call::shift_left, pattern_arg (a), pattern_arg (b)};
1572
+ auto shift_left (A &&a, B &&b) noexcept -> Intrin<Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1573
+ return {pattern_arg (a), pattern_arg (b)};
1564
1574
}
1565
1575
template <typename A, typename B>
1566
- auto shift_right (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1567
- return {Call::shift_right, pattern_arg (a), pattern_arg (b)};
1576
+ auto shift_right (A &&a, B &&b) noexcept -> Intrin<Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1577
+ return {pattern_arg (a), pattern_arg (b)};
1568
1578
}
1569
1579
template <typename A, typename B>
1570
- auto rounding_shift_left (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1571
- return {Call::rounding_shift_left, pattern_arg (a), pattern_arg (b)};
1580
+ auto rounding_shift_left (A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1581
+ return {pattern_arg (a), pattern_arg (b)};
1572
1582
}
1573
1583
template <typename A, typename B>
1574
- auto rounding_shift_right (A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1575
- return {Call::rounding_shift_right, pattern_arg (a), pattern_arg (b)};
1584
+ auto rounding_shift_right (A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1585
+ return {pattern_arg (a), pattern_arg (b)};
1576
1586
}
1577
1587
template <typename A, typename B, typename C>
1578
- auto mul_shift_right (A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1579
- return {Call::mul_shift_right, pattern_arg (a), pattern_arg (b), pattern_arg (c)};
1588
+ auto mul_shift_right (A &&a, B &&b, C &&c) noexcept -> Intrin<Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1589
+ return {pattern_arg (a), pattern_arg (b), pattern_arg (c)};
1580
1590
}
1581
1591
template <typename A, typename B, typename C>
1582
- auto rounding_mul_shift_right (A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1583
- return {Call::rounding_mul_shift_right, pattern_arg (a), pattern_arg (b), pattern_arg (c)};
1592
+ auto rounding_mul_shift_right (A &&a, B &&b, C &&c) noexcept -> Intrin<Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1593
+ return {pattern_arg (a), pattern_arg (b), pattern_arg (c)};
1594
+ }
1595
+
1596
+ template <typename A>
1597
+ auto abs (A &&a) noexcept -> Intrin<Call::abs, decltype(pattern_arg(a))> {
1598
+ return {pattern_arg (a)};
1599
+ }
1600
+
1601
+ template <typename A, typename B>
1602
+ auto absd (A &&a, B &&b) noexcept -> Intrin<Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1603
+ return {pattern_arg (a), pattern_arg (b)};
1604
+ }
1605
+
1606
+ template <typename A>
1607
+ auto likely (A &&a) noexcept -> Intrin<Call::likely, decltype(pattern_arg(a))> {
1608
+ return {pattern_arg (a)};
1609
+ }
1610
+
1611
+ template <typename A>
1612
+ auto likely_if_innermost (A &&a) noexcept -> Intrin<Call::likely_if_innermost, decltype(pattern_arg(a))> {
1613
+ return {pattern_arg (a)};
1584
1614
}
1585
1615
1586
1616
template <typename A>
@@ -2425,7 +2455,8 @@ template<typename A>
2425
2455
struct IsInt {
2426
2456
struct pattern_tag {};
2427
2457
A a;
2428
- int bits, lanes;
2458
+ uint8_t bits;
2459
+ uint16_t lanes;
2429
2460
2430
2461
constexpr static uint32_t binds = bindings<A>::mask;
2431
2462
@@ -2448,7 +2479,7 @@ struct IsInt {
2448
2479
};
2449
2480
2450
2481
template <typename A>
2451
- HALIDE_ALWAYS_INLINE auto is_int (A &&a, int bits = 0 , int lanes = 0 ) noexcept -> IsInt<decltype(pattern_arg(a))> {
2482
+ HALIDE_ALWAYS_INLINE auto is_int (A &&a, uint8_t bits = 0 , uint16_t lanes = 0 ) noexcept -> IsInt<decltype(pattern_arg(a))> {
2452
2483
assert_is_lvalue_if_expr<A>();
2453
2484
return {pattern_arg (a), bits, lanes};
2454
2485
}
@@ -2470,7 +2501,8 @@ template<typename A>
2470
2501
struct IsUInt {
2471
2502
struct pattern_tag {};
2472
2503
A a;
2473
- int bits, lanes;
2504
+ uint8_t bits;
2505
+ uint16_t lanes;
2474
2506
2475
2507
constexpr static uint32_t binds = bindings<A>::mask;
2476
2508
@@ -2493,7 +2525,7 @@ struct IsUInt {
2493
2525
};
2494
2526
2495
2527
template <typename A>
2496
- HALIDE_ALWAYS_INLINE auto is_uint (A &&a, int bits = 0 , int lanes = 0 ) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2528
+ HALIDE_ALWAYS_INLINE auto is_uint (A &&a, uint8_t bits = 0 , uint16_t lanes = 0 ) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2497
2529
assert_is_lvalue_if_expr<A>();
2498
2530
return {pattern_arg (a), bits, lanes};
2499
2531
}
0 commit comments