Skip to content

Commit 1cf7f28

Browse files
vtjnash authored and quinnj committed
correctly handle union-store-splitting in "is" and "object_id"
1 parent f0488f2 commit 1cf7f28

File tree

5 files changed

+156
-64
lines changed

5 files changed

+156
-64
lines changed

src/builtins.c

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -77,30 +77,40 @@ static int NOINLINE compare_svec(jl_svec_t *a, jl_svec_t *b)
7777
// See comment above for an explanation of NOINLINE.
7878
/*
 * Field-by-field egality (`===`) comparison of two values `a` and `b` that
 * are both laid out according to datatype `dt`.
 *
 * Returns 1 when every field matches and 0 at the first mismatch.
 *
 *  - Pointer fields match when the pointers are identical, or when both are
 *    non-NULL and `jl_egal` says the referenced values are egal.
 *  - Inline fields are compared bytewise when the field type has no padding,
 *    and recursively otherwise.
 *  - Inline union fields keep a one-byte selector in the LAST byte of the
 *    field.  The selectors must agree, and then only the bytes belonging to
 *    the selected component are compared.
 */
static int NOINLINE compare_fields(jl_value_t *a, jl_value_t *b, jl_datatype_t *dt)
{
    size_t f, nf = jl_datatype_nfields(dt);
    for (f = 0; f < nf; f++) {
        size_t offs = jl_field_offset(dt, f);
        char *ao = (char*)jl_data_ptr(a) + offs;
        char *bo = (char*)jl_data_ptr(b) + offs;
        if (jl_field_isptr(dt, f)) {
            jl_value_t *af = *(jl_value_t**)ao;
            jl_value_t *bf = *(jl_value_t**)bo;
            if (af != bf) {
                // exactly one side undefined => not egal
                if (af == NULL || bf == NULL)
                    return 0;
                if (!jl_egal(af, bf))
                    return 0;
            }
        }
        else {
            jl_datatype_t *ft = (jl_datatype_t*)jl_field_type(dt, f);
            // number of payload bytes that take part in the comparison
            size_t nbytes = jl_field_size(dt, f);
            if (jl_is_uniontype(ft)) {
                // the selector byte sits at the very end of the field
                uint8_t asel = ((uint8_t*)ao)[nbytes - 1];
                uint8_t bsel = ((uint8_t*)bo)[nbytes - 1];
                if (asel != bsel)
                    return 0;
                ft = (jl_datatype_t*)jl_nth_union_component((jl_value_t*)ft, asel);
                // Only the selected component's bytes are meaningful.  The
                // rest of the union slot (and the selector, already checked)
                // may hold stale data, e.g. a ghost component stores no
                // payload at all, so it must not influence the result.
                nbytes = jl_datatype_size(ft);
            }
            if (!ft->layout->haspadding) {
                if (!bits_equal(ao, bo, nbytes))
                    return 0;
            }
            else {
                // padding bytes are unspecified: compare the sub-fields instead
                assert(jl_datatype_nfields(ft) > 0);
                if (!compare_fields((jl_value_t*)ao, (jl_value_t*)bo, ft))
                    return 0;
            }
        }
    }
    return 1;
}
@@ -127,9 +137,11 @@ JL_DLLEXPORT int jl_egal(jl_value_t *a, jl_value_t *b)
127137
return 0;
128138
return !memcmp(jl_string_data(a), jl_string_data(b), l);
129139
}
130-
if (dt->mutabl) return 0;
140+
if (dt->mutabl)
141+
return 0;
131142
size_t sz = jl_datatype_size(dt);
132-
if (sz == 0) return 1;
143+
if (sz == 0)
144+
return 1;
133145
size_t nf = jl_datatype_nfields(dt);
134146
if (nf == 0)
135147
return bits_equal(jl_data_ptr(a), jl_data_ptr(b), sz);
@@ -161,10 +173,10 @@ static uintptr_t bits_hash(void *b, size_t sz)
161173
static uintptr_t NOINLINE hash_svec(jl_svec_t *v)
162174
{
163175
uintptr_t h = 0;
164-
size_t l = jl_svec_len(v);
165-
for(size_t i = 0; i < l; i++) {
166-
jl_value_t *x = jl_svecref(v,i);
167-
uintptr_t u = x==NULL ? 0 : jl_object_id(x);
176+
size_t i, l = jl_svec_len(v);
177+
for (i = 0; i < l; i++) {
178+
jl_value_t *x = jl_svecref(v, i);
179+
uintptr_t u = (x == NULL) ? 0 : jl_object_id(x);
168180
h = bitmix(h, u);
169181
}
170182
return h;
@@ -188,9 +200,11 @@ static uintptr_t jl_object_id_(jl_value_t *tv, jl_value_t *v)
188200
if (dt == jl_typename_type)
189201
return ((jl_typename_t*)v)->hash;
190202
#ifdef _P64
191-
if (v == jl_ANY_flag) return 0x31c472f68ee30bddULL;
203+
if (v == jl_ANY_flag)
204+
return 0x31c472f68ee30bddULL;
192205
#else
193-
if (v == jl_ANY_flag) return 0x8ee30bdd;
206+
if (v == jl_ANY_flag)
207+
return 0x8ee30bdd;
194208
#endif
195209
if (dt == jl_string_type) {
196210
#ifdef _P64
@@ -199,24 +213,29 @@ static uintptr_t jl_object_id_(jl_value_t *tv, jl_value_t *v)
199213
return memhash32_seed(jl_string_data(v), jl_string_len(v), 0xedc3b677);
200214
#endif
201215
}
202-
if (dt->mutabl) return inthash((uintptr_t)v);
216+
if (dt->mutabl)
217+
return inthash((uintptr_t)v);
203218
size_t sz = jl_datatype_size(tv);
204219
uintptr_t h = jl_object_id(tv);
205-
if (sz == 0) return ~h;
206-
size_t nf = jl_datatype_nfields(dt);
207-
if (nf == 0) {
220+
if (sz == 0)
221+
return ~h;
222+
size_t f, nf = jl_datatype_nfields(dt);
223+
if (nf == 0)
208224
return bits_hash(jl_data_ptr(v), sz) ^ h;
209-
}
210-
for (size_t f=0; f < nf; f++) {
225+
for (f = 0; f < nf; f++) {
211226
size_t offs = jl_field_offset(dt, f);
212227
char *vo = (char*)jl_data_ptr(v) + offs;
213228
uintptr_t u;
214229
if (jl_field_isptr(dt, f)) {
215230
jl_value_t *f = *(jl_value_t**)vo;
216-
u = f==NULL ? 0 : jl_object_id(f);
231+
u = (f == NULL) ? 0 : jl_object_id(f);
217232
}
218233
else {
219234
jl_datatype_t *fieldtype = (jl_datatype_t*)jl_field_type(dt, f);
235+
if (jl_is_uniontype(fieldtype)) {
236+
uint8_t sel = ((uint8_t*)vo)[jl_field_size(dt, f) - 1];
237+
fieldtype = (jl_datatype_t*)jl_nth_union_component((jl_value_t*)fieldtype, sel);
238+
}
220239
assert(jl_is_datatype(fieldtype) && !fieldtype->abstract && !fieldtype->mutabl);
221240
if (fieldtype->layout->haspadding)
222241
u = jl_object_id_((jl_value_t*)fieldtype, (jl_value_t*)vo);
@@ -244,7 +263,7 @@ JL_CALLABLE(jl_f_is)
244263
JL_NARGS(===, 2, 2);
245264
if (args[0] == args[1])
246265
return jl_true;
247-
return jl_egal(args[0],args[1]) ? jl_true : jl_false;
266+
return jl_egal(args[0], args[1]) ? jl_true : jl_false;
248267
}
249268

250269
JL_CALLABLE(jl_f_typeof)

src/cgutils.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2214,28 +2214,26 @@ static void emit_setfield(jl_codectx_t &ctx,
22142214
if (wb && strct.isboxed)
22152215
emit_checked_write_barrier(ctx, boxed(ctx, strct), r);
22162216
}
2217-
else {
2218-
if (jl_is_uniontype(jfty)) {
2219-
int fsz = jl_field_size(sty, idx0);
2220-
// compute tindex from rhs
2221-
jl_cgval_t rhs_union = convert_julia_type(ctx, rhs, jfty);
2222-
if (rhs_union.typ == jl_bottom_type)
2223-
return;
2224-
Value *tindex = compute_tindex_unboxed(ctx, rhs_union, jfty);
2225-
tindex = ctx.builder.CreateNUWSub(tindex, ConstantInt::get(T_int8, 1));
2226-
Value *ptindex = ctx.builder.CreateGEP(T_int8, emit_bitcast(ctx, addr, T_pint8), ConstantInt::get(T_size, fsz - 1));
2227-
ctx.builder.CreateStore(tindex, ptindex);
2228-
// copy data
2229-
if (!rhs.isghost) {
2230-
emit_unionmove(ctx, addr, rhs, NULL, false, NULL);
2231-
}
2232-
}
2233-
else {
2234-
int align = jl_field_align(sty, idx0);
2235-
typed_store(ctx, addr, ConstantInt::get(T_size, 0), rhs, jfty,
2236-
strct.tbaa, data_pointer(ctx, strct, T_pjlvalue), align);
2217+
else if (jl_is_uniontype(jfty)) {
2218+
int fsz = jl_field_size(sty, idx0);
2219+
// compute tindex from rhs
2220+
jl_cgval_t rhs_union = convert_julia_type(ctx, rhs, jfty);
2221+
if (rhs_union.typ == jl_bottom_type)
2222+
return;
2223+
Value *tindex = compute_tindex_unboxed(ctx, rhs_union, jfty);
2224+
tindex = ctx.builder.CreateNUWSub(tindex, ConstantInt::get(T_int8, 1));
2225+
Value *ptindex = ctx.builder.CreateGEP(T_int8, emit_bitcast(ctx, addr, T_pint8), ConstantInt::get(T_size, fsz - 1));
2226+
ctx.builder.CreateStore(tindex, ptindex);
2227+
// copy data
2228+
if (!rhs.isghost) {
2229+
emit_unionmove(ctx, addr, rhs, NULL, false, NULL);
22372230
}
22382231
}
2232+
else {
2233+
int align = jl_field_align(sty, idx0);
2234+
typed_store(ctx, addr, ConstantInt::get(T_size, 0), rhs, jfty,
2235+
strct.tbaa, data_pointer(ctx, strct, T_pjlvalue), align);
2236+
}
22392237
}
22402238
else {
22412239
// TODO: better error

src/codegen.cpp

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2036,11 +2036,48 @@ static jl_cgval_t emit_getfield(jl_codectx_t &ctx, const jl_cgval_t &strct, jl_s
20362036
return mark_julia_type(ctx, result, true, jl_any_type);
20372037
}
20382038

2039+
static Value *emit_bits_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const jl_cgval_t &arg2);
2040+
2041+
// Emit IR that evaluates `arg1 === arg2` for two inline ("bits") union values
// of the same union type, each carrying a type selector in `TIndex`.
//
// A switch on arg1's selector dispatches to one block per union component
// visited by `for_each_uniontype_small`; each of those blocks compares the two
// payloads as that component type and feeds the result into a PHI in the join
// block.  A selector with no matching case falls into the default block, which
// traps.  The returned i1 is the PHI result ANDed with "both selectors are
// equal", so values holding different components never compare egal.
static Value *emit_bitsunion_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const jl_cgval_t &arg2)
{
    assert(arg1.typ == arg2.typ && arg1.TIndex && arg2.TIndex && jl_is_uniontype(arg1.typ) && "unimplemented");
    Value *tindex = arg1.TIndex;
    BasicBlock *defaultBB = BasicBlock::Create(jl_LLVMContext, "unionbits_is_boxed", ctx.f);
    SwitchInst *switchInst = ctx.builder.CreateSwitch(tindex, defaultBB);
    BasicBlock *postBB = BasicBlock::Create(jl_LLVMContext, "post_unionbits_is", ctx.f);
    // create the PHI up front so every case block can register its result
    ctx.builder.SetInsertPoint(postBB);
    PHINode *phi = ctx.builder.CreatePHI(T_int1, 2);
    unsigned counter = 0;
    for_each_uniontype_small(
        [&](unsigned idx, jl_datatype_t *jt) {
            BasicBlock *tempBB = BasicBlock::Create(jl_LLVMContext, "unionbits_is", ctx.f);
            ctx.builder.SetInsertPoint(tempBB);
            switchInst->addCase(ConstantInt::get(T_int8, idx), tempBB);
            // reinterpret both values as the concrete component type `jt`
            jl_cgval_t sel_arg1(arg1, (jl_value_t*)jt, NULL);
            jl_cgval_t sel_arg2(arg2, (jl_value_t*)jt, NULL);
            Value *cmp = emit_bits_compare(ctx, sel_arg1, sel_arg2);
            // emit_bits_compare may itself open new basic blocks (a struct
            // component with an inline union field recurses into this
            // function), so the branch below is not necessarily emitted from
            // tempBB.  The PHI must name the block that actually branches.
            BasicBlock *fromBB = ctx.builder.GetInsertBlock();
            phi->addIncoming(cmp, fromBB);
            ctx.builder.CreateBr(postBB);
        },
        arg1.typ,
        counter);
    // unmatched selector: emit a trap followed by `unreachable`
    ctx.builder.SetInsertPoint(defaultBB);
    Function *trap_func = Intrinsic::getDeclaration(
        ctx.f->getParent(),
        Intrinsic::trap);
    ctx.builder.CreateCall(trap_func);
    ctx.builder.CreateUnreachable();
    ctx.builder.SetInsertPoint(postBB);
    return ctx.builder.CreateAnd(phi, ctx.builder.CreateICmpEQ(arg1.TIndex, arg2.TIndex));
}
2072+
20392073
static Value *emit_bits_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const jl_cgval_t &arg2)
20402074
{
20412075
assert(jl_is_datatype(arg1.typ) && arg1.typ == arg2.typ);
20422076
Type *at = julia_type_to_llvm(arg1.typ);
20432077

2078+
if (type_is_ghost(at))
2079+
return ConstantInt::get(T_int1, 1);
2080+
20442081
if (at->isIntegerTy() || at->isPointerTy() || at->isFloatingPointTy()) {
20452082
Type *at_int = INTT(at);
20462083
Value *varg1 = emit_unbox(ctx, at_int, arg1, arg1.typ);
@@ -2089,11 +2126,29 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const
20892126
Value *subAns, *fld1, *fld2;
20902127
fld1 = ctx.builder.CreateConstGEP2_32(at, varg1, 0, i);
20912128
fld2 = ctx.builder.CreateConstGEP2_32(at, varg2, 0, i);
2092-
if (type_is_ghost(fld1->getType()->getPointerElementType()))
2129+
Type *at_i = cast<GetElementPtrInst>(fld1)->getResultElementType();
2130+
if (type_is_ghost(at_i))
20932131
continue;
2094-
subAns = emit_bits_compare(ctx,
2095-
mark_julia_slot(fld1, fldty, NULL, arg1.tbaa),
2096-
mark_julia_slot(fld2, fldty, NULL, arg2.tbaa));
2132+
if (jl_is_uniontype(fldty)) {
2133+
unsigned tindex_offset = cast<StructType>(at_i)->getNumElements() - 1;
2134+
Value *ptindex1 = ctx.builder.CreateConstInBoundsGEP2_32(
2135+
at_i, fld1, 0, tindex_offset);
2136+
Value *ptindex2 = ctx.builder.CreateConstInBoundsGEP2_32(
2137+
at_i, fld2, 0, tindex_offset);
2138+
Value *tindex1 = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1),
2139+
ctx.builder.CreateLoad(T_int8, ptindex1));
2140+
Value *tindex2 = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1),
2141+
ctx.builder.CreateLoad(T_int8, ptindex2));
2142+
subAns = emit_bitsunion_compare(ctx,
2143+
mark_julia_slot(fld1, fldty, tindex1, arg1.tbaa),
2144+
mark_julia_slot(fld2, fldty, tindex2, arg2.tbaa));
2145+
}
2146+
else {
2147+
assert(jl_is_leaf_type(fldty));
2148+
subAns = emit_bits_compare(ctx,
2149+
mark_julia_slot(fld1, fldty, NULL, arg1.tbaa),
2150+
mark_julia_slot(fld2, fldty, NULL, arg2.tbaa));
2151+
}
20972152
answer = ctx.builder.CreateAnd(answer, subAns);
20982153
}
20992154
return answer;
@@ -2157,6 +2212,9 @@ static Value *emit_f_is(jl_codectx_t &ctx, const jl_cgval_t &arg1, const jl_cgva
21572212
return cmp;
21582213
}
21592214

