Skip to content

Commit d96a9a9

Browse files
committed
correctly handle union-store-splitting in "is" and "object_id"
1 parent a35763d commit d96a9a9

File tree

5 files changed

+156
-64
lines changed

5 files changed

+156
-64
lines changed

src/builtins.c

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -77,30 +77,40 @@ static int NOINLINE compare_svec(jl_svec_t *a, jl_svec_t *b)
7777
// See comment above for an explanation of NOINLINE.
7878
// Field-by-field equality of two values `a` and `b` that share datatype `dt`.
// Returns 1 when every field compares equal and 0 at the first mismatch.
//
// NOTE: this span is a rendered diff. A line following an `N-` marker appears
// to be the removed (old) text, a line following an `N+` marker the added
// (new) text, and a line following a doubled number is unchanged context.
// The old version accumulated the result in a local `eq`; the new version
// returns 0 directly at each mismatch.
static int NOINLINE compare_fields(jl_value_t *a, jl_value_t *b, jl_datatype_t *dt)
7979
{
80-
size_t nf = jl_datatype_nfields(dt);
81-
for (size_t f=0; f < nf; f++) {
80+
size_t f, nf = jl_datatype_nfields(dt);
81+
for (f = 0; f < nf; f++) {
// ao / bo: address of field f inside a and b respectively.
8282
size_t offs = jl_field_offset(dt, f);
8383
char *ao = (char*)jl_data_ptr(a) + offs;
8484
char *bo = (char*)jl_data_ptr(b) + offs;
85-
int eq;
// Pointer field: load both references. Identical pointers are equal, a NULL
// on exactly one side is unequal, otherwise defer to jl_egal.
8685
if (jl_field_isptr(dt, f)) {
8786
jl_value_t *af = *(jl_value_t**)ao;
8887
jl_value_t *bf = *(jl_value_t**)bo;
89-
if (af == bf) eq = 1;
90-
else if (af==NULL || bf==NULL) eq = 0;
91-
else eq = jl_egal(af, bf);
88+
if (af != bf) {
89+
if (af == NULL || bf == NULL)
90+
return 0;
91+
if (!jl_egal(af, bf))
92+
return 0;
93+
}
9294
}
// Non-pointer field: the data is stored inline at ao / bo.
9395
else {
9496
jl_datatype_t *ft = (jl_datatype_t*)jl_field_type(dt, f);
// Union-typed inline field (added by this commit): the last byte of the
// field (index jl_field_size - 1) is read as a selector. Different selectors
// mean unequal; otherwise ft is narrowed to the selected union component.
97+
if (jl_is_uniontype(ft)) {
98+
uint8_t asel = ((uint8_t*)ao)[jl_field_size(dt, f) - 1];
99+
uint8_t bsel = ((uint8_t*)bo)[jl_field_size(dt, f) - 1];
100+
if (asel != bsel)
101+
return 0;
102+
ft = (jl_datatype_t*)jl_nth_union_component((jl_value_t*)ft, asel);
103+
}
// No padding in the (possibly narrowed) field type: compare raw bytes.
// NOTE(review): the length is jl_field_size(dt, f), i.e. the whole field. For
// a union field this presumably also covers the selector byte and any storage
// beyond the selected component's own size -- confirm those bytes are always
// identical for equal values, or compare only the component's size.
95104
if (!ft->layout->haspadding) {
96-
eq = bits_equal(ao, bo, jl_field_size(dt, f));
105+
if (!bits_equal(ao, bo, jl_field_size(dt, f)))
106+
return 0;
97107
}
// Field type has padding: recurse so that only its own fields are compared.
98108
else {
99109
assert(jl_datatype_nfields(ft) > 0);
100-
eq = compare_fields((jl_value_t*)ao, (jl_value_t*)bo, ft);
110+
if (!compare_fields((jl_value_t*)ao, (jl_value_t*)bo, ft))
111+
return 0;
101112
}
102113
}
103-
if (!eq) return 0;
104114
}
// Every field matched.
105115
return 1;
106116
}
@@ -127,9 +137,11 @@ JL_DLLEXPORT int jl_egal(jl_value_t *a, jl_value_t *b)
127137
return 0;
128138
return !memcmp(jl_string_data(a), jl_string_data(b), l);
129139
}
130-
if (dt->mutabl) return 0;
140+
if (dt->mutabl)
141+
return 0;
131142
size_t sz = jl_datatype_size(dt);
132-
if (sz == 0) return 1;
143+
if (sz == 0)
144+
return 1;
133145
size_t nf = jl_datatype_nfields(dt);
134146
if (nf == 0)
135147
return bits_equal(jl_data_ptr(a), jl_data_ptr(b), sz);
@@ -161,10 +173,10 @@ static uintptr_t bits_hash(void *b, size_t sz)
161173
static uintptr_t NOINLINE hash_svec(jl_svec_t *v)
162174
{
163175
uintptr_t h = 0;
164-
size_t l = jl_svec_len(v);
165-
for(size_t i = 0; i < l; i++) {
166-
jl_value_t *x = jl_svecref(v,i);
167-
uintptr_t u = x==NULL ? 0 : jl_object_id(x);
176+
size_t i, l = jl_svec_len(v);
177+
for (i = 0; i < l; i++) {
178+
jl_value_t *x = jl_svecref(v, i);
179+
uintptr_t u = (x == NULL) ? 0 : jl_object_id(x);
168180
h = bitmix(h, u);
169181
}
170182
return h;
@@ -188,9 +200,11 @@ static uintptr_t jl_object_id_(jl_value_t *tv, jl_value_t *v)
188200
if (dt == jl_typename_type)
189201
return ((jl_typename_t*)v)->hash;
190202
#ifdef _P64
191-
if (v == jl_ANY_flag) return 0x31c472f68ee30bddULL;
203+
if (v == jl_ANY_flag)
204+
return 0x31c472f68ee30bddULL;
192205
#else
193-
if (v == jl_ANY_flag) return 0x8ee30bdd;
206+
if (v == jl_ANY_flag)
207+
return 0x8ee30bdd;
194208
#endif
195209
if (dt == jl_string_type) {
196210
#ifdef _P64
@@ -199,24 +213,29 @@ static uintptr_t jl_object_id_(jl_value_t *tv, jl_value_t *v)
199213
return memhash32_seed(jl_string_data(v), jl_string_len(v), 0xedc3b677);
200214
#endif
201215
}
202-
if (dt->mutabl) return inthash((uintptr_t)v);
216+
if (dt->mutabl)
217+
return inthash((uintptr_t)v);
203218
size_t sz = jl_datatype_size(tv);
204219
uintptr_t h = jl_object_id(tv);
205-
if (sz == 0) return ~h;
206-
size_t nf = jl_datatype_nfields(dt);
207-
if (nf == 0) {
220+
if (sz == 0)
221+
return ~h;
222+
size_t f, nf = jl_datatype_nfields(dt);
223+
if (nf == 0)
208224
return bits_hash(jl_data_ptr(v), sz) ^ h;
209-
}
210-
for (size_t f=0; f < nf; f++) {
225+
for (f = 0; f < nf; f++) {
211226
size_t offs = jl_field_offset(dt, f);
212227
char *vo = (char*)jl_data_ptr(v) + offs;
213228
uintptr_t u;
214229
if (jl_field_isptr(dt, f)) {
215230
jl_value_t *f = *(jl_value_t**)vo;
216-
u = f==NULL ? 0 : jl_object_id(f);
231+
u = (f == NULL) ? 0 : jl_object_id(f);
217232
}
218233
else {
219234
jl_datatype_t *fieldtype = (jl_datatype_t*)jl_field_type(dt, f);
235+
if (jl_is_uniontype(fieldtype)) {
236+
uint8_t sel = ((uint8_t*)vo)[jl_field_size(dt, f) - 1];
237+
fieldtype = (jl_datatype_t*)jl_nth_union_component((jl_value_t*)fieldtype, sel);
238+
}
220239
assert(jl_is_datatype(fieldtype) && !fieldtype->abstract && !fieldtype->mutabl);
221240
if (fieldtype->layout->haspadding)
222241
u = jl_object_id_((jl_value_t*)fieldtype, (jl_value_t*)vo);
@@ -244,7 +263,7 @@ JL_CALLABLE(jl_f_is)
244263
JL_NARGS(===, 2, 2);
245264
if (args[0] == args[1])
246265
return jl_true;
247-
return jl_egal(args[0],args[1]) ? jl_true : jl_false;
266+
return jl_egal(args[0], args[1]) ? jl_true : jl_false;
248267
}
249268

250269
JL_CALLABLE(jl_f_typeof)

src/cgutils.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2245,28 +2245,26 @@ static void emit_setfield(jl_codectx_t &ctx,
22452245
if (wb && strct.isboxed)
22462246
emit_checked_write_barrier(ctx, boxed(ctx, strct), r);
22472247
}
2248-
else {
2249-
if (jl_is_uniontype(jfty)) {
2250-
int fsz = jl_field_size(sty, idx0);
2251-
// compute tindex from rhs
2252-
jl_cgval_t rhs_union = convert_julia_type(ctx, rhs, jfty);
2253-
if (rhs_union.typ == jl_bottom_type)
2254-
return;
2255-
Value *tindex = compute_tindex_unboxed(ctx, rhs_union, jfty);
2256-
tindex = ctx.builder.CreateNUWSub(tindex, ConstantInt::get(T_int8, 1));
2257-
Value *ptindex = ctx.builder.CreateGEP(T_int8, emit_bitcast(ctx, addr, T_pint8), ConstantInt::get(T_size, fsz - 1));
2258-
ctx.builder.CreateStore(tindex, ptindex);
2259-
// copy data
2260-
if (!rhs.isghost) {
2261-
emit_unionmove(ctx, addr, rhs, NULL, false, NULL);
2262-
}
2263-
}
2264-
else {
2265-
int align = jl_field_align(sty, idx0);
2266-
typed_store(ctx, addr, ConstantInt::get(T_size, 0), rhs, jfty,
2267-
strct.tbaa, data_pointer(ctx, strct, T_pjlvalue), align);
2248+
else if (jl_is_uniontype(jfty)) {
2249+
int fsz = jl_field_size(sty, idx0);
2250+
// compute tindex from rhs
2251+
jl_cgval_t rhs_union = convert_julia_type(ctx, rhs, jfty);
2252+
if (rhs_union.typ == jl_bottom_type)
2253+
return;
2254+
Value *tindex = compute_tindex_unboxed(ctx, rhs_union, jfty);
2255+
tindex = ctx.builder.CreateNUWSub(tindex, ConstantInt::get(T_int8, 1));
2256+
Value *ptindex = ctx.builder.CreateGEP(T_int8, emit_bitcast(ctx, addr, T_pint8), ConstantInt::get(T_size, fsz - 1));
2257+
ctx.builder.CreateStore(tindex, ptindex);
2258+
// copy data
2259+
if (!rhs.isghost) {
2260+
emit_unionmove(ctx, addr, rhs, NULL, false, NULL);
22682261
}
22692262
}
2263+
else {
2264+
int align = jl_field_align(sty, idx0);
2265+
typed_store(ctx, addr, ConstantInt::get(T_size, 0), rhs, jfty,
2266+
strct.tbaa, data_pointer(ctx, strct, T_pjlvalue), align);
2267+
}
22702268
}
22712269
else {
22722270
// TODO: better error

src/codegen.cpp

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2077,11 +2077,48 @@ static jl_cgval_t emit_getfield(jl_codectx_t &ctx, const jl_cgval_t &strct, jl_s
20772077
return mark_julia_type(ctx, result, true, jl_any_type, needsgcroot);
20782078
}
20792079

2080+
static Value *emit_bits_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const jl_cgval_t &arg2);
2081+
2082+
// Emits IR that compares two union-typed values `arg1` and `arg2` and returns
// the resulting T_int1 Value. The assert requires both arguments to have the
// same union type and to carry a TIndex; anything else is "unimplemented".
//
// Shape of the emitted code, as built below:
//   - a switch on arg1.TIndex with one case block per union component
//   - each case compares the two values as that component type
//   - a default block that calls the trap intrinsic and is unreachable
//   - a join block whose PHI collects the per-case result
// The returned value is the PHI AND-ed with (arg1.TIndex == arg2.TIndex), so
// values with different type indices compare unequal.
static Value *emit_bitsunion_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const jl_cgval_t &arg2)
2083+
{
2084+
assert(arg1.typ == arg2.typ && arg1.TIndex && arg2.TIndex && jl_is_uniontype(arg1.typ) && "unimplemented");
2085+
Value *tindex = arg1.TIndex;
2086+
BasicBlock *defaultBB = BasicBlock::Create(jl_LLVMContext, "unionbits_is_boxed", ctx.f);
2087+
SwitchInst *switchInst = ctx.builder.CreateSwitch(tindex, defaultBB);
// The PHI is created in the join block up front so the per-component
// callback below can register its result as it goes.
2088+
BasicBlock *postBB = BasicBlock::Create(jl_LLVMContext, "post_unionbits_is", ctx.f);
2089+
ctx.builder.SetInsertPoint(postBB);
2090+
PHINode *phi = ctx.builder.CreatePHI(T_int1, 2);
2091+
unsigned counter = 0;
// The callback receives an index `idx` and a datatype `jt` for each union
// component visited by for_each_uniontype_small. It adds a switch case for
// `idx`, re-types both arguments as `jt` (with no TIndex), compares them with
// emit_bits_compare, and branches to the join block.
// NOTE(review): the PHI records tempBB as the predecessor, but the branch is
// emitted at whatever insert point emit_bits_compare leaves behind. The hunk
// further down shows emit_bits_compare calling emit_bitsunion_compare for
// union-typed struct fields, which ends with the builder in its own join
// block. In that case the PHI predecessor would not match the branching
// block -- confirm, and consider using ctx.builder.GetInsertBlock().
2092+
for_each_uniontype_small(
2093+
[&](unsigned idx, jl_datatype_t *jt) {
2094+
BasicBlock *tempBB = BasicBlock::Create(jl_LLVMContext, "unionbits_is", ctx.f);
2095+
ctx.builder.SetInsertPoint(tempBB);
2096+
switchInst->addCase(ConstantInt::get(T_int8, idx), tempBB);
2097+
jl_cgval_t sel_arg1(arg1, (jl_value_t*)jt, NULL);
2098+
jl_cgval_t sel_arg2(arg2, (jl_value_t*)jt, NULL);
2099+
phi->addIncoming(emit_bits_compare(ctx, sel_arg1, sel_arg2), tempBB);
2100+
ctx.builder.CreateBr(postBB);
2101+
},
2102+
arg1.typ,
2103+
counter);
// Default case: a TIndex that matches no registered case traps at runtime.
2104+
ctx.builder.SetInsertPoint(defaultBB);
2105+
Function *trap_func = Intrinsic::getDeclaration(
2106+
ctx.f->getParent(),
2107+
Intrinsic::trap);
2108+
ctx.builder.CreateCall(trap_func);
2109+
ctx.builder.CreateUnreachable();
// Continue emitting in the join block. The switch dispatched on arg1's index
// only, so the type-index equality is folded into the final answer here.
2110+
ctx.builder.SetInsertPoint(postBB);
2111+
return ctx.builder.CreateAnd(phi, ctx.builder.CreateICmpEQ(arg1.TIndex, arg2.TIndex));
2112+
}
2113+
20802114
static Value *emit_bits_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const jl_cgval_t &arg2)
20812115
{
20822116
assert(jl_is_datatype(arg1.typ) && arg1.typ == arg2.typ);
20832117
Type *at = julia_type_to_llvm(arg1.typ);
20842118

2119+
if (type_is_ghost(at))
2120+
return ConstantInt::get(T_int1, 1);
2121+
20852122
if (at->isIntegerTy() || at->isPointerTy() || at->isFloatingPointTy()) {
20862123
Type *at_int = INTT(at);
20872124
Value *varg1 = emit_unbox(ctx, at_int, arg1, arg1.typ);
@@ -2130,11 +2167,29 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const
21302167
Value *subAns, *fld1, *fld2;
21312168
fld1 = ctx.builder.CreateConstGEP2_32(at, varg1, 0, i);
21322169
fld2 = ctx.builder.CreateConstGEP2_32(at, varg2, 0, i);
2133-
if (type_is_ghost(fld1->getType()->getPointerElementType()))
2170+
Type *at_i = cast<GetElementPtrInst>(fld1)->getResultElementType();
2171+
if (type_is_ghost(at_i))
21342172
continue;
2135-
subAns = emit_bits_compare(ctx,
2136-
mark_julia_slot(fld1, fldty, NULL, arg1.tbaa),
2137-
mark_julia_slot(fld2, fldty, NULL, arg2.tbaa));
2173+
if (jl_is_uniontype(fldty)) {
2174+
unsigned tindex_offset = cast<StructType>(at_i)->getNumElements() - 1;
2175+
Value *ptindex1 = ctx.builder.CreateConstInBoundsGEP2_32(
2176+
at_i, fld1, 0, tindex_offset);
2177+
Value *ptindex2 = ctx.builder.CreateConstInBoundsGEP2_32(
2178+
at_i, fld2, 0, tindex_offset);
2179+
Value *tindex1 = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1),
2180+
ctx.builder.CreateLoad(T_int8, ptindex1));
2181+
Value *tindex2 = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1),
2182+
ctx.builder.CreateLoad(T_int8, ptindex2));
2183+
subAns = emit_bitsunion_compare(ctx,
2184+
mark_julia_slot(fld1, fldty, tindex1, arg1.tbaa),
2185+
mark_julia_slot(fld2, fldty, tindex2, arg2.tbaa));
2186+
}
2187+
else {
2188+
assert(jl_is_leaf_type(fldty));
2189+
subAns = emit_bits_compare(ctx,
2190+
mark_julia_slot(fld1, fldty, NULL, arg1.tbaa),
2191+
mark_julia_slot(fld2, fldty, NULL, arg2.tbaa));
2192+
}
21382193
answer = ctx.builder.CreateAnd(answer, subAns);
21392194
}
21402195
return answer;
@@ -2198,6 +2253,9 @@ static Value *emit_f_is(jl_codectx_t &ctx, const jl_cgval_t &arg1, const jl_cgva
21982253
return cmp;
21992254
}
22002255

