|
5 | 5 |
|
6 | 6 | package software.amazon.smithy.rust.codegen.core.testutil
|
7 | 7 |
|
| 8 | +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait |
| 9 | +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait |
| 10 | +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait |
| 11 | +import software.amazon.smithy.aws.traits.protocols.RestXmlTrait |
8 | 12 | import software.amazon.smithy.build.PluginContext
|
9 | 13 | import software.amazon.smithy.model.Model
|
10 | 14 | import software.amazon.smithy.model.node.ObjectNode
|
11 | 15 | import software.amazon.smithy.model.node.ToNode
|
| 16 | +import software.amazon.smithy.model.shapes.ServiceShape |
| 17 | +import software.amazon.smithy.model.shapes.ShapeId |
| 18 | +import software.amazon.smithy.model.traits.AbstractTrait |
| 19 | +import software.amazon.smithy.model.transform.ModelTransformer |
| 20 | +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait |
12 | 21 | import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
|
13 | 22 | import software.amazon.smithy.rust.codegen.core.util.runCommand
|
14 | 23 | import java.io.File
|
@@ -153,3 +162,128 @@ fun codegenIntegrationTest(
|
153 | 162 | logger.fine(out.toString())
|
154 | 163 | return testDir
|
155 | 164 | }
|
| 165 | + |
| 166 | +/** |
| 167 | + * Metadata associated with a protocol that provides additional information needed for testing. |
| 168 | + * |
| 169 | + * @property protocol The protocol enum value this metadata is associated with |
| 170 | + * @property contentType The HTTP Content-Type header value associated with this protocol. |
| 171 | + */ |
| 172 | +data class ProtocolMetadata( |
| 173 | + val protocol: ModelProtocol, |
| 174 | + val contentType: String, |
| 175 | +) |
| 176 | + |
| 177 | +/** |
| 178 | + * Represents the supported protocol traits in Smithy models. |
| 179 | + * |
| 180 | + * @property trait The Smithy trait instance with which the service shape must be annotated. |
| 181 | + */ |
| 182 | +enum class ModelProtocol(val trait: AbstractTrait) { |
| 183 | + AwsJson10(AwsJson1_0Trait.builder().build()), |
| 184 | + AwsJson11(AwsJson1_1Trait.builder().build()), |
| 185 | + RestJson(RestJson1Trait.builder().build()), |
| 186 | + RestXml(RestXmlTrait.builder().build()), |
| 187 | + Rpcv2Cbor(Rpcv2CborTrait.builder().build()), |
| 188 | + ; |
| 189 | + |
| 190 | + // Create metadata after enum is initialized |
| 191 | + val metadata: ProtocolMetadata by lazy { |
| 192 | + when (this) { |
| 193 | + AwsJson10 -> ProtocolMetadata(this, "application/x-amz-json-1.0") |
| 194 | + AwsJson11 -> ProtocolMetadata(this, "application/x-amz-json-1.1") |
| 195 | + RestJson -> ProtocolMetadata(this, "application/json") |
| 196 | + RestXml -> ProtocolMetadata(this, "application/xml") |
| 197 | + Rpcv2Cbor -> ProtocolMetadata(this, "application/cbor") |
| 198 | + } |
| 199 | + } |
| 200 | + |
| 201 | + companion object { |
| 202 | + private val TRAIT_IDS = values().map { it.trait.toShapeId() }.toSet() |
| 203 | + val ALL: Set<ModelProtocol> = values().toSet() |
| 204 | + |
| 205 | + fun getTraitIds() = TRAIT_IDS |
| 206 | + } |
| 207 | +} |
| 208 | + |
| 209 | +/** |
| 210 | + * Removes all existing protocol traits annotated on the given service, |
| 211 | + * then sets the provided `protocol` as the sole protocol trait for the service. |
| 212 | + */ |
| 213 | +fun Model.replaceProtocolTraitOnServerShapeId( |
| 214 | + serviceShapeId: ShapeId, |
| 215 | + modelProtocol: ModelProtocol, |
| 216 | +): Model { |
| 217 | + val serviceShape = this.expectShape(serviceShapeId, ServiceShape::class.java) |
| 218 | + return replaceProtocolTraitOnServiceShape(serviceShape, modelProtocol) |
| 219 | +} |
| 220 | + |
| 221 | +/** |
| 222 | + * Removes all existing protocol traits annotated on the given service shape, |
| 223 | + * then sets the provided `protocol` as the sole protocol trait for the service. |
| 224 | + */ |
| 225 | +fun Model.replaceProtocolTraitOnServiceShape( |
| 226 | + serviceShape: ServiceShape, |
| 227 | + modelProtocol: ModelProtocol, |
| 228 | +): Model { |
| 229 | + val serviceBuilder = serviceShape.toBuilder() |
| 230 | + ModelProtocol.getTraitIds().forEach { traitId -> |
| 231 | + serviceBuilder.removeTrait(traitId) |
| 232 | + } |
| 233 | + val service = serviceBuilder.addTrait(modelProtocol.trait).build() |
| 234 | + return ModelTransformer.create().replaceShapes(this, listOf(service)) |
| 235 | +} |
| 236 | + |
| 237 | +/** |
| 238 | + * Processes a Smithy model string by applying different protocol traits and invoking the tests block on the model. |
| 239 | + * For each protocol, this function: |
| 240 | + * 1. Parses the Smithy model string |
| 241 | + * 2. Replaces any existing protocol traits on service shapes with the specified protocol |
| 242 | + * 3. Runs the provided test with the transformed model and protocol metadata |
| 243 | + * |
| 244 | + * @param protocolTraitIds Set of protocols to test against |
| 245 | + * @param test Function that receives the transformed model and protocol metadata for testing |
| 246 | + */ |
| 247 | +fun String.forProtocols( |
| 248 | + protocolTraitIds: Set<ModelProtocol>, |
| 249 | + test: (Model, ProtocolMetadata) -> Unit, |
| 250 | +) { |
| 251 | + val baseModel = this.asSmithyModel(smithyVersion = "2") |
| 252 | + val serviceShapes = baseModel.serviceShapes.toList() |
| 253 | + |
| 254 | + protocolTraitIds.forEach { protocol -> |
| 255 | + val transformedModel = |
| 256 | + serviceShapes.fold(baseModel) { acc, shape -> |
| 257 | + acc.replaceProtocolTraitOnServiceShape(shape, protocol) |
| 258 | + } |
| 259 | + test(transformedModel, protocol.metadata) |
| 260 | + } |
| 261 | +} |
| 262 | + |
| 263 | +/** |
| 264 | + * Convenience overload that accepts vararg protocols instead of a Set. |
| 265 | + * |
| 266 | + * @param protocols Variable number of protocols to test against |
| 267 | + * @param test Function that receives the transformed model and protocol metadata for testing |
| 268 | + * @see forProtocols |
| 269 | + */ |
| 270 | +fun String.forProtocols( |
| 271 | + vararg protocols: ModelProtocol, |
| 272 | + test: (Model, ProtocolMetadata) -> Unit, |
| 273 | +) { |
| 274 | + forProtocols(protocols.toSet(), test) |
| 275 | +} |
| 276 | + |
| 277 | +/** |
| 278 | + * Tests a Smithy model string against all supported protocols, with optional exclusions. |
| 279 | + * |
| 280 | + * @param exclude Set of protocols to exclude from testing (default is empty) |
| 281 | + * @param test Function that receives the transformed model and protocol metadata for testing |
| 282 | + * @see forProtocols |
| 283 | + */ |
| 284 | +fun String.forAllProtocols( |
| 285 | + exclude: Set<ModelProtocol> = emptySet(), |
| 286 | + test: (Model, ProtocolMetadata) -> Unit, |
| 287 | +) { |
| 288 | + forProtocols(ModelProtocol.ALL - exclude, test) |
| 289 | +} |
0 commit comments