2215+
// if (arg1.tindex || arg2.tindex)
2216+
// TODO: handle with emit_bitsunion_compare
2217+
21602218
int ptr_comparable = 0; // whether this type is unique'd by pointer
21612219
if (rt1 == (jl_value_t*)jl_sym_type || rt2 == (jl_value_t*)jl_sym_type)
21622220
ptr_comparable = 1;
@@ -2373,7 +2431,8 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
23732431
Value *selidx = ctx.builder.CreateMul(emit_arraylen_prim(ctx, ary), nbytes);
23742432
selidx = ctx.builder.CreateAdd(selidx, idx);
23752433
Value *ptindex = ctx.builder.CreateGEP(T_int8, data, selidx);
2376-
Value *tindex = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1), tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateLoad(T_int8, ptindex)));
2434+
Value *tindex = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1),
2435+
tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateLoad(T_int8, ptindex)));
23772436
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
23782437
AllocaInst *lv = emit_static_alloca(ctx, AT);
23792438
if (al > 1)

src/rtutils.c

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -887,9 +887,12 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
887887
n += jl_static_show_x(out, *(jl_value_t**)fld_ptr, depth);
888888
}
889889
else {
890-
n += jl_static_show_x_(out, (jl_value_t*)fld_ptr,
891-
(jl_datatype_t*)jl_field_type(vt, i),
892-
depth);
890+
jl_datatype_t *ft = (jl_datatype_t*)jl_field_type(vt, i);
891+
if (jl_is_uniontype(ft)) {
892+
uint8_t sel = ((uint8_t*)fld_ptr)[jl_field_size(vt, i) - 1];
893+
ft = (jl_datatype_t*)jl_nth_union_component((jl_value_t*)ft, sel);
894+
}
895+
n += jl_static_show_x_(out, (jl_value_t*)fld_ptr, ft, depth);
893896
}
894897
if (istuple && tlen == 1)
895898
n += jl_printf(out, ",");

