Skip to content

Commit 8e6c29f

Browse files
committed
.
1 parent 96aff06 commit 8e6c29f

File tree

6 files changed

+183
-6
lines changed

6 files changed

+183
-6
lines changed

compiler/src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ private sealed trait XSettings:
325325
val XprintSuspension: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprint-suspension", "Show when code is suspended until macros are compiled.")
326326
val Xprompt: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprompt", "Display a prompt after each error (debugging option).")
327327
val XreplDisableDisplay: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xrepl-disable-display", "Do not display definitions in REPL.")
328+
val XreplDisableBytecodeInstrumentation: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xrepl-disable-bytecode-instrumentation", "Disable bytecode instrumentation for interrupt handling in REPL.")
328329
val XverifySignatures: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xverify-signatures", "Verify generic signatures in generated bytecode.")
329330
val XignoreScala2Macros: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xignore-scala2-macros", "Ignore errors when compiling code that calls Scala2 macros, these will fail at runtime.")
330331
val XimportSuggestionTimeout: Setting[Int] = IntSetting(AdvancedSetting, "Ximport-suggestion-timeout", "Timeout (in ms) for searching for import suggestions when errors are reported.", 8000)

compiler/src/dotty/tools/repl/AbstractFileClassLoader.scala

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ package repl
1616
import scala.language.unsafeNulls
1717

1818
import io.AbstractFile
19+
import dotty.tools.repl.ReplBytecodeInstrumentation
1920

2021
import java.net.{URL, URLConnection, URLStreamHandler}
2122
import java.util.Collections
2223

23-
class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) extends ClassLoader(parent):
24+
class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader, instrumentBytecode: Boolean = true) extends ClassLoader(parent):
2425
private def findAbstractFile(name: String) = root.lookupPath(name.split('/').toIndexedSeq, directory = false)
2526

2627
// on JDK 20 the URL constructor we're using is deprecated,
@@ -53,9 +54,62 @@ class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) exten
5354
if (file == null) {
5455
throw new ClassNotFoundException(name)
5556
}
56-
val bytes = file.toByteArray
57+
val originalBytes = file.toByteArray
58+
59+
// Instrument bytecode for everything except ReplCancel itself to avoid infinite recursion
60+
val bytes =
61+
if !instrumentBytecode || name == "dotty.tools.repl.ReplCancel" then originalBytes
62+
else ReplBytecodeInstrumentation.instrument(originalBytes)
63+
5764
defineClass(name, bytes, 0, bytes.length)
5865
}
5966

60-
override def loadClass(name: String): Class[?] = try findClass(name) catch case _: ClassNotFoundException => super.loadClass(name)
67+
private def tryInstrumentLibraryClass(name: String): Class[?] =
68+
try
69+
val resourceName = name.replace('.', '/') + ".class"
70+
val is = getParent.getResourceAsStream(resourceName)
71+
72+
if (is == null) return super.loadClass(name)
73+
74+
try
75+
val bytes = is.readAllBytes()
76+
val instrumentedBytes = if instrumentBytecode then ReplBytecodeInstrumentation.instrument(bytes) else bytes
77+
defineClass(name, instrumentedBytes, 0, instrumentedBytes.length)
78+
finally
79+
is.close()
80+
catch
81+
case ex: Exception => super.loadClass(name)
82+
83+
override def loadClass(name: String): Class[?] =
84+
// Check if already loaded
85+
val loaded = findLoadedClass(name)
86+
if loaded != null then
87+
return loaded
88+
89+
// Don't instrument JDK classes or ReplCancel
90+
name match{
91+
case s"java.$_" => super.loadClass(name)
92+
case s"javax.$_" => super.loadClass(name)
93+
case s"sun.$_" => super.loadClass(name)
94+
case s"jdk.$_" => super.loadClass(name)
95+
case "dotty.tools.repl.ReplCancel" =>
96+
// Load ReplCancel from parent but ensure each classloader gets its own copy
97+
val is = getParent.getResourceAsStream(name.replace('.', '/') + ".class")
98+
if is != null then
99+
try
100+
val bytes = is.readAllBytes()
101+
defineClass(name, bytes, 0, bytes.length)
102+
finally
103+
is.close()
104+
else
105+
// Can't get as resource, use the classloader that loaded this AbstractFileClassLoader
106+
// class itself, which must have access to ReplCancel
107+
classOf[AbstractFileClassLoader].getClassLoader.loadClass(name)
108+
case _ =>
109+
try findClass(name)
110+
catch case _: ClassNotFoundException =>
111+
// Not in REPL output, try to load from parent and instrument it
112+
tryInstrumentLibraryClass(name)
113+
}
114+
61115
end AbstractFileClassLoader