2256+
// if (arg1.tindex || arg2.tindex)
2257+
// TODO: handle with emit_bitsunion_compare
2258+
22012259
int ptr_comparable = 0; // whether this type is unique'd by pointer
22022260
if (rt1 == (jl_value_t*)jl_sym_type || rt2 == (jl_value_t*)jl_sym_type)
22032261
ptr_comparable = 1;
@@ -2413,7 +2471,8 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
24132471
Value *selidx = ctx.builder.CreateMul(emit_arraylen_prim(ctx, ary), nbytes);
24142472
selidx = ctx.builder.CreateAdd(selidx, idx);
24152473
Value *ptindex = ctx.builder.CreateGEP(T_int8, data, selidx);
2416-
Value *tindex = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1), tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateLoad(T_int8, ptindex)));
2474+
Value *tindex = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1),
2475+
tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateLoad(T_int8, ptindex)));
24172476
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
24182477
AllocaInst *lv = emit_static_alloca(ctx, AT);
24192478
if (al > 1)

src/rtutils.c

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -883,9 +883,12 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
883883
n += jl_static_show_x(out, *(jl_value_t**)fld_ptr, depth);
884884
}
885885
else {
886-
n += jl_static_show_x_(out, (jl_value_t*)fld_ptr,
887-
(jl_datatype_t*)jl_field_type(vt, i),
888-
depth);
886+
jl_datatype_t *ft = (jl_datatype_t*)jl_field_type(vt, i);
887+
if (jl_is_uniontype(ft)) {
888+
uint8_t sel = ((uint8_t*)fld_ptr)[jl_field_size(vt, i) - 1];
889+
ft = (jl_datatype_t*)jl_nth_union_component((jl_value_t*)ft, sel);
890+
}
891+
n += jl_static_show_x_(out, (jl_value_t*)fld_ptr, ft, depth);
889892
}
890893
if (istuple && tlen == 1)
891894
n += jl_printf(out, ",");