test/core.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5324,17 +5324,30 @@ struct B
53245324
y::AA
53255325
z::Int8
53265326
end
5327-
b = B(91, AA(ntuple(i -> Int8(i), Val(7))), 23)
5328-
5329-
@test b.x === Int8(91)
5330-
@test b.z === Int8(23)
5331-
@test b.y === AA(ntuple(i -> Int8(i), Val(7)))
5332-
@test sizeof(b) == 12
5333-
@test AA(Int8(1)).x === Int8(1)
5334-
@test AA(Int8(0)).x === Int8(0)
5335-
@test AA(Int16(1)).x === Int16(1)
5336-
@test AA(nothing).x === nothing
5337-
@test sizeof(b.y) == 8
5327+
@noinline compare(a, b) = (a === b) # test code-generation of `is`
5328+
let
    # One value plus two copies made through different storage paths; all
    # three must be indistinguishable to `===` and `object_id`.
    orig = B(91, AA(ntuple(i -> Int8(i), Val(7))), 23)
    via_ref = Ref(orig)[] # copy via field assignment
    via_array = B[orig][1] # copy via array assignment

    # the copies are expected to be distinct boxes
    @test pointer_from_objref(orig) == pointer_from_objref(orig)
    @test pointer_from_objref(orig) != pointer_from_objref(via_ref)
    @test pointer_from_objref(orig) != pointer_from_objref(via_array)
    @test pointer_from_objref(via_ref) != pointer_from_objref(via_array)

    # egality, both directly and through the non-inlined helper
    @test orig === via_ref === via_array
    @test compare(orig, via_ref)
    @test compare(orig, via_array)
    @test object_id(orig) === object_id(via_ref) == object_id(via_array)

    # field contents and sizes
    @test orig.x === Int8(91)
    @test orig.z === Int8(23)
    @test orig.y === AA((Int8(1), Int8(2), Int8(3), Int8(4), Int8(5), Int8(6), Int8(7)))
    @test sizeof(orig) == 12
    @test AA(Int8(1)).x === Int8(1)
    @test AA(Int8(0)).x === Int8(0)
    @test AA(Int16(1)).x === Int16(1)
    @test AA(nothing).x === nothing
    @test sizeof(orig.y) == 8
end
53385351

53395352
for U in boxedunions
53405353
local U

0 commit comments

Comments
 (0)