diff --git a/compiler/src/dotty/tools/dotc/core/Constants.scala b/compiler/src/dotty/tools/dotc/core/Constants.scala index 959c677b143f..5f25dbc7d431 100644 --- a/compiler/src/dotty/tools/dotc/core/Constants.scala +++ b/compiler/src/dotty/tools/dotc/core/Constants.scala @@ -147,46 +147,48 @@ object Constants { case _ => throw new Error("value " + value + " is not a Double") } - /** Convert constant value to conform to given type. - */ - def convertTo(pt: Type)(using Context): Constant | Null = { - def classBound(pt: Type): Type = pt.dealias.stripTypeVar.stripNull() match { - case tref: TypeRef if !tref.symbol.isClass && tref.info.exists => - classBound(tref.info.bounds.lo) - case param: TypeParamRef => - ctx.typerState.constraint.entry(param) match { - case TypeBounds(lo, hi) => - if (hi.classSymbol.isPrimitiveValueClass) hi //constrain further with high bound - else classBound(lo) - case NoType => classBound(param.binder.paramInfos(param.paramNum).lo) - case inst => classBound(inst) - } - case pt => pt - } - pt match - case ConstantType(value) if value == this => this - case _: SingletonType => null - case _ => - val target = classBound(pt).typeSymbol - if (target == tpe.typeSymbol) - this - else if ((target == defn.ByteClass) && isByteRange) - Constant(byteValue) - else if (target == defn.ShortClass && isShortRange) - Constant(shortValue) - else if (target == defn.CharClass && isCharRange) - Constant(charValue) - else if (target == defn.IntClass && isIntRange) - Constant(intValue) - else if (target == defn.LongClass && isLongRange) - Constant(longValue) - else if (target == defn.FloatClass && isFloatRange) - Constant(floatValue) - else if (target == defn.DoubleClass && isNumeric) - Constant(doubleValue) - else - null - } + import dotty.tools.dotc.core.Decorators.i + + /** Convert constant value to conform to given type. */ + def convertTo(pt: Type)(using Context): Constant | Null = pt.dealias.stripTypeVar match + case ConstantType(value) if value == this => this + case _: SingletonType => null + case tref: TypeRef if !tref.symbol.isClass && tref.info.exists => + convertTo(tref.info.bounds.lo) + case param: TypeParamRef => + ctx.typerState.constraint.entry(param) match + case TypeBounds(lo, hi) => + val hiResult = convertTo(hi) + if hiResult != null then hiResult + else convertTo(lo) + case NoType => convertTo(param.binder.paramInfos(param.paramNum).lo) + case inst => convertTo(inst) + case pt: OrType => + // For a union type, if both sides convert to constants, + // return a constant only if two sides convert to the same constant + // (same value and same tag). + // For example, `2` can be converted to `Byte | String | Byte`, + // but not to `Byte | Int` (which would be ambiguous). + // TODO: However, the logic will not work here, since `2: Int` is already + // a subtype of `Byte | Int` before `adapt`. + val leftResult = convertTo(pt.tp1) + val rightResult = convertTo(pt.tp2) + // println(s"convertTo OrType: $this to $pt, leftResult = $leftResult, rightResult = $rightResult, compare = ${leftResult == rightResult}") + if leftResult == null then rightResult + else if rightResult == null then leftResult + else if leftResult == rightResult then leftResult + else null + case pt => + val target = pt.typeSymbol + if target == tpe.typeSymbol then this + else if (target == defn.ByteClass) && isByteRange then Constant(byteValue) + else if (target == defn.ShortClass) && isShortRange then Constant(shortValue) + else if (target == defn.CharClass) && isCharRange then Constant(charValue) + else if (target == defn.IntClass) && isIntRange then Constant(intValue) + else if (target == defn.LongClass) && isLongRange then Constant(longValue) + else if (target == defn.FloatClass) && isFloatRange then Constant(floatValue) + else if (target == defn.DoubleClass) && isNumeric then Constant(doubleValue) + else null def stringValue: String = value.toString diff --git a/tests/pos/i24571.scala b/tests/pos/i24571.scala index 7b5180acc20f..2319c3b35371 100644 --- a/tests/pos/i24571.scala +++ b/tests/pos/i24571.scala @@ -4,12 +4,16 @@ val n3: Int = 2 val n4: Int | Null = 2222 val n5: Int | Byte = 2 val n6: Byte | Int = 10000 +val n7: 1 | Null = 1 +val n8: Byte | String = 2 val x: Option[Byte] = Option(2) val x2: Option[Byte] = Option[Byte](2) val x3: Option[Int] = Option(2) val x4: Option[Null] = Option(null) val x5: Option[Byte | Null] = Option(2) +val x6: Option[1 | Null] = Option(1) +val x7: Option[Byte | String] = Option(2) trait MyOption[+T] @@ -22,3 +26,4 @@ val test2: MyOption[Byte] = MyOption.applyOld(2) val test3: MyOption[Int] = MyOption(2) val test4: MyOption[Null] = MyOption(null) val test5: MyOption[Byte | Null] = MyOption(2) +val test6: MyOption[Byte | String] = MyOption(2)