Skip to content

Commit a68d5b1

Browse files
committed
.
1 parent 96aff06 commit a68d5b1

File tree

6 files changed

+180
-6
lines changed

6 files changed

+180
-6
lines changed

compiler/src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ private sealed trait XSettings:
325325
val XprintSuspension: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprint-suspension", "Show when code is suspended until macros are compiled.")
326326
val Xprompt: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xprompt", "Display a prompt after each error (debugging option).")
327327
val XreplDisableDisplay: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xrepl-disable-display", "Do not display definitions in REPL.")
328+
val XreplDisableBytecodeInstrumentation: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xrepl-disable-bytecode-instrumentation", "Disable bytecode instrumentation for interrupt handling in REPL.")
328329
val XverifySignatures: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xverify-signatures", "Verify generic signatures in generated bytecode.")
329330
val XignoreScala2Macros: Setting[Boolean] = BooleanSetting(AdvancedSetting, "Xignore-scala2-macros", "Ignore errors when compiling code that calls Scala2 macros, these will fail at runtime.")
330331
val XimportSuggestionTimeout: Setting[Int] = IntSetting(AdvancedSetting, "Ximport-suggestion-timeout", "Timeout (in ms) for searching for import suggestions when errors are reported.", 8000)

compiler/src/dotty/tools/repl/AbstractFileClassLoader.scala

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ package repl
1616
import scala.language.unsafeNulls
1717

1818
import io.AbstractFile
19+
import dotty.tools.repl.ReplBytecodeInstrumentation
1920

2021
import java.net.{URL, URLConnection, URLStreamHandler}
2122
import java.util.Collections
2223

23-
class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) extends ClassLoader(parent):
24+
class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader, instrumentBytecode: Boolean = true) extends ClassLoader(parent):
2425
private def findAbstractFile(name: String) = root.lookupPath(name.split('/').toIndexedSeq, directory = false)
2526

2627
// on JDK 20 the URL constructor we're using is deprecated,
@@ -53,9 +54,66 @@ class AbstractFileClassLoader(val root: AbstractFile, parent: ClassLoader) exten
5354
if (file == null) {
5455
throw new ClassNotFoundException(name)
5556
}
56-
val bytes = file.toByteArray
57+
val originalBytes = file.toByteArray
58+
59+
// Instrument bytecode for everything except ReplCancel itself to avoid infinite recursion
60+
val bytes =
61+
if !instrumentBytecode || name == "dotty.tools.repl.ReplCancel" then originalBytes
62+
else ReplBytecodeInstrumentation.instrument(originalBytes)
63+
5764
defineClass(name, bytes, 0, bytes.length)
5865
}
5966

60-
override def loadClass(name: String): Class[?] = try findClass(name) catch case _: ClassNotFoundException => super.loadClass(name)
67+
private def tryInstrumentLibraryClass(name: String): Class[?] =
68+
try
69+
val resourceName = name.replace('.', '/') + ".class"
70+
getParent.getResourceAsStream(resourceName) match{
71+
case null => super.loadClass(resourceName)
72+
case is =>
73+
try
74+
val bytes = is.readAllBytes()
75+
val instrumentedBytes =
76+
if instrumentBytecode then ReplBytecodeInstrumentation.instrument(bytes)
77+
else bytes
78+
defineClass(name, instrumentedBytes, 0, instrumentedBytes.length)
79+
finally is.close()
80+
}
81+
catch
82+
case ex: Exception => super.loadClass(name)
83+
84+
override def loadClass(name: String): Class[?] =
85+
if !instrumentBytecode then
86+
return super.loadClass(name)
87+
88+
// Check if already loaded
89+
val loaded = findLoadedClass(name)
90+
if loaded != null then
91+
return loaded
92+
93+
// Don't instrument JDK classes or ReplCancel
94+
name match{
95+
case s"java.$_" => super.loadClass(name)
96+
case s"javax.$_" => super.loadClass(name)
97+
case s"sun.$_" => super.loadClass(name)
98+
case s"jdk.$_" => super.loadClass(name)
99+
case "dotty.tools.repl.ReplCancel" =>
100+
// Load ReplCancel from parent but ensure each classloader gets its own copy
101+
val is = getParent.getResourceAsStream(name.replace('.', '/') + ".class")
102+
if is != null then
103+
try
104+
val bytes = is.readAllBytes()
105+
defineClass(name, bytes, 0, bytes.length)
106+
finally
107+
is.close()
108+
else
109+
// Can't get as resource, use the classloader that loaded this AbstractFileClassLoader
110+
// class itself, which must have access to ReplCancel
111+
classOf[AbstractFileClassLoader].getClassLoader.loadClass(name)
112+
case _ =>
113+
try findClass(name)
114+
catch case _: ClassNotFoundException =>
115+
// Not in REPL output, try to load from parent and instrument it
116+
tryInstrumentLibraryClass(name)
117+
}
118+
61119
end AbstractFileClassLoader