compiler/src/dotty/tools/repl/Rendering.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None):
7272
new java.net.URLClassLoader(compilerClasspath.toArray, baseClassLoader)
7373
}
7474

75-
myClassLoader = new AbstractFileClassLoader(ctx.settings.outputDir.value, parent)
75+
val instrumentBytecode = !ctx.settings.XreplDisableBytecodeInstrumentation.value
76+
myClassLoader = new AbstractFileClassLoader(ctx.settings.outputDir.value, parent, instrumentBytecode)
7677
myClassLoader
7778
}
7879

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package dotty.tools
2+
package repl
3+
4+
import scala.language.unsafeNulls
5+
6+
import scala.tools.asm.*
7+
import scala.tools.asm.Opcodes.*
8+
import scala.tools.asm.tree.*
9+
import scala.collection.JavaConverters.*
10+
import java.util.concurrent.atomic.AtomicBoolean
11+
12+
object ReplBytecodeInstrumentation:
13+
14+
/** The internal name for the cancel class (with slashes) */
15+
private val CANCEL_CLASS_INTERNAL = "dotty/tools/repl/ReplCancel"
16+
17+
/** The name of the stop check method */
18+
private val STOP_CHECK_METHOD = "stopCheck"
19+
20+
/** Instrument bytecode to add stop checks at backward branches.
21+
*
22+
* Backward branches indicate loops, which is where code can hang.
23+
* We inject a call to stopCheck() before each backward branch.
24+
*
25+
* @param originalBytes the original class bytecode
26+
* @return the instrumented bytecode
27+
*/
28+
def instrument(originalBytes: Array[Byte]): Array[Byte] =
29+
try
30+
val cr = new ClassReader(originalBytes)
31+
val cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES)
32+
val instrumenter = new StopCheckInstrumenter(cw)
33+
cr.accept(instrumenter, ClassReader.EXPAND_FRAMES)
34+
cw.toByteArray
35+
catch
36+
case ex: Exception => originalBytes
37+
38+
def setStopFlag(classLoader: ClassLoader, b: Boolean): Unit =
39+
val cancelClassOpt =
40+
try Some(classLoader.loadClass(CANCEL_CLASS_INTERNAL.replace('/', '.')))
41+
catch{
42+
case _: java.lang.ClassNotFoundException => None
43+
}
44+
for(cancelClass <- cancelClassOpt){
45+
val setAllStopMethod = cancelClass.getDeclaredMethod("setAllStop", classOf[Boolean])
46+
setAllStopMethod.invoke(null, b.asInstanceOf[AnyRef])
47+
}
48+
49+
/** ClassVisitor that instruments methods to add stop checks at backward branches */
50+
private class StopCheckInstrumenter(cv: ClassVisitor) extends ClassVisitor(ASM9, cv):
51+
52+
override def visitMethod(
53+
access: Int,
54+
name: String,
55+
descriptor: String,
56+
signature: String,
57+
exceptions: Array[String]
58+
): MethodVisitor =
59+
val mv = super.visitMethod(access, name, descriptor, signature, exceptions)
60+
if mv == null then mv
61+
else new StopCheckMethodVisitor(mv)
62+
63+
/** MethodVisitor that inserts stop checks at backward branches */
64+
private class StopCheckMethodVisitor(mv: MethodVisitor) extends MethodVisitor(ASM9, mv):
65+
// Track labels we've seen to identify backward branches
66+
private val seenLabels = scala.collection.mutable.Set[Label]()
67+
68+
override def visitLabel(label: Label): Unit =
69+
seenLabels.add(label)
70+
super.visitLabel(label)
71+
72+
override def visitJumpInsn(opcode: Int, label: Label): Unit =
73+
// Check if this is a backward branch (jumping to a label we've already seen)
74+
if seenLabels.contains(label) then
75+
// Insert stopCheck() call before the backward branch
76+
mv.visitMethodInsn(
77+
INVOKESTATIC,
78+
CANCEL_CLASS_INTERNAL,
79+
STOP_CHECK_METHOD,
80+
"()V",
81+
false
82+
)
83+
84+
// Emit the original jump instruction
85+
super.visitJumpInsn(opcode, label)
86+
87+
end ReplBytecodeInstrumentation
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package dotty.tools.repl
2+
3+
import scala.annotation.static
4+
5+
/** Cancel class for REPL interrupt handling.
6+
*
7+
* Each REPL classloader gets its own copy of this class. When Ctrl-C is pressed,
8+
* the allStop flag is set to 0, causing the 1/0 to throw an exception that we translate
9+
* into a `ThreadDeath`. If not, the the 1/1 computation should be fast and add minimal
10+
* overhead to the instrumented code in the hot path.
11+
*/
12+
class ReplCancel
13+
14+
object ReplCancel {
15+
// Needs to be volatile, otherwise changes to this may not get seen by other threads
16+
// for arbitrarily long periods of time (minutes!)
17+
@static @volatile private var allStop: Boolean = false
18+
19+
@static def setAllStop(n: Boolean): Unit = { allStop = n }
20+
21+
/** Check if execution should stop, and throw ThreadDeath if so */
22+
@static def stopCheck(): Unit = {
23+
if (allStop) throw new ThreadDeath()
24+
}
25+
}

