@@ -1073,31 +1073,26 @@ typedef void (fintrinsic_op1)(unsigned, jl_value_t*, void*, void*);
1073
1073
static inline jl_value_t * jl_fintrinsic_1 (jl_value_t * ty , jl_value_t * a , const char * name , fintrinsic_op1 * bfloatop , fintrinsic_op1 * halfop , fintrinsic_op1 * floatop , fintrinsic_op1 * doubleop )
1074
1074
{
1075
1075
jl_task_t * ct = jl_current_task ;
1076
- if (!jl_is_primitivetype (jl_typeof (a )))
1076
+ jl_datatype_t * aty = (jl_datatype_t * )jl_typeof (a );
1077
+ if (!jl_is_primitivetype (aty ))
1077
1078
jl_errorf ("%s: value is not a primitive type" , name );
1078
1079
if (!jl_is_primitivetype (ty ))
1079
1080
jl_errorf ("%s: type is not a primitive type" , name );
1080
1081
unsigned sz2 = jl_datatype_size (ty );
1081
1082
jl_value_t * newv = jl_gc_alloc (ct -> ptls , sz2 , ty );
1082
1083
void * pa = jl_data_ptr (a ), * pr = jl_data_ptr (newv );
1083
- unsigned sz = jl_datatype_size (jl_typeof (a ));
1084
- switch (sz ) {
1085
- /* choose the right size c-type operation based on the input */
1086
- case 2 :
1087
- if (jl_typeof (a ) == (jl_value_t * )jl_float16_type )
1088
- halfop (sz2 * host_char_bit , ty , pa , pr );
1089
- else /*if (jl_typeof(a) == (jl_value_t*)jl_bfloat16_type)*/
1090
- bfloatop (sz2 * host_char_bit , ty , pa , pr );
1091
- break ;
1092
- case 4 :
1084
+
1085
+ if (aty == jl_float16_type )
1086
+ halfop (sz2 * host_char_bit , ty , pa , pr );
1087
+ else if (aty == jl_bfloat16_type )
1088
+ bfloatop (sz2 * host_char_bit , ty , pa , pr );
1089
+ else if (aty == jl_float32_type )
1093
1090
floatop (sz2 * host_char_bit , ty , pa , pr );
1094
- break ;
1095
- case 8 :
1091
+ else if (aty == jl_float64_type )
1096
1092
doubleop (sz2 * host_char_bit , ty , pa , pr );
1097
- break ;
1098
- default :
1099
- jl_errorf ("%s: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64" , name );
1100
- }
1093
+ else
1094
+ jl_errorf ("%s: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64" , name );
1095
+
1101
1096
return newv ;
1102
1097
}
1103
1098
@@ -1273,30 +1268,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
1273
1268
{ \
1274
1269
jl_task_t *ct = jl_current_task; \
1275
1270
jl_value_t *ty = jl_typeof(a); \
1271
+ jl_datatype_t *aty = (jl_datatype_t *)ty; \
1276
1272
if (jl_typeof(b) != ty) \
1277
1273
jl_error(#name ": types of a and b must match"); \
1278
1274
if (!jl_is_primitivetype(ty)) \
1279
1275
jl_error(#name ": values are not primitive types"); \
1280
1276
int sz = jl_datatype_size(ty); \
1281
1277
jl_value_t *newv = jl_gc_alloc(ct->ptls, sz, ty); \
1282
1278
void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b), *pr = jl_data_ptr(newv); \
1283
- switch (sz) { \
1284
- /* choose the right size c-type operation */ \
1285
- case 2 : \
1286
- if ((jl_datatype_t * )ty == jl_float16_type ) \
1287
- jl_ ##name ##16(16, pa, pb, pr); \
1288
- else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
1289
- jl_ ##name ##bf16(16, pa, pb, pr); \
1290
- break; \
1291
- case 4: \
1279
+ if (aty == jl_float16_type) \
1280
+ jl_##name##16(16, pa, pb, pr); \
1281
+ else if (aty == jl_bfloat16_type) \
1282
+ jl_##name##bf16(16, pa, pb, pr); \
1283
+ else if (aty == jl_float32_type) \
1292
1284
jl_##name##32(32, pa, pb, pr); \
1293
- break; \
1294
- case 8: \
1285
+ else if (aty == jl_float64_type) \
1295
1286
jl_##name##64(64, pa, pb, pr); \
1296
- break; \
1297
- default: \
1298
- jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64"); \
1299
- } \
1287
+ else \
1288
+ jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
1300
1289
return newv; \
1301
1290
}
1302
1291
@@ -1308,30 +1297,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
1308
1297
JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
1309
1298
{ \
1310
1299
jl_value_t *ty = jl_typeof(a); \
1300
+ jl_datatype_t *aty = (jl_datatype_t *)ty; \
1311
1301
if (jl_typeof(b) != ty) \
1312
1302
jl_error(#name ": types of a and b must match"); \
1313
1303
if (!jl_is_primitivetype(ty)) \
1314
1304
jl_error(#name ": values are not primitive types"); \
1315
1305
void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b); \
1316
- int sz = jl_datatype_size(ty); \
1317
1306
int cmp; \
1318
- switch (sz) { \
1319
- /* choose the right size c-type operation */ \
1320
- case 2 : \
1321
- if ((jl_datatype_t * )ty == jl_float16_type ) \
1322
- cmp = jl_ ##name ##16(16, pa, pb); \
1323
- else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
1324
- cmp = jl_ ##name ##bf16(16, pa, pb); \
1325
- break; \
1326
- case 4: \
1307
+ if (aty == jl_float16_type) \
1308
+ cmp = jl_##name##16(16, pa, pb); \
1309
+ else if (aty == jl_bfloat16_type) \
1310
+ cmp = jl_##name##bf16(16, pa, pb); \
1311
+ else if (aty == jl_float32_type) \
1327
1312
cmp = jl_##name##32(32, pa, pb); \
1328
- break; \
1329
- case 8: \
1313
+ else if (aty == jl_float64_type) \
1330
1314
cmp = jl_##name##64(64, pa, pb); \
1331
- break; \
1332
- default: \
1333
- jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64"); \
1334
- } \
1315
+ else \
1316
+ jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
1317
+ \
1335
1318
return cmp ? jl_true : jl_false; \
1336
1319
}
1337
1320
@@ -1344,30 +1327,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c)
1344
1327
{ \
1345
1328
jl_task_t *ct = jl_current_task; \
1346
1329
jl_value_t *ty = jl_typeof(a); \
1330
+ jl_datatype_t *aty = (jl_datatype_t *)ty; \
1347
1331
if (jl_typeof(b) != ty || jl_typeof(c) != ty) \
1348
1332
jl_error(#name ": types of a, b, and c must match"); \
1349
1333
if (!jl_is_primitivetype(ty)) \
1350
1334
jl_error(#name ": values are not primitive types"); \
1351
1335
int sz = jl_datatype_size(ty); \
1352
1336
jl_value_t *newv = jl_gc_alloc(ct->ptls, sz, ty); \
1353
1337
void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b), *pc = jl_data_ptr(c), *pr = jl_data_ptr(newv); \
1354
- switch (sz) { \
1355
- /* choose the right size c-type operation */ \
1356
- case 2 : \
1357
- if ((jl_datatype_t * )ty == jl_float16_type ) \
1338
+ if (aty == jl_float16_type) \
1358
1339
jl_##name##16(16, pa, pb, pc, pr); \
1359
- else /* if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
1340
+ else if (aty == jl_bfloat16_type) \
1360
1341
jl_##name##bf16(16, pa, pb, pc, pr); \
1361
- break; \
1362
- case 4: \
1342
+ else if (aty == jl_float32_type) \
1363
1343
jl_##name##32(32, pa, pb, pc, pr); \
1364
- break; \
1365
- case 8: \
1344
+ else if (aty == jl_float64_type) \
1366
1345
jl_##name##64(64, pa, pb, pc, pr); \
1367
- break; \
1368
- default: \
1369
- jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64"); \
1370
- } \
1346
+ else \
1347
+ jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
1371
1348
return newv; \
1372
1349
}
1373
1350
@@ -1661,7 +1638,7 @@ static inline void fptrunc(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void
1661
1638
fptrunc_convert (float64 , bfloat16 );
1662
1639
fptrunc_convert (float64 , float32 );
1663
1640
else
1664
- jl_error ("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64 " );
1641
+ jl_error ("fptrunc: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64 " );
1665
1642
#undef fptrunc_convert
1666
1643
}
1667
1644
@@ -1685,7 +1662,7 @@ static inline void fpext(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void *
1685
1662
fpext_convert (bfloat16 , float64 );
1686
1663
fpext_convert (float32 , float64 );
1687
1664
else
1688
- jl_error ("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64 " );
1665
+ jl_error ("fptrunc: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64 " );
1689
1666
#undef fpext_convert
1690
1667
}
1691
1668
0 commit comments