compiler/src/dotty/tools/repl/Rendering.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None):
7272
new java.net.URLClassLoader(compilerClasspath.toArray, baseClassLoader)
7373
}
7474

75-
myClassLoader = new AbstractFileClassLoader(ctx.settings.outputDir.value, parent)
75+
val instrumentBytecode = !ctx.settings.XreplDisableBytecodeInstrumentation.value
76+
myClassLoader = new AbstractFileClassLoader(ctx.settings.outputDir.value, parent, instrumentBytecode)
7677
myClassLoader
7778
}
7879

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package dotty.tools
2+
package repl
3+
4+
import scala.language.unsafeNulls
5+
6+
import scala.tools.asm.*
7+
import scala.tools.asm.Opcodes.*
8+
import scala.tools.asm.tree.*
9+
import scala.collection.JavaConverters.*
10+
import java.util.concurrent.atomic.AtomicBoolean
11+
12+
object ReplBytecodeInstrumentation:
13+
14+
/** The internal name for the cancel class (with slashes) */
15+
private val CANCEL_CLASS_INTERNAL = "dotty/tools/repl/ReplCancel"
16+
17+
/** The name of the stop check method */
18+
private val STOP_CHECK_METHOD = "stopCheck"
19+
20+
/** Instrument bytecode to add stop checks at backward branches.
21+
*
22+
* Backward branches indicate loops, which is where code can hang.
23+
* We inject a call to stopCheck() before each backward branch.
24+
*
25+
* @param originalBytes the original class bytecode
26+
* @return the instrumented bytecode
27+
*/
28+
def instrument(originalBytes: Array[Byte]): Array[Byte] =
29+
try
30+
val cr = new ClassReader(originalBytes)
31+
val cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES)
32+
val instrumenter = new StopCheckInstrumenter(cw)
33+
cr.accept(instrumenter, ClassReader.EXPAND_FRAMES)
34+
cw.toByteArray
35+
catch
36+
case ex: Exception => originalBytes
37+
38+
def setStopFlag(classLoader: ClassLoader, b: Boolean): Unit =
39+
val cancelClassOpt =
40+
try Some(classLoader.loadClass(CANCEL_CLASS_INTERNAL.replace('/', '.')))
41+
catch{
42+
case _: java.lang.ClassNotFoundException => None
43+
}
44+
for(cancelClass <- cancelClassOpt){
45+
val setAllStopMethod = cancelClass.getDeclaredMethod("setAllStop", classOf[Boolean])
46+
setAllStopMethod.invoke(null, b.asInstanceOf[AnyRef])
47+
}
48+
49+
/** ClassVisitor that instruments methods to add stop checks at backward branches */
50+
private class StopCheckInstrumenter(cv: ClassVisitor) extends ClassVisitor(ASM9, cv):
51+
52+
override def visitMethod(
53+
access: Int,
54+
name: String,
55+
descriptor: String,
56+
signature: String,
57+
exceptions: Array[String]
58+
): MethodVisitor =
59+
val mv = super.visitMethod(access, name, descriptor, signature, exceptions)
60+
if mv == null then mv
61+
else new StopCheckMethodVisitor(mv)
62+
63+
/** MethodVisitor that inserts stop checks at backward branches */
64+
private class StopCheckMethodVisitor(mv: MethodVisitor) extends MethodVisitor(ASM9, mv):
65+
// Track labels we've seen to identify backward branches
66+
private val seenLabels = scala.collection.mutable.Set[Label]()
67+
68+
override def visitLabel(label: Label): Unit =
69+
seenLabels.add(label)
70+
super.visitLabel(label)
71+
72+
override def visitJumpInsn(opcode: Int, label: Label): Unit =
73+
// Check if this is a backward branch (jumping to a label we've already seen)
74+
if seenLabels.contains(label) then
75+
// Insert stopCheck() call before the backward branch
76+
mv.visitMethodInsn(
77+
INVOKESTATIC,
78+
CANCEL_CLASS_INTERNAL,
79+
STOP_CHECK_METHOD,
80+
"()V",
81+
false
82+
)
83+
84+
// Emit the original jump instruction
85+
super.visitJumpInsn(opcode, label)
86+
87+
end ReplBytecodeInstrumentation
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package dotty.tools.repl
2+
3+
import scala.annotation.static
4+
5+
class ReplCancel
6+
7+
object ReplCancel {
8+
// Needs to be volatile, otherwise changes to this may not get seen by other threads
9+
// for arbitrarily long periods of time (minutes!)
10+
@static @volatile private var allStop: Boolean = false
11+
12+
@static def setAllStop(n: Boolean): Unit = { allStop = n }
13+
14+
/** Check if execution should stop, and throw ThreadDeath if so */
15+
@static def stopCheck(): Unit = {
16+
if (allStop) throw new ThreadDeath()
17+
}
18+
}

