Skip to content

Commit df77d5f

Browse files
drganjooFahad Zubair
and
Fahad Zubair
authored
Enforce constraints for unnamed enums (#3884)
### Enforces Constraints for Unnamed Enums This PR addresses the issue where, on the server side, unnamed enums were incorrectly treated as infallible during deserialization, allowing any string value to be converted without validation. The solution introduces a `ConstraintViolation` and `TryFrom` implementation for unnamed enums, ensuring that deserialized values conform to the enum variants defined in the Smithy model. The following is an example of an unnamed enum: ```smithy @enum([ { value: "MONDAY" }, { value: "TUESDAY" } ]) string UnnamedDayOfWeek ``` On the server side the following type is generated for the Smithy shape: ```rust pub struct UnnamedDayOfWeek(String); impl ::std::convert::TryFrom<::std::string::String> for UnnamedDayOfWeek { type Error = crate::model::unnamed_day_of_week::ConstraintViolation; fn try_from( s: ::std::string::String, ) -> ::std::result::Result<Self, <Self as ::std::convert::TryFrom<::std::string::String>>::Error> { match s.as_str() { "MONDAY" | "TUESDAY" => Ok(Self(s)), _ => Err(crate::model::unnamed_day_of_week::ConstraintViolation(s)), } } } ``` This change prevents invalid values from being deserialized into unnamed enums and raises appropriate constraint violations when necessary. There is one difference between the Rust code generated for `TryFrom<String>` for named enums versus unnamed enums. The implementation for unnamed enums passes the ownership of the `String` parameter to the generated structure, and the implementation for `TryFrom<&str>` delegates to `TryFrom<String>`. ```rust impl ::std::convert::TryFrom<::std::string::String> for UnnamedDayOfWeek { type Error = crate::model::unnamed_day_of_week::ConstraintViolation; fn try_from( s: ::std::string::String, ) -> ::std::result::Result<Self, <Self as ::std::convert::TryFrom<::std::string::String>>::Error> { match s.as_str() { "MONDAY" | "TUESDAY" => Ok(Self(s)), _ => Err(crate::model::unnamed_day_of_week::ConstraintViolation(s)), } } } impl ::std::convert::TryFrom<&str> for UnnamedDayOfWeek { type Error = crate::model::unnamed_day_of_week::ConstraintViolation; fn try_from( s: &str, ) -> ::std::result::Result<Self, <Self as ::std::convert::TryFrom<&str>>::Error> { s.to_owned().try_into() } } ``` On the client side, the behaviour is unchanged, and the client does not validate for backward compatibility reasons. An [existing test](https://github.com/smithy-lang/smithy-rs/pull/3884/files#diff-021ec60146cfe231105d21a7389f2dffcd546595964fbb3f0684ebf068325e48R82) has been modified to ensure this. ```rust #[test] fn generate_unnamed_enums() { let result = "t2.nano" .parse::<crate::types::UnnamedEnum>() .expect("static value validated to member"); assert_eq!(result, UnnamedEnum("t2.nano".to_owned())); let result = "not-a-valid-variant" .parse::<crate::types::UnnamedEnum>() .expect("static value validated to member"); assert_eq!(result, UnnamedEnum("not-a-valid-variant".to_owned())); } ``` Fixes issue #3880 --------- Co-authored-by: Fahad Zubair <[email protected]>
1 parent 8cf9ebd commit df77d5f

File tree

8 files changed

+235
-62
lines changed

8 files changed

+235
-62
lines changed

.changelog/4329788.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
---
2+
applies_to: ["server"]
3+
authors: ["drganjoo"]
4+
references: ["smithy-rs#3880"]
5+
breaking: true
6+
new_feature: false
7+
bug_fix: true
8+
---
9+
Unnamed enums now validate assigned values and will raise a `ConstraintViolation` if an unknown variant is set.
10+
11+
The following is an example of an unnamed enum:
12+
```smithy
13+
@enum([
14+
{ value: "MONDAY" },
15+
{ value: "TUESDAY" }
16+
])
17+
string UnnamedDayOfWeek
18+
```

codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,37 @@ data class InfallibleEnumType(
7979
)
8080
}
8181

