Skip to content

Commit 708053c

Browse files
authored
[CIR] Upstream initial support for union type (llvm#137501)
Closes llvm#136059
1 parent afd738c commit 708053c

File tree

8 files changed

+349
-41
lines changed

8 files changed

+349
-41
lines changed

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@ def CIR_RecordType : CIR_Type<"Record", "record",
531531
bool isComplete() const { return !isIncomplete(); };
532532
bool isIncomplete() const;
533533

534+
mlir::Type getLargestMember(const mlir::DataLayout &dataLayout) const;
534535
size_t getNumElements() const { return getMembers().size(); };
535536
std::string getKindAsStr() {
536537
switch (getKind()) {

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -317,20 +317,21 @@ LValue CIRGenFunction::emitLValueForField(LValue base, const FieldDecl *field) {
317317
}
318318

319319
unsigned recordCVR = base.getVRQualifiers();
320-
if (rec->isUnion()) {
321-
cgm.errorNYI(field->getSourceRange(), "emitLValueForField: union");
322-
return LValue();
323-
}
324320

325-
assert(!cir::MissingFeatures::preservedAccessIndexRegion());
326321
llvm::StringRef fieldName = field->getName();
327-
const CIRGenRecordLayout &layout =
328-
cgm.getTypes().getCIRGenRecordLayout(field->getParent());
329-
unsigned fieldIndex = layout.getCIRFieldNo(field);
330-
322+
unsigned fieldIndex;
331323
assert(!cir::MissingFeatures::lambdaFieldToName());
332324

325+
if (rec->isUnion())
326+
fieldIndex = field->getFieldIndex();
327+
else {
328+
const CIRGenRecordLayout &layout =
329+
cgm.getTypes().getCIRGenRecordLayout(field->getParent());
330+
fieldIndex = layout.getCIRFieldNo(field);
331+
}
332+
333333
addr = emitAddrOfFieldStorage(addr, field, fieldName, fieldIndex);
334+
assert(!cir::MissingFeatures::preservedAccessIndexRegion());
334335

335336
// If this is a reference field, load the reference right now.
336337
if (fieldType->isReferenceType()) {

clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===--- CIRGenExprAgg.cpp - Emit CIR Code from Aggregate Expressions -----===//
1+
//===- CIRGenExprAggregrate.cpp - Emit CIR Code from Aggregate Expressions ===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.

clang/lib/CIR/CodeGen/CIRGenRecordLayout.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class CIRGenRecordLayout {
3131

3232
/// Map from (non-bit-field) record field to the corresponding cir record type
3333
/// field no. This info is populated by the record builder.
34-
llvm::DenseMap<const clang::FieldDecl *, unsigned> fieldInfo;
34+
llvm::DenseMap<const clang::FieldDecl *, unsigned> fieldIdxMap;
3535

3636
public:
3737
CIRGenRecordLayout(cir::RecordType completeObjectType)
@@ -44,8 +44,8 @@ class CIRGenRecordLayout {
4444
/// Return cir::RecordType element number that corresponds to the field FD.
4545
unsigned getCIRFieldNo(const clang::FieldDecl *fd) const {
4646
fd = fd->getCanonicalDecl();
47-
assert(fieldInfo.count(fd) && "Invalid field for record!");
48-
return fieldInfo.lookup(fd);
47+
assert(fieldIdxMap.count(fd) && "Invalid field for record!");
48+
return fieldIdxMap.lookup(fd);
4949
}
5050
};
5151

clang/lib/CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,15 @@ struct CIRRecordLowering final {
5656
};
5757
// The constructor.
5858
CIRRecordLowering(CIRGenTypes &cirGenTypes, const RecordDecl *recordDecl,
59-
bool isPacked);
59+
bool packed);
6060

6161
/// Constructs a MemberInfo instance from an offset and mlir::Type.
6262
MemberInfo makeStorageInfo(CharUnits offset, mlir::Type data) {
6363
return MemberInfo(offset, MemberInfo::InfoKind::Field, data);
6464
}
6565

6666
void lower();
67+
void lowerUnion();
6768

6869
/// Determines if we need a packed llvm struct.
6970
void determinePacked();
@@ -83,6 +84,10 @@ struct CIRRecordLowering final {
8384
return CharUnits::fromQuantity(dataLayout.layout.getTypeABIAlignment(Ty));
8485
}
8586

87+
bool isZeroInitializable(const FieldDecl *fd) {
88+
return cirGenTypes.isZeroInitializable(fd->getType());
89+
}
90+
8691
/// Wraps cir::IntType with some implicit arguments.
8792
mlir::Type getUIntNType(uint64_t numBits) {
8893
unsigned alignedBits = llvm::PowerOf2Ceil(numBits);
@@ -121,6 +126,13 @@ struct CIRRecordLowering final {
121126
/// Fills out the structures that are ultimately consumed.
122127
void fillOutputFields();
123128

129+
void appendPaddingBytes(CharUnits size) {
130+
if (!size.isZero()) {
131+
fieldTypes.push_back(getByteArrayType(size));
132+
padded = true;
133+
}
134+
}
135+
124136
CIRGenTypes &cirGenTypes;
125137
CIRGenBuilderTy &builder;
126138
const ASTContext &astContext;
@@ -130,12 +142,14 @@ struct CIRRecordLowering final {
130142
std::vector<MemberInfo> members;
131143
// Output fields, consumed by CIRGenTypes::computeRecordLayout
132144
llvm::SmallVector<mlir::Type, 16> fieldTypes;
133-
llvm::DenseMap<const FieldDecl *, unsigned> fields;
145+
llvm::DenseMap<const FieldDecl *, unsigned> fieldIdxMap;
134146
cir::CIRDataLayout dataLayout;
135147

136148
LLVM_PREFERRED_TYPE(bool)
137149
unsigned zeroInitializable : 1;
138150
LLVM_PREFERRED_TYPE(bool)
151+
unsigned zeroInitializableAsBase : 1;
152+
LLVM_PREFERRED_TYPE(bool)
139153
unsigned packed : 1;
140154
LLVM_PREFERRED_TYPE(bool)
141155
unsigned padded : 1;
@@ -147,19 +161,19 @@ struct CIRRecordLowering final {
147161
} // namespace
148162

149163
CIRRecordLowering::CIRRecordLowering(CIRGenTypes &cirGenTypes,
150-
const RecordDecl *recordDecl,
151-
bool isPacked)
164+
const RecordDecl *recordDecl, bool packed)
152165
: cirGenTypes(cirGenTypes), builder(cirGenTypes.getBuilder()),
153166
astContext(cirGenTypes.getASTContext()), recordDecl(recordDecl),
154167
astRecordLayout(
155168
cirGenTypes.getASTContext().getASTRecordLayout(recordDecl)),
156169
dataLayout(cirGenTypes.getCGModule().getModule()),
157-
zeroInitializable(true), packed(isPacked), padded(false) {}
170+
zeroInitializable(true), zeroInitializableAsBase(true), packed(packed),
171+
padded(false) {}
158172

159173
void CIRRecordLowering::lower() {
160174
if (recordDecl->isUnion()) {
161-
cirGenTypes.getCGModule().errorNYI(recordDecl->getSourceRange(),
162-
"lower: union");
175+
lowerUnion();
176+
assert(!cir::MissingFeatures::bitfields());
163177
return;
164178
}
165179

@@ -194,7 +208,8 @@ void CIRRecordLowering::fillOutputFields() {
194208
fieldTypes.push_back(member.data);
195209
if (member.kind == MemberInfo::InfoKind::Field) {
196210
if (member.fieldDecl)
197-
fields[member.fieldDecl->getCanonicalDecl()] = fieldTypes.size() - 1;
211+
fieldIdxMap[member.fieldDecl->getCanonicalDecl()] =
212+
fieldTypes.size() - 1;
198213
// A field without storage must be a bitfield.
199214
assert(!cir::MissingFeatures::bitfields());
200215
}
@@ -296,7 +311,7 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *rd, cir::RecordType *ty) {
296311
assert(!cir::MissingFeatures::bitfields());
297312

298313
// Add all the field numbers.
299-
rl->fieldInfo.swap(lowering.fields);
314+
rl->fieldIdxMap.swap(lowering.fieldIdxMap);
300315

301316
// Dump the layout, if requested.
302317
if (getASTContext().getLangOpts().DumpRecordLayouts) {
@@ -306,3 +321,68 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *rd, cir::RecordType *ty) {
306321
// TODO: implement verification
307322
return rl;
308323
}
324+
325+
void CIRRecordLowering::lowerUnion() {
326+
CharUnits layoutSize = astRecordLayout.getSize();
327+
mlir::Type storageType = nullptr;
328+
bool seenNamedMember = false;
329+
330+
// Iterate through the fields setting bitFieldInfo and the Fields array. Also
331+
// locate the "most appropriate" storage type.
332+
for (const FieldDecl *field : recordDecl->fields()) {
333+
mlir::Type fieldType;
334+
if (field->isBitField())
335+
cirGenTypes.getCGModule().errorNYI(recordDecl->getSourceRange(),
336+
"bitfields in lowerUnion");
337+
else
338+
fieldType = getStorageType(field);
339+
340+
// This maps a field to its index. For unions, the index is always 0.
341+
fieldIdxMap[field->getCanonicalDecl()] = 0;
342+
343+
// Compute zero-initializable status.
344+
// This union might not be zero initialized: it may contain a pointer to
345+
// data member which might have some exotic initialization sequence.
346+
// If this is the case, then we ought not to try and come up with a "better"
347+
// type, it might not be very easy to come up with a Constant which
348+
// correctly initializes it.
349+
if (!seenNamedMember) {
350+
seenNamedMember = field->getIdentifier();
351+
if (!seenNamedMember)
352+
if (const RecordDecl *fieldRD = field->getType()->getAsRecordDecl())
353+
seenNamedMember = fieldRD->findFirstNamedDataMember();
354+
if (seenNamedMember && !isZeroInitializable(field)) {
355+
zeroInitializable = zeroInitializableAsBase = false;
356+
storageType = fieldType;
357+
}
358+
}
359+
360+
// Because our union isn't zero initializable, we won't be getting a better
361+
// storage type.
362+
if (!zeroInitializable)
363+
continue;
364+
365+
// Conditionally update our storage type if we've got a new "better" one.
366+
if (!storageType || getAlignment(fieldType) > getAlignment(storageType) ||
367+
(getAlignment(fieldType) == getAlignment(storageType) &&
368+
getSize(fieldType) > getSize(storageType)))
369+
storageType = fieldType;
370+
371+
// NOTE(cir): Track all union member's types, not just the largest one. It
372+
// allows for proper type-checking and retain more info for analisys.
373+
fieldTypes.push_back(fieldType);
374+
}
375+
376+
if (!storageType)
377+
cirGenTypes.getCGModule().errorNYI(recordDecl->getSourceRange(),
378+
"No-storage Union NYI");
379+
380+
if (layoutSize < getSize(storageType))
381+
storageType = getByteArrayType(layoutSize);
382+
else
383+
appendPaddingBytes(layoutSize - getSize(storageType));
384+
385+
// Set packed if we need it.
386+
if (layoutSize % getAlignment(storageType))
387+
packed = true;
388+
}

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,17 +230,34 @@ void RecordType::complete(ArrayRef<Type> members, bool packed, bool padded) {
230230
llvm_unreachable("failed to complete record");
231231
}
232232

233+
/// Return the largest member of in the type.
234+
///
235+
/// Recurses into union members never returning a union as the largest member.
236+
Type RecordType::getLargestMember(const ::mlir::DataLayout &dataLayout) const {
237+
assert(isUnion() && "Only call getLargestMember on unions");
238+
llvm::ArrayRef<Type> members = getMembers();
239+
// If the union is padded, we need to ignore the last member,
240+
// which is the padding.
241+
return *std::max_element(
242+
members.begin(), getPadded() ? members.end() - 1 : members.end(),
243+
[&](Type lhs, Type rhs) {
244+
return dataLayout.getTypeABIAlignment(lhs) <
245+
dataLayout.getTypeABIAlignment(rhs) ||
246+
(dataLayout.getTypeABIAlignment(lhs) ==
247+
dataLayout.getTypeABIAlignment(rhs) &&
248+
dataLayout.getTypeSize(lhs) < dataLayout.getTypeSize(rhs));
249+
});
250+
}
251+
233252
//===----------------------------------------------------------------------===//
234253
// Data Layout information for types
235254
//===----------------------------------------------------------------------===//
236255

237256
llvm::TypeSize
238257
RecordType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
239258
mlir::DataLayoutEntryListRef params) const {
240-
if (isUnion()) {
241-
// TODO(CIR): Implement union layout.
242-
return llvm::TypeSize::getFixed(8);
243-
}
259+
if (isUnion())
260+
return dataLayout.getTypeSize(getLargestMember(dataLayout));
244261

245262
unsigned recordSize = computeStructSize(dataLayout);
246263
return llvm::TypeSize::getFixed(recordSize * 8);
@@ -249,10 +266,8 @@ RecordType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
249266
uint64_t
250267
RecordType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
251268
::mlir::DataLayoutEntryListRef params) const {
252-
if (isUnion()) {
253-
// TODO(CIR): Implement union layout.
254-
return 8;
255-
}
269+
if (isUnion())
270+
return dataLayout.getTypeABIAlignment(getLargestMember(dataLayout));
256271

257272
// Packed structures always have an ABI alignment of 1.
258273
if (getPacked())
@@ -268,8 +283,6 @@ RecordType::computeStructSize(const mlir::DataLayout &dataLayout) const {
268283
unsigned recordSize = 0;
269284
uint64_t recordAlignment = 1;
270285

271-
// We can't use a range-based for loop here because we might be ignoring the
272-
// last element.
273286
for (mlir::Type ty : getMembers()) {
274287
// This assumes that we're calculating size based on the ABI alignment, not
275288
// the preferred alignment for each type.

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,7 +1436,14 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
14361436
break;
14371437
// Unions are lowered as only the largest member.
14381438
case cir::RecordType::Union:
1439-
llvm_unreachable("Lowering of unions is NYI");
1439+
if (auto largestMember = type.getLargestMember(dataLayout))
1440+
llvmMembers.push_back(
1441+
convertTypeForMemory(converter, dataLayout, largestMember));
1442+
if (type.getPadded()) {
1443+
auto last = *type.getMembers().rbegin();
1444+
llvmMembers.push_back(
1445+
convertTypeForMemory(converter, dataLayout, last));
1446+
}
14401447
break;
14411448
}
14421449

@@ -1609,7 +1616,11 @@ mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite(
16091616
return mlir::success();
16101617
}
16111618
case cir::RecordType::Union:
1612-
return op.emitError() << "NYI: union get_member lowering";
1619+
// Union members share the address space, so we just need a bitcast to
1620+
// conform to type-checking.
1621+
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy,
1622+
adaptor.getAddr());
1623+
return mlir::success();
16131624
}
16141625
}
16151626

0 commit comments

Comments
 (0)