Skip to content

Commit 6a15587

Browse files
committed
.
1 parent 96aff06 commit 6a15587

File tree

6 files changed

+176
-6
lines changed

6 files changed

+176
-6
lines changed

compiler/src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ private sealed trait XSettings:
325325
val XprintSuspension: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprint-suspension", "Show when code is suspended until macros are compiled.")
326326
val Xprompt: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprompt", "Display a prompt after each error (debugging option).")
327327
val XreplDisableDisplay: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xrepl-disable-display", "Do not display definitions in REPL.")
328+
val XreplDisableBytecodeInstrumentation: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xrepl-disable-bytecode-instrumentation", "Disable bytecode instrumentation for interrupt handling in REPL.")
328329
val XverifySignatures: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xverify-signatures", "Verify generic signatures in generated bytecode.")
329330
val XignoreScala2Macros: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xignore-scala2-macros", "Ignore errors when compiling code that calls Scala2 macros, these will fail at runtime.")
330331
val XimportSuggestionTimeout: Setting[Int] = IntSetting(AdvancedSetting, "Ximport-suggestion-timeout", "Timeout (in ms) for searching for import suggestions when errors are reported.", 8000)

compiler/src/dotty/tools/repl/AbstractFileClassLoader.scala

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ package repl
1616
import scala.language.unsafeNulls
1717

1818
import io.AbstractFile
19+
import dotty.tools.repl.ReplBytecodeInstrumentation
1920

2021
import java.net.{URL, URLConnection, URLStreamHandler}
2122
import java.util.Collections
2223

23-
class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) extends ClassLoader(parent):
24+
class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader, instrumentBytecode: Boolean = true) extends ClassLoader(parent):
2425
private def findAbstractFile(name: String) = root.lookupPath(name.split('/').toIndexedSeq, directory = false)
2526

2627
// on JDK 20 the URL constructor we're using is deprecated,
@@ -53,9 +54,66 @@ class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) exten
5354
if (file == null) {
5455
throw new ClassNotFoundException(name)
5556
}
56-
val bytes = file.toByteArray
57+
val originalBytes = file.toByteArray
58+
59+
// Instrument bytecode for everything except StopRepl itself to avoid infinite recursion
60+
val bytes =
61+
if !instrumentBytecode || name == "dotty.tools.repl.StopRepl" then originalBytes
62+
else ReplBytecodeInstrumentation.instrument(originalBytes)
63+
5764
defineClass(name, bytes, 0, bytes.length)
5865
}
5966

60-
override def loadClass(name: String): Class[?] = try findClass(name) catch case _: ClassNotFoundException => super.loadClass(name)
67+
private def tryInstrumentLibraryClass(name: String): Class[?] =
68+
try
69+
val resourceName = name.replace('.', '/') + ".class"
70+
getParent.getResourceAsStream(resourceName) match{
71+
case null => super.loadClass(resourceName)
72+
case is =>
73+
try
74+
val bytes = is.readAllBytes()
75+
val instrumentedBytes =
76+
if instrumentBytecode then ReplBytecodeInstrumentation.instrument(bytes)
77+
else bytes
78+
defineClass(name, instrumentedBytes, 0, instrumentedBytes.length)
79+
finally is.close()
80+
}
81+
catch
82+
case ex: Exception => super.loadClass(name)
83+
84+
override def loadClass(name: String): Class[?] =
85+
if !instrumentBytecode then
86+
return super.loadClass(name)
87+
88+
// Check if already loaded
89+
val loaded = findLoadedClass(name)
90+
if loaded != null then
91+
return loaded
92+
93+
// Don't instrument JDK classes or StopRepl
94+
name match{
95+
case s"java.$_" => super.loadClass(name)
96+
case s"javax.$_" => super.loadClass(name)
97+
case s"sun.$_" => super.loadClass(name)
98+
case s"jdk.$_" => super.loadClass(name)
99+
case "dotty.tools.repl.StopRepl" =>
100+
// Load StopRepl from parent but ensure each classloader gets its own copy
101+
val is = getParent.getResourceAsStream(name.replace('.', '/') + ".class")
102+
if is != null then
103+
try
104+
val bytes = is.readAllBytes()
105+
defineClass(name, bytes, 0, bytes.length)
106+
finally
107+
is.close()
108+
else
109+
// Can't get as resource, use the classloader that loaded this AbstractFileClassLoader
110+
// class itself, which must have access to StopRepl
111+
classOf[AbstractFileClassLoader].getClassLoader.loadClass(name)
112+
case _ =>
113+
try findClass(name)
114+
catch case _: ClassNotFoundException =>
115+
// Not in REPL output, try to load from parent and instrument it
116+
tryInstrumentLibraryClass(name)
117+
}
118+
61119
end AbstractFileClassLoader

compiler/src/dotty/tools/repl/Rendering.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None):
7272
new java.net.URLClassLoader(compilerClasspath.toArray, baseClassLoader)
7373
}
7474

75-
myClassLoader = new AbstractFileClassLoader(ctx.settings.outputDir.value, parent)
75+
val instrumentBytecode = !ctx.settings.XreplDisableBytecodeInstrumentation.value
76+
myClassLoader = new AbstractFileClassLoader(ctx.settings.outputDir.value, parent, instrumentBytecode)
7677
myClassLoader
7778
}
7879

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package dotty.tools
2+
package repl
3+
4+
import scala.language.unsafeNulls
5+
6+
import scala.tools.asm.*
7+
import scala.tools.asm.Opcodes.*
8+
import scala.tools.asm.tree.*
9+
import scala.collection.JavaConverters.*
10+
import java.util.concurrent.atomic.AtomicBoolean
11+
12+
object ReplBytecodeInstrumentation:
13+
/** Instrument bytecode to add stop checks at backward branches.
14+
*
15+
* Backward branches indicate loops, which is where code can hang.
16+
* We inject a call to throwIfReplStopped() before each backward branch.
17+
*
18+
* @param originalBytes the original class bytecode
19+
* @return the instrumented bytecode
20+
*/
21+
def instrument(originalBytes: Array[Byte]): Array[Byte] =
22+
try
23+
val cr = new ClassReader(originalBytes)
24+
val cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES)
25+
val instrumenter = new InstrumentClassVisitor(cw)
26+
cr.accept(instrumenter, ClassReader.EXPAND_FRAMES)
27+
cw.toByteArray
28+
catch
29+
case ex: Exception => originalBytes
30+
31+
def setStopFlag(classLoader: ClassLoader, b: Boolean): Unit =
32+
val cancelClassOpt =
33+
try Some(classLoader.loadClass(classOf[dotty.tools.repl.StopRepl].getName))
34+
catch{
35+
case _: java.lang.ClassNotFoundException => None
36+
}
37+
for(cancelClass <- cancelClassOpt){
38+
val setAllStopMethod = cancelClass.getDeclaredMethod("setStop", classOf[Boolean])
39+
setAllStopMethod.invoke(null, b.asInstanceOf[AnyRef])
40+
}
41+
42+
private class InstrumentClassVisitor(cv: ClassVisitor) extends ClassVisitor(ASM9, cv):
43+
44+
override def visitMethod(
45+
access: Int,
46+
name: String,
47+
descriptor: String,
48+
signature: String,
49+
exceptions: Array[String]
50+
): MethodVisitor =
51+
val mv = super.visitMethod(access, name, descriptor, signature, exceptions)
52+
if mv == null then mv
53+
else new InstrumentMethodVisitor(mv)
54+
55+
/** MethodVisitor that inserts stop checks at backward branches */
56+
private class InstrumentMethodVisitor(mv: MethodVisitor) extends MethodVisitor(ASM9, mv):
57+
// Track labels we've seen to identify backward branches
58+
private val seenLabels = scala.collection.mutable.Set[Label]()
59+
60+
def addStopCheck() = mv.visitMethodInsn(
61+
INVOKESTATIC,
62+
classOf[dotty.tools.repl.StopRepl].getName.replace('.', '/'),
63+
"throwIfReplStopped",
64+
"()V",
65+
false
66+
)
67+
68+
override def visitCode(): Unit =
69+
super.visitCode()
70+
// Insert throwIfReplStopped() call at the start of the method
71+
// to allow breaking out of deeply recursive methods like fib(99)
72+
addStopCheck()
73+
74+
override def visitLabel(label: Label): Unit =
75+
seenLabels.add(label)
76+
super.visitLabel(label)
77+
78+
override def visitJumpInsn(opcode: Int, label: Label): Unit =
79+
// Add throwIfReplStopped if this is a backward branch (jumping to a label we've already seen)
80+
if seenLabels.contains(label) then addStopCheck()
81+
super.visitJumpInsn(opcode, label)
82+
83+
end ReplBytecodeInstrumentation

compiler/src/dotty/tools/repl/ReplDriver.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import dotty.tools.dotc.{CompilationUnit, Driver}
3434
import dotty.tools.dotc.config.CompilerCommand
3535
import dotty.tools.io.*
3636
import dotty.tools.repl.Rendering.showUser
37+
import dotty.tools.repl.ReplBytecodeInstrumentation
3738
import dotty.tools.runner.ScalaClassLoader.*
3839
import org.jline.reader.*
3940

@@ -228,13 +229,20 @@ class ReplDriver(settings: Array[String],
228229
// Set up interrupt handler for command execution
229230
var firstCtrlCEntered = false
230231
val thread = Thread.currentThread()
232+
233+
// Clear the stop flag before executing new code
234+
ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), false)
235+
231236
val previousSignalHandler = terminal.handle(
232237
org.jline.terminal.Terminal.Signal.INT,
233238
(sig: org.jline.terminal.Terminal.Signal) => {
234239
if (!firstCtrlCEntered) {
235240
firstCtrlCEntered = true
241+
// Set the stop flag to trigger throwIfReplStopped() in instrumented code
242+
ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), true)
243+
// Also interrupt the thread as a fallback for non-instrumented code
236244
thread.interrupt()
237-
out.println("\nInterrupting running thread, Ctrl-C again to terminate the REPL Process")
245+
out.println("\nInterrupting running thread")
238246
} else {
239247
out.println("\nTerminating REPL Process...")
240248
System.exit(130) // Standard exit code for SIGINT
@@ -591,8 +599,9 @@ class ReplDriver(settings: Array[String],
591599
val prevClassLoader = rendering.classLoader()
592600
val jarClassLoader = fromURLsParallelCapable(
593601
jarClassPath.asURLs, prevClassLoader)
602+
val instrumentBytecode = !ctx.settings.XreplDisableBytecodeInstrumentation.value
594603
rendering.myClassLoader = new AbstractFileClassLoader(
595-
prevOutputDir, jarClassLoader)
604+
prevOutputDir, jarClassLoader, instrumentBytecode)
596605

597606
out.println(s"Added '$path' to classpath.")
598607
} catch {
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package dotty.tools.repl
2+
3+
import scala.annotation.static
4+
5+
class StopRepl
6+
7+
object StopRepl {
8+
// Needs to be volatile, otherwise changes to this may not get seen by other threads
9+
// for arbitrarily long periods of time (minutes!)
10+
@static @volatile private var stop: Boolean = false
11+
12+
@static def setStop(n: Boolean): Unit = { stop = n }
13+
14+
/** Check if execution should stop, and throw ThreadDeath if so */
15+
@static def throwIfReplStopped(): Unit = {
16+
if (stop) throw new ThreadDeath()
17+
}
18+
}

0 commit comments

Comments
 (0)