Skip to content
Merged
7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/config/ScalaSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,13 @@ private sealed trait XSettings:
val XprintSuspension: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprint-suspension", "Show when code is suspended until macros are compiled.")
val Xprompt: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprompt", "Display a prompt after each error (debugging option).")
val XreplDisableDisplay: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xrepl-disable-display", "Do not display definitions in REPL.")
val XreplInterruptInstrumentation: Setting[String] = StringSetting(
AdvancedSetting,
"Xrepl-interrupt-instrumentation",
"true|false|local",
"pass `false` to disable bytecode instrumentation for interrupt handling in REPL, or `local` to limit interrupt support to only REPL-defined classes",
"true"
)
val XverifySignatures: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xverify-signatures", "Verify generic signatures in generated bytecode.")
val XignoreScala2Macros: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xignore-scala2-macros", "Ignore errors when compiling code that calls Scala2 macros, these will fail at runtime.")
val XimportSuggestionTimeout: Setting[Int] = IntSetting(AdvancedSetting, "Ximport-suggestion-timeout", "Timeout (in ms) for searching for import suggestions when errors are reported.", 8000)
Expand Down
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/quoted/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Interpreter(pos: SrcPos, classLoader0: ClassLoader)(using Context):

val classLoader =
if ctx.owner.topLevelClass.name.startsWith(str.REPL_SESSION_LINE) then
new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader0)
new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader0, "false")
else classLoader0

/** Local variable environment */
Expand Down Expand Up @@ -204,7 +204,11 @@ class Interpreter(pos: SrcPos, classLoader0: ClassLoader)(using Context):
}

private def loadReplLineClass(moduleClass: Symbol): Class[?] = {
val lineClassloader = new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader)
val lineClassloader = new AbstractFileClassLoader(
ctx.settings.outputDir.value,
classLoader,
"false"
)
lineClassloader.loadClass(moduleClass.name.firstPart.toString)
}

Expand Down
63 changes: 54 additions & 9 deletions compiler/src/dotty/tools/repl/AbstractFileClassLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ package repl
import scala.language.unsafeNulls

import io.AbstractFile
import dotty.tools.repl.ReplBytecodeInstrumentation

import java.net.{URL, URLConnection, URLStreamHandler}
import java.util.Collections

class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) extends ClassLoader(parent):
class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader, interruptInstrumentation: String) extends ClassLoader(parent):
private def findAbstractFile(name: String) = root.lookupPath(name.split('/').toIndexedSeq, directory = false)

// on JDK 20 the URL constructor we're using is deprecated,
Expand All @@ -45,17 +46,61 @@ class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) exten
val pathParts = name.split("[./]").toList
for (dirPart <- pathParts.init) {
file = file.lookupName(dirPart, true)
if (file == null) {
throw new ClassNotFoundException(name)
}
if (file == null) throw new ClassNotFoundException(name)
}
file = file.lookupName(pathParts.last+".class", false)
if (file == null) {
throw new ClassNotFoundException(name)
}
if (file == null) throw new ClassNotFoundException(name)

val bytes = file.toByteArray
defineClass(name, bytes, 0, bytes.length)

if interruptInstrumentation != "false" then defineClassInstrumented(name, bytes)
else defineClass(name, bytes, 0, bytes.length)
}

override def loadClass(name: String): Class[?] = try findClass(name) catch case _: ClassNotFoundException => super.loadClass(name)
def defineClassInstrumented(name: String, originalBytes: Array[Byte]) = {
val instrumentedBytes = ReplBytecodeInstrumentation.instrument(originalBytes)
defineClass(name, instrumentedBytes, 0, instrumentedBytes.length)
}

override def loadClass(name: String): Class[?] =
if interruptInstrumentation == "false" || interruptInstrumentation == "local"
then return super.loadClass(name)

val loaded = findLoadedClass(name) // Check if already loaded
if loaded != null then return loaded

