Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make remaining float intrinsics require float arguments #57398

Merged
merged 9 commits into from
Mar 21, 2025
54 changes: 51 additions & 3 deletions Compiler/src/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2432,6 +2432,43 @@ const _SPECIAL_BUILTINS = Any[
Core._apply_iterate,
]

# Intrinsics that require all arguments to be floats.
# `intrinsic_exct` consults this list: when `f` is one of these intrinsics and the
# (widened) type of its first argument is not a subtype of `CORE_FLOAT_TYPES`, it
# returns `ErrorException` instead of treating the call as non-throwing.
const _FLOAT_INTRINSICS = Any[
# arithmetic
Intrinsics.neg_float,
Intrinsics.add_float,
Intrinsics.sub_float,
Intrinsics.mul_float,
Intrinsics.div_float,
Intrinsics.min_float,
Intrinsics.max_float,
Intrinsics.fma_float,
Intrinsics.muladd_float,
# `_fast` counterparts of the arithmetic intrinsics
Intrinsics.neg_float_fast,
Intrinsics.add_float_fast,
Intrinsics.sub_float_fast,
Intrinsics.mul_float_fast,
Intrinsics.div_float_fast,
Intrinsics.min_float_fast,
Intrinsics.max_float_fast,
# comparisons and their `_fast` counterparts
Intrinsics.eq_float,
Intrinsics.ne_float,
Intrinsics.lt_float,
Intrinsics.le_float,
Intrinsics.eq_float_fast,
Intrinsics.ne_float_fast,
Intrinsics.lt_float_fast,
Intrinsics.le_float_fast,
Intrinsics.fpiseq,
# sign manipulation, rounding, and square root
Intrinsics.abs_float,
Intrinsics.copysign_float,
Intrinsics.ceil_llvm,
Intrinsics.floor_llvm,
Intrinsics.trunc_llvm,
Intrinsics.rint_llvm,
Intrinsics.sqrt_llvm,
Intrinsics.sqrt_llvm_fast
]

# Types compatible with fpext/fptrunc
const CORE_FLOAT_TYPES = Union{Core.BFloat16, Float16, Float32, Float64}

Expand Down Expand Up @@ -2849,7 +2886,8 @@ function intrinsic_exct(𝕃::AbstractLattice, f::IntrinsicFunction, argtypes::V
return ErrorException
end

# fpext and fptrunc have further restrictions on the allowed types.
# fpext, fptrunc, fptoui, fptosi, uitofp, and sitofp have further
# restrictions on the allowed types.
if f === Intrinsics.fpext &&
!(ty <: CORE_FLOAT_TYPES && xty <: CORE_FLOAT_TYPES && Core.sizeof(ty) > Core.sizeof(xty))
return ErrorException
Expand All @@ -2858,6 +2896,12 @@ function intrinsic_exct(𝕃::AbstractLattice, f::IntrinsicFunction, argtypes::V
!(ty <: CORE_FLOAT_TYPES && xty <: CORE_FLOAT_TYPES && Core.sizeof(ty) < Core.sizeof(xty))
return ErrorException
end
if (f === Intrinsics.fptoui || f === Intrinsics.fptosi) && !(xty <: CORE_FLOAT_TYPES)
return ErrorException
end
if (f === Intrinsics.uitofp || f === Intrinsics.sitofp) && !(ty <: CORE_FLOAT_TYPES)
return ErrorException
end

return Union{}
end
Expand All @@ -2870,11 +2914,15 @@ function intrinsic_exct(𝕃::AbstractLattice, f::IntrinsicFunction, argtypes::V
return Union{}
end

# The remaining intrinsics are math/bits/comparison intrinsics. They work on all
# primitive types of the same type.
# The remaining intrinsics are math/bits/comparison intrinsics.
# All the non-floating point intrinsics work on primitive values of the same type.
isshift = f === shl_int || f === lshr_int || f === ashr_int
argtype1 = widenconst(argtypes[1])
isprimitivetype(argtype1) || return ErrorException
if contains_is(_FLOAT_INTRINSICS, f)
argtype1 <: CORE_FLOAT_TYPES || return ErrorException
end

for i = 2:length(argtypes)
argtype = widenconst(argtypes[i])
if isshift ? !isprimitivetype(argtype) : argtype !== argtype1
Expand Down
26 changes: 26 additions & 0 deletions Compiler/test/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1401,3 +1401,29 @@ end == Compiler.EFFECTS_UNKNOWN
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Float64])
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Int32}, Float16])
@test !Compiler.intrinsic_nothrow(Core.Intrinsics.fpext, Any[Type{Float32}, Int16])