82+
override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
83+
writable {
84+
rustTemplate(
85+
"""
86+
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
87+
fn from(s: T) -> Self {
88+
${context.enumName}(s.as_ref().to_owned())
89+
}
90+
}
91+
""",
92+
*preludeScope,
93+
)
94+
}
95+
96+
override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
97+
writable {
98+
// Add an infallible FromStr implementation for uniformity
99+
rustTemplate(
100+
"""
101+
impl ::std::str::FromStr for ${context.enumName} {
102+
type Err = ::std::convert::Infallible;
103+
104+
fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
105+
#{Ok}(${context.enumName}::from(s))
106+
}
107+
}
108+
""",
109+
*preludeScope,
110+
)
111+
}
112+
82113
override fun additionalEnumImpls(context: EnumGeneratorContext): Writable =
83114
writable {
84115
// `try_parse` isn't needed for unnamed enums

codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ internal class ClientInstantiatorTest {
6969
val shape = model.lookup<StringShape>("com.test#UnnamedEnum")
7070
val sut = ClientInstantiator(codegenContext)
7171
val data = Node.parse("t2.nano".dq())
72+
// The client SDK should accept unknown variants as valid.
73+
val notValidVariant = Node.parse("not-a-valid-variant".dq())
7274

7375
val project = TestWorkspace.testProject(symbolProvider)
7476
project.moduleFor(shape) {
@@ -77,7 +79,11 @@ internal class ClientInstantiatorTest {
7779
withBlock("let result = ", ";") {
7880
sut.render(this, shape, data)
7981
}
80-
rust("""assert_eq!(result, UnnamedEnum("t2.nano".to_owned()));""")
82+
rust("""assert_eq!(result, UnnamedEnum("$data".to_owned()));""")
83+
withBlock("let result = ", ";") {
84+
sut.render(this, shape, notValidVariant)
85+
}
86+
rust("""assert_eq!(result, UnnamedEnum("$notValidVariant".to_owned()));""")
8187
}
8288
}
8389
project.compileAndTest()

codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ abstract class EnumType {
5959
/** Returns a writable that implements `FromStr` for the enum */
6060
abstract fun implFromStr(context: EnumGeneratorContext): Writable
6161

62+
/** Returns a writable that implements `From<&str>` and/or `TryFrom<&str>` for the unnamed enum */
63+
abstract fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable
64+
65+
/** Returns a writable that implements `FromStr` for the unnamed enum */
66+
abstract fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable
67+
6268
/** Optionally adds additional documentation to the `enum` docs */
6369
open fun additionalDocs(context: EnumGeneratorContext): Writable = writable {}
6470

@@ -237,32 +243,10 @@ open class EnumGenerator(
237243
rust("&self.0")
238244
},
239245
)
240-
241-
// Add an infallible FromStr implementation for uniformity
242-
rustTemplate(
243-
"""
244-
impl ::std::str::FromStr for ${context.enumName} {
245-
type Err = ::std::convert::Infallible;
246-
247-
fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
248-
#{Ok}(${context.enumName}::from(s))
249-
}
250-
}
251-
""",
252-
*preludeScope,
253-
)
254-
255-
rustTemplate(
256-
"""
257-
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
258-
fn from(s: T) -> Self {
259-
${context.enumName}(s.as_ref().to_owned())
260-
}
261-
}
262-
263-
""",
264-
*preludeScope,
265-
)
246+
// impl From<str> for Blah { ... }
247+
enumType.implFromForStrForUnnamedEnum(context)(this)
248+
// impl FromStr for Blah { ... }
249+
enumType.implFromStrForUnnamedEnum(context)(this)
266250
}
267251

268252
private fun RustWriter.renderEnum() {

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,16 @@ class EnumGeneratorTest {
494494
// intentional no-op
495495
}
496496

497+
override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
498+
writable {
499+
// intentional no-op
500+
}
501+
502+
override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
503+
writable {
504+
// intentional no-op
505+
}
506+
497507
override fun additionalEnumMembers(context: EnumGeneratorContext): Writable =
498508
writable {
499509
rust("// additional enum members")

codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
1010
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
1111
import software.amazon.smithy.rust.codegen.core.rustlang.writable
1212
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
13+
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
1314
import software.amazon.smithy.rust.codegen.core.util.dq
1415

1516
object TestEnumType : EnumType() {
@@ -49,4 +50,35 @@ object TestEnumType : EnumType() {
4950
""",
5051
)
5152
}
53+
54+
override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
55+
writable {
56+
rustTemplate(
57+
"""
58+
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
59+
fn from(s: T) -> Self {
60+
${context.enumName}(s.as_ref().to_owned())
61+
}
62+
}
63+
""",
64+
*preludeScope,
65+
)
66+
}
67+
68+
override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
69+
writable {
70+
// Add an infallible FromStr implementation for uniformity
71+
rustTemplate(
72+
"""
73+
impl ::std::str::FromStr for ${context.enumName} {
74+
type Err = ::std::convert::Infallible;
75+
76+
fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
77+
#{Ok}(${context.enumName}::from(s))
78+
}
79+
}
80+
""",
81+
*preludeScope,
82+
)
83+
}
5284
}

codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt

Lines changed: 71 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
package software.amazon.smithy.rust.codegen.server.smithy.generators
66

77
import software.amazon.smithy.model.shapes.StringShape
8+
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
89
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
910
import software.amazon.smithy.rust.codegen.core.rustlang.rust
10-
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
11-
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
1211
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
1312
import software.amazon.smithy.rust.codegen.core.rustlang.writable
1413
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
@@ -39,16 +38,14 @@ open class ConstrainedEnum(
3938
}
4039
private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape)
4140
private val constraintViolationName = constraintViolationSymbol.name
42-
private val codegenScope =
43-
arrayOf(
44-
"String" to RuntimeType.String,
45-
)
4641

47-
override fun implFromForStr(context: EnumGeneratorContext): Writable =
48-
writable {
49-
withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) {
50-
rustTemplate(
51-
"""
42+
private fun generateConstraintViolation(
43+
context: EnumGeneratorContext,
44+
generateTryFromStrAndString: RustWriter.(EnumGeneratorContext) -> Unit,
45+
) = writable {
46+
withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) {
47+
rustTemplate(
48+
"""
5249
##[derive(Debug, PartialEq)]
5350
pub struct $constraintViolationName(pub(crate) #{String});
5451
@@ -60,47 +57,86 @@ open class ConstrainedEnum(
6057
6158
impl #{Error} for $constraintViolationName {}
6259
""",
63-
*codegenScope,
64-
"Error" to RuntimeType.StdError,
65-
"Display" to RuntimeType.Display,
66-
)
60+
*preludeScope,
61+
"Error" to RuntimeType.StdError,
62+
"Display" to RuntimeType.Display,
63+
)
6764

68-
if (shape.isReachableFromOperationInput()) {
69-
rustTemplate(
70-
"""
65+
if (shape.isReachableFromOperationInput()) {
66+
rustTemplate(
67+
"""
7168
impl $constraintViolationName {
7269
#{EnumShapeConstraintViolationImplBlock:W}
7370
}
7471
""",
75-
"EnumShapeConstraintViolationImplBlock" to
76-
validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock(
77-
context.enumTrait,
78-
),
79-
)
80-
}
72+
"EnumShapeConstraintViolationImplBlock" to
73+
validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock(
74+
context.enumTrait,
75+
),
76+
)
8177
}
82-
rustBlock("impl #T<&str> for ${context.enumName}", RuntimeType.TryFrom) {
83-
rust("type Error = #T;", constraintViolationSymbol)
84-
rustBlockTemplate("fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error>", *preludeScope) {
85-
rustBlock("match s") {
86-
context.sortedMembers.forEach { member ->
87-
rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),")
78+
}
79+
80+
generateTryFromStrAndString(context)
81+
}
82+
83+
override fun implFromForStr(context: EnumGeneratorContext): Writable =
84+
generateConstraintViolation(context) {
85+
rustTemplate(
86+
"""
87+
impl #{TryFrom}<&str> for ${context.enumName} {
88+
type Error = #{ConstraintViolation};
89+
fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error> {
90+
match s {
91+
#{MatchArms}
92+
_ => Err(#{ConstraintViolation}(s.to_owned()))
8893
}
89-
rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol)
9094
}
9195
}
92-
}
96+
impl #{TryFrom}<#{String}> for ${context.enumName} {
97+
type Error = #{ConstraintViolation};
98+
fn try_from(s: #{String}) -> #{Result}<Self, <Self as #{TryFrom}<#{String}>>::Error> {
99+
s.as_str().try_into()
100+
}
101+
}
102+
""",
103+
*preludeScope,
104+
"ConstraintViolation" to constraintViolationSymbol,
105+
"MatchArms" to
106+
writable {
107+
context.sortedMembers.forEach { member ->
108+
rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),")
109+
}
110+
},
111+
)
112+
}
113+
114+
override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
115+
generateConstraintViolation(context) {
93116
rustTemplate(
94117
"""
118+
impl #{TryFrom}<&str> for ${context.enumName} {
119+
type Error = #{ConstraintViolation};
120+
fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error> {
121+
s.to_owned().try_into()
122+
}
123+
}
95124
impl #{TryFrom}<#{String}> for ${context.enumName} {
96125
type Error = #{ConstraintViolation};
97126
fn try_from(s: #{String}) -> #{Result}<Self, <Self as #{TryFrom}<#{String}>>::Error> {
98-
s.as_str().try_into()
127+
match s.as_str() {
128+
#{Values} => Ok(Self(s)),
129+
_ => Err(#{ConstraintViolation}(s))
130+
}
99131
}
100132
}
101133
""",
102134
*preludeScope,
103135
"ConstraintViolation" to constraintViolationSymbol,
136+
"Values" to
137+
writable {
138+
rust(context.sortedMembers.joinToString(" | ") { it.value.dq() })
139+
},
104140
)
105141
}
106142

@@ -118,6 +154,8 @@ open class ConstrainedEnum(
118154
"ConstraintViolation" to constraintViolationSymbol,
119155
)
120156
}
157+
158+
override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext) = implFromStr(context)
121159
}
122160

123161
class ServerEnumGenerator(

0 commit comments

Comments
 (0)