Skip to content

Commit 0bcc9cd

Browse files
authored
Make remaining float intrinsics require float arguments (#57398)
The `fptrunc`/`fpext` intrinsics were modified in #57160 to throw on non-float arguments. - The arithmetic and math float intrinsics now require all their arguments to be floats - `fptosi`/`fptoui` require their source to be a float - `sitofp`/`uitofp` require their destination type to be a float Also fixes #57384.
1 parent e985652 commit 0bcc9cd

File tree

6 files changed

+155
-82
lines changed

6 files changed

+155
-82
lines changed

Compiler/src/tfuncs.jl

+51-3
Original file line numberDiff line numberDiff line change
@@ -2432,6 +2432,43 @@ const _SPECIAL_BUILTINS = Any[
24322432
Core._apply_iterate,
24332433
]
24342434

2435+
# Intrinsics that require all arguments to be floats
2436+
const _FLOAT_INTRINSICS = Any[
2437+
Intrinsics.neg_float,
2438+
Intrinsics.add_float,
2439+
Intrinsics.sub_float,
2440+
Intrinsics.mul_float,
2441+
Intrinsics.div_float,
2442+
Intrinsics.min_float,
2443+
Intrinsics.max_float,
2444+
Intrinsics.fma_float,
2445+
Intrinsics.muladd_float,
2446+
Intrinsics.neg_float_fast,
2447+
Intrinsics.add_float_fast,
2448+
Intrinsics.sub_float_fast,
2449+
Intrinsics.mul_float_fast,
2450+
Intrinsics.div_float_fast,
2451+
Intrinsics.min_float_fast,
2452+
Intrinsics.max_float_fast,
2453+
Intrinsics.eq_float,
2454+
Intrinsics.ne_float,
2455+
Intrinsics.lt_float,
2456+
Intrinsics.le_float,
2457+
Intrinsics.eq_float_fast,
2458+
Intrinsics.ne_float_fast,
2459+
Intrinsics.lt_float_fast,
2460+
Intrinsics.le_float_fast,
2461+
Intrinsics.fpiseq,
2462+
Intrinsics.abs_float,
2463+
Intrinsics.copysign_float,
2464+
Intrinsics.ceil_llvm,
2465+
Intrinsics.floor_llvm,
2466+
Intrinsics.trunc_llvm,
2467+
Intrinsics.rint_llvm,
2468+
Intrinsics.sqrt_llvm,
2469+
Intrinsics.sqrt_llvm_fast
2470+
]
2471+
24352472
# Types compatible with fpext/fptrunc
24362473
const CORE_FLOAT_TYPES = Union{Core.BFloat16, Float16, Float32, Float64}
24372474

@@ -2849,7 +2886,8 @@ function intrinsic_exct(𝕃::AbstractLattice, f::IntrinsicFunction, argtypes::V
28492886
return ErrorException
28502887
end
28512888

2852-
# fpext and fptrunc have further restrictions on the allowed types.
2889+
# fpext, fptrunc, fptoui, fptosi, uitofp, and sitofp have further
2890+
# restrictions on the allowed types.
28532891
if f === Intrinsics.fpext &&
28542892
!(ty <: CORE_FLOAT_TYPES && xty <: CORE_FLOAT_TYPES && Core.sizeof(ty) > Core.sizeof(xty))
28552893
return ErrorException
@@ -2858,6 +2896,12 @@ function intrinsic_exct(𝕃::AbstractLattice, f::IntrinsicFunction, argtypes::V
28582896
!(ty <: CORE_FLOAT_TYPES && xty <: CORE_FLOAT_TYPES && Core.sizeof(ty) < Core.sizeof(xty))
28592897
return ErrorException
28602898
end
2899+
if (f === Intrinsics.fptoui || f === Intrinsics.fptosi) && !(xty <: CORE_FLOAT_TYPES)
2900+
return ErrorException
2901+
end
2902+
if (f === Intrinsics.uitofp || f === Intrinsics.sitofp) && !(ty <: CORE_FLOAT_TYPES)
2903+
return ErrorException
2904+
end
28612905

28622906
return Union{}
28632907
end
@@ -2870,11 +2914,15 @@ function intrinsic_exct(𝕃::AbstractLattice, f::IntrinsicFunction, argtypes::V
28702914
return Union{}
28712915
end
28722916

2873-
# The remaining intrinsics are math/bits/comparison intrinsics. They work on all
2874-
# primitive types of the same type.
2917+
# The remaining intrinsics are math/bits/comparison intrinsics.
2918+
# All the non-floating point intrinsics work on primitive values of the same type.
28752919
isshift = f === shl_int || f === lshr_int || f === ashr_int
28762920
argtype1 = widenconst(argtypes[1])
28772921
isprimitivetype(argtype1) || return ErrorException
2922+
if contains_is(_FLOAT_INTRINSICS, f)
2923+
argtype1 <: CORE_FLOAT_TYPES || return ErrorException
2924+
end
2925+
28782926
for i = 2:length(argtypes)
28792927
argtype = widenconst(argtypes[i])
28802928
if isshift ? !isprimitivetype(argtype) : argtype !== argtype1

Compiler/test/effects.jl

+26
Original file line numberDiff line numberDiff line change
@@ -1401,3 +1401,29 @@ end == Compiler.EFFECTS_UNKNOWN
14011401
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Float64])
14021402
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Int32}, Float16])
14031403
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Int16])
1404+
1405+
# Float intrinsics require float arguments
1406+
@test Base.infer_effects((Int16,)) do x
1407+
return Core.Intrinsics.abs_float(x)
1408+
end |> !Compiler.is_nothrow
1409+
@test Base.infer_effects((Int32, Int32)) do x, y
1410+
return Core.Intrinsics.add_float(x, y)
1411+
end |> !Compiler.is_nothrow
1412+
@test Base.infer_effects((Int32, Int32)) do x, y
1413+
return Core.Intrinsics.add_float(x, y)
1414+
end |> !Compiler.is_nothrow
1415+
@test Base.infer_effects((Int64, Int64, Int64)) do x, y, z
1416+
return Core.Intrinsics.fma_float(x, y, z)
1417+
end |> !Compiler.is_nothrow
1418+
@test Base.infer_effects((Int64,)) do x
1419+
return Core.Intrinsics.fptoui(UInt32, x)
1420+
end |> !Compiler.is_nothrow
1421+
@test Base.infer_effects((Int64,)) do x
1422+
return Core.Intrinsics.fptosi(Int32, x)
1423+
end |> !Compiler.is_nothrow
1424+
@test Base.infer_effects((Int64,)) do x
1425+
return Core.Intrinsics.sitofp(Int64, x)
1426+
end |> !Compiler.is_nothrow
1427+
@test Base.infer_effects((UInt64,)) do x
1428+
return Core.Intrinsics.uitofp(Int64, x)
1429+
end |> !Compiler.is_nothrow