# Float intrinsics require float arguments: calling them with integer-typed
# arguments must never be inferred as `nothrow`.
@test Base.infer_effects((Int16,)) do x
    return Core.Intrinsics.abs_float(x)
end |> !Compiler.is_nothrow
@test Base.infer_effects((Int32, Int32)) do x, y
    return Core.Intrinsics.add_float(x, y)
end |> !Compiler.is_nothrow
# Comparison intrinsics are listed in `_FLOAT_INTRINSICS` too, so cover one of
# them instead of repeating the `add_float` case.
@test Base.infer_effects((Int32, Int32)) do x, y
    return Core.Intrinsics.lt_float(x, y)
end |> !Compiler.is_nothrow
@test Base.infer_effects((Int64, Int64, Int64)) do x, y, z
    return Core.Intrinsics.fma_float(x, y, z)
end |> !Compiler.is_nothrow
# fptoui/fptosi require a float *value* argument.
@test Base.infer_effects((Int64,)) do x
    return Core.Intrinsics.fptoui(UInt32, x)
end |> !Compiler.is_nothrow
@test Base.infer_effects((Int64,)) do x
    return Core.Intrinsics.fptosi(Int32, x)
end |> !Compiler.is_nothrow
# sitofp/uitofp require a float *target type* argument.
@test Base.infer_effects((Int64,)) do x
    return Core.Intrinsics.sitofp(Int64, x)
end |> !Compiler.is_nothrow
@test Base.infer_effects((UInt64,)) do x
    return Core.Intrinsics.uitofp(Int64, x)