compiler/src/dotty/tools/repl/ReplDriver.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import dotty.tools.dotc.{CompilationUnit, Driver}
3434
import dotty.tools.dotc.config.CompilerCommand
3535
import dotty.tools.io.*
3636
import dotty.tools.repl.Rendering.showUser
37+
import dotty.tools.repl.ReplBytecodeInstrumentation
3738
import dotty.tools.runner.ScalaClassLoader.*
3839
import org.jline.reader.*
3940

@@ -228,13 +229,20 @@ class ReplDriver(settings: Array[String],
228229
// Set up interrupt handler for command execution
229230
var firstCtrlCEntered = false
230231
val thread = Thread.currentThread()
232+
233+
// Clear the stop flag before executing new code
234+
ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), false)
235+
231236
val previousSignalHandler = terminal.handle(
232237
org.jline.terminal.Terminal.Signal.INT,
233238
(sig: org.jline.terminal.Terminal.Signal) => {
234239
if (!firstCtrlCEntered) {
235240
firstCtrlCEntered = true
241+
// Set the stop flag to trigger stopCheck() in instrumented code
242+
ReplBytecodeInstrumentation.setStopFlag(rendering.classLoader()(using state.context), true)
243+
// Also interrupt the thread as a fallback for non-instrumented code
236244
thread.interrupt()
237-
out.println("\nInterrupting running thread, Ctrl-C again to terminate the REPL Process")
245+
out.println("\nInterrupting running thread")
238246
} else {
239247
out.println("\nTerminating REPL Process...")
240248
System.exit(130) // Standard exit code for SIGINT
@@ -591,8 +599,9 @@ class ReplDriver(settings: Array[String],
591599
val prevClassLoader = rendering.classLoader()
592600
val jarClassLoader = fromURLsParallelCapable(
593601
jarClassPath.asURLs, prevClassLoader)
602+
val instrumentBytecode = !ctx.settings.XreplDisableBytecodeInstrumentation.value
594603
rendering.myClassLoader = new AbstractFileClassLoader(
595-
prevOutputDir, jarClassLoader)
604+
prevOutputDir, jarClassLoader, instrumentBytecode)
596605

597606
out.println(s"Added '$path' to classpath.")
598607
} catch {

0 commit comments

Comments
 (0)