src/APInt-C.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ void LLVMFPtoInt(jl_datatype_t *ty, void *pa, jl_datatype_t *oty, integerPart *p
321321
Val = julia_bfloat_to_float(*(uint16_t*)pa);
322322
else if (ty == jl_float32_type)
323323
Val = *(float*)pa;
324-
else if (jl_float64_type)
324+
else if (ty == jl_float64_type)
325325
Val = *(double*)pa;
326326
else
327327
jl_error("FPtoSI: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
@@ -352,7 +352,7 @@ void LLVMFPtoInt(jl_datatype_t *ty, void *pa, jl_datatype_t *oty, integerPart *p
352352
else {
353353
APFloat a(Val);
354354
bool isVeryExact;
355-
APFloat::roundingMode rounding_mode = APFloat::rmNearestTiesToEven;
355+
APFloat::roundingMode rounding_mode = RoundingMode::TowardZero;
356356
unsigned nbytes = alignTo(onumbits, integerPartWidth) / host_char_bit;
357357
integerPart *parts = (integerPart*)alloca(nbytes);
358358
APFloat::opStatus status = a.convertToInteger(MutableArrayRef<integerPart>(parts, nbytes), onumbits, isSigned, rounding_mode, &isVeryExact);

src/intrinsics.cpp

+22-13
Original file line numberDiff line numberDiff line change
@@ -676,17 +676,23 @@ static jl_cgval_t generic_cast(
676676
Type *to = bitstype_to_llvm((jl_value_t*)jlto, ctx.builder.getContext(), true);
677677
Type *vt = bitstype_to_llvm(v.typ, ctx.builder.getContext(), true);
678678

679-
// fptrunc fpext depend on the specific floating point format to work
680-
// correctly, and so do not pun their argument types.
679+
// fptrunc and fpext depend on the specific floating point
680+
// format to work correctly, and so do not pun their argument types.
681681
if (!(f == fpext || f == fptrunc)) {
682-
if (toint)
683-
to = INTT(to, DL);
684-
else
685-
to = FLOATT(to);
686-
if (fromint)
687-
vt = INTT(vt, DL);
688-
else
689-
vt = FLOATT(vt);
682+
// uitofp/sitofp require a specific float type argument
683+
if (!(f == uitofp || f == sitofp)){
684+
if (toint)
685+
to = INTT(to, DL);
686+
else
687+
to = FLOATT(to);
688+
}
689+
// fptoui/fptosi require a specific float value argument
690+
if (!(f == fptoui || f == fptosi)) {
691+
if (fromint)
692+
vt = INTT(vt, DL);
693+
else
694+
vt = FLOATT(vt);
695+
}
690696
}
691697

692698
if (!to || !vt)
@@ -1428,10 +1434,13 @@ static jl_cgval_t emit_intrinsic(jl_codectx_t &ctx, intrinsic f, jl_value_t **ar
14281434
if (!jl_is_primitivetype(xinfo.typ))
14291435
return emit_runtime_call(ctx, f, argv, nargs);
14301436
Type *xtyp = bitstype_to_llvm(xinfo.typ, ctx.builder.getContext(), true);
1431-
if (float_func()[f])
1432-
xtyp = FLOATT(xtyp);
1433-
else
1437+
if (float_func()[f]) {
1438+
if (!xtyp->isFloatingPointTy())
1439+
return emit_runtime_call(ctx, f, argv, nargs);
1440+
}
1441+
else {
14341442
xtyp = INTT(xtyp, DL);
1443+
}
14351444
if (!xtyp)
14361445
return emit_runtime_call(ctx, f, argv, nargs);
14371446
////Bool are required to be in the range [0,1]

src/runtime_intrinsics.c

+40-63
Original file line numberDiff line numberDiff line change
@@ -1073,31 +1073,26 @@ typedef void (fintrinsic_op1)(unsigned, jl_value_t*, void*, void*);
10731073
static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const char *name, fintrinsic_op1 *bfloatop, fintrinsic_op1 *halfop, fintrinsic_op1 *floatop, fintrinsic_op1 *doubleop)
10741074
{
10751075
jl_task_t *ct = jl_current_task;
1076-
if (!jl_is_primitivetype(jl_typeof(a)))
1076+
jl_datatype_t *aty = (jl_datatype_t *)jl_typeof(a);
1077+
if (!jl_is_primitivetype(aty))
10771078
jl_errorf("%s: value is not a primitive type", name);
10781079
if (!jl_is_primitivetype(ty))
10791080
jl_errorf("%s: type is not a primitive type", name);
10801081
unsigned sz2 = jl_datatype_size(ty);
10811082
jl_value_t *newv = jl_gc_alloc(ct->ptls, sz2, ty);
10821083
void *pa = jl_data_ptr(a), *pr = jl_data_ptr(newv);
1083-
unsigned sz = jl_datatype_size(jl_typeof(a));
1084-
switch (sz) {
1085-
/* choose the right size c-type operation based on the input */
1086-
case 2:
1087-
if (jl_typeof(a) == (jl_value_t*)jl_float16_type)
1088-
halfop(sz2 * host_char_bit, ty, pa, pr);
1089-
else /*if (jl_typeof(a) == (jl_value_t*)jl_bfloat16_type)*/
1090-
bfloatop(sz2 * host_char_bit, ty, pa, pr);
1091-
break;
1092-
case 4:
1084+
1085+
if (aty == jl_float16_type)
1086+
halfop(sz2 * host_char_bit, ty, pa, pr);
1087+
else if (aty == jl_bfloat16_type)
1088+
bfloatop(sz2 * host_char_bit, ty, pa, pr);
1089+
else if (aty == jl_float32_type)
10931090
floatop(sz2 * host_char_bit, ty, pa, pr);
1094-
break;
1095-
case 8:
1091+
else if (aty == jl_float64_type)
10961092
doubleop(sz2 * host_char_bit, ty, pa, pr);
1097-
break;
1098-
default:
1099-
jl_errorf("%s: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64", name);
1100-
}
1093+
else
1094+
jl_errorf("%s: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64", name);
1095+
11011096
return newv;
11021097
}
11031098

@@ -1273,30 +1268,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
12731268
{ \
12741269
jl_task_t *ct = jl_current_task; \
12751270
jl_value_t *ty = jl_typeof(a); \
1271+
jl_datatype_t *aty = (jl_datatype_t *)ty; \
12761272
if (jl_typeof(b) != ty) \
12771273
jl_error(#name ": types of a and b must match"); \
12781274
if (!jl_is_primitivetype(ty)) \
12791275
jl_error(#name ": values are not primitive types"); \
12801276
int sz = jl_datatype_size(ty); \
12811277
jl_value_t *newv = jl_gc_alloc(ct->ptls, sz, ty); \
12821278
void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b), *pr = jl_data_ptr(newv); \
1283-
switch (sz) { \
1284-
/* choose the right size c-type operation */ \
1285-
case 2: \
1286-
if ((jl_datatype_t*)ty == jl_float16_type) \
1287-
jl_##name##16(16, pa, pb, pr); \
1288-
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
1289-
jl_##name##bf16(16, pa, pb, pr); \
1290-
break; \
1291-
case 4: \
1279+
if (aty == jl_float16_type) \
1280+
jl_##name##16(16, pa, pb, pr); \
1281+
else if (aty == jl_bfloat16_type) \
1282+
jl_##name##bf16(16, pa, pb, pr); \
1283+
else if (aty == jl_float32_type) \
12921284
jl_##name##32(32, pa, pb, pr); \
1293-
break; \
1294-
case 8: \
1285+
else if (aty == jl_float64_type) \
12951286
jl_##name##64(64, pa, pb, pr); \
1296-
break; \
1297-
default: \
1298-
jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64"); \
1299-
} \
1287+
else \
1288+
jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
13001289
return newv; \
13011290
}
13021291

@@ -1308,30 +1297,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
13081297
JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
13091298
{ \
13101299
jl_value_t *ty = jl_typeof(a); \
1300+
jl_datatype_t *aty = (jl_datatype_t *)ty; \
13111301
if (jl_typeof(b) != ty) \
13121302
jl_error(#name ": types of a and b must match"); \
13131303
if (!jl_is_primitivetype(ty)) \
13141304
jl_error(#name ": values are not primitive types"); \
13151305
void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b); \
1316-
int sz = jl_datatype_size(ty); \
13171306
int cmp; \
1318-
switch (sz) { \
1319-
/* choose the right size c-type operation */ \
1320-
case 2: \
1321-
if ((jl_datatype_t*)ty == jl_float16_type) \
1322-
cmp = jl_##name##16(16, pa, pb); \
1323-
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
1324-
cmp = jl_##name##bf16(16, pa, pb); \
1325-
break; \
1326-
case 4: \
1307+
if (aty == jl_float16_type) \
1308+
cmp = jl_##name##16(16, pa, pb); \
1309+
else if (aty == jl_bfloat16_type) \
1310+
cmp = jl_##name##bf16(16, pa, pb); \
1311+
else if (aty == jl_float32_type) \
13271312
cmp = jl_##name##32(32, pa, pb); \
1328-
break; \
1329-
case 8: \
1313+
else if (aty == jl_float64_type) \
13301314
cmp = jl_##name##64(64, pa, pb); \
1331-
break; \
1332-
default: \
1333-
jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64"); \
1334-
} \
1315+
else \
1316+
jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
1317+
\
13351318
return cmp ? jl_true : jl_false; \
13361319
}
13371320

@@ -1344,30 +1327,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c)
13441327
{ \
13451328
jl_task_t *ct = jl_current_task; \
13461329
jl_value_t *ty = jl_typeof(a); \
1330+
jl_datatype_t *aty = (jl_datatype_t *)ty; \
13471331
if (jl_typeof(b) != ty || jl_typeof(c) != ty) \
13481332
jl_error(#name ": types of a, b, and c must match"); \
13491333
if (!jl_is_primitivetype(ty)) \
13501334
jl_error(#name ": values are not primitive types"); \
13511335
int sz = jl_datatype_size(ty); \
13521336
jl_value_t *newv = jl_gc_alloc(ct->ptls, sz, ty); \
13531337
void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b), *pc = jl_data_ptr(c), *pr = jl_data_ptr(newv); \
1354-
switch (sz) { \
1355-
/* choose the right size c-type operation */ \
1356-
case 2: \
1357-
if ((jl_datatype_t*)ty == jl_float16_type) \
1338+
if (aty == jl_float16_type) \
13581339
jl_##name##16(16, pa, pb, pc, pr); \
1359-
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
1340+
else if (aty == jl_bfloat16_type) \
13601341
jl_##name##bf16(16, pa, pb, pc, pr); \
1361-
break; \
1362-
case 4: \
1342+
else if (aty == jl_float32_type) \
13631343
jl_##name##32(32, pa, pb, pc, pr); \
1364-
break; \
1365-
case 8: \
1344+
else if (aty == jl_float64_type) \
13661345
jl_##name##64(64, pa, pb, pc, pr); \
1367-
break; \
1368-
default: \
1369-
jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64"); \
1370-
} \
1346+
else \
1347+
jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
13711348
return newv; \
13721349
}
13731350

@@ -1661,7 +1638,7 @@ static inline void fptrunc(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void
16611638
fptrunc_convert(float64, bfloat16);
16621639
fptrunc_convert(float64, float32);
16631640
else
1664-
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
1641+
jl_error("fptrunc: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64");
16651642
#undef fptrunc_convert
16661643
}
16671644

@@ -1685,7 +1662,7 @@ static inline void fpext(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void *
16851662
fpext_convert(bfloat16, float64);
16861663
fpext_convert(float32, float64);
16871664
else
1688-
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
1665+
jl_error("fptrunc: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64");
16891666
#undef fpext_convert
16901667
}
16911668

0 commit comments

Comments
 (0)