diff --git a/enumeratum-core/src/test/scala/enumeratum/EnumSpec.scala b/enumeratum-core/src/test/scala/enumeratum/EnumSpec.scala index 05ac0c6c..53c7b41e 100644 --- a/enumeratum-core/src/test/scala/enumeratum/EnumSpec.scala +++ b/enumeratum-core/src/test/scala/enumeratum/EnumSpec.scala @@ -43,6 +43,38 @@ class EnumSpec extends AnyFunSpec with Matchers with EnumSpecCompat { Word.values should be(IndexedSeq(Word.Hello, Word.Hi)) } + + it("should contain instance of generic subclass") { + sealed trait NestedGenericEnum[T] extends EnumEntry with Serializable + + object NestedGenericEnum extends Enum[NestedGenericEnum[Unit]] { + sealed trait Intermediary[T] extends NestedGenericEnum[T] + + case object A extends Intermediary[Unit] + case object B extends NestedGenericEnum[Unit] + val values = findValues + } + + NestedGenericEnum.values should contain theSameElementsAs Seq( + NestedGenericEnum.A, + NestedGenericEnum.B + ) + } + + it("should contain exactly one instance even if it inherits two subclasses") { + sealed trait MyEnum extends EnumEntry with Serializable + object MyEnum extends Enum[MyEnum] { + sealed trait Intermediary1 extends MyEnum + sealed trait Intermediary2 extends MyEnum + + case object A extends Intermediary1 with Intermediary2 + val values = findValues + } + + MyEnum.values should contain theSameElementsAs Seq( + MyEnum.A + ) + } } describe("#withName") { diff --git a/macros/src/main/scala-3/enumeratum/EnumMacros.scala b/macros/src/main/scala-3/enumeratum/EnumMacros.scala index 465feb14..05857f77 100644 --- a/macros/src/main/scala-3/enumeratum/EnumMacros.scala +++ b/macros/src/main/scala-3/enumeratum/EnumMacros.scala @@ -149,22 +149,14 @@ object EnumMacros: } } - childTpr match { case Some(child) => { + val tpeSym = child.typeSymbol child.asType match { - case ct @ '[IsEntry[t]] => { - val tpeSym = child.typeSymbol - - if (!isObject(tpeSym)) { - subclasses(tpeSym.children.map(_.tree) ::: children.tail, out) - } else { - subclasses(children.tail, child :: out) - } - } - + case ct @ '[IsEntry[t]] if isObject(tpeSym) => + subclasses(children.tail, child :: out) case _ => - subclasses(children.tail, out) + subclasses(tpeSym.children.map(_.tree) ::: children.tail, out) } } @@ -175,7 +167,7 @@ object EnumMacros: tpr.classSymbol .flatMap { cls => - val types = subclasses(cls.children.map(_.tree), Nil) + val types = subclasses(cls.children.map(_.tree), Nil).distinct if (types.isEmpty) None else Some(types) }