test/core.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5326,17 +5326,30 @@ struct B
53265326
y::AA
53275327
z::Int8
53285328
end
5329-
b = B(91, AA(ntuple(i -> Int8(i), Val(7))), 23)
5330-
5331-
@test b.x === Int8(91)
5332-
@test b.z === Int8(23)
5333-
@test b.y === AA(ntuple(i -> Int8(i), Val(7)))
5334-
@test sizeof(b) == 12
5335-
@test AA(Int8(1)).x === Int8(1)
5336-
@test AA(Int8(0)).x === Int8(0)
5337-
@test AA(Int16(1)).x === Int16(1)
5338-
@test AA(nothing).x === nothing
5339-
@test sizeof(b.y) == 8
5329+
# Regression checks added by this commit for `===` and `object_id` on a struct
# `B` that has a field `y::AA` (both types are declared above this hunk; AA is
# presumably the inline-union wrapper under test -- confirm against its
# definition). `compare` is marked @noinline and wraps `===`.
@noinline compare(a, b) = (a === b) # test code-generation of `is`
5330+
let
5331+
b = B(91, AA(ntuple(i -> Int8(i), Val(7))), 23)
5332+
b2 = Ref(b)[] # copy b via field assignment
5333+
b3 = B[b][1] # copy b via array assignment
# The three values are expected to live at distinct addresses ...
5334+
@test pointer_from_objref(b) == pointer_from_objref(b)
5335+
@test pointer_from_objref(b) != pointer_from_objref(b2)
5336+
@test pointer_from_objref(b) != pointer_from_objref(b3)
5337+
@test pointer_from_objref(b2) != pointer_from_objref(b3)
5338+
# ... and still be egal, both via `===` directly and via the `compare`
# wrapper, with matching object_id values.
5339+
@test b === b2 === b3
5340+
@test compare(b, b2)
5341+
@test compare(b, b3)
5342+
@test object_id(b) === object_id(b2) == object_id(b3)
# Field reads, sizes, and AA construction from several argument types.
5343+
@test b.x === Int8(91)
5344+
@test b.z === Int8(23)
5345+
@test b.y === AA((Int8(1), Int8(2), Int8(3), Int8(4), Int8(5), Int8(6), Int8(7)))
5346+
@test sizeof(b) == 12
5347+
@test AA(Int8(1)).x === Int8(1)
5348+
@test AA(Int8(0)).x === Int8(0)
5349+
@test AA(Int16(1)).x === Int16(1)
5350+
@test AA(nothing).x === nothing
5351+
@test sizeof(b.y) == 8
5352+
end
53405353

53415354
for U in boxedunions
53425355
local U

0 commit comments

Comments
 (0)