end |> !Compiler.is_nothrow
4 changes: 2 additions & 2 deletions src/APInt-C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ void LLVMFPtoInt(jl_datatype_t *ty, void *pa, jl_datatype_t *oty, integerPart *p
Val = julia_bfloat_to_float(*(uint16_t*)pa);
else if (ty == jl_float32_type)
Val = *(float*)pa;
else if (jl_float64_type)
else if (ty == jl_float64_type)
Val = *(double*)pa;
else
jl_error("FPtoSI: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
Expand Down Expand Up @@ -352,7 +352,7 @@ void LLVMFPtoInt(jl_datatype_t *ty, void *pa, jl_datatype_t *oty, integerPart *p
else {
APFloat a(Val);
bool isVeryExact;
APFloat::roundingMode rounding_mode = APFloat::rmNearestTiesToEven;
APFloat::roundingMode rounding_mode = RoundingMode::TowardZero;
unsigned nbytes = alignTo(onumbits, integerPartWidth) / host_char_bit;
integerPart *parts = (integerPart*)alloca(nbytes);
APFloat::opStatus status = a.convertToInteger(MutableArrayRef<integerPart>(parts, nbytes), onumbits, isSigned, rounding_mode, &isVeryExact);
Expand Down
35 changes: 22 additions & 13 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,17 +676,23 @@ static jl_cgval_t generic_cast(
Type *to = bitstype_to_llvm((jl_value_t*)jlto, ctx.builder.getContext(), true);
Type *vt = bitstype_to_llvm(v.typ, ctx.builder.getContext(), true);

// fptrunc fpext depend on the specific floating point format to work
// correctly, and so do not pun their argument types.
// fptrunc and fpext depend on the specific floating point
// format to work correctly, and so do not pun their argument types.
if (!(f == fpext || f == fptrunc)) {
if (toint)
to = INTT(to, DL);
else
to = FLOATT(to);
if (fromint)
vt = INTT(vt, DL);
else
vt = FLOATT(vt);
// uitofp/sitofp require a specific float type argument
if (!(f == uitofp || f == sitofp)){
if (toint)
to = INTT(to, DL);
else
to = FLOATT(to);
}
// fptoui/fptosi require a specific float value argument
if (!(f == fptoui || f == fptosi)) {
if (fromint)
vt = INTT(vt, DL);
else
vt = FLOATT(vt);
}
}

if (!to || !vt)
Expand Down Expand Up @@ -1428,10 +1434,13 @@ static jl_cgval_t emit_intrinsic(jl_codectx_t &ctx, intrinsic f, jl_value_t **ar
if (!jl_is_primitivetype(xinfo.typ))
return emit_runtime_call(ctx, f, argv, nargs);
Type *xtyp = bitstype_to_llvm(xinfo.typ, ctx.builder.getContext(), true);
if (float_func()[f])
xtyp = FLOATT(xtyp);
else
if (float_func()[f]) {
if (!xtyp->isFloatingPointTy())
return emit_runtime_call(ctx, f, argv, nargs);
}
else {
xtyp = INTT(xtyp, DL);
}
if (!xtyp)
return emit_runtime_call(ctx, f, argv, nargs);
////Bool are required to be in the range [0,1]
Expand Down
103 changes: 40 additions & 63 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -1073,31 +1073,26 @@ typedef void (fintrinsic_op1)(unsigned, jl_value_t*, void*, void*);
static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const char *name, fintrinsic_op1 *bfloatop, fintrinsic_op1 *halfop, fintrinsic_op1 *floatop, fintrinsic_op1 *doubleop)
{
jl_task_t *ct = jl_current_task;
if (!jl_is_primitivetype(jl_typeof(a)))
jl_datatype_t *aty = (jl_datatype_t *)jl_typeof(a);
if (!jl_is_primitivetype(aty))
jl_errorf("%s: value is not a primitive type", name);
if (!jl_is_primitivetype(ty))
jl_errorf("%s: type is not a primitive type", name);
unsigned sz2 = jl_datatype_size(ty);
jl_value_t *newv = jl_gc_alloc(ct->ptls, sz2, ty);
void *pa = jl_data_ptr(a), *pr = jl_data_ptr(newv);
unsigned sz = jl_datatype_size(jl_typeof(a));
switch (sz) {
/* choose the right size c-type operation based on the input */
case 2:
if (jl_typeof(a) == (jl_value_t*)jl_float16_type)
halfop(sz2 * host_char_bit, ty, pa, pr);
else /*if (jl_typeof(a) == (jl_value_t*)jl_bfloat16_type)*/
bfloatop(sz2 * host_char_bit, ty, pa, pr);
break;
case 4:

if (aty == jl_float16_type)
halfop(sz2 * host_char_bit, ty, pa, pr);
else if (aty == jl_bfloat16_type)
bfloatop(sz2 * host_char_bit, ty, pa, pr);
else if (aty == jl_float32_type)
floatop(sz2 * host_char_bit, ty, pa, pr);
break;
case 8:
else if (aty == jl_float64_type)
doubleop(sz2 * host_char_bit, ty, pa, pr);
break;
default:
jl_errorf("%s: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64", name);
}
else
jl_errorf("%s: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64", name);

return newv;
}

Expand Down Expand Up @@ -1273,30 +1268,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
{ \
jl_task_t *ct = jl_current_task; \
jl_value_t *ty = jl_typeof(a); \
jl_datatype_t *aty = (jl_datatype_t *)ty; \
if (jl_typeof(b) != ty) \
jl_error(#name ": types of a and b must match"); \
if (!jl_is_primitivetype(ty)) \
jl_error(#name ": values are not primitive types"); \
int sz = jl_datatype_size(ty); \
jl_value_t *newv = jl_gc_alloc(ct->ptls, sz, ty); \
void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b), *pr = jl_data_ptr(newv); \
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
if ((jl_datatype_t*)ty == jl_float16_type) \
jl_##name##16(16, pa, pb, pr); \
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
jl_##name##bf16(16, pa, pb, pr); \
break; \
case 4: \
if (aty == jl_float16_type) \
jl_##name##16(16, pa, pb, pr); \
else if (aty == jl_bfloat16_type) \
jl_##name##bf16(16, pa, pb, pr); \
else if (aty == jl_float32_type) \
jl_##name##32(32, pa, pb, pr); \
break; \
case 8: \
else if (aty == jl_float64_type) \
jl_##name##64(64, pa, pb, pr); \
break; \
default: \
jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64"); \
} \
else \
jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
return newv; \
}

Expand All @@ -1308,30 +1297,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
{ \
jl_value_t *ty = jl_typeof(a); \
jl_datatype_t *aty = (jl_datatype_t *)ty; \
if (jl_typeof(b) != ty) \
jl_error(#name ": types of a and b must match"); \
if (!jl_is_primitivetype(ty)) \
jl_error(#name ": values are not primitive types"); \
void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b); \
int sz = jl_datatype_size(ty); \
int cmp; \
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
if ((jl_datatype_t*)ty == jl_float16_type) \
cmp = jl_##name##16(16, pa, pb); \
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
cmp = jl_##name##bf16(16, pa, pb); \
break; \
case 4: \
if (aty == jl_float16_type) \
cmp = jl_##name##16(16, pa, pb); \
else if (aty == jl_bfloat16_type) \
cmp = jl_##name##bf16(16, pa, pb); \
else if (aty == jl_float32_type) \
cmp = jl_##name##32(32, pa, pb); \
break; \
case 8: \
else if (aty == jl_float64_type) \
cmp = jl_##name##64(64, pa, pb); \
break; \
default: \
jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64"); \
} \
else \
jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
\
return cmp ? jl_true : jl_false; \
}

Expand All @@ -1344,30 +1327,24 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c)
{ \
jl_task_t *ct = jl_current_task; \
jl_value_t *ty = jl_typeof(a); \
jl_datatype_t *aty = (jl_datatype_t *)ty; \
if (jl_typeof(b) != ty || jl_typeof(c) != ty) \
jl_error(#name ": types of a, b, and c must match"); \
if (!jl_is_primitivetype(ty)) \
jl_error(#name ": values are not primitive types"); \
int sz = jl_datatype_size(ty); \
jl_value_t *newv = jl_gc_alloc(ct->ptls, sz, ty); \
void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b), *pc = jl_data_ptr(c), *pr = jl_data_ptr(newv); \
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
if ((jl_datatype_t*)ty == jl_float16_type) \
if (aty == jl_float16_type) \
jl_##name##16(16, pa, pb, pc, pr); \
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
else if (aty == jl_bfloat16_type) \
jl_##name##bf16(16, pa, pb, pc, pr); \
break; \
case 4: \
else if (aty == jl_float32_type) \
jl_##name##32(32, pa, pb, pc, pr); \
break; \
case 8: \
else if (aty == jl_float64_type) \
jl_##name##64(64, pa, pb, pc, pr); \
break; \
default: \
jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64"); \
} \
else \
jl_error(#name ": runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64"); \
return newv; \
}

Expand Down Expand Up @@ -1661,7 +1638,7 @@ static inline void fptrunc(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void
fptrunc_convert(float64, bfloat16);
fptrunc_convert(float64, float32);
else
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
jl_error("fptrunc: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64");
#undef fptrunc_convert
}

Expand All @@ -1685,7 +1662,7 @@ static inline void fpext(jl_datatype_t *aty, void *pa, jl_datatype_t *ty, void *
fpext_convert(bfloat16, float64);
fpext_convert(float32, float64);
else
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
jl_error("fptrunc: runtime floating point intrinsics require both arguments to be Float16, BFloat16, Float32, or Float64");
#undef fpext_convert
}

Expand Down
Loading