diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index 332129e72850..739a06abb7f3 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -225,16 +225,18 @@ object TypeOps: */ def orDominator(tp: Type)(using Context): Type = { - /** a faster version of cs1 intersect cs2 that treats bottom types correctly */ + /** a faster version of cs1 intersect cs2 */ def intersect(cs1: List[ClassSymbol], cs2: List[ClassSymbol]): List[ClassSymbol] = - if cs1.head == defn.NothingClass then cs2 - else if cs2.head == defn.NothingClass then cs1 - else if cs1.head == defn.NullClass && !ctx.explicitNulls && cs2.head.derivesFrom(defn.ObjectClass) then cs2 - else if cs2.head == defn.NullClass && !ctx.explicitNulls && cs1.head.derivesFrom(defn.ObjectClass) then cs1 - else - val cs2AsSet = new util.HashSet[ClassSymbol](128) - cs2.foreach(cs2AsSet += _) - cs1.filter(cs2AsSet.contains) + val cs2AsSet = BaseClassSet(cs2) + cs1.filter(cs2AsSet.contains) + + /** a version of Type#baseClasses that treats bottom types correctly */ + def orBaseClasses(tp: Type): List[ClassSymbol] = tp.stripTypeVar match + case OrType(tp1, tp2) => + if tp1.isBottomType && (tp1 frozen_<:< tp2) then orBaseClasses(tp2) + else if tp2.isBottomType && (tp2 frozen_<:< tp1) then orBaseClasses(tp1) + else intersect(orBaseClasses(tp1), orBaseClasses(tp2)) + case _ => tp.baseClasses /** The minimal set of classes in `cs` which derive all other classes in `cs` */ def dominators(cs: List[ClassSymbol], accu: List[ClassSymbol]): List[ClassSymbol] = (cs: @unchecked) match { @@ -369,7 +371,7 @@ object TypeOps: } // Step 3: Intersect base classes of both sides - val commonBaseClasses = tp.mapReduceOr(_.baseClasses)(intersect) + val commonBaseClasses = orBaseClasses(tp) val doms = dominators(commonBaseClasses, Nil) def baseTp(cls: ClassSymbol): Type = tp.baseType(cls).mapReduceOr(identity)(mergeRefinedOrApplied) diff --git a/tests/explicit-nulls/pos/i16236.scala b/tests/explicit-nulls/pos/i16236.scala new file mode 100644 index 000000000000..a64f5bc176ce --- /dev/null +++ b/tests/explicit-nulls/pos/i16236.scala @@ -0,0 +1,10 @@ +// Copy of tests/pos/i16236.scala +trait A + +def consume[T](t: T): Unit = () + +def fails(p: (Double & A) | Null): Unit = consume(p) // was: assertion failed: & A + +def switchedOrder(p: (A & Double) | Null): Unit = consume(p) // ok +def nonPrimitive(p: (String & A) | Null): Unit = consume(p) // ok +def notNull(p: (Double & A)): Unit = consume(p) // ok diff --git a/tests/pos/i16236.scala b/tests/pos/i16236.scala new file mode 100644 index 000000000000..6451689ad94d --- /dev/null +++ b/tests/pos/i16236.scala @@ -0,0 +1,9 @@ +trait A + +def consume[T](t: T): Unit = () + +def fails(p: (Double & A) | Null): Unit = consume(p) // was: assertion failed: & A + +def switchedOrder(p: (A & Double) | Null): Unit = consume(p) // ok +def nonPrimitive(p: (String & A) | Null): Unit = consume(p) // ok +def notNull(p: (Double & A)): Unit = consume(p) // ok