compiler/src/dotty/tools/repl/ReplDriver.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import dotty.tools.dotc.{CompilationUnit, Driver}
3434
import dotty.tools.dotc.config.CompilerCommand
3535
import dotty.tools.io.*
3636
import dotty.tools.repl.Rendering.showUser
37+
import dotty.tools.repl.ReplBytecodeInstrumentation
3738
import dotty.tools.runner.ScalaClassLoader.*
3839
import org.jline.reader.*
3940

@@ -228,13 +229,20 @@ class ReplDriver(settings: Array[String],
228229
// Set up interrupt handler for command execution
229230
var firstCtrlCEntered = false
230231
val thread = Thread.currentThread()
232+
233+
// Clear the stop flag before executing new code
234+
ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), false)
235+
231236
val previousSignalHandler = terminal.handle(
232237
org.jline.terminal.Terminal.Signal.INT,
233238
(sig: org.jline.terminal.Terminal.Signal) => {
234239
if (!firstCtrlCEntered) {
235240
firstCtrlCEntered = true
241+
// Set the stop flag to trigger stopCheck() in instrumented code
242+
ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), true)
243+
// Also interrupt the thread as a fallback for non-instrumented code
236244
thread.interrupt()
237-
out.println("\nInterrupting running thread, Ctrl-C again to terminate the REPL Process")
245+
out.println("\nInterrupting running thread")
238246
} else {
239247
out.println("\nTerminating REPL Process...")
240248
System.exit(130) // Standard exit code for SIGINT
@@ -591,8 +599,9 @@ class ReplDriver(settings: Array[String],
591599
val prevClassLoader = rendering.classLoader()
592600
val jarClassLoader = fromURLsParallelCapable(
593601
jarClassPath.asURLs, prevClassLoader)
602+
val instrumentBytecode = !ctx.settings.XreplDisableBytecodeInstrumentation.value
594603
rendering.myClassLoader = new AbstractFileClassLoader(
595-
prevOutputDir, jarClassLoader)
604+
prevOutputDir, jarClassLoader, instrumentBytecode)
596605

597606
out.println(s"Added '$path' to classpath.")
598607
} catch {

0 commit comments

Comments
 (0)