name match { // Don't instrument JDK classes or StopRepl
case s"java.$_" => super.loadClass(name)
case s"javax.$_" => super.loadClass(name)
case s"sun.$_" => super.loadClass(name)
case s"jdk.$_" => super.loadClass(name)
case "dotty.tools.repl.StopRepl" =>
// Load StopRepl bytecode from parent but ensure each classloader gets its own copy
val classFileName = name.replace('.', '/') + ".class"
val is = Option(getParent.getResourceAsStream(classFileName))
// Can't get as resource, use the classloader that loaded this AbstractFileClassLoader
// class itself, which must have access to StopRepl
.getOrElse(classOf[AbstractFileClassLoader].getClassLoader.getResourceAsStream(classFileName))

try
val bytes = is.readAllBytes()
defineClass(name, bytes, 0, bytes.length)
finally is.close()

case _ =>
try findClass(name)
catch case _: ClassNotFoundException =>
// Not in REPL output, try to load from parent and instrument it
try
val resourceName = name.replace('.', '/') + ".class"
getParent.getResourceAsStream(resourceName) match {
case null => super.loadClass(resourceName)
case is =>
try defineClassInstrumented(name, is.readAllBytes())
finally is.close()
}
catch
case ex: Exception => super.loadClass(name)
}

end AbstractFileClassLoader
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/repl/Rendering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None):
new java.net.URLClassLoader(compilerClasspath.toArray, baseClassLoader)
}

myClassLoader = new AbstractFileClassLoader(ctx.settings.outputDir.value, parent)
myClassLoader = new AbstractFileClassLoader(
ctx.settings.outputDir.value,
parent,
ctx.settings.XreplInterruptInstrumentation.value
)
myClassLoader
}

Expand Down
75 changes: 75 additions & 0 deletions compiler/src/dotty/tools/repl/ReplBytecodeInstrumentation.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package dotty.tools
package repl

import scala.language.unsafeNulls

import scala.tools.asm.*
import scala.tools.asm.Opcodes.*
import scala.tools.asm.tree.*
import scala.collection.JavaConverters.*
import java.util.concurrent.atomic.AtomicBoolean

object ReplBytecodeInstrumentation:
/** Instrument bytecode to add checks to throw an exception if the REPL command is cancelled
*/
def instrument(originalBytes: Array[Byte]): Array[Byte] =
try
val cr = new ClassReader(originalBytes)
val cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES)
val instrumenter = new InstrumentClassVisitor(cw)
cr.accept(instrumenter, ClassReader.EXPAND_FRAMES)
cw.toByteArray
catch
case ex: Exception => originalBytes

def setStopFlag(classLoader: ClassLoader, b: Boolean): Unit =
val cancelClassOpt =
try Some(classLoader.loadClass(classOf[dotty.tools.repl.StopRepl].getName))
catch {
case _: java.lang.ClassNotFoundException => None
}
for(cancelClass <- cancelClassOpt) {
val setAllStopMethod = cancelClass.getDeclaredMethod("setStop", classOf[Boolean])
setAllStopMethod.invoke(null, b.asInstanceOf[AnyRef])
}

private class InstrumentClassVisitor(cv: ClassVisitor) extends ClassVisitor(ASM9, cv):

override def visitMethod(
access: Int,
name: String,
descriptor: String,
signature: String,
exceptions: Array[String]
): MethodVisitor =
new InstrumentMethodVisitor(super.visitMethod(access, name, descriptor, signature, exceptions))

/** MethodVisitor that inserts stop checks at backward branches */
private class InstrumentMethodVisitor(mv: MethodVisitor) extends MethodVisitor(ASM9, mv):
// Track labels we've seen to identify backward branches
private val seenLabels = scala.collection.mutable.Set[Label]()

def addStopCheck() = mv.visitMethodInsn(
INVOKESTATIC,
classOf[dotty.tools.repl.StopRepl].getName.replace('.', '/'),
"throwIfReplStopped",
"()V",
false
)

override def visitCode(): Unit =
super.visitCode()
// Insert throwIfReplStopped() call at the start of the method
// to allow breaking out of deeply recursive methods like fib(99)
addStopCheck()

override def visitLabel(label: Label): Unit =
seenLabels.add(label)
super.visitLabel(label)

override def visitJumpInsn(opcode: Int, label: Label): Unit =
// Add throwIfReplStopped if this is a backward branch (jumping to a label we've already seen)
if seenLabels.contains(label) then addStopCheck()
super.visitJumpInsn(opcode, label)

end ReplBytecodeInstrumentation
15 changes: 13 additions & 2 deletions compiler/src/dotty/tools/repl/ReplDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import dotty.tools.dotc.{CompilationUnit, Driver}
import dotty.tools.dotc.config.CompilerCommand
import dotty.tools.io.*
import dotty.tools.repl.Rendering.showUser
import dotty.tools.repl.ReplBytecodeInstrumentation
import dotty.tools.runner.ScalaClassLoader.*
import org.jline.reader.*

Expand Down Expand Up @@ -228,13 +229,20 @@ class ReplDriver(settings: Array[String],
// Set up interrupt handler for command execution
var firstCtrlCEntered = false
val thread = Thread.currentThread()

// Clear the stop flag before executing new code
ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), false)

val previousSignalHandler = terminal.handle(
org.jline.terminal.Terminal.Signal.INT,
(sig: org.jline.terminal.Terminal.Signal) => {
if (!firstCtrlCEntered) {
firstCtrlCEntered = true
// Set the stop flag to trigger throwIfReplStopped() in instrumented code
ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), true)
// Also interrupt the thread as a fallback for non-instrumented code
thread.interrupt()
out.println("\nInterrupting running thread, Ctrl-C again to terminate the REPL Process")
out.println("\nAttempting to interrupt running thread with `Thread.interrupt`")
} else {
out.println("\nTerminating REPL Process...")
System.exit(130) // Standard exit code for SIGINT
Expand Down Expand Up @@ -592,7 +600,10 @@ class ReplDriver(settings: Array[String],
val jarClassLoader = fromURLsParallelCapable(
jarClassPath.asURLs, prevClassLoader)
rendering.myClassLoader = new AbstractFileClassLoader(
prevOutputDir, jarClassLoader)
prevOutputDir,
jarClassLoader,
ctx.settings.XreplInterruptInstrumentation.value
)

out.println(s"Added '$path' to classpath.")
} catch {
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/repl/ScriptEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ class ScriptEngine extends AbstractScriptEngine {
"-classpath", "", // Avoid the default "."
"-usejavacp",
"-color:never",
"-Xrepl-disable-display"
"-Xrepl-disable-display",
"-Xrepl-interrupt-instrumentation",
"false"
), Console.out, None)

private val rendering = new Rendering(Some(getClass.getClassLoader))
private var state: State = driver.initialState

Expand Down
18 changes: 18 additions & 0 deletions compiler/src/dotty/tools/repl/StopRepl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package dotty.tools.repl

import scala.annotation.static

class StopRepl

object StopRepl {
// Needs to be volatile, otherwise changes to this may not get seen by other threads
// for arbitrarily long periods of time (minutes!)
@static @volatile private var stop: Boolean = false

@static def setStop(n: Boolean): Unit = { stop = n }

/** Check if execution should stop, and throw ThreadDeath if so */
@static def throwIfReplStopped(): Unit = {
if (stop) throw new ThreadDeath()
}
}
24 changes: 12 additions & 12 deletions compiler/test/dotty/tools/repl/AbstractFileClassLoaderTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ class AbstractFileClassLoaderTest:
@Test def afclGetsParent(): Unit =
val p = new URLClassLoader(Array.empty[URL])
val d = new VirtualDirectory("vd", None)
val x = new AbstractFileClassLoader(d, p)
val x = new AbstractFileClassLoader(d, p, "false")
assertSame(p, x.getParent)

@Test def afclGetsResource(): Unit =
val (fuzz, booz) = fuzzBuzzBooz
booz.writeContent("hello, world")
val sut = new AbstractFileClassLoader(fuzz, NoClassLoader)
val sut = new AbstractFileClassLoader(fuzz, NoClassLoader, "false")
val res = sut.getResource("buzz/booz.class")
assertNotNull("Find buzz/booz.class", res)
assertEquals("hello, world", slurp(res))
Expand All @@ -66,8 +66,8 @@ class AbstractFileClassLoaderTest:
val (fuzz_, booz_) = fuzzBuzzBooz
booz.writeContent("hello, world")
booz_.writeContent("hello, world_")
val p = new AbstractFileClassLoader(fuzz, NoClassLoader)
val sut = new AbstractFileClassLoader(fuzz_, p)
val p = new AbstractFileClassLoader(fuzz, NoClassLoader, "false")
val sut = new AbstractFileClassLoader(fuzz_, p, "false")
val res = sut.getResource("buzz/booz.class")
assertNotNull("Find buzz/booz.class", res)
assertEquals("hello, world", slurp(res))
Expand All @@ -78,7 +78,7 @@ class AbstractFileClassLoaderTest:
val bass = fuzz.fileNamed("bass")
booz.writeContent("hello, world")
bass.writeContent("lo tone")
val sut = new AbstractFileClassLoader(fuzz, NoClassLoader)
val sut = new AbstractFileClassLoader(fuzz, NoClassLoader, "false")
val res = sut.getResource("booz.class")
assertNotNull(res)
assertEquals("hello, world", slurp(res))
Expand All @@ -88,7 +88,7 @@ class AbstractFileClassLoaderTest:
@Test def afclGetsResources(): Unit =
val (fuzz, booz) = fuzzBuzzBooz
booz.writeContent("hello, world")
val sut = new AbstractFileClassLoader(fuzz, NoClassLoader)
val sut = new AbstractFileClassLoader(fuzz, NoClassLoader, "false")
val e = sut.getResources("buzz/booz.class")
assertTrue("At least one buzz/booz.class", e.hasMoreElements)
assertEquals("hello, world", slurp(e.nextElement))
Expand All @@ -99,8 +99,8 @@ class AbstractFileClassLoaderTest:
val (fuzz_, booz_) = fuzzBuzzBooz
booz.writeContent("hello, world")
booz_.writeContent("hello, world_")
val p = new AbstractFileClassLoader(fuzz, NoClassLoader)
val x = new AbstractFileClassLoader(fuzz_, p)
val p = new AbstractFileClassLoader(fuzz, NoClassLoader, "false")
val x = new AbstractFileClassLoader(fuzz_, p, "false")
val e = x.getResources("buzz/booz.class")
assertTrue(e.hasMoreElements)
assertEquals("hello, world", slurp(e.nextElement))
Expand All @@ -111,15 +111,15 @@ class AbstractFileClassLoaderTest:
@Test def afclGetsResourceAsStream(): Unit =
val (fuzz, booz) = fuzzBuzzBooz
booz.writeContent("hello, world")
val x = new AbstractFileClassLoader(fuzz, NoClassLoader)
val x = new AbstractFileClassLoader(fuzz, NoClassLoader, "false")
val r = x.getResourceAsStream("buzz/booz.class")
assertNotNull(r)
assertEquals("hello, world", closing(r)(is => Source.fromInputStream(is).mkString))

@Test def afclGetsClassBytes(): Unit =
val (fuzz, booz) = fuzzBuzzBooz
booz.writeContent("hello, world")
val sut = new AbstractFileClassLoader(fuzz, NoClassLoader)
val sut = new AbstractFileClassLoader(fuzz, NoClassLoader, "false")
val b = sut.classBytes("buzz/booz.class")
assertEquals("hello, world", new String(b, UTF8.charSet))

Expand All @@ -129,8 +129,8 @@ class AbstractFileClassLoaderTest:
booz.writeContent("hello, world")
booz_.writeContent("hello, world_")

val p = new AbstractFileClassLoader(fuzz, NoClassLoader)
val sut = new AbstractFileClassLoader(fuzz_, p)
val p = new AbstractFileClassLoader(fuzz, NoClassLoader, "false")
val sut = new AbstractFileClassLoader(fuzz_, p, "false")
val b = sut.classBytes("buzz/booz.class")
assertEquals("hello, world", new String(b, UTF8.charSet))
end AbstractFileClassLoaderTest
2 changes: 1 addition & 1 deletion staging/src/scala/quoted/staging/QuoteDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ private class QuoteDriver(appClassloader: ClassLoader) extends Driver:
case Left(classname) =>
assert(!ctx.reporter.hasErrors)

val classLoader = new AbstractFileClassLoader(outDir, appClassloader)
val classLoader = new AbstractFileClassLoader(outDir, appClassloader, "false")

val clazz = classLoader.loadClass(classname)
val method = clazz.getMethod("apply")
Expand Down