Skip to content

Commit 9ae1496

Browse files
committed
Implement Field and update traverse implementation to use static _fields var
1 parent 4ac8a2e commit 9ae1496

File tree

6 files changed

+848
-37
lines changed

6 files changed

+848
-37
lines changed

Sources/SwiftProtobuf/Field.swift

+735
Large diffs are not rendered by default.

Sources/SwiftProtobuf/Message.swift

+11
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ public protocol Message: _CommonMessageConformances {
110110
/// normal `Equatable`. `Equatable` is provided with specific generated
111111
/// types.
112112
func isEqualTo(message: any Message) -> Bool
113+
114+
/// Provides `Field` information for this `Message` type used to provide a default implementation of `traverse`
115+
static var _fields: [Field<Self>] { get }
113116
}
114117

115118
#if DEBUG
@@ -129,6 +132,14 @@ extension Message {
129132
return true
130133
}
131134

135+
/// Default traverse implementation
136+
public func traverse<V: Visitor>(visitor: inout V) throws {
137+
for field in Self._fields {
138+
try field.traverse(message: self, visitor: &visitor)
139+
}
140+
try unknownFields.traverse(visitor: &visitor)
141+
}
142+
132143
/// A hash based on the message's full contents.
133144
public func hash(into hasher: inout Hasher) {
134145
var visitor = HashVisitor(hasher)

Sources/protoc-gen-swift/FieldGenerator.swift

+3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ protocol FieldGenerator {
4444

4545
/// Generate the support for traversing this field.
4646
func generateTraverse(printer: inout CodePrinter)
47+
48+
/// Generate the field node
49+
func generateFieldNode(printer: inout CodePrinter)
4750

4851
/// Generate support for comparing this field's value.
4952
/// The generated code should return false in the current scope if the field's don't match.

Sources/protoc-gen-swift/MessageFieldGenerator.swift

+28
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class MessageFieldGenerator: FieldGeneratorBase, FieldGenerator {
2626
private let swiftName: String
2727
private let underscoreSwiftName: String
2828
private let storedProperty: String
29+
private let storedPropertyWithoutSelf: String
2930
private let swiftHasName: String
3031
private let swiftClearName: String
3132
private let swiftType: String
@@ -75,8 +76,10 @@ class MessageFieldGenerator: FieldGeneratorBase, FieldGenerator {
7576

7677
if usesHeapStorage {
7778
storedProperty = "_storage.\(underscoreSwiftName)"
79+
storedPropertyWithoutSelf = storedProperty
7880
} else {
7981
storedProperty = "self.\(hasFieldPresence ? underscoreSwiftName : swiftName)"
82+
storedPropertyWithoutSelf = "\(hasFieldPresence ? underscoreSwiftName : swiftName)"
8083
}
8184

8285
super.init(descriptor: descriptor)
@@ -221,4 +224,29 @@ class MessageFieldGenerator: FieldGeneratorBase, FieldGenerator {
221224
p.printIndented("try visitor.\(visitMethod)(\(traitsArg)value: \(varName), fieldNumber: \(number))")
222225
p.print("}\(suffix)")
223226
}
227+
228+
func generateFieldNode(printer p: inout SwiftProtobufPluginLibrary.CodePrinter) {
229+
let factoryMethod: String
230+
let traitsArg: String
231+
if isMap {
232+
factoryMethod = "map"
233+
traitsArg = "type: \(traitsType).self, "
234+
} else {
235+
let modifier = isPacked ? "packed" : isRepeated ? "repeated" : "singular"
236+
factoryMethod = "\(modifier)\(fieldDescriptor.protoGenericType)"
237+
traitsArg = ""
238+
}
239+
240+
let suffix: String
241+
if isRepeated {
242+
suffix = ""
243+
} else if hasFieldPresence {
244+
suffix = ", isUnset: { $0.\(storedPropertyWithoutSelf) == nil }"
245+
} else if swiftDefaultValue != "0" && swiftDefaultValue != "false" && swiftDefaultValue != "String()" && swiftDefaultValue != "Data()" {
246+
suffix = ", defaultValue: \(swiftDefaultValue)"
247+
} else {
248+
suffix = ""
249+
}
250+
p.print(".\(factoryMethod)(\(traitsArg){ $0.\(swiftName) }, fieldNumber: \(number)\(suffix)),")
251+
}
224252
}

Sources/protoc-gen-swift/MessageGenerator.swift

+39-37
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ class MessageGenerator {
217217
// generateIsInitialized provides a blank line after itself.
218218
generateDecodeMessage(printer: &p)
219219
p.print()
220+
generateFieldNodes(printer: &p)
221+
p.print()
220222
generateTraverse(printer: &p)
221223
p.print()
222224
generateMessageEquality(printer: &p)
@@ -311,53 +313,53 @@ class MessageGenerator {
311313
p.print("}")
312314
}
313315
}
314-
315316
}
316317
}
317318
p.print("}")
318319
}
319320

320-
/// Generates the `traverse` method for the message.
321-
///
322-
/// - Parameter p: The code printer.
323-
private func generateTraverse(printer p: inout CodePrinter) {
324-
p.print("\(visibility)func traverse<V: \(namer.swiftProtobufModulePrefix)Visitor>(visitor: inout V) throws {")
325-
p.withIndentation { p in
326-
generateWithLifetimeExtension(printer: &p, throws: true) { p in
327-
if let storage = storage {
328-
storage.generatePreTraverse(printer: &p)
329-
}
321+
private func generateFieldNodes(printer p: inout CodePrinter) {
322+
let visitExtensionsName = descriptor.useMessageSetWireFormat ? "extensionFieldsAsMessageSet" : "extensionFields"
330323

331-
let visitExtensionsName =
332-
descriptor.useMessageSetWireFormat ? "visitExtensionFieldsAsMessageSet" : "visitExtensionFields"
333-
334-
let usesLocals = fields.reduce(false) { $0 || $1.generateTraverseUsesLocals }
335-
if usesLocals {
336-
p.print("""
337-
// The use of inline closures is to circumvent an issue where the compiler
338-
// allocates stack space for every if/case branch local when no optimizations
339-
// are enabled. https://github.com/apple/swift-protobuf/issues/1034 and
340-
// https://github.com/apple/swift-protobuf/issues/1182
341-
""")
342-
}
324+
p.print("\(visibility)static let _fields: [Field<Self>] = [")
325+
326+
// Use the "ambitious" ranges because for visit because subranges with no
327+
// intermixed fields can be merged to reduce the number of calls for
328+
// extension visitation.
329+
var ranges = descriptor.ambitiousExtensionRanges.makeIterator()
330+
var nextRange = ranges.next()
343331

344-
// Use the "ambitious" ranges because for visit because subranges with no
345-
// intermixed fields can be merged to reduce the number of calls for
346-
// extension visitation.
347-
var ranges = descriptor.ambitiousExtensionRanges.makeIterator()
348-
var nextRange = ranges.next()
349-
for f in fieldsSortedByNumber {
350-
while nextRange != nil && Int(nextRange!.lowerBound) < f.number {
351-
p.print("try visitor.\(visitExtensionsName)(fields: _protobuf_extensionFieldValues, start: \(nextRange!.lowerBound), end: \(nextRange!.upperBound))")
352-
nextRange = ranges.next()
353-
}
354-
f.generateTraverse(printer: &p)
355-
}
356-
while nextRange != nil {
357-
p.print("try visitor.\(visitExtensionsName)(fields: _protobuf_extensionFieldValues, start: \(nextRange!.lowerBound), end: \(nextRange!.upperBound))")
332+
p.withIndentation { p in
333+
for f in fieldsSortedByNumber {
334+
while nextRange != nil && Int(nextRange!.lowerBound) < f.number {
335+
p.print(".\(visitExtensionsName)({ $0._protobuf_extensionFieldValues }, start: \(nextRange!.lowerBound), end: \(nextRange!.upperBound)),")
358336
nextRange = ranges.next()
359337
}
338+
f.generateFieldNode(printer: &p)
339+
}
340+
while nextRange != nil {
341+
p.print(".\(visitExtensionsName)({ $0._protobuf_extensionFieldValues }, start: \(nextRange!.lowerBound), end: \(nextRange!.upperBound)),")
342+
nextRange = ranges.next()
360343
}
344+
}
345+
p.print("]")
346+
347+
for oneof in oneofs {
348+
oneof.generateFieldNodeStaticLet(printer: &p)
349+
}
350+
}
351+
352+
private func generateTraverse(printer p: inout CodePrinter) {
353+
// If we have an AnyMessageStorageClass we need to generate the traverse function so it can include a preTraverse call
354+
guard let storage, storage is AnyMessageStorageClassGenerator else {
355+
return
356+
}
357+
p.print("\(visibility)func traverse<V: Visitor>(visitor: inout V) throws {")
358+
p.withIndentation { p in
359+
p.print("try _storage.preTraverse()")
360+
p.print("for field in Self._fields {")
361+
p.printIndented("try field.traverse(message: self, visitor: &visitor)")
362+
p.print("}")
361363
p.print("try unknownFields.traverse(visitor: &visitor)")
362364
}
363365
p.print("}")

Sources/protoc-gen-swift/OneofGenerator.swift

+32
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ class OneofGenerator {
104104
func generateTraverse(printer p: inout CodePrinter) {
105105
oneof.generateTraverse(printer: &p, field: self)
106106
}
107+
108+
func generateFieldNode(printer p: inout SwiftProtobufPluginLibrary.CodePrinter) {
109+
oneof.generateFieldNode(printer: &p, field: self)
110+
}
111+
107112
}
108113

109114
private let oneofDescriptor: OneofDescriptor
@@ -360,6 +365,33 @@ class OneofGenerator {
360365
}
361366

362367
var generateTraverseUsesLocals: Bool { return true }
368+
369+
func generateFieldNode(printer p: inout CodePrinter, field: MemberFieldGenerator) {
370+
// First field in the group causes the output.
371+
let group = fieldSortedGrouped[field.group]
372+
guard field === group.first else { return }
373+
p.print(".oneOf({ $0.\(swiftFieldName) }) {")
374+
p.withIndentation { p in
375+
p.print("switch $0 {")
376+
for field in group {
377+
p.print("case \(field.dottedSwiftName):")
378+
p.printIndented("return _oneOfField_\(field.swiftName)")
379+
}
380+
if group.count != fields.count {
381+
p.print("default:")
382+
p.printIndented("return nil")
383+
}
384+
p.print("}")
385+
}
386+
p.print("},")
387+
}
388+
389+
func generateFieldNodeStaticLet(printer p: inout CodePrinter) {
390+
for field in fieldsSortedByNumber {
391+
p.print("private static let _oneOfField_\(field.swiftName): Field<Self> = .singular\(field.protoGenericType)({ $0.\(field.swiftName) }, fieldNumber: \(field.number), isUnset: { _ in false })")
392+
}
393+
394+
}
363395

364396
func generateTraverse(printer p: inout CodePrinter, field: MemberFieldGenerator) {
365397
// First field in the group causes the output.

0 commit comments

Comments
 (0)