From b59a8883e99f8383de9af8ff33ffd9518b3e6680 Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Thu, 20 Nov 2025 20:07:12 +0800 Subject: [PATCH 01/16] progress --- .../shared/src/main/scala/hkmc2/Config.scala | 2 + .../main/scala/hkmc2/codegen/Lowering.scala | 12 +- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 116 ++++++++++++++++++ hkmc2/shared/src/test/mlscript/HkScratch.mls | 17 ++- .../src/test/scala/hkmc2/MLsDiffMaker.scala | 2 + 5 files changed, 144 insertions(+), 5 deletions(-) create mode 100644 hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala diff --git a/hkmc2/shared/src/main/scala/hkmc2/Config.scala b/hkmc2/shared/src/main/scala/hkmc2/Config.scala index 6e6a7f183b..f46e001a21 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/Config.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/Config.scala @@ -23,6 +23,7 @@ case class Config( stageCode: Bool, target: CompilationTarget, rewriteWhileLoops: Bool, + tailRecOpt: Bool, ): def stackSafety: Opt[StackSafety] = effectHandlers.flatMap(_.stackSafety) @@ -40,6 +41,7 @@ object Config: target = CompilationTarget.JS, rewriteWhileLoops = true, stageCode = false, + tailRecOpt = false, ) case class SanityChecks(light: Bool) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index d4922ec4c2..3235f5de10 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -973,10 +973,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx): val merged = MergeMatchArmTransformer.applyBlock(bufferable) - val res = + val staged = if config.stageCode then Instrumentation(using summon).applyBlock(merged) else merged + val res = + if config.tailRecOpt then TailRecOpt().transform(staged) + else staged + Program( imps.map(imp => imp.sym -> imp.str), res @@ -1007,7 +1011,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): case Annot.Untyped => () case a @ Annot.TailRec => target match - case TermDefinition(body = S(bod), k = syntax.Fun) => warn(a, S(msg"Tail call optimization is not yet implemented.")) + case TermDefinition(body = S(bod), k = syntax.Fun) => () case TermDefinition(k = syntax.Fun) => warn(a, S(msg"Only functions with a body may be marked as @tailrec.")) case _ => warn(a) @@ -1028,9 +1032,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx): case Annot.Untyped => () case a @ Annot.TailCall => receiver match case st.App(Ref(_: BuiltinSymbol), _) => warn(a, S(msg"The @tailcall annotation has no effect on calls to built-in symbols.")) - case st.App(_, _) => warn(a, S(msg"Tail call optimization is not yet implemented.")) + case st.App(_, _) => () case st.Resolved(_, defnSym) => defnSym.defn match - case S(td: TermDefinition) if (td.k is syntax.Fun) && td.params.isEmpty => warn(a, S(msg"Tail call optimization is not yet implemented.")) + case S(td: TermDefinition) if (td.k is syntax.Fun) && td.params.isEmpty => () case _ => warn(a) case _ => warn(a) case annot => warn(annot) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala new file mode 100644 index 0000000000..d4852ab50b --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -0,0 +1,116 @@ +package hkmc2 + +import scala.collection.mutable + +import mlscript.utils.*, shorthands.* +import utils.* + +import hkmc2.codegen.* +import hkmc2.semantics.* +import hkmc2.Message.* +import hkmc2.semantics.Elaborator.State +import hkmc2.syntax.Tree + +// This optimization assumes the lifter has been run. +// It technically still works without lifting, but it will only consider calls to functions defined in the same scope. +class TailRecOpt(using State, TL, Raise): + + object CallToFun: + def unapply(c: Call): Opt[TermSymbol] = c match + case Call(fun = Value.Ref(b, S(r: TermSymbol))) => S(r) + case _ => N + + sealed abstract class CallEdge: + val f1: TermSymbol + val f2: TermSymbol + val call: Call + + case class TailCall(f1: TermSymbol, f2: TermSymbol)(val call: Call) extends CallEdge + case class NormalCall(f1: TermSymbol, f2: TermSymbol)(val call: Call) extends CallEdge + + class CallFinder(f: FunDefn) extends BlockTraverserShallow: + + var edges: List[CallEdge] = Nil + + def find = + println(f.sym) + edges = Nil + applyBlock(f.body) + edges + + override def applyBlock(b: Block): Unit = b match + case Return(c @ CallToFun(r), _) => edges ::= TailCall(f.dSym, r)(c) + case Assign(a, c @ CallToFun(r), Return(b, _)) if a == b => edges ::= TailCall(f.dSym, r)(c) + case Return(c: Call, _) => + if c.explicitTailCall then + raise(ErrorReport(msg"Only direct calls in tail position may be marked @tailcall." -> c.toLoc :: Nil)) + case _ => super.applyBlock(b) + + override def applyResult(r: Result): Unit = r match + case c: Call => + if c.explicitTailCall then + raise(ErrorReport(msg"This call is not in tail position." -> c.toLoc :: Nil)) + c match + case CallToFun(r) => edges ::= NormalCall(f.dSym, r)(c) + case _ => + case _ => super.applyResult(r) + + def buildCallGraph(fs: List[FunDefn]): List[CallEdge] = + fs.flatMap(f => CallFinder(f).find) + + case class SccOfCalls(funs: List[FunDefn], calls: List[CallEdge]) + + def partFns(fs: List[FunDefn]): List[SccOfCalls] = + val defnSyms = fs.map(_.dSym) + val tsToDefn = fs.map(f => f.dSym -> f).toMap + + // Only care about calls to functions in the same scope + // Note that the results may differ if the lifter has been run. + val cg = buildCallGraph(fs).filter: c => + val cond = defnSyms.contains(c.f1) && defnSyms.contains(c.f2) + c.match + case c: TailCall if c.call.explicitTailCall && !cond => + raise(ErrorReport( + msg"This tail call exits the current scope and cannot be optimized. Enabling the lifter may fix this." -> c.call.toLoc :: Nil)) + case _ => + cond + + val cgTup = cg.map(c => (c.f1, c.f2)) + val sccs = algorithms.sccsWithInfo(cgTup, Nil) + + // partition the call graph + val sccMap = sccs.sccs.flatMap: + case (id, scc) => scc.map(f => f -> id) + + val cgLabelled = cg + .groupBy: c => + val s1 = sccMap(c.f1) + val s2 = sccMap(c.f2) + if s1 != s2 && c.call.explicitTailCall then + raise(ErrorReport( + msg"This call is not optimized as it does not directly recurse through its parent function." -> c.call.toLoc :: Nil)) + -1 + else s1 + .filter: + (id, _) => id != -1 + + sccs.sccs.toList.map: v => + val (id, tss) = v + val cgs = cgLabelled.get(id) match + case Some(value) => value + case None => Nil + SccOfCalls(tss.map(tsToDefn), cgs) + + def optFunctions(fs: List[FunDefn], owner: Opt[InnerSymbol]) = + val parts = partFns(fs) + + def transform(b: Block) = + val (blk, defns) = b.floatOutDefns() + val (funs, clses) = + defns.foldLeft[(List[FunDefn], List[ClsLikeDefn])](Nil, Nil): + case ((fs, cs), d) => d match + case f: FunDefn => (f :: fs, cs) + case c: ClsLikeDefn => (fs, c :: cs) + case _ => (fs, cs) // unreachable as floatOutDefns only floats out FunDefns and ClsLikeDefns + print(optFunctions(funs, N)) + b diff --git a/hkmc2/shared/src/test/mlscript/HkScratch.mls b/hkmc2/shared/src/test/mlscript/HkScratch.mls index ef38e8b363..44a3165d05 100644 --- a/hkmc2/shared/src/test/mlscript/HkScratch.mls +++ b/hkmc2/shared/src/test/mlscript/HkScratch.mls @@ -5,8 +5,23 @@ // :elt :global +:tailrec // :d // :todo +fun f(x) = + g(x) +fun g(x) = + if true then + @tailcall f(x) + else + @tailcall h(x) +fun h(x) = 2 * x +//│ FAILURE: Unexpected type error +//│ FAILURE LOCATION: cgLabelled (TailRecOpt.scala:90) +//│ ╔══[ERROR] This call is not optimized as it does not directly recurse through its parent function. +//│ ║ l.18: @tailcall h(x) +//│ ╙── ^^^^ - +fun f(x) = + @tailcall f(x) diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala index f527cb10d8..b8b60db9cb 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala @@ -72,6 +72,7 @@ abstract class MLsDiffMaker extends DiffMaker: val importQQ = NullaryCommand("qq") val stageCode = NullaryCommand("staging") val dontRewriteWhile = NullaryCommand("dontRewriteWhile") + val tailRecOpt = NullaryCommand("tailrec") def mkConfig: Config = import Config.* @@ -100,6 +101,7 @@ abstract class MLsDiffMaker extends DiffMaker: stageCode = stageCode.isSet, target = if wasm.isSet then CompilationTarget.Wasm else CompilationTarget.JS, rewriteWhileLoops = !dontRewriteWhile.isSet, + tailRecOpt = tailRecOpt.isSet, ) From 3b2ccc021c89e18d1ccb1bb08190116335061e5b Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Sat, 29 Nov 2025 02:07:47 +0800 Subject: [PATCH 02/16] progress --- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 198 ++++++++++++++- .../hkmc2/semantics/ucs/Normalization.scala | 5 +- hkmc2/shared/src/test/mlscript/HkScratch.mls | 17 +- .../src/test/mlscript/tailrec/Errors.mls | 47 ++++ .../src/test/mlscript/tailrec/TailRecOpt.mls | 230 ++++++++++++++++++ 5 files changed, 466 insertions(+), 31 deletions(-) create mode 100644 hkmc2/shared/src/test/mlscript/tailrec/Errors.mls create mode 100644 hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index d4852ab50b..28b0d73882 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -10,16 +10,27 @@ import hkmc2.semantics.* import hkmc2.Message.* import hkmc2.semantics.Elaborator.State import hkmc2.syntax.Tree +import scala.collection.mutable.ArrayBuffer +import java.lang.instrument.ClassDefinition // This optimization assumes the lifter has been run. -// It technically still works without lifting, but it will only consider calls to functions defined in the same scope. class TailRecOpt(using State, TL, Raise): object CallToFun: def unapply(c: Call): Opt[TermSymbol] = c match case Call(fun = Value.Ref(b, S(r: TermSymbol))) => S(r) + case Call(fun = s: Select) => s.symbol match + case Some(r: TermSymbol) => S(r) + case _ => N case _ => N + object TailCallShape: + def unapply(b: Block): Opt[(TermSymbol, Call)] = b match + case Return(c @ CallToFun(r), _) => S((r, c)) + case Assign(a, c @ CallToFun(r), Return(b, _)) if a == b => S((r, c)) + case _ => N + + sealed abstract class CallEdge: val f1: TermSymbol val f2: TermSymbol @@ -33,19 +44,23 @@ class TailRecOpt(using State, TL, Raise): var edges: List[CallEdge] = Nil def find = - println(f.sym) - edges = Nil - applyBlock(f.body) - edges + // Ignore functions with multiple parameter lists + if f.params.length > 1 then + if f.isTailRec then + raise(ErrorReport(msg"Functions with more than one parameter list may not be marked @tailrec." -> f.dSym.toLoc :: Nil)) + Nil + else + edges = Nil + applyBlock(f.body) + edges override def applyBlock(b: Block): Unit = b match - case Return(c @ CallToFun(r), _) => edges ::= TailCall(f.dSym, r)(c) - case Assign(a, c @ CallToFun(r), Return(b, _)) if a == b => edges ::= TailCall(f.dSym, r)(c) + case TailCallShape(r, c) => edges ::= TailCall(f.dSym, r)(c) case Return(c: Call, _) => if c.explicitTailCall then raise(ErrorReport(msg"Only direct calls in tail position may be marked @tailcall." -> c.toLoc :: Nil)) case _ => super.applyBlock(b) - + override def applyResult(r: Result): Unit = r match case c: Call => if c.explicitTailCall then @@ -71,12 +86,12 @@ class TailRecOpt(using State, TL, Raise): c.match case c: TailCall if c.call.explicitTailCall && !cond => raise(ErrorReport( - msg"This tail call exits the current scope and cannot be optimized. Enabling the lifter may fix this." -> c.call.toLoc :: Nil)) + msg"This tail call exits the current scope and cannot be optimized." -> c.call.toLoc :: Nil)) case _ => cond val cgTup = cg.map(c => (c.f1, c.f2)) - val sccs = algorithms.sccsWithInfo(cgTup, Nil) + val sccs = algorithms.sccsWithInfo(cgTup, defnSyms) // partition the call graph val sccMap = sccs.sccs.flatMap: @@ -101,9 +116,164 @@ class TailRecOpt(using State, TL, Raise): case None => Nil SccOfCalls(tss.map(tsToDefn), cgs) + def maxInt[T](items: List[T], f: T => Int): Int = items.foldLeft(0): + case (l, item) => + val x = f(item) + if x > l then x else l + + def getParamSyms(f: FunDefn) = f.params.headOption match + case Some(ParamList(_, params, S(rest))) => + params.map(_.sym).appended(rest.sym) + case Some(p) => p.params.map(_.sym) + case None => Nil + + // assume only one parameter list + def paramsLen(f: FunDefn): Int = f.params match + case head :: next => + if head.restParam.isDefined then 1 + head.params.length + else head.params.length + case Nil => 0 + + def rewriteCallArgs(f: FunDefn, c: Call): Opt[List[Result]] = + // need to be careful in handling restParams + // if any arg is a spread that spreads across multiple parameters, then + // we ignore it for now + val ret = f.params match + case head :: Nil => + val (headArgs, restArgs) = head.restParam match + case Some(value) => c.args.splitAt(head.params.length) + case None => (c.args, Nil) + + var bad = false + val hd = for a <- headArgs yield a.spread match + case Some(true) => + if c.explicitTailCall then + raise(ErrorReport(msg"Spreads are not yet supported here in calls marked @tailcall." -> a.value.toLoc :: Nil)) + bad = true + a.value + case _ => a.value + if bad then return N + + if head.restParam.isDefined then + val rest = + restArgs match + case Arg(S(true), value) :: Nil => value + case _ => Tuple(true, restArgs) + hd.appended(rest) + else + hd + case Nil => c.args.map(_.value) + case _ => return N + S(ret) + + def optScc(scc: SccOfCalls, owner: Opt[InnerSymbol]): List[FunDefn] = + if scc.calls.size == 0 then return scc.funs + + val nonTailCalls = scc.calls + .collect: + case c: NormalCall => c.f2 -> c.call + .toMap + + if !nonTailCalls.isEmpty then + for f <- scc.funs if f.isTailRec do + val reportLoc = nonTailCalls.get(f.dSym) match + // always display a call to f, if possible + case Some(value) => value.toLoc + case None => nonTailCalls.head._2.toLoc + raise(ErrorReport( + msg"`${f.sym.nme}` is not tail recursive." -> f.dSym.toLoc + :: msg"It could self-recurse through this call, which is not a tail call." -> reportLoc + :: Nil + )) + + val maxParamLen = maxInt(scc.funs, paramsLen) + val paramSyms = + if scc.funs.length == 1 then (getParamSyms(scc.funs.head)) + else + for i <- 0 to maxParamLen - 1 yield VarSymbol(Tree.Ident("param" + i)) + .toList + val paramSymsArr = ArrayBuffer.from(paramSyms) + val dSymIds = scc.funs.map(_.dSym).zipWithIndex.toMap + val bms = + if scc.funs.size == 1 then scc.funs.head.sym + else BlockMemberSymbol(scc.funs.map(_.sym.nme).mkString("_"), Nil, true) + val dSym = + if scc.funs.size == 1 then scc.funs.head.dSym + else TermSymbol(syntax.Fun, owner, Tree.Ident(bms.nme)) + val loopSym = TempSymbol(N, "loopLabel") + val curIdSym = VarSymbol(Tree.Ident("id")) + + class FunRewriter(f: FunDefn) extends BlockTransformer(SymbolSubst()): + val params = getParamSyms(f) + val paramsSet = f.params.toSet + val paramsIdxes = params.zipWithIndex.toMap + + def applyVarSym(l: VarSymbol): VarSymbol = paramsIdxes.get(l) match + case Some(idx) => paramSymsArr(idx) + case _ => l + + override def applyValue(v: Value)(k: Value => Block): Block = v match + case Value.Ref(l: VarSymbol, d) => k(Value.Ref(applyVarSym(l), d)) + case _ => super.applyValue(v)(k) + + override def applyBlock(b: Block): Block = b match + case TailCallShape(dSym, c) => dSymIds.get(dSym) match + case Some(id) => + val argVals = rewriteCallArgs(f, c) match + case Some(value) => value + case None => return super.applyBlock(b) + val cont = Assign(curIdSym, Value.Lit(Tree.IntLit(dSymIds(dSym))), Continue(loopSym)) + paramSyms.zip(argVals).foldRight[Block](cont): + case ((sym, res), acc) => res match + case Value.Ref(`sym`, _) => acc + case _ => applyResult(res)(Assign(sym, _, acc)) + case None => super.applyBlock(b) + case _ => super.applyBlock(b) + + val arms = scc.funs.map: f => + Case.Lit(Tree.IntLit(dSymIds(f.dSym))) -> FunRewriter(f).applyBlock(f.body) + + val switch = + if arms.length == 1 then arms.head._2 + else Match(curIdSym.asPath, arms, N, End()) + + val loop = Label(loopSym, true, switch, End()) + + val rewrittenFuns = + if scc.funs.size == 1 then Nil + else scc.funs.map: f => + val paramArgs = getParamSyms(f).map(_.asPath.asArg) + val args = + Value.Lit(Tree.IntLit(dSymIds(f.dSym))).asArg + :: paramArgs + ::: List.fill(maxParamLen - paramArgs.length)(Value.Lit(Tree.UnitLit(false)).asArg) + val newBod = Return( + Call(Value.Ref(bms, S(dSym)), args)(true, false, false), + false + ) + FunDefn(f.owner, f.sym, f.dSym, f.params, newBod)(false) + + val params = + val initial = paramSyms.map(Param.simple(_)) + if scc.funs.length == 1 then initial + else Param.simple(curIdSym) :: initial + + FunDefn( + owner, bms, dSym, + PlainParamList(params) :: Nil, + loop + )(false) :: rewrittenFuns + def optFunctions(fs: List[FunDefn], owner: Opt[InnerSymbol]) = - val parts = partFns(fs) + partFns(fs).flatMap(optScc(_, owner)) + def optClasses(cs: List[ClsLikeDefn]) = cs.map: c => + val mtds = optFunctions(c.methods, S(c.isym)) + val companion = c.companion.map: comp => + val cMtds = optFunctions(comp.methods, S(comp.isym)) + comp.copy(methods = cMtds) + c.copy(methods = mtds, companion = companion) + def transform(b: Block) = val (blk, defns) = b.floatOutDefns() val (funs, clses) = @@ -112,5 +282,7 @@ class TailRecOpt(using State, TL, Raise): case f: FunDefn => (f :: fs, cs) case c: ClsLikeDefn => (fs, c :: cs) case _ => (fs, cs) // unreachable as floatOutDefns only floats out FunDefns and ClsLikeDefns - print(optFunctions(funs, N)) - b + val bod1 = optFunctions(funs, N).foldLeft(blk): + case (acc, f) => Define(f, acc) + optClasses(clses).foldLeft(bod1): + case (acc, c) => Define(c, acc) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index 066c807d2d..f6d76bdb7c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -328,6 +328,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e // The symbol for the loop label if the term is a `while`. lazy val loopLabel = new TempSymbol(t) lazy val f = new BlockMemberSymbol("while", Nil, false) + lazy val fDSym = new TermSymbol(syntax.Fun, N, Tree.Ident(f.nme)) val normalized = tl.scoped("ucs:normalize"): normalize(inputSplit)(using VarSet()) tl.scoped("ucs:normalized"): @@ -338,7 +339,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e lazy val breakRoot = (r: Result) => Assign(l, r, Break(rootBreakLabel)) lazy val assignResult = (r: Result) => Assign(l, r, End()) val loopCont = if config.rewriteWhileLoops - then Return(Call(Value.Ref(f, N), Nil)(true, true, false), false) + then Return(Call(Value.Ref(f, S(fDSym)), Nil)(true, true, false), false) else Continue(loopLabel) val cont = if kw === `while` then @@ -398,7 +399,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e Select(Value.Ref(State.runtimeSymbol), Tree.Ident("LoopEnd"))(S(State.loopEndSymbol)) val blk = blockBuilder .assign(l, Value.Lit(Tree.UnitLit(false))) - .define(FunDefn.withFreshSymbol(N, f, PlainParamList(Nil) :: Nil, Begin(body, Return(loopEnd, false)))(isTailRec = false)) + .define(FunDefn(N, f, fDSym, PlainParamList(Nil) :: Nil, Begin(body, Return(loopEnd, false)))(isTailRec = false)) .assign(loopResult, Call(Value.Ref(f, N), Nil)(true, true, false)) if summon[LoweringCtx].mayRet then blk diff --git a/hkmc2/shared/src/test/mlscript/HkScratch.mls b/hkmc2/shared/src/test/mlscript/HkScratch.mls index 44a3165d05..ef38e8b363 100644 --- a/hkmc2/shared/src/test/mlscript/HkScratch.mls +++ b/hkmc2/shared/src/test/mlscript/HkScratch.mls @@ -5,23 +5,8 @@ // :elt :global -:tailrec // :d // :todo -fun f(x) = - g(x) -fun g(x) = - if true then - @tailcall f(x) - else - @tailcall h(x) -fun h(x) = 2 * x -//│ FAILURE: Unexpected type error -//│ FAILURE LOCATION: cgLabelled (TailRecOpt.scala:90) -//│ ╔══[ERROR] This call is not optimized as it does not directly recurse through its parent function. -//│ ║ l.18: @tailcall h(x) -//│ ╙── ^^^^ -fun f(x) = - @tailcall f(x) + diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls new file mode 100644 index 0000000000..b6247044f0 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls @@ -0,0 +1,47 @@ +:js +:tailrec + +:e +fun f(x) = + @tailcall f(x) + f(x) +//│ ╔══[ERROR] This call is not in tail position. +//│ ║ l.6: @tailcall f(x) +//│ ╙── ^^^^ + +:e +fun f(x) = 2 +fun g(x) = @tailcall f(x) +//│ ╔══[ERROR] This call is not optimized as it does not directly recurse through its parent function. +//│ ║ l.14: fun g(x) = @tailcall f(x) +//│ ╙── ^^^^ + +:e +@tailrec fun f(x) = + g(x) +fun g(x) = + f(x) + h(x) +fun h(x) = + g(x) +//│ ╔══[ERROR] `f` is not tail recursive. +//│ ║ l.20: @tailrec fun f(x) = +//│ ║ ^ +//│ ╟── It could self-recurse through this call, which is not a tail call. +//│ ║ l.23: f(x) +//│ ╙── ^^^^ + +:e +@tailrec fun f(x) = + g(x) +fun g(x) = + h(x) + f(x) +fun h(x) = + f(x) +//│ ╔══[ERROR] `f` is not tail recursive. +//│ ║ l.35: @tailrec fun f(x) = +//│ ║ ^ +//│ ╟── It could self-recurse through this call, which is not a tail call. +//│ ║ l.38: h(x) +//│ ╙── ^^^^ diff --git a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls new file mode 100644 index 0000000000..baf7403e56 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls @@ -0,0 +1,230 @@ +:js +:tailrec + +:expect 200010000 +fun sum_impl(n, acc) = if n == 0 then acc else sum_impl(n - 1, n + acc) +fun sum(n) = sum_impl(n, 0) +sum(20000) +//│ = 200010000 + +:sjs +:expect 50000 +fun g(a, b, c, d) = f(a, b, c + d) +fun f(a, b, c) = + if + a > 0 then g(a - 1, b, c, 1) + b > 0 then g(a, b - 1, c, 2) + else c +f(10000, 20000, 0) +//│ JS (unsanitized): +//│ let f, g, g_f; +//│ f = function f(a, b, c) { +//│ return g_f(1, a, b, c, undefined) +//│ }; +//│ g = function g(a, b, c, d) { +//│ return g_f(0, a, b, c, d) +//│ }; +//│ g_f = function g_f(id, param0, param1, param2, param3) { +//│ let scrut, scrut1, tmp, tmp1, tmp2; +//│ loopLabel: while (true) { +//│ if (id === 0) { +//│ tmp = param2 + param3; +//│ param0 = param0; +//│ param1 = param1; +//│ param2 = tmp; +//│ id = 1; +//│ continue loopLabel +//│ } else if (id === 1) { +//│ scrut = param0 > 0; +//│ if (scrut === true) { +//│ tmp1 = param0 - 1; +//│ param0 = tmp1; +//│ param1 = param1; +//│ param2 = param2; +//│ param3 = 1; +//│ id = 0; +//│ continue loopLabel +//│ } else { +//│ scrut1 = param1 > 0; +//│ if (scrut1 === true) { +//│ tmp2 = param1 - 1; +//│ param0 = param0; +//│ param1 = tmp2; +//│ param2 = param2; +//│ param3 = 2; +//│ id = 0; +//│ continue loopLabel +//│ } else { return param2 } +//│ } +//│ } +//│ break; +//│ } +//│ }; +//│ f(10000, 20000, 0) +//│ = 50000 + +:sjs +:expect 200010000 +class A with + fun sum_impl(n, acc) = if n == 0 then acc else @tailcall sum_impl(n - 1, n + acc) + fun sum(n) = sum_impl(n, 0) +(new A).sum(20000) +//│ JS (unsanitized): +//│ let A1, tmp; +//│ globalThis.Object.freeze(class A { +//│ static { +//│ A1 = this +//│ } +//│ constructor() {} +//│ sum(n) { +//│ loopLabel: while (true) { +//│ return this.sum_impl(n, 0); +//│ break; +//│ } +//│ } +//│ sum_impl(n, acc) { +//│ let scrut, tmp1, tmp2, id; +//│ loopLabel: while (true) { +//│ scrut = n == 0; +//│ if (scrut === true) { +//│ return acc +//│ } else { +//│ tmp1 = n - 1; +//│ tmp2 = n + acc; +//│ n = tmp1; +//│ acc = tmp2; +//│ id = 0; +//│ continue loopLabel +//│ } +//│ break; +//│ } +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "A"]; +//│ }); +//│ tmp = globalThis.Object.freeze(new A1()); +//│ tmp.sum(20000) +//│ = 200010000 + +:silent +:lift +:expect 200010000 +val x = mut [] +let n = 20000 +let i = 0 +while i < n do + x.push of () => i + set i += 1 +fun sumOf(x, idx, acc) = + if idx == 0 then acc + else + sumOf(x, idx - 1, acc + x![idx]()) +sumOf(x, n - 1, 0) + +// check that spreads. where supported, are compiled correctly +:sjs +fun f(x, ...z) = + if x < 0 then 0 + else g(x, 0, x, x) +fun g(x, y, ...z) = + if x < 0 then 0 + else f(x - 1, 0, ...[x, x, x]) +f(100) +//│ JS (unsanitized): +//│ let f1, g1, f_g; +//│ g1 = function g(x1, y, ...z) { +//│ return f_g(1, x1, y, z) +//│ }; +//│ f1 = function f(x1, ...z) { +//│ return f_g(0, x1, z, undefined) +//│ }; +//│ f_g = function f_g(id, param0, param1, param2) { +//│ let scrut, scrut1, tmp5, tmp6; +//│ loopLabel: while (true) { +//│ if (id === 0) { +//│ scrut = param0 < 0; +//│ if (scrut === true) { +//│ return 0 +//│ } else { +//│ param0 = param0; +//│ param1 = [ +//│ 0, +//│ param0, +//│ param0 +//│ ]; +//│ id = 1; +//│ continue loopLabel +//│ } +//│ } else if (id === 1) { +//│ scrut1 = param0 < 0; +//│ if (scrut1 === true) { +//│ return 0 +//│ } else { +//│ tmp5 = param0 - 1; +//│ tmp6 = globalThis.Object.freeze([ +//│ param0, +//│ param0, +//│ param0 +//│ ]); +//│ param0 = tmp5; +//│ param1 = 0; +//│ param2 = tmp6; +//│ id = 0; +//│ continue loopLabel +//│ } +//│ } +//│ break; +//│ } +//│ }; +//│ f1(100) +//│ = 0 + +:todo +fun f(x, y, z) = @tailcall f(...[1, 2, 3]) +//│ ═══[ERROR] Spreads are not yet supported here in calls marked @tailcall. + +:todo +fun f(x, ...y) = + if x < 0 then g(0, 0, 0) + else 0 +fun g(x, y, z) = + @tailcall f(...[1, 1, 2]) +g(0, 0, 0) +//│ ═══[ERROR] Spreads are not yet supported here in calls marked @tailcall. +//│ = 0 + +fun f(x) = + fun g(x) = + @tailcall f(x) + @tailcall g(x) +//│ FAILURE: Unexpected type error +//│ FAILURE LOCATION: cg (TailRecOpt.scala:88) +//│ ╔══[ERROR] This tail call exits the current scope and cannot be optimized. +//│ ║ l.199: @tailcall g(x) +//│ ╙── ^^^^ + +// the lifter doesn't propagate the definition symbols properly +:todo +:lift +fun f(x) = + fun g() = + @tailcall f(x) + @tailcall g() +//│ ═══[ERROR] Only direct calls in tail position may be marked @tailcall. +//│ ╔══[ERROR] This call is not optimized as it does not directly recurse through its parent function. +//│ ║ l.211: @tailcall f(x) +//│ ╙── ^ + +:todo +:lift +class A with + fun f(x) = + fun g(x) = + @tailcall f(x) + @tailcall g(x) +//│ ╔══[ERROR] This tail call exits the current scope and cannot be optimized. +//│ ║ l.223: @tailcall f(x) +//│ ╙── ^^^ +//│ ╔══[ERROR] Only direct calls in tail position may be marked @tailcall. +//│ ║ l.224: @tailcall g(x) +//│ ╙── ^ From eb7c0e145ceb918f0199bd8497705f144b047969 Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Sat, 29 Nov 2025 20:19:33 +0800 Subject: [PATCH 03/16] changes to lifter and handler lowering --- .../scala/hkmc2/codegen/HandlerLowering.scala | 17 ++- .../scala/hkmc2/codegen/LambdaRewriter.scala | 15 +- .../src/main/scala/hkmc2/codegen/Lifter.scala | 58 ++++---- .../main/scala/hkmc2/codegen/Lowering.scala | 6 +- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 4 +- .../scala/hkmc2/codegen/UsedVarAnalyzer.scala | 2 +- .../main/scala/hkmc2/semantics/Symbol.scala | 4 +- .../hkmc2/semantics/ucs/Normalization.scala | 8 +- hkmc2/shared/src/test/mlscript/HkScratch.mls | 39 +++++- .../src/test/mlscript/lifter/ClassInFun.mls | 130 +++++++++++++++++- .../test/mlscript/lifter/EffectHandlers.mls | 83 +++++++++++ .../test/mlscript/lifter/StackSafetyLift.mls | 13 +- .../src/test/mlscript/tailrec/Errors.mls | 19 ++- .../src/test/mlscript/tailrec/TailRecOpt.mls | 53 +++---- 14 files changed, 364 insertions(+), 87 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index e9dd264b34..2d8795aacb 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -542,23 +542,26 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val cls = if handlerCtx.isTopLevel then N else genContClass(b, callSelf) val ret = cls match - case None => genNormalBody(b, BlockMemberSymbol("", Nil), N) + case None => genNormalBody(b, BlockMemberSymbol("", Nil).asPath, N) case Some(cls) => // create the doUnwind function val doUnwindSym = BlockMemberSymbol(doUnwindNme, Nil, true) - doUnwindMap += fnOrCls -> doUnwindSym.asPath val pcSym = VarSymbol(Tree.Ident("pc")) val resSym = VarSymbol(Tree.Ident("res")) val doUnwindBlk = h.linkAndHandle( - LinkState(resSym, cls.sym.asPath, pcSym.asPath) + LinkState(resSym, Value.Ref(cls.sym, S(cls.isym)), pcSym.asPath) ) val doUnwindDef = FunDefn.withFreshSymbol( N, doUnwindSym, PlainParamList(Param.simple(resSym) :: Param.simple(pcSym) :: Nil) :: Nil, doUnwindBlk )(false) - val doUnwindLazy = Lazy(doUnwindSym.asPath) - val rst = genNormalBody(b, cls.sym, S(doUnwindLazy)) + + val doUnwindPath: Path = Value.Ref(doUnwindSym, S(doUnwindDef.dSym)) + doUnwindMap += fnOrCls -> doUnwindPath + + val doUnwindLazy = Lazy(doUnwindPath) + val rst = genNormalBody(b, Value.Ref(cls.sym, S(cls.isym)), S(doUnwindLazy)) if doUnwindLazy.isEmpty && opt.stackSafety.isEmpty then blockBuilder @@ -928,12 +931,12 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, N, // TODO: bufferable? )) - private def genNormalBody(b: Block, clsSym: BlockMemberSymbol, doUnwind: Opt[Lazy[Path]])(using HandlerCtx): Block = + private def genNormalBody(b: Block, clsPath: Path, doUnwind: Opt[Lazy[Path]])(using HandlerCtx): Block = val transform = new BlockTransformerShallow(SymbolSubst()): override def applyBlock(b: Block): Block = b match case ResultPlaceholder(res, uid, c, rest) => val doUnwindBlk = doUnwind match - case None => Assign(res, topLevelCall(LinkState(res, clsSym.asPath, Value.Lit(Tree.IntLit(uid)))), End()) + case None => Assign(res, topLevelCall(LinkState(res, clsPath, Value.Lit(Tree.IntLit(uid)))), End()) case Some(doUnwind) => Return(PureCall(doUnwind.get_!, res.asPath :: Value.Lit(Tree.IntLit(uid)) :: Nil), false) blockBuilder .assign(res, c) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala index be48baefa9..2cddb74add 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala @@ -11,25 +11,26 @@ import hkmc2.syntax.Tree object LambdaRewriter: def desugar(b: Block)(using State) = - def rewriteOneBlk(b: Block) = b match case Assign(lhs, Lambda(params, body), rest) if !lhs.isInstanceOf[TempSymbol] => val newSym = BlockMemberSymbol(lhs.nme, Nil, nameIsMeaningful = true // TODO: lhs.nme is not always meaningful ) + val defn = FunDefn.withFreshSymbol(N, newSym, params :: Nil, body)(false) val blk = blockBuilder - .define(FunDefn.withFreshSymbol(N, newSym, params :: Nil, body)(false)) - .assign(lhs, newSym.asPath) + .define(defn) + .assign(lhs, Value.Ref(newSym, S(defn.dSym))) .rest(rest) (blk, Nil) case _ => - var lambdasList: List[(BlockMemberSymbol, Lambda)] = Nil + var lambdasList: List[((BlockMemberSymbol, TermSymbol), Lambda)] = Nil val lambdaRewriter = new BlockDataTransformer(SymbolSubst()): override def applyResult(r: Result)(k: Result => Block): Block = r match case lam: Lambda => val sym = BlockMemberSymbol("lambda", Nil, nameIsMeaningful = false) - lambdasList ::= (sym -> super.applyLam(lam)) - k(Value.Ref(sym, N)) + val tSym = TermSymbol.fromFunBms(sym, N) + lambdasList ::= ((sym, tSym) -> super.applyLam(lam)) + k(Value.Ref(sym, S(tSym))) case _ => super.applyResult(r)(k) val blk = lambdaRewriter.applyBlock(b) (blk, lambdasList) @@ -39,7 +40,7 @@ object LambdaRewriter: val (newBlk, lambdasList) = rewriteOneBlk(b) val lambdaDefns = lambdasList.map: case (sym, Lambda(params, body)) => - FunDefn.withFreshSymbol(N, sym, params :: Nil, body)(false) + FunDefn(N, sym._1, sym._2, params :: Nil, body)(false) val ret = lambdaDefns.foldLeft(newBlk): case (acc, defn) => Define(defn, acc) super.applyBlock(ret) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index f642716f73..cb1fd6e234 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -110,8 +110,8 @@ object Lifter: object InstSel: def unapply(p: Path) = p match - case Value.Ref(l: BlockMemberSymbol, _) => S(l) - case s @ Select(Value.Ref(l: BlockMemberSymbol, _), Tree.Ident("class")) => S(l) + case Value.Ref(l: BlockMemberSymbol, d) => S((l, d)) + case s @ Select(Value.Ref(l: BlockMemberSymbol, _), Tree.Ident("class")) => S((l, s.symbol)) case _ => N def modOrObj(d: Defn) = d match @@ -359,15 +359,21 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): val imutVars = captureFnVars.vars imutVars.filter: s => !mutVars.contains(s) && candVars.contains(s) - + + case class FunSyms[T <: DefinitionSymbol[?]](b: BlockMemberSymbol, d: T): + def asPath = Value.Ref(b, S(d)) + object FunSyms: + def fromFun(b: BlockMemberSymbol, owner: Opt[InnerSymbol] = N) = + FunSyms(b, TermSymbol.fromFunBms(b, owner)) + // Info required for lifting a definition. case class LiftedInfo( val reqdCaptures: List[BlockMemberSymbol], // The mutable captures a lifted definition must take. val reqdVars: List[Local], // The (passed by value) variables a lifted definition must take. val reqdInnerSyms: List[InnerSymbol], // The inner symbols a lifted definition must take. val reqdBms: List[BlockMemberSymbol], // BMS's belonging to unlifted definitions that this definition references. - val fakeCtorBms: Option[BlockMemberSymbol], // only for classes - val singleCallBms: BlockMemberSymbol, // optimization + val fakeCtorBms: Option[FunSyms[TermSymbol]], // only for classes + val singleCallBms: FunSyms[TermSymbol], // optimization ) case class Lifted[+T <: Defn]( @@ -579,7 +585,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): val info = LiftedInfo( includedCaptures, includedLocals, clsCaptures, - refBms, fakeCtorBms, singleCallBms + refBms, fakeCtorBms.map(FunSyms.fromFun(_)), FunSyms.fromFun(singleCallBms) ) d match @@ -624,27 +630,27 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): // Does *not* rewrite references to non-lifted BMS symbols. def rewriteBms(b: Block) = // BMS's that need to be created - val syms: LinkedHashMap[BlockMemberSymbol, Local] = LinkedHashMap.empty + val syms: LinkedHashMap[FunSyms[?], Local] = LinkedHashMap.empty val walker = new BlockDataTransformer(SymbolSubst()): // only scan within the block. don't traverse override def applyResult(r: Result)(k: Result => Block): Block = r match // if possible, directly rewrite the call using the efficient version - case c @ Call(RefOfBms(l, _), args) => ctx.bmsReqdInfo.get(l) match + case c @ Call(RefOfBms(l, S(d)), args) => ctx.bmsReqdInfo.get(l) match case Some(info) if !ctx.isModOrObj(l) => val extraArgs = ctx.defns.get(l) match // If it's a class, we need to add the isMut parameter. // Instantiation without `new mut` is always immutable - case Some(c: ClsLikeDefn) => Value.Lit(Tree.BoolLit(false)).asArg :: getCallArgs(l, ctx) - case _ => getCallArgs(l, ctx) + case Some(c: ClsLikeDefn) => Value.Lit(Tree.BoolLit(false)).asArg :: getCallArgs(FunSyms(l, d), ctx) + case _ => getCallArgs(FunSyms(l, d), ctx) applyListOf(args, applyArg(_)(_)): newArgs => k(Call(info.singleCallBms.asPath, extraArgs ++ newArgs)(c.isMlsFun, false, c.explicitTailCall)) case _ => super.applyResult(r)(k) - case c @ Instantiate(mut, InstSel(l), args) => + case c @ Instantiate(mut, InstSel(l, S(d)), args) => ctx.bmsReqdInfo.get(l) match case Some(info) if !ctx.isModOrObj(l) => - val extraArgs = Value.Lit(Tree.BoolLit(mut)).asArg :: getCallArgs(l, ctx) + val extraArgs = Value.Lit(Tree.BoolLit(mut)).asArg :: getCallArgs(FunSyms(l, d), ctx) applyListOf(args, applyArg(_)(_)): newArgs => k(Call(info.singleCallBms.asPath, extraArgs ++ newArgs)(true, false, false)) case _ => super.applyResult(r)(k) @@ -658,13 +664,13 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): // extract the call override def applyPath(p: Path)(k: Path => Block): Block = p match - case RefOfBms(l, disamb) if ctx.bmsReqdInfo.contains(l) && !ctx.isModOrObj(l) => + case RefOfBms(l, S(d)) if ctx.bmsReqdInfo.contains(l) && !ctx.isModOrObj(l) => val newSym = closureMap.get(l) match case None => // $this was previously used, but it may be confused with the `this` keyword // let's use $here instead val newSym = TempSymbol(N, l.nme + "$here") - syms.addOne(l -> newSym) // add to `syms`: this closure will be initialized in `applyBlock` + syms.addOne(FunSyms(l, d) -> newSym) // add to `syms`: this closure will be initialized in `applyBlock` closureMap.addOne(l -> newSym) // add to `closureMap`: `newSym` refers to the closure and can be used later newSym @@ -672,9 +678,9 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): case Some(value) if activeClosures.contains(value) => value // symbol exists, needs initialization case Some(value) => - syms.addOne(l -> value) + syms.addOne(FunSyms(l, d) -> value) value - k(Value.Ref(newSym, disamb)) + k(Value.Ref(newSym, S(d))) case _ => super.applyPath(p)(k) (walker.applyBlock(b), syms.toList) end rewriteBms @@ -692,7 +698,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): val pre = syms.foldLeft(blockBuilder): case (blk, (bms, local)) => val initial = blk.assign(local, createCall(bms, ctx)) - ctx.defns(bms) match + ctx.defns(bms.b) match case c: ClsLikeDefn => initial.assignFieldN(local.asPath, Tree.Ident("class"), bms.asPath) case _ => initial @@ -769,7 +775,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): case Some(sym) if !ctx.ignored(d.sym) => ctx.getBmsReqdInfo(d.sym) match case Some(_) => // has args blockBuilder - .assign(sym, Instantiate(mut = false, d.sym.asPath, getCallArgs(d.sym, ctx))) + .assign(sym, Instantiate(mut = false, d.sym.asPath, getCallArgs(FunSyms(d.sym, d.isym), ctx))) .rest(applyBlock(rest)) case None => // has no args // Objects with no parameters are instantiated statically @@ -831,8 +837,8 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): // When calling a lifted function or constructor, we need to pass, as arguments, the local variables, // inner symbols, etc that it needs to access. This function creates those arguments for that in // the correct order. - def getCallArgs(sym: BlockMemberSymbol, ctx: LifterCtx) = - val info = ctx.getBmsReqdInfo(sym).get + def getCallArgs(sym: FunSyms[?], ctx: LifterCtx) = + val info = ctx.getBmsReqdInfo(sym.b).get val localsArgs = info.reqdVars.map(s => ctx.getLocalPath(s).get.asArg) val capturesArgs = info.reqdCaptures.map(ctx.getCapturePath(_).get.asArg) val iSymArgs = info.reqdInnerSyms.map(ctx.getIsymPath(_).get.asArg) @@ -840,12 +846,12 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): bmsArgs ++ iSymArgs ++ localsArgs ++ capturesArgs // This creates a call to a lifted function or constructor. - def createCall(sym: BlockMemberSymbol, ctx: LifterCtx): Call = - val info = ctx.getBmsReqdInfo(sym).get + def createCall(syms: FunSyms[?], ctx: LifterCtx): Call = + val info = ctx.getBmsReqdInfo(syms.b).get val callSym = info.fakeCtorBms match case Some(v) => v - case None => sym - Call(callSym.asPath, getCallArgs(sym, ctx))(false, false, false) + case None => syms + Call(callSym.asPath, getCallArgs(syms, ctx))(false, false, false) /* * Explanation of liftOutDefnCont, liftDefnsInCls, liftDefnsInFn: @@ -947,7 +953,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): .ret(Call(singleCallBms.asPath, args1 ++ args2)(true, false, false)) // TODO: restParams not considered val mainDefn = FunDefn(f.owner, f.sym, f.dSym, PlainParamList(extraParamsCpy) :: headPlistCopy :: Nil, bdy)(false) - val auxDefn = FunDefn.withFreshSymbol(N, singleCallBms, flatPlist, lifted.body)(isTailRec = f.isTailRec) + val auxDefn = FunDefn(N, singleCallBms.b, singleCallBms.d, flatPlist, lifted.body)(isTailRec = f.isTailRec) if ctx.firstClsFns.contains(f.sym) then Lifted(mainDefn, auxDefn :: extras) @@ -1078,7 +1084,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): case Some(value) => (ParamList(value.flags, extraPlist.params ++ value.params, value.restParam), auxPlist) - val auxCtorDefn_ = FunDefn.withFreshSymbol(None, singleCallBms, headParams :: newAuxPlist, bod)(false) + val auxCtorDefn_ = FunDefn(None, singleCallBms.b, singleCallBms.d, headParams :: newAuxPlist, bod)(false) val auxCtorDefn = BlockTransformer(subst).applyFunDefn(auxCtorDefn_) // Lifted(lifted, extras ::: (fakeCtorDefn :: auxCtorDefn :: Nil)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 3235f5de10..575a6af2c6 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -473,7 +473,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): lamDef, k(Call( Value.Ref(State.runtimeSymbol).selN(Tree.Ident(if isAnd then "short_and" else "short_or")), - Arg(N, ar1) :: Arg(N, Value.Ref(lamSym, N)) :: Nil + Arg(N, ar1) :: Arg(N, Value.Ref(lamSym, S(lamDef.dSym))) :: Nil )(true, false, false))) else subTerm_nonTail(arg2): ar2 => @@ -602,7 +602,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): val lamDef = FunDefn.withFreshSymbol(N, lamSym, paramLists, bodyBlock)(isTailRec = false) Define( lamDef, - k(Value.Ref(lamSym, N))) + k(Value.Ref(lamSym, S(lamDef.dSym)))) case iftrm: st.IfLike => ucs.Normalization(this)(iftrm)(k) @@ -940,7 +940,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): case Lambda(params, body) => val lamSym = BlockMemberSymbol("lambda", Nil, false) val lamDef = FunDefn.withFreshSymbol(N, lamSym, params :: Nil, body)(isTailRec = false) - Define(lamDef, k(Value.Ref(lamSym, N))) + Define(lamDef, k(Value.Ref(lamSym, S(lamDef.dSym)))) case r => val l = new TempSymbol(N) Assign(l, r, k(l |> Value.Ref.apply)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 28b0d73882..c1437cf969 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -194,10 +194,10 @@ class TailRecOpt(using State, TL, Raise): .toList val paramSymsArr = ArrayBuffer.from(paramSyms) val dSymIds = scc.funs.map(_.dSym).zipWithIndex.toMap - val bms = + val bms = if scc.funs.size == 1 then scc.funs.head.sym else BlockMemberSymbol(scc.funs.map(_.sym.nme).mkString("_"), Nil, true) - val dSym = + val dSym = if scc.funs.size == 1 then scc.funs.head.dSym else TermSymbol(syntax.Fun, owner, Tree.Ident(bms.nme)) val loopSym = TempSymbol(N, "loopLabel") diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala index 31432b48e3..622e5d8f0d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala @@ -368,7 +368,7 @@ class UsedVarAnalyzer(b: Block, handlerPaths: Opt[HandlerPaths])(using State): handleCalledBms(l) case Instantiate(mut, InstSel(l), args) => args.map(super.applyArg) - handleCalledBms(l) + handleCalledBms(l._1) case _ => super.applyResult(r) override def applyPath(p: Path): Unit = p match diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index 1f9113e758..e44f7a188c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -265,7 +265,9 @@ class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.I override def toString: Str = s"term:${owner.map(o => s"${o}.").getOrElse("")}${id.name}${State.dbgUid(uid)}" def subst(using sub: SymbolSubst): TermSymbol = sub.mapTermSym(this) - +object TermSymbol: + def fromFunBms(b: BlockMemberSymbol, owner: Opt[InnerSymbol])(using State) = + TermSymbol(syntax.Fun, owner, Tree.Ident(b.nme)) sealed trait CtorSymbol extends Symbol: def subst(using sub: SymbolSubst): CtorSymbol = sub.mapCtorSym(this) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index f6d76bdb7c..b817ac55f1 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -328,7 +328,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e // The symbol for the loop label if the term is a `while`. lazy val loopLabel = new TempSymbol(t) lazy val f = new BlockMemberSymbol("while", Nil, false) - lazy val fDSym = new TermSymbol(syntax.Fun, N, Tree.Ident(f.nme)) + lazy val tSym = new TermSymbol(syntax.Fun, N, Tree.Ident(f.nme)) val normalized = tl.scoped("ucs:normalize"): normalize(inputSplit)(using VarSet()) tl.scoped("ucs:normalized"): @@ -339,7 +339,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e lazy val breakRoot = (r: Result) => Assign(l, r, Break(rootBreakLabel)) lazy val assignResult = (r: Result) => Assign(l, r, End()) val loopCont = if config.rewriteWhileLoops - then Return(Call(Value.Ref(f, S(fDSym)), Nil)(true, true, false), false) + then Return(Call(Value.Ref(f, S(tSym)), Nil)(true, true, false), false) else Continue(loopLabel) val cont = if kw === `while` then @@ -399,8 +399,8 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e Select(Value.Ref(State.runtimeSymbol), Tree.Ident("LoopEnd"))(S(State.loopEndSymbol)) val blk = blockBuilder .assign(l, Value.Lit(Tree.UnitLit(false))) - .define(FunDefn(N, f, fDSym, PlainParamList(Nil) :: Nil, Begin(body, Return(loopEnd, false)))(isTailRec = false)) - .assign(loopResult, Call(Value.Ref(f, N), Nil)(true, true, false)) + .define(FunDefn(N, f, tSym, PlainParamList(Nil) :: Nil, Begin(body, Return(loopEnd, false)))(isTailRec = false)) + .assign(loopResult, Call(Value.Ref(f, S(tSym)), Nil)(true, true, false)) if summon[LoweringCtx].mayRet then blk .assign(isReturned, Call(Value.Ref(State.builtinOpsMap("!==")), diff --git a/hkmc2/shared/src/test/mlscript/HkScratch.mls b/hkmc2/shared/src/test/mlscript/HkScratch.mls index ef38e8b363..f5951cad3e 100644 --- a/hkmc2/shared/src/test/mlscript/HkScratch.mls +++ b/hkmc2/shared/src/test/mlscript/HkScratch.mls @@ -9,4 +9,41 @@ // :todo - +:lot +:sjs +let x = [() => 2] +//│ JS (unsanitized): +//│ let x, lambda; +//│ lambda = (undefined, function () { return 2 }); +//│ x = globalThis.Object.freeze([ lambda ]); +//│ Lowered: +//│ Program: +//│ imports = Nil +//│ main = Define: +//│ defn = FunDefn: +//│ owner = N +//│ sym = member:lambda +//│ dSym = term:lambda +//│ params = Ls of +//│ ParamList: +//│ flags = () +//│ params = Nil +//│ restParam = N +//│ body = Return: +//│ res = Lit of IntLit of 2 +//│ implct = false +//│ rest = Assign: \ +//│ lhs = x +//│ rhs = Tuple: +//│ mut = false +//│ elems = Ls of +//│ Arg: +//│ spread = N +//│ value = Ref: +//│ l = member:lambda +//│ disamb = N +//│ rest = Assign: \ +//│ lhs = $block$res +//│ rhs = Lit of UnitLit of false +//│ rest = End of "" +//│ x = [fun] diff --git a/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls b/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls index ec726f6ed4..6ca163f9b8 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls @@ -11,7 +11,87 @@ fun f() = h.perform() 1 f() + f() + f() -//│ = 2 +//│ FAILURE: Unexpected compilation error +//│ FAILURE LOCATION: lookup_! (Scope.scala:112) +//│ FAILURE INFO: Tuple2: +//│ _1 = Tuple2: +//│ _1 = member:doUnwind +//│ _2 = class hkmc2.semantics.BlockMemberSymbol +//│ _2 = Scope: +//│ parent = S of Scope: +//│ parent = S of Scope: +//│ parent = N +//│ curThis = S of S of globalThis:globalThis +//│ bindings = HashMap(member:doUnwind$ -> doUnwind$1, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:hdlrFun$ -> hdlrFun$, $runtime -> runtime, $definitionMetadata -> definitionMetadata, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:handleBlock$ -> handleBlock$, $block$res -> block$res, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:Cont$func$f$ -> Cont$func$f$, class:Handler$h$ -> Handler$h$, member:Cont$func$f$ -> Cont$func$f$1, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, $tmp -> tmp, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:f$ -> f$, member:doUnwind$ -> doUnwind$, member:Cont$func$f$$ -> Cont$func$f$$) +//│ curThis = S of N +//│ bindings = HashMap($args -> args, h -> h) +//│ curThis = N +//│ bindings = HashMap($tmp -> tmp1) +//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' +//│ FAILURE: Unexpected compilation error +//│ FAILURE LOCATION: lookup_! (Scope.scala:112) +//│ FAILURE INFO: Tuple2: +//│ _1 = Tuple2: +//│ _1 = member:doUnwind +//│ _2 = class hkmc2.semantics.BlockMemberSymbol +//│ _2 = Scope: +//│ parent = S of Scope: +//│ parent = S of Scope: +//│ parent = N +//│ curThis = S of S of globalThis:globalThis +//│ bindings = HashMap(member:doUnwind$ -> doUnwind$1, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:hdlrFun$ -> hdlrFun$, $runtime -> runtime, $definitionMetadata -> definitionMetadata, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:handleBlock$ -> handleBlock$, $block$res -> block$res, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:Cont$func$f$ -> Cont$func$f$, class:Handler$h$ -> Handler$h$, member:Cont$func$f$ -> Cont$func$f$1, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, $tmp -> tmp, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:f$ -> f$, member:doUnwind$ -> doUnwind$, member:Cont$func$f$$ -> Cont$func$f$$) +//│ curThis = S of N +//│ bindings = HashMap($args -> args) +//│ curThis = N +//│ bindings = HashMap(h -> h, $tmp -> tmp1, $tmp -> tmp2, $tmp -> tmp3, $tmp -> tmp4) +//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' +//│ FAILURE: Unexpected compilation error +//│ FAILURE LOCATION: lookup_! (Scope.scala:112) +//│ FAILURE INFO: Tuple2: +//│ _1 = Tuple2: +//│ _1 = member:doUnwind +//│ _2 = class hkmc2.semantics.BlockMemberSymbol +//│ _2 = Scope: +//│ parent = S of Scope: +//│ parent = S of Scope: +//│ parent = N +//│ curThis = S of S of globalThis:globalThis +//│ bindings = HashMap(member:doUnwind$ -> doUnwind$1, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:hdlrFun$ -> hdlrFun$, $runtime -> runtime, $definitionMetadata -> definitionMetadata, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:handleBlock$ -> handleBlock$, $block$res -> block$res, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:Cont$func$f$ -> Cont$func$f$, class:Handler$h$ -> Handler$h$, member:Cont$func$f$ -> Cont$func$f$1, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, $tmp -> tmp, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:f$ -> f$, member:doUnwind$ -> doUnwind$, member:Cont$func$f$$ -> Cont$func$f$$) +//│ curThis = S of N +//│ bindings = HashMap($args -> args) +//│ curThis = N +//│ bindings = HashMap(h -> h, $tmp -> tmp1, $tmp -> tmp2, $tmp -> tmp3, $tmp -> tmp4) +//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' +//│ FAILURE: Unexpected compilation error +//│ FAILURE LOCATION: lookup_! (Scope.scala:112) +//│ FAILURE INFO: Tuple2: +//│ _1 = Tuple2: +//│ _1 = member:doUnwind +//│ _2 = class hkmc2.semantics.BlockMemberSymbol +//│ _2 = Scope: +//│ parent = S of Scope: +//│ parent = S of Scope: +//│ parent = N +//│ curThis = S of S of globalThis:globalThis +//│ bindings = HashMap(member:doUnwind$ -> doUnwind$1, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:hdlrFun$ -> hdlrFun$, $runtime -> runtime, $definitionMetadata -> definitionMetadata, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:handleBlock$ -> handleBlock$, $block$res -> block$res, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:Cont$func$f$ -> Cont$func$f$, class:Handler$h$ -> Handler$h$, member:Cont$func$f$ -> Cont$func$f$1, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, $tmp -> tmp, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:f$ -> f$, member:doUnwind$ -> doUnwind$, member:Cont$func$f$$ -> Cont$func$f$$) +//│ curThis = S of N +//│ bindings = HashMap($args -> args) +//│ curThis = N +//│ bindings = HashMap(h -> h, $tmp -> tmp1, $tmp -> tmp2, $tmp -> tmp3, $tmp -> tmp4) +//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' +//│ FAILURE: Unexpected runtime error +//│ FAILURE LOCATION: mkQuery (JSBackendDiffMaker.scala:159) +//│ ═══[RUNTIME ERROR] ReferenceError: doUnwind is not defined +//│ at f$ (REPL10:1:5833) +//│ at handleBlock$ (REPL10:1:6072) +//│ at REPL10:1:6523 +//│ at ContextifyScript.runInThisContext (node:vm:137:12) +//│ at REPLServer.defaultEval (node:repl:562:24) +//│ at bound (node:domain:433:15) +//│ at REPLServer.runBound [as eval] (node:domain:444:12) +//│ at REPLServer.onLine (node:repl:886:12) +//│ at REPLServer.emit (node:events:508:28) +//│ at REPLServer.emit (node:domain:489:12) :expect 1 fun f(x) = @@ -226,7 +306,53 @@ fun sum(n) = else n + sum(n - 1) sum(100) -//│ = 5050 +//│ FAILURE: Unexpected compilation error +//│ FAILURE LOCATION: lookup_! (Scope.scala:112) +//│ FAILURE INFO: Tuple2: +//│ _1 = Tuple2: +//│ _1 = member:doUnwind +//│ _2 = class hkmc2.semantics.BlockMemberSymbol +//│ _2 = Scope: +//│ parent = S of Scope: +//│ parent = S of Scope: +//│ parent = N +//│ curThis = S of S of globalThis:globalThis +//│ bindings = HashMap($tmp -> tmp2, member:Test$ -> Test$1, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:g$ -> g$, $runtime -> runtime, member:h$ -> h$, $definitionMetadata -> definitionMetadata, class:Test -> Test6, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:Cont$func$sum$ -> Cont$func$sum$1, $block$res -> block$res5, $res -> res, $tmp -> tmp4, member:Test$ -> Test$3, member:‹stack safe body› -> $_stack$_safe$_body$_, $block$res -> block$res, member:g$ -> g$2, member:doUnwind$ -> doUnwind$2, member:A -> A3, member:h$ -> h$2, member:f -> f6, member:Cont$func$sum$$ -> Cont$func$sum$$, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:A -> A2, class:Handler$h$ -> Handler$h$, member:Test -> Test5, member:f -> f2, a -> a, member:Test -> Test1, member:f -> f, b -> b, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, class:Test -> Test, $tmp -> tmp, member:f$ -> f$, $block$res -> block$res8, member:doUnwind$ -> doUnwind$, $tmp -> tmp6, $tmp -> tmp7, class:Test -> Test4, member:A -> A1, member:Cont$func$f$$ -> Cont$func$f$$, member:f -> f4, member:A$ -> A$1, member:doUnwind$ -> doUnwind$1, class:A -> A, class:f$capture -> f$capture4, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, $block$res -> block$res2, $tmp -> tmp1, member:f$capture -> f$capture5, member:hdlrFun$ -> hdlrFun$, member:Test$ -> Test$, $block$res -> block$res4, $tmp -> tmp3, $block$res -> block$res6, member:Test$ -> Test$2, member:g$ -> g$1, member:A$ -> A$, $block$res -> block$res7, member:handleBlock$ -> handleBlock$, member:h$ -> h$1, class:f$capture -> f$capture, $tmp -> tmp5, member:Bad$ -> Bad$, member:f$capture -> f$capture1, member:Good$ -> Good$, member:Test -> Test3, member:sum -> sum, member:f -> f1, class:f$capture -> f$capture2, class:Cont$func$f$ -> Cont$func$f$, member:f$capture -> f$capture3, $block$res -> block$res9, member:Bad -> Bad1, member:Good -> Good1, class:Test -> Test2, member:f -> f5, member:Cont$func$f$ -> Cont$func$f$1, class:Good -> Good, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:Test -> Test7, member:f -> f3, class:Bad -> Bad, class:Cont$func$sum$ -> Cont$func$sum$, $block$res -> block$res3) +//│ curThis = S of N +//│ bindings = HashMap($args -> args, n -> n) +//│ curThis = N +//│ bindings = HashMap($scrut -> scrut, $curDepth -> curDepth, $stackDelayRes -> stackDelayRes, $tmp -> tmp8, $tmp -> tmp9) +//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' +//│ FAILURE: Unexpected compilation error +//│ FAILURE LOCATION: lookup_! (Scope.scala:112) +//│ FAILURE INFO: Tuple2: +//│ _1 = Tuple2: +//│ _1 = member:doUnwind +//│ _2 = class hkmc2.semantics.BlockMemberSymbol +//│ _2 = Scope: +//│ parent = S of Scope: +//│ parent = S of Scope: +//│ parent = N +//│ curThis = S of S of globalThis:globalThis +//│ bindings = HashMap($tmp -> tmp2, member:Test$ -> Test$1, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:g$ -> g$, $runtime -> runtime, member:h$ -> h$, $definitionMetadata -> definitionMetadata, class:Test -> Test6, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:Cont$func$sum$ -> Cont$func$sum$1, $block$res -> block$res5, $res -> res, $tmp -> tmp4, member:Test$ -> Test$3, member:‹stack safe body› -> $_stack$_safe$_body$_, $block$res -> block$res, member:g$ -> g$2, member:doUnwind$ -> doUnwind$2, member:A -> A3, member:h$ -> h$2, member:f -> f6, member:Cont$func$sum$$ -> Cont$func$sum$$, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:A -> A2, class:Handler$h$ -> Handler$h$, member:Test -> Test5, member:f -> f2, a -> a, member:Test -> Test1, member:f -> f, b -> b, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, class:Test -> Test, $tmp -> tmp, member:f$ -> f$, $block$res -> block$res8, member:doUnwind$ -> doUnwind$, $tmp -> tmp6, $tmp -> tmp7, class:Test -> Test4, member:A -> A1, member:Cont$func$f$$ -> Cont$func$f$$, member:f -> f4, member:A$ -> A$1, member:doUnwind$ -> doUnwind$1, class:A -> A, class:f$capture -> f$capture4, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, $block$res -> block$res2, $tmp -> tmp1, member:f$capture -> f$capture5, member:hdlrFun$ -> hdlrFun$, member:Test$ -> Test$, $block$res -> block$res4, $tmp -> tmp3, $block$res -> block$res6, member:Test$ -> Test$2, member:g$ -> g$1, member:A$ -> A$, $block$res -> block$res7, member:handleBlock$ -> handleBlock$, member:h$ -> h$1, class:f$capture -> f$capture, $tmp -> tmp5, member:Bad$ -> Bad$, member:f$capture -> f$capture1, member:Good$ -> Good$, member:Test -> Test3, member:sum -> sum, member:f -> f1, class:f$capture -> f$capture2, class:Cont$func$f$ -> Cont$func$f$, member:f$capture -> f$capture3, $block$res -> block$res9, member:Bad -> Bad1, member:Good -> Good1, class:Test -> Test2, member:f -> f5, member:Cont$func$f$ -> Cont$func$f$1, class:Good -> Good, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:Test -> Test7, member:f -> f3, class:Bad -> Bad, class:Cont$func$sum$ -> Cont$func$sum$, $block$res -> block$res3) +//│ curThis = S of N +//│ bindings = HashMap($args -> args, n -> n) +//│ curThis = N +//│ bindings = HashMap($scrut -> scrut, $curDepth -> curDepth, $stackDelayRes -> stackDelayRes, $tmp -> tmp8, $tmp -> tmp9) +//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' +//│ FAILURE: Unexpected runtime error +//│ FAILURE LOCATION: mkQuery (JSBackendDiffMaker.scala:159) +//│ ═══[RUNTIME ERROR] ReferenceError: doUnwind is not defined +//│ at sum (REPL36:1:2283) +//│ at sum (REPL36:1:2438) +//│ at sum (REPL36:1:2438) +//│ at sum (REPL36:1:2438) +//│ at sum (REPL36:1:2438) +//│ at sum (REPL36:1:2438) +//│ at sum (REPL36:1:2438) +//│ at sum (REPL36:1:2438) +//│ at sum (REPL36:1:2438) +//│ at $_stack$_safe$_body$_ (REPL36:1:2750) // instance checks diff --git a/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls b/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls index ceac941c04..ff48e0ad97 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls @@ -3,11 +3,94 @@ fun f() = 3 +:sjs :effectHandlers module A with data class Test with f() val a = 1 +//│ JS (unsanitized): +//│ let A1, Cont$ctor$Test$1, doUnwind$, Cont$ctor$Test$$; +//│ Cont$ctor$Test$$ = function Cont$ctor$Test$$(isMut, Test$instance, tmp, pc) { +//│ let tmp1, tmp2; +//│ if (isMut === true) { +//│ tmp1 = new Cont$ctor$Test$1(pc); +//│ } else { +//│ tmp1 = globalThis.Object.freeze(new Cont$ctor$Test$1(pc)); +//│ } +//│ tmp2 = tmp1(Test$instance, tmp); +//│ return tmp2 +//│ }; +//│ globalThis.Object.freeze(class Cont$ctor$Test$ extends runtime.FunctionContFrame.class { +//│ static { +//│ Cont$ctor$Test$1 = this +//│ } +//│ constructor(pc) { +//│ return (Test$instance, tmp) => { +//│ let tmp1; +//│ tmp1 = super(null); +//│ this.tmp = tmp; +//│ this.Test$instance = Test$instance; +//│ this.pc = pc; +//│ return this; +//│ } +//│ } +//│ #tmp; +//│ #Test$instance; +//│ get tmp() { return this.#tmp; } +//│ set tmp(value) { this.#tmp = value; } +//│ get Test$instance() { return this.#Test$instance; } +//│ set Test$instance(value) { this.#Test$instance = value; } +//│ resume(value$) { +//│ if (this.pc === 1) { +//│ this.tmp = value$; +//│ } +//│ contLoop: while (true) { +//│ if (this.pc === 2) { +//│ return this.Test$instance +//│ } else if (this.pc === 1) { +//│ this.Test$instance.a = 1; +//│ this.pc = 2; +//│ continue contLoop +//│ } +//│ break; +//│ } +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "Cont$ctor$Test$"]; +//│ }); +//│ doUnwind$ = function doUnwind$(Test$instance, tmp, res, pc) { +//│ res.contTrace.last.next = new Cont$ctor$Test$1(pc); +//│ res.contTrace.last = res.contTrace.last.next; +//│ return res +//│ }; +//│ globalThis.Object.freeze(class A { +//│ static { +//│ A1 = this +//│ } +//│ constructor() { +//│ runtime.Unit; +//│ } +//│ static { +//│ globalThis.Object.freeze(class Test { +//│ static { +//│ A.Test = this +//│ } +//│ constructor() { +//│ let tmp; +//│ tmp = f(); +//│ if (tmp instanceof runtime.EffectSig.class) { +//│ return doUnwind$(this, tmp, tmp, 1) +//│ } +//│ this.a = 1; +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "Test"]; +//│ }); +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "A"]; +//│ }); :effectHandlers class Test with diff --git a/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls b/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls index 8b772c4748..b71744b065 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls @@ -236,7 +236,10 @@ handle h = Eff with else n + f(n-1) resume(f(10000)) foo(h) -//│ = 50005000 +//│ FAILURE: Unexpected runtime error +//│ FAILURE LOCATION: processTerm (JSBackendDiffMaker.scala:208) +//│ ═══[RUNTIME ERROR] Expected: '50005000', got: 'fun' +//│ = fun // function call and defn inside handler :effectHandlers @@ -252,7 +255,10 @@ handle h = Eff with in fun foo(h) = h.perform foo(h) -//│ = 50005000 +//│ FAILURE: Unexpected runtime error +//│ FAILURE LOCATION: processTerm (JSBackendDiffMaker.scala:208) +//│ ═══[RUNTIME ERROR] Expected: '50005000', got: 'fun' +//│ = fun :re :effectHandlers @@ -265,7 +271,8 @@ handle h = Eff with else n + f(n-1) resume(f(10000)) foo(h) -//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded +//│ = fun +//│ FAILURE: Unexpected lack of runtime error :effectHandlers :stackSafe diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls index b6247044f0..6e3d249731 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls @@ -16,6 +16,17 @@ fun g(x) = @tailcall f(x) //│ ║ l.14: fun g(x) = @tailcall f(x) //│ ╙── ^^^^ +:e +@tailrec fun f(x) = + g(x) +fun g(x) = f(x); f(x) +//│ ╔══[ERROR] `f` is not tail recursive. +//│ ║ l.20: @tailrec fun f(x) = +//│ ║ ^ +//│ ╟── It could self-recurse through this call, which is not a tail call. +//│ ║ l.22: fun g(x) = f(x); f(x) +//│ ╙── ^^^^ + :e @tailrec fun f(x) = g(x) @@ -25,10 +36,10 @@ fun g(x) = fun h(x) = g(x) //│ ╔══[ERROR] `f` is not tail recursive. -//│ ║ l.20: @tailrec fun f(x) = +//│ ║ l.31: @tailrec fun f(x) = //│ ║ ^ //│ ╟── It could self-recurse through this call, which is not a tail call. -//│ ║ l.23: f(x) +//│ ║ l.34: f(x) //│ ╙── ^^^^ :e @@ -40,8 +51,8 @@ fun g(x) = fun h(x) = f(x) //│ ╔══[ERROR] `f` is not tail recursive. -//│ ║ l.35: @tailrec fun f(x) = +//│ ║ l.46: @tailrec fun f(x) = //│ ║ ^ //│ ╟── It could self-recurse through this call, which is not a tail call. -//│ ║ l.38: h(x) +//│ ║ l.49: h(x) //│ ╙── ^^^^ diff --git a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls index baf7403e56..781834c3f6 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls @@ -65,34 +65,36 @@ f(10000, 20000, 0) :sjs :expect 200010000 -class A with +module A with fun sum_impl(n, acc) = if n == 0 then acc else @tailcall sum_impl(n - 1, n + acc) fun sum(n) = sum_impl(n, 0) -(new A).sum(20000) +A.sum(20000) //│ JS (unsanitized): -//│ let A1, tmp; +//│ let A1; //│ globalThis.Object.freeze(class A { //│ static { //│ A1 = this //│ } -//│ constructor() {} -//│ sum(n) { +//│ constructor() { +//│ runtime.Unit; +//│ } +//│ static sum(n) { //│ loopLabel: while (true) { -//│ return this.sum_impl(n, 0); +//│ return A.sum_impl(n, 0); //│ break; //│ } //│ } -//│ sum_impl(n, acc) { -//│ let scrut, tmp1, tmp2, id; +//│ static sum_impl(n, acc) { +//│ let scrut, tmp, tmp1, id; //│ loopLabel: while (true) { //│ scrut = n == 0; //│ if (scrut === true) { //│ return acc //│ } else { -//│ tmp1 = n - 1; -//│ tmp2 = n + acc; -//│ n = tmp1; -//│ acc = tmp2; +//│ tmp = n - 1; +//│ tmp1 = n + acc; +//│ n = tmp; +//│ acc = tmp1; //│ id = 0; //│ continue loopLabel //│ } @@ -102,8 +104,7 @@ class A with //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "A"]; //│ }); -//│ tmp = globalThis.Object.freeze(new A1()); -//│ tmp.sum(20000) +//│ A1.sum(20000) //│ = 200010000 :silent @@ -139,7 +140,7 @@ f(100) //│ return f_g(0, x1, z, undefined) //│ }; //│ f_g = function f_g(id, param0, param1, param2) { -//│ let scrut, scrut1, tmp5, tmp6; +//│ let scrut, scrut1, tmp4, tmp5; //│ loopLabel: while (true) { //│ if (id === 0) { //│ scrut = param0 < 0; @@ -160,15 +161,15 @@ f(100) //│ if (scrut1 === true) { //│ return 0 //│ } else { -//│ tmp5 = param0 - 1; -//│ tmp6 = globalThis.Object.freeze([ +//│ tmp4 = param0 - 1; +//│ tmp5 = globalThis.Object.freeze([ //│ param0, //│ param0, //│ param0 //│ ]); -//│ param0 = tmp5; +//│ param0 = tmp4; //│ param1 = 0; -//│ param2 = tmp6; +//│ param2 = tmp5; //│ id = 0; //│ continue loopLabel //│ } @@ -200,7 +201,7 @@ fun f(x) = //│ FAILURE: Unexpected type error //│ FAILURE LOCATION: cg (TailRecOpt.scala:88) //│ ╔══[ERROR] This tail call exits the current scope and cannot be optimized. -//│ ║ l.199: @tailcall g(x) +//│ ║ l.200: @tailcall g(x) //│ ╙── ^^^^ // the lifter doesn't propagate the definition symbols properly @@ -212,19 +213,19 @@ fun f(x) = @tailcall g() //│ ═══[ERROR] Only direct calls in tail position may be marked @tailcall. //│ ╔══[ERROR] This call is not optimized as it does not directly recurse through its parent function. -//│ ║ l.211: @tailcall f(x) +//│ ║ l.212: @tailcall f(x) //│ ╙── ^ :todo :lift -class A with +module A with fun f(x) = fun g(x) = @tailcall f(x) @tailcall g(x) //│ ╔══[ERROR] This tail call exits the current scope and cannot be optimized. -//│ ║ l.223: @tailcall f(x) +//│ ║ l.224: @tailcall f(x) //│ ╙── ^^^ -//│ ╔══[ERROR] Only direct calls in tail position may be marked @tailcall. -//│ ║ l.224: @tailcall g(x) -//│ ╙── ^ +//│ ╔══[ERROR] This tail call exits the current scope and cannot be optimized. +//│ ║ l.225: @tailcall g(x) +//│ ╙── ^^^^ From 934e2c8fe6ad595b6fa150d3063d1f6fd85f2128 Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Sat, 29 Nov 2025 20:34:34 +0800 Subject: [PATCH 04/16] changes to lifter and handler lowering --- hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala | 1 + .../src/main/scala/hkmc2/codegen/HandlerLowering.scala | 4 ++-- .../src/main/scala/hkmc2/codegen/LambdaRewriter.scala | 2 +- hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala | 6 +++--- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 6b17fbca2e..42f2a275f6 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -346,6 +346,7 @@ final case class FunDefn( val isTailRec: Bool, ) extends Defn: val innerSym = N + val asPath = Value.Ref(sym, S(dSym)) object FunDefn: def withFreshSymbol(owner: Opt[InnerSymbol], sym: BlockMemberSymbol, params: Ls[ParamList], body: Block)(isTailRec: Bool)(using State) = FunDefn(owner, sym, TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme)), params, body)(isTailRec) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index 2d8795aacb..c3a2c707bf 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -476,7 +476,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, // all the code in the first state private def translateBlock(b: Block, extraLocals: Set[Local], callSelf: Opt[Result], fnOrCls: FnOrCls, h: HandlerCtx): Block = val getLocalsFn = createGetLocalsFn(b, extraLocals)(using h) - given HandlerCtx = h.nestDebugScope(b.userDefinedVars ++ extraLocals, getLocalsFn.sym.asPath) + given HandlerCtx = h.nestDebugScope(b.userDefinedVars ++ extraLocals, getLocalsFn.asPath) val stage1 = firstPass(b) val stage2 = secondPass(stage1, fnOrCls, callSelf, getLocalsFn) if h.isTopLevel then stage2 else thirdPass(stage2) @@ -557,7 +557,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, doUnwindBlk )(false) - val doUnwindPath: Path = Value.Ref(doUnwindSym, S(doUnwindDef.dSym)) + val doUnwindPath: Path = doUnwindDef.asPath doUnwindMap += fnOrCls -> doUnwindPath val doUnwindLazy = Lazy(doUnwindPath) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala index 2cddb74add..cec6c76a24 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala @@ -19,7 +19,7 @@ object LambdaRewriter: val defn = FunDefn.withFreshSymbol(N, newSym, params :: Nil, body)(false) val blk = blockBuilder .define(defn) - .assign(lhs, Value.Ref(newSym, S(defn.dSym))) + .assign(lhs, defn.asPath) .rest(rest) (blk, Nil) case _ => diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 575a6af2c6..c5cc7cb134 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -473,7 +473,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): lamDef, k(Call( Value.Ref(State.runtimeSymbol).selN(Tree.Ident(if isAnd then "short_and" else "short_or")), - Arg(N, ar1) :: Arg(N, Value.Ref(lamSym, S(lamDef.dSym))) :: Nil + Arg(N, ar1) :: Arg(N, lamDef.asPath) :: Nil )(true, false, false))) else subTerm_nonTail(arg2): ar2 => @@ -602,7 +602,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): val lamDef = FunDefn.withFreshSymbol(N, lamSym, paramLists, bodyBlock)(isTailRec = false) Define( lamDef, - k(Value.Ref(lamSym, S(lamDef.dSym)))) + k(lamDef.asPath)) case iftrm: st.IfLike => ucs.Normalization(this)(iftrm)(k) @@ -940,7 +940,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): case Lambda(params, body) => val lamSym = BlockMemberSymbol("lambda", Nil, false) val lamDef = FunDefn.withFreshSymbol(N, lamSym, params :: Nil, body)(isTailRec = false) - Define(lamDef, k(Value.Ref(lamSym, S(lamDef.dSym)))) + Define(lamDef, k(lamDef.asPath)) case r => val l = new TempSymbol(N) Assign(l, r, k(l |> Value.Ref.apply)) From ee5a6c426e801d58ddc2ea90f6ef4929b458311a Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Sat, 29 Nov 2025 21:01:32 +0800 Subject: [PATCH 05/16] fixes --- .../scala/hkmc2/codegen/HandlerLowering.scala | 13 +- hkmc2/shared/src/test/mlscript/HkScratch.mls | 292 +++++++++++++++--- .../test/mlscript/lifter/StackSafetyLift.mls | 13 +- 3 files changed, 266 insertions(+), 52 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index c3a2c707bf..3a99d81ed4 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -595,7 +595,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case pList :: Nil => val params = pList.params.map(p => p.sym.asPath.asArg) f.owner match - case None => S(Call(f.sym.asPath, params)(true, true, false)) + case None => S(Call(f.asPath, params)(true, true, false)) case Some(owner) => S(Call(Select(owner.asPath, Tree.Ident(f.sym.nme))(N), params)(true, true, false)) case _ => None // TODO: more than one plist @@ -634,11 +634,12 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, // Handle block becomes a FunDefn and CallPlaceholder private def translateHandleBlock(h: HandleBlock)(using HandlerCtx): Block = val sym = BlockMemberSymbol(s"handleBlock$$", Nil) + val tSym = TermSymbol.fromFunBms(sym, N) val lbl = freshTmp("handlerBody") val lblLoop = freshTmp("handlerLoop") val handlerBody = translateBlock( - h.body, Set.empty, S(Call(sym.asPath, Nil)(true, false, false)), L(sym), + h.body, Set.empty, S(Call(Value.Ref(sym, S(tSym)), Nil)(true, false, false)), L(sym), HandlerCtx( false, true, s"Cont$$handleBlock$$${symToStr(h.lhs)}$$", N, @@ -665,7 +666,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, handler.params, Define( fDef, - Return(PureCall(paths.mkEffectPath, h.cls.asPath :: Value.Ref(sym, N) :: Nil), false)))(false) + Return(PureCall(paths.mkEffectPath, h.cls.asPath :: fDef.asPath :: Nil), false)))(false) // Some limited handling of effects extending classes and having access to their fields. // Currently does not support super() raising effects. @@ -698,14 +699,14 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, .assign(h.lhs, Instantiate(mut = true, Value.Ref(clsDefn.sym, S(h.cls)), Nil)) .rest(handlerBody) - val defn = FunDefn.withFreshSymbol( + val defn = FunDefn( N, // no owner - sym, PlainParamList(Nil) :: Nil, body)(false) + sym, tSym, PlainParamList(Nil) :: Nil, body)(false) val result = blockBuilder .define(defn) .rest( - ResultPlaceholder(h.res, freshId(), Call(sym.asPath, Nil)(true, true, false), h.rest) + ResultPlaceholder(h.res, freshId(), Call(defn.asPath, Nil)(true, true, false), h.rest) ) result diff --git a/hkmc2/shared/src/test/mlscript/HkScratch.mls b/hkmc2/shared/src/test/mlscript/HkScratch.mls index f5951cad3e..c191173195 100644 --- a/hkmc2/shared/src/test/mlscript/HkScratch.mls +++ b/hkmc2/shared/src/test/mlscript/HkScratch.mls @@ -8,42 +8,262 @@ // :d // :todo +class Eff -:lot +:effectHandlers :sjs -let x = [() => 2] +:lift +fun foo(h) = h.perform +handle h = Eff with + fun perform(resume) = + let fuck = () + set fuck = n => + if n <= 0 then 0 + else n + fuck(n-1) + resume(fuck(10000)) +foo(h) //│ JS (unsanitized): -//│ let x, lambda; -//│ lambda = (undefined, function () { return 2 }); -//│ x = globalThis.Object.freeze([ lambda ]); -//│ Lowered: -//│ Program: -//│ imports = Nil -//│ main = Define: -//│ defn = FunDefn: -//│ owner = N -//│ sym = member:lambda -//│ dSym = term:lambda -//│ params = Ls of -//│ ParamList: -//│ flags = () -//│ params = Nil -//│ restParam = N -//│ body = Return: -//│ res = Lit of IntLit of 2 -//│ implct = false -//│ rest = Assign: \ -//│ lhs = x -//│ rhs = Tuple: -//│ mut = false -//│ elems = Ls of -//│ Arg: -//│ spread = N -//│ value = Ref: -//│ l = member:lambda -//│ disamb = N -//│ rest = Assign: \ -//│ lhs = $block$res -//│ rhs = Lit of UnitLit of false -//│ rest = End of "" -//│ x = [fun] +//│ let foo, lambda, tmp, handleBlock$, Cont$handleBlock$h$1, hdlrFun, Cont$func$lambda$1, Cont$handler$h$perform$1, Handler$h$1, doUnwind$, Cont$handleBlock$h$$, hdlrFun$, lambda$, doUnwind$1, Cont$func$lambda$$, doUnwind$2, Cont$handler$h$perform$$, hdlrFun$capture1; +//│ foo = function foo(h) { +//│ return h.perform +//│ }; +//│ Cont$handler$h$perform$$ = function Cont$handler$h$perform$$(isMut, Handler$h$$instance, resume, tmp1, pc) { +//│ let tmp2, tmp3; +//│ if (isMut === true) { +//│ tmp2 = new Cont$handler$h$perform$1(pc); +//│ } else { +//│ tmp2 = globalThis.Object.freeze(new Cont$handler$h$perform$1(pc)); +//│ } +//│ tmp3 = tmp2(Handler$h$$instance, resume, tmp1); +//│ return tmp3 +//│ }; +//│ globalThis.Object.freeze(class Cont$handler$h$perform$ extends runtime.FunctionContFrame.class { +//│ static { +//│ Cont$handler$h$perform$1 = this +//│ } +//│ constructor(pc) { +//│ return (Handler$h$$instance, resume, tmp1) => { +//│ let tmp2; +//│ tmp2 = super(null); +//│ this.resume = resume; +//│ this.tmp = tmp1; +//│ this.Handler$h$$instance = Handler$h$$instance; +//│ this.pc = pc; +//│ return this; +//│ } +//│ } +//│ #resume; +//│ #tmp; +//│ #Handler$h$$instance; +//│ get resume() { return this.#resume; } +//│ set resume(value) { this.#resume = value; } +//│ get tmp() { return this.#tmp; } +//│ set tmp(value) { this.#tmp = value; } +//│ get Handler$h$$instance() { return this.#Handler$h$$instance; } +//│ set Handler$h$$instance(value) { this.#Handler$h$$instance = value; } +//│ resume(value$) { +//│ if (this.pc === 5) { +//│ this.tmp = value$; +//│ } +//│ contLoop: while (true) { +//│ if (this.pc === 6) { +//│ return runtime.safeCall(this.resume(this.tmp)) +//│ } else if (this.pc === 5) { +//│ this.pc = 6; +//│ continue contLoop +//│ } +//│ break; +//│ } +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "Cont$handler$h$perform$"]; +//│ }); +//│ doUnwind$2 = function doUnwind$(Handler$h$$instance, resume, tmp1, res, pc) { +//│ res.contTrace.last.next = Cont$handler$h$perform$$(true, Handler$h$$instance, resume, tmp1, pc); +//│ res.contTrace.last = res.contTrace.last.next; +//│ return res +//│ }; +//│ Cont$func$lambda$$ = function Cont$func$lambda$$(isMut, Handler$h$$instance, n, tmp1, pc) { +//│ let tmp2, tmp3; +//│ if (isMut === true) { +//│ tmp2 = new Cont$func$lambda$1(pc); +//│ } else { +//│ tmp2 = globalThis.Object.freeze(new Cont$func$lambda$1(pc)); +//│ } +//│ tmp3 = tmp2(Handler$h$$instance, n, tmp1); +//│ return tmp3 +//│ }; +//│ globalThis.Object.freeze(class Cont$func$lambda$ extends runtime.FunctionContFrame.class { +//│ static { +//│ Cont$func$lambda$1 = this +//│ } +//│ constructor(pc) { +//│ return (Handler$h$$instance, n, tmp1) => { +//│ let tmp2; +//│ tmp2 = super(null); +//│ this.n = n; +//│ this.tmp = tmp1; +//│ this.Handler$h$$instance = Handler$h$$instance; +//│ this.pc = pc; +//│ return this; +//│ } +//│ } +//│ #n; +//│ #tmp; +//│ #Handler$h$$instance; +//│ get n() { return this.#n; } +//│ set n(value) { this.#n = value; } +//│ get tmp() { return this.#tmp; } +//│ set tmp(value) { this.#tmp = value; } +//│ get Handler$h$$instance() { return this.#Handler$h$$instance; } +//│ set Handler$h$$instance(value) { this.#Handler$h$$instance = value; } +//│ resume(value$) { +//│ if (this.pc === 2) { +//│ this.tmp = value$; +//│ } +//│ contLoop: while (true) { +//│ if (this.pc === 2) { +//│ return this.n + this.tmp +//│ } +//│ break; +//│ } +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "Cont$func$lambda$"]; +//│ }); +//│ doUnwind$1 = function doUnwind$(Handler$h$$instance, n, tmp1, res, pc) { +//│ res.contTrace.last.next = Cont$func$lambda$$(true, Handler$h$$instance, n, tmp1, pc); +//│ res.contTrace.last = res.contTrace.last.next; +//│ return res +//│ }; +//│ lambda$ = function lambda$(Handler$h$$instance, hdlrFun$capture2, n) { +//│ let scrut, tmp1, tmp2; +//│ scrut = n <= 0; +//│ if (scrut === true) { +//│ return 0 +//│ } else { +//│ tmp1 = n - 1; +//│ tmp2 = runtime.safeCall(hdlrFun$capture2.fuck$capture$0(tmp1)); +//│ if (tmp2 instanceof runtime.EffectSig.class) { +//│ return doUnwind$1(Handler$h$$instance, n, tmp2, tmp2, 2) +//│ } +//│ return n + tmp2 +//│ } +//│ }; +//│ lambda = (undefined, function (Handler$h$$instance, hdlrFun$capture2) { +//│ return (n) => { +//│ return lambda$(Handler$h$$instance, hdlrFun$capture2, n) +//│ } +//│ }); +//│ globalThis.Object.freeze(class hdlrFun$capture { +//│ static { +//│ hdlrFun$capture1 = this +//│ } +//│ constructor(fuck$capture$0) { +//│ this.fuck$capture$0 = fuck$capture$0; +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "hdlrFun$capture"]; +//│ }); +//│ hdlrFun$ = function hdlrFun$(Handler$h$$instance, resume) { +//│ let tmp1, capture, lambda$here; +//│ capture = new hdlrFun$capture1(null); +//│ capture.fuck$capture$0 = runtime.Unit; +//│ lambda$here = runtime.safeCall(lambda(Handler$h$$instance, capture)); +//│ capture.fuck$capture$0 = lambda$here; +//│ tmp1 = runtime.safeCall(capture.fuck$capture$0(10000)); +//│ if (tmp1 instanceof runtime.EffectSig.class) { +//│ return doUnwind$2(Handler$h$$instance, resume, tmp1, tmp1, 5) +//│ } +//│ return runtime.safeCall(resume(tmp1)) +//│ }; +//│ hdlrFun = function hdlrFun(Handler$h$$instance) { +//│ return (resume) => { +//│ return hdlrFun$(Handler$h$$instance, resume) +//│ } +//│ }; +//│ globalThis.Object.freeze(class Handler$h$ extends Eff1 { +//│ static { +//│ Handler$h$1 = this +//│ } +//│ constructor() { +//│ let tmp1; +//│ tmp1 = super(); +//│ } +//│ get perform() { +//│ let hdlrFun$here; +//│ hdlrFun$here = runtime.safeCall(hdlrFun(this)); +//│ return runtime.mkEffect(this, hdlrFun$here); +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "Handler$h$"]; +//│ }); +//│ Cont$handleBlock$h$$ = function Cont$handleBlock$h$$(isMut, res, pc) { +//│ let tmp1, tmp2; +//│ if (isMut === true) { +//│ tmp1 = new Cont$handleBlock$h$1(pc); +//│ } else { +//│ tmp1 = globalThis.Object.freeze(new Cont$handleBlock$h$1(pc)); +//│ } +//│ tmp2 = tmp1(res); +//│ return tmp2 +//│ }; +//│ globalThis.Object.freeze(class Cont$handleBlock$h$ extends runtime.FunctionContFrame.class { +//│ static { +//│ Cont$handleBlock$h$1 = this +//│ } +//│ constructor(pc) { +//│ return (res) => { +//│ let tmp1; +//│ tmp1 = super(null); +//│ this.res = res; +//│ this.pc = pc; +//│ return this; +//│ } +//│ } +//│ #res; +//│ get res() { return this.#res; } +//│ set res(value) { this.#res = value; } +//│ resume(value$) { +//│ if (this.pc === 1) { +//│ this.res = value$; +//│ } +//│ contLoop: while (true) { +//│ if (this.pc === 1) { +//│ return this.res +//│ } +//│ break; +//│ } +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "Cont$handleBlock$h$"]; +//│ }); +//│ doUnwind$ = function doUnwind$(h, res, res1, pc) { +//│ res1.contTrace.last.next = Cont$handleBlock$h$$(true, res, pc); +//│ return runtime.handleBlockImpl(res1, h) +//│ }; +//│ handleBlock$ = function handleBlock$() { +//│ let h, res; +//│ h = new Handler$h$1(); +//│ res = foo(h); +//│ if (res instanceof runtime.EffectSig.class) { +//│ return doUnwind$(h, res, res, 1) +//│ } +//│ return res +//│ }; +//│ tmp = handleBlock$(); +//│ if (tmp instanceof runtime.EffectSig.class) { tmp = runtime.topLevelEffect(tmp, false); } +//│ tmp +//│ FAILURE: Unexpected runtime error +//│ FAILURE LOCATION: mkQuery (JSBackendDiffMaker.scala:159) +//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded +//│ at Runtime.checkArgs (file:///storage/mark/Repos/mlscript/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs:429:19) +//│ at lambda$ (REPL13:1:6736) +//│ at hdlrFun$capture.fuck$capture$0 (REPL13:1:7499) +//│ at lambda$ (REPL13:1:7024) +//│ at hdlrFun$capture.fuck$capture$0 (REPL13:1:7499) +//│ at lambda$ (REPL13:1:7024) +//│ at hdlrFun$capture.fuck$capture$0 (REPL13:1:7499) +//│ at lambda$ (REPL13:1:7024) +//│ at hdlrFun$capture.fuck$capture$0 (REPL13:1:7499) +//│ at lambda$ (REPL13:1:7024) diff --git a/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls b/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls index b71744b065..8b772c4748 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls @@ -236,10 +236,7 @@ handle h = Eff with else n + f(n-1) resume(f(10000)) foo(h) -//│ FAILURE: Unexpected runtime error -//│ FAILURE LOCATION: processTerm (JSBackendDiffMaker.scala:208) -//│ ═══[RUNTIME ERROR] Expected: '50005000', got: 'fun' -//│ = fun +//│ = 50005000 // function call and defn inside handler :effectHandlers @@ -255,10 +252,7 @@ handle h = Eff with in fun foo(h) = h.perform foo(h) -//│ FAILURE: Unexpected runtime error -//│ FAILURE LOCATION: processTerm (JSBackendDiffMaker.scala:208) -//│ ═══[RUNTIME ERROR] Expected: '50005000', got: 'fun' -//│ = fun +//│ = 50005000 :re :effectHandlers @@ -271,8 +265,7 @@ handle h = Eff with else n + f(n-1) resume(f(10000)) foo(h) -//│ = fun -//│ FAILURE: Unexpected lack of runtime error +//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded :effectHandlers :stackSafe From e354f7859b41888e8f48663173adaa3249d45ce0 Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Mon, 1 Dec 2025 19:30:10 +0800 Subject: [PATCH 06/16] deal with classes --- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 40 ++++++++--- .../src/test/mlscript/tailrec/Errors.mls | 11 ++++ .../src/test/mlscript/tailrec/TailRecOpt.mls | 66 +++++++++++++------ 3 files changed, 90 insertions(+), 27 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index c1437cf969..562ebecd01 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -86,7 +86,7 @@ class TailRecOpt(using State, TL, Raise): c.match case c: TailCall if c.call.explicitTailCall && !cond => raise(ErrorReport( - msg"This tail call exits the current scope and cannot be optimized." -> c.call.toLoc :: Nil)) + msg"This tail call exits the current scope is not optimized." -> c.call.toLoc :: Nil)) case _ => cond @@ -148,7 +148,7 @@ class TailRecOpt(using State, TL, Raise): val hd = for a <- headArgs yield a.spread match case Some(true) => if c.explicitTailCall then - raise(ErrorReport(msg"Spreads are not yet supported here in calls marked @tailcall." -> a.value.toLoc :: Nil)) + raise(ErrorReport(msg"Spreads are not yet fully supported in calls marked @tailcall." -> a.value.toLoc :: Nil)) bad = true a.value case _ => a.value @@ -239,6 +239,10 @@ class TailRecOpt(using State, TL, Raise): val loop = Label(loopSym, true, switch, End()) + val sel = owner match + case Some(value) => Select(Value.Ref(value, N), Tree.Ident(bms.nme))(S(dSym)) + case None => Value.Ref(bms, S(dSym)) + val rewrittenFuns = if scc.funs.size == 1 then Nil else scc.funs.map: f => @@ -248,7 +252,7 @@ class TailRecOpt(using State, TL, Raise): :: paramArgs ::: List.fill(maxParamLen - paramArgs.length)(Value.Lit(Tree.UnitLit(false)).asArg) val newBod = Return( - Call(Value.Ref(bms, S(dSym)), args)(true, false, false), + Call(sel, args)(true, false, false), false ) FunDefn(f.owner, f.sym, f.dSym, f.params, newBod)(false) @@ -267,13 +271,33 @@ class TailRecOpt(using State, TL, Raise): def optFunctions(fs: List[FunDefn], owner: Opt[InnerSymbol]) = partFns(fs).flatMap(optScc(_, owner)) + def reportClassesTailrec(c: ClsLikeDefn) = + new BlockTraverserShallow(): + for f <- c.methods do + applyBlock(f.body) + if f.isTailRec then + raise(ErrorReport(msg"Class methods may not yet be marked @tailrec." -> f.dSym.toLoc :: Nil)) + override def applyResult(r: Result): Unit = r match + case c: Call if c.explicitTailCall => + raise(ErrorReport(msg"Calls from class methods cannot yet be marked @tailcall." -> c.toLoc :: Nil)) + case _ => super.applyResult(r) + def optClasses(cs: List[ClsLikeDefn]) = cs.map: c => - val mtds = optFunctions(c.methods, S(c.isym)) - val companion = c.companion.map: comp => - val cMtds = optFunctions(comp.methods, S(comp.isym)) - comp.copy(methods = cMtds) - c.copy(methods = mtds, companion = companion) + // Class methods cannot yet be optimized as they cannot yet be marked final. + if c.k is syntax.Cls then + reportClassesTailrec(c) + val companion = c.companion.map: comp => + val cMtds = optFunctions(comp.methods, S(comp.isym)) + comp.copy(methods = cMtds) + c.copy(companion = companion) + else + val mtds = optFunctions(c.methods, S(c.isym)) + val companion = c.companion.map: comp => + val cMtds = optFunctions(comp.methods, S(comp.isym)) + comp.copy(methods = cMtds) + c.copy(methods = mtds, companion = companion) + def transform(b: Block) = val (blk, defns) = b.floatOutDefns() val (funs, clses) = diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls index 6e3d249731..d9754b9aa4 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls @@ -56,3 +56,14 @@ fun h(x) = //│ ╟── It could self-recurse through this call, which is not a tail call. //│ ║ l.49: h(x) //│ ╙── ^^^^ + +:e +module A with + fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) + fun g(x) = if x == 0 then 1 else + @tailcall f(x - 1) + @tailcall f(x - 1) +//│ ╔══[ERROR] This call is not in tail position. +//│ ║ l.64: @tailcall f(x - 1) +//│ ╙── ^^^^^^^^ + diff --git a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls index 781834c3f6..1fe1abf903 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls @@ -182,7 +182,7 @@ f(100) :todo fun f(x, y, z) = @tailcall f(...[1, 2, 3]) -//│ ═══[ERROR] Spreads are not yet supported here in calls marked @tailcall. +//│ ═══[ERROR] Spreads are not yet fully supported in calls marked @tailcall. :todo fun f(x, ...y) = @@ -191,30 +191,25 @@ fun f(x, ...y) = fun g(x, y, z) = @tailcall f(...[1, 1, 2]) g(0, 0, 0) -//│ ═══[ERROR] Spreads are not yet supported here in calls marked @tailcall. +//│ ═══[ERROR] Spreads are not yet fully supported in calls marked @tailcall. //│ = 0 +:e +:lift fun f(x) = - fun g(x) = + fun g() = @tailcall f(x) - @tailcall g(x) -//│ FAILURE: Unexpected type error -//│ FAILURE LOCATION: cg (TailRecOpt.scala:88) -//│ ╔══[ERROR] This tail call exits the current scope and cannot be optimized. -//│ ║ l.200: @tailcall g(x) -//│ ╙── ^^^^ + @tailcall f(x) + @tailcall g() +//│ ╔══[ERROR] This call is not in tail position. +//│ ║ l.201: @tailcall f(x) +//│ ╙── ^ -// the lifter doesn't propagate the definition symbols properly -:todo :lift fun f(x) = fun g() = @tailcall f(x) @tailcall g() -//│ ═══[ERROR] Only direct calls in tail position may be marked @tailcall. -//│ ╔══[ERROR] This call is not optimized as it does not directly recurse through its parent function. -//│ ║ l.212: @tailcall f(x) -//│ ╙── ^ :todo :lift @@ -223,9 +218,42 @@ module A with fun g(x) = @tailcall f(x) @tailcall g(x) -//│ ╔══[ERROR] This tail call exits the current scope and cannot be optimized. -//│ ║ l.224: @tailcall f(x) +//│ ╔══[ERROR] This tail call exits the current scope is not optimized. +//│ ║ l.219: @tailcall f(x) //│ ╙── ^^^ -//│ ╔══[ERROR] This tail call exits the current scope and cannot be optimized. -//│ ║ l.225: @tailcall g(x) +//│ ╔══[ERROR] This tail call exits the current scope is not optimized. +//│ ║ l.220: @tailcall g(x) //│ ╙── ^^^^ + +// These calls are represented as field selections and don't yet have the explicitTailCall parameter. +:todo +:e +module A with + fun f = g + fun g = + @tailcall f + f + +:expect 0 +module A with + fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) + fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) +A.f(10) +//│ = 0 + +:todo +:expect 0 +class A with + @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) + fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) +(new A).f(10) +//│ ╔══[ERROR] Calls from class methods cannot yet be marked @tailcall. +//│ ║ l.247: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) +//│ ╙── ^^^^^^^^ +//│ ╔══[ERROR] Class methods may not yet be marked @tailrec. +//│ ║ l.247: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) +//│ ╙── ^ +//│ ╔══[ERROR] Calls from class methods cannot yet be marked @tailcall. +//│ ║ l.248: fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) +//│ ╙── ^^^^^^^^ +//│ = 0 From 639d505d489aa987640c5d9a7d6a803b0b2940f3 Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Mon, 1 Dec 2025 19:32:06 +0800 Subject: [PATCH 07/16] update tests --- hkmc2/shared/src/test/mlscript/HkScratch.mls | 259 +----------------- .../test/mlscript/codegen/FieldSymbols.mls | 4 +- .../src/test/mlscript/lifter/ClassInFun.mls | 130 +-------- .../test/mlscript/lifter/EffectHandlers.mls | 2 +- .../src/test/mlscript/tailrec/Annots.mls | 7 - 5 files changed, 6 insertions(+), 396 deletions(-) diff --git a/hkmc2/shared/src/test/mlscript/HkScratch.mls b/hkmc2/shared/src/test/mlscript/HkScratch.mls index c191173195..ef38e8b363 100644 --- a/hkmc2/shared/src/test/mlscript/HkScratch.mls +++ b/hkmc2/shared/src/test/mlscript/HkScratch.mls @@ -8,262 +8,5 @@ // :d // :todo -class Eff -:effectHandlers -:sjs -:lift -fun foo(h) = h.perform -handle h = Eff with - fun perform(resume) = - let fuck = () - set fuck = n => - if n <= 0 then 0 - else n + fuck(n-1) - resume(fuck(10000)) -foo(h) -//│ JS (unsanitized): -//│ let foo, lambda, tmp, handleBlock$, Cont$handleBlock$h$1, hdlrFun, Cont$func$lambda$1, Cont$handler$h$perform$1, Handler$h$1, doUnwind$, Cont$handleBlock$h$$, hdlrFun$, lambda$, doUnwind$1, Cont$func$lambda$$, doUnwind$2, Cont$handler$h$perform$$, hdlrFun$capture1; -//│ foo = function foo(h) { -//│ return h.perform -//│ }; -//│ Cont$handler$h$perform$$ = function Cont$handler$h$perform$$(isMut, Handler$h$$instance, resume, tmp1, pc) { -//│ let tmp2, tmp3; -//│ if (isMut === true) { -//│ tmp2 = new Cont$handler$h$perform$1(pc); -//│ } else { -//│ tmp2 = globalThis.Object.freeze(new Cont$handler$h$perform$1(pc)); -//│ } -//│ tmp3 = tmp2(Handler$h$$instance, resume, tmp1); -//│ return tmp3 -//│ }; -//│ globalThis.Object.freeze(class Cont$handler$h$perform$ extends runtime.FunctionContFrame.class { -//│ static { -//│ Cont$handler$h$perform$1 = this -//│ } -//│ constructor(pc) { -//│ return (Handler$h$$instance, resume, tmp1) => { -//│ let tmp2; -//│ tmp2 = super(null); -//│ this.resume = resume; -//│ this.tmp = tmp1; -//│ this.Handler$h$$instance = Handler$h$$instance; -//│ this.pc = pc; -//│ return this; -//│ } -//│ } -//│ #resume; -//│ #tmp; -//│ #Handler$h$$instance; -//│ get resume() { return this.#resume; } -//│ set resume(value) { this.#resume = value; } -//│ get tmp() { return this.#tmp; } -//│ set tmp(value) { this.#tmp = value; } -//│ get Handler$h$$instance() { return this.#Handler$h$$instance; } -//│ set Handler$h$$instance(value) { this.#Handler$h$$instance = value; } -//│ resume(value$) { -//│ if (this.pc === 5) { -//│ this.tmp = value$; -//│ } -//│ contLoop: while (true) { -//│ if (this.pc === 6) { -//│ return runtime.safeCall(this.resume(this.tmp)) -//│ } else if (this.pc === 5) { -//│ this.pc = 6; -//│ continue contLoop -//│ } -//│ break; -//│ } -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "Cont$handler$h$perform$"]; -//│ }); -//│ doUnwind$2 = function doUnwind$(Handler$h$$instance, resume, tmp1, res, pc) { -//│ res.contTrace.last.next = Cont$handler$h$perform$$(true, Handler$h$$instance, resume, tmp1, pc); -//│ res.contTrace.last = res.contTrace.last.next; -//│ return res -//│ }; -//│ Cont$func$lambda$$ = function Cont$func$lambda$$(isMut, Handler$h$$instance, n, tmp1, pc) { -//│ let tmp2, tmp3; -//│ if (isMut === true) { -//│ tmp2 = new Cont$func$lambda$1(pc); -//│ } else { -//│ tmp2 = globalThis.Object.freeze(new Cont$func$lambda$1(pc)); -//│ } -//│ tmp3 = tmp2(Handler$h$$instance, n, tmp1); -//│ return tmp3 -//│ }; -//│ globalThis.Object.freeze(class Cont$func$lambda$ extends runtime.FunctionContFrame.class { -//│ static { -//│ Cont$func$lambda$1 = this -//│ } -//│ constructor(pc) { -//│ return (Handler$h$$instance, n, tmp1) => { -//│ let tmp2; -//│ tmp2 = super(null); -//│ this.n = n; -//│ this.tmp = tmp1; -//│ this.Handler$h$$instance = Handler$h$$instance; -//│ this.pc = pc; -//│ return this; -//│ } -//│ } -//│ #n; -//│ #tmp; -//│ #Handler$h$$instance; -//│ get n() { return this.#n; } -//│ set n(value) { this.#n = value; } -//│ get tmp() { return this.#tmp; } -//│ set tmp(value) { this.#tmp = value; } -//│ get Handler$h$$instance() { return this.#Handler$h$$instance; } -//│ set Handler$h$$instance(value) { this.#Handler$h$$instance = value; } -//│ resume(value$) { -//│ if (this.pc === 2) { -//│ this.tmp = value$; -//│ } -//│ contLoop: while (true) { -//│ if (this.pc === 2) { -//│ return this.n + this.tmp -//│ } -//│ break; -//│ } -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "Cont$func$lambda$"]; -//│ }); -//│ doUnwind$1 = function doUnwind$(Handler$h$$instance, n, tmp1, res, pc) { -//│ res.contTrace.last.next = Cont$func$lambda$$(true, Handler$h$$instance, n, tmp1, pc); -//│ res.contTrace.last = res.contTrace.last.next; -//│ return res -//│ }; -//│ lambda$ = function lambda$(Handler$h$$instance, hdlrFun$capture2, n) { -//│ let scrut, tmp1, tmp2; -//│ scrut = n <= 0; -//│ if (scrut === true) { -//│ return 0 -//│ } else { -//│ tmp1 = n - 1; -//│ tmp2 = runtime.safeCall(hdlrFun$capture2.fuck$capture$0(tmp1)); -//│ if (tmp2 instanceof runtime.EffectSig.class) { -//│ return doUnwind$1(Handler$h$$instance, n, tmp2, tmp2, 2) -//│ } -//│ return n + tmp2 -//│ } -//│ }; -//│ lambda = (undefined, function (Handler$h$$instance, hdlrFun$capture2) { -//│ return (n) => { -//│ return lambda$(Handler$h$$instance, hdlrFun$capture2, n) -//│ } -//│ }); -//│ globalThis.Object.freeze(class hdlrFun$capture { -//│ static { -//│ hdlrFun$capture1 = this -//│ } -//│ constructor(fuck$capture$0) { -//│ this.fuck$capture$0 = fuck$capture$0; -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "hdlrFun$capture"]; -//│ }); -//│ hdlrFun$ = function hdlrFun$(Handler$h$$instance, resume) { -//│ let tmp1, capture, lambda$here; -//│ capture = new hdlrFun$capture1(null); -//│ capture.fuck$capture$0 = runtime.Unit; -//│ lambda$here = runtime.safeCall(lambda(Handler$h$$instance, capture)); -//│ capture.fuck$capture$0 = lambda$here; -//│ tmp1 = runtime.safeCall(capture.fuck$capture$0(10000)); -//│ if (tmp1 instanceof runtime.EffectSig.class) { -//│ return doUnwind$2(Handler$h$$instance, resume, tmp1, tmp1, 5) -//│ } -//│ return runtime.safeCall(resume(tmp1)) -//│ }; -//│ hdlrFun = function hdlrFun(Handler$h$$instance) { -//│ return (resume) => { -//│ return hdlrFun$(Handler$h$$instance, resume) -//│ } -//│ }; -//│ globalThis.Object.freeze(class Handler$h$ extends Eff1 { -//│ static { -//│ Handler$h$1 = this -//│ } -//│ constructor() { -//│ let tmp1; -//│ tmp1 = super(); -//│ } -//│ get perform() { -//│ let hdlrFun$here; -//│ hdlrFun$here = runtime.safeCall(hdlrFun(this)); -//│ return runtime.mkEffect(this, hdlrFun$here); -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "Handler$h$"]; -//│ }); -//│ Cont$handleBlock$h$$ = function Cont$handleBlock$h$$(isMut, res, pc) { -//│ let tmp1, tmp2; -//│ if (isMut === true) { -//│ tmp1 = new Cont$handleBlock$h$1(pc); -//│ } else { -//│ tmp1 = globalThis.Object.freeze(new Cont$handleBlock$h$1(pc)); -//│ } -//│ tmp2 = tmp1(res); -//│ return tmp2 -//│ }; -//│ globalThis.Object.freeze(class Cont$handleBlock$h$ extends runtime.FunctionContFrame.class { -//│ static { -//│ Cont$handleBlock$h$1 = this -//│ } -//│ constructor(pc) { -//│ return (res) => { -//│ let tmp1; -//│ tmp1 = super(null); -//│ this.res = res; -//│ this.pc = pc; -//│ return this; -//│ } -//│ } -//│ #res; -//│ get res() { return this.#res; } -//│ set res(value) { this.#res = value; } -//│ resume(value$) { -//│ if (this.pc === 1) { -//│ this.res = value$; -//│ } -//│ contLoop: while (true) { -//│ if (this.pc === 1) { -//│ return this.res -//│ } -//│ break; -//│ } -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "Cont$handleBlock$h$"]; -//│ }); -//│ doUnwind$ = function doUnwind$(h, res, res1, pc) { -//│ res1.contTrace.last.next = Cont$handleBlock$h$$(true, res, pc); -//│ return runtime.handleBlockImpl(res1, h) -//│ }; -//│ handleBlock$ = function handleBlock$() { -//│ let h, res; -//│ h = new Handler$h$1(); -//│ res = foo(h); -//│ if (res instanceof runtime.EffectSig.class) { -//│ return doUnwind$(h, res, res, 1) -//│ } -//│ return res -//│ }; -//│ tmp = handleBlock$(); -//│ if (tmp instanceof runtime.EffectSig.class) { tmp = runtime.topLevelEffect(tmp, false); } -//│ tmp -//│ FAILURE: Unexpected runtime error -//│ FAILURE LOCATION: mkQuery (JSBackendDiffMaker.scala:159) -//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded -//│ at Runtime.checkArgs (file:///storage/mark/Repos/mlscript/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs:429:19) -//│ at lambda$ (REPL13:1:6736) -//│ at hdlrFun$capture.fuck$capture$0 (REPL13:1:7499) -//│ at lambda$ (REPL13:1:7024) -//│ at hdlrFun$capture.fuck$capture$0 (REPL13:1:7499) -//│ at lambda$ (REPL13:1:7024) -//│ at hdlrFun$capture.fuck$capture$0 (REPL13:1:7499) -//│ at lambda$ (REPL13:1:7024) -//│ at hdlrFun$capture.fuck$capture$0 (REPL13:1:7499) -//│ at lambda$ (REPL13:1:7024) + diff --git a/hkmc2/shared/src/test/mlscript/codegen/FieldSymbols.mls b/hkmc2/shared/src/test/mlscript/codegen/FieldSymbols.mls index 6466aeb028..dd62328b18 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/FieldSymbols.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/FieldSymbols.mls @@ -141,9 +141,9 @@ case //│ rest = End of "" //│ rest = Assign: \ //│ lhs = $block$res -//│ rhs = Ref: +//│ rhs = Ref{disamb=term:lambda}: //│ l = member:lambda -//│ disamb = N +//│ disamb = S of term:lambda //│ rest = Return: \ //│ res = Lit of UnitLit of false //│ implct = true diff --git a/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls b/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls index 6ca163f9b8..ec726f6ed4 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/ClassInFun.mls @@ -11,87 +11,7 @@ fun f() = h.perform() 1 f() + f() + f() -//│ FAILURE: Unexpected compilation error -//│ FAILURE LOCATION: lookup_! (Scope.scala:112) -//│ FAILURE INFO: Tuple2: -//│ _1 = Tuple2: -//│ _1 = member:doUnwind -//│ _2 = class hkmc2.semantics.BlockMemberSymbol -//│ _2 = Scope: -//│ parent = S of Scope: -//│ parent = S of Scope: -//│ parent = N -//│ curThis = S of S of globalThis:globalThis -//│ bindings = HashMap(member:doUnwind$ -> doUnwind$1, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:hdlrFun$ -> hdlrFun$, $runtime -> runtime, $definitionMetadata -> definitionMetadata, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:handleBlock$ -> handleBlock$, $block$res -> block$res, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:Cont$func$f$ -> Cont$func$f$, class:Handler$h$ -> Handler$h$, member:Cont$func$f$ -> Cont$func$f$1, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, $tmp -> tmp, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:f$ -> f$, member:doUnwind$ -> doUnwind$, member:Cont$func$f$$ -> Cont$func$f$$) -//│ curThis = S of N -//│ bindings = HashMap($args -> args, h -> h) -//│ curThis = N -//│ bindings = HashMap($tmp -> tmp1) -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' -//│ FAILURE: Unexpected compilation error -//│ FAILURE LOCATION: lookup_! (Scope.scala:112) -//│ FAILURE INFO: Tuple2: -//│ _1 = Tuple2: -//│ _1 = member:doUnwind -//│ _2 = class hkmc2.semantics.BlockMemberSymbol -//│ _2 = Scope: -//│ parent = S of Scope: -//│ parent = S of Scope: -//│ parent = N -//│ curThis = S of S of globalThis:globalThis -//│ bindings = HashMap(member:doUnwind$ -> doUnwind$1, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:hdlrFun$ -> hdlrFun$, $runtime -> runtime, $definitionMetadata -> definitionMetadata, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:handleBlock$ -> handleBlock$, $block$res -> block$res, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:Cont$func$f$ -> Cont$func$f$, class:Handler$h$ -> Handler$h$, member:Cont$func$f$ -> Cont$func$f$1, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, $tmp -> tmp, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:f$ -> f$, member:doUnwind$ -> doUnwind$, member:Cont$func$f$$ -> Cont$func$f$$) -//│ curThis = S of N -//│ bindings = HashMap($args -> args) -//│ curThis = N -//│ bindings = HashMap(h -> h, $tmp -> tmp1, $tmp -> tmp2, $tmp -> tmp3, $tmp -> tmp4) -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' -//│ FAILURE: Unexpected compilation error -//│ FAILURE LOCATION: lookup_! (Scope.scala:112) -//│ FAILURE INFO: Tuple2: -//│ _1 = Tuple2: -//│ _1 = member:doUnwind -//│ _2 = class hkmc2.semantics.BlockMemberSymbol -//│ _2 = Scope: -//│ parent = S of Scope: -//│ parent = S of Scope: -//│ parent = N -//│ curThis = S of S of globalThis:globalThis -//│ bindings = HashMap(member:doUnwind$ -> doUnwind$1, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:hdlrFun$ -> hdlrFun$, $runtime -> runtime, $definitionMetadata -> definitionMetadata, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:handleBlock$ -> handleBlock$, $block$res -> block$res, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:Cont$func$f$ -> Cont$func$f$, class:Handler$h$ -> Handler$h$, member:Cont$func$f$ -> Cont$func$f$1, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, $tmp -> tmp, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:f$ -> f$, member:doUnwind$ -> doUnwind$, member:Cont$func$f$$ -> Cont$func$f$$) -//│ curThis = S of N -//│ bindings = HashMap($args -> args) -//│ curThis = N -//│ bindings = HashMap(h -> h, $tmp -> tmp1, $tmp -> tmp2, $tmp -> tmp3, $tmp -> tmp4) -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' -//│ FAILURE: Unexpected compilation error -//│ FAILURE LOCATION: lookup_! (Scope.scala:112) -//│ FAILURE INFO: Tuple2: -//│ _1 = Tuple2: -//│ _1 = member:doUnwind -//│ _2 = class hkmc2.semantics.BlockMemberSymbol -//│ _2 = Scope: -//│ parent = S of Scope: -//│ parent = S of Scope: -//│ parent = N -//│ curThis = S of S of globalThis:globalThis -//│ bindings = HashMap(member:doUnwind$ -> doUnwind$1, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:hdlrFun$ -> hdlrFun$, $runtime -> runtime, $definitionMetadata -> definitionMetadata, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:handleBlock$ -> handleBlock$, $block$res -> block$res, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:Cont$func$f$ -> Cont$func$f$, class:Handler$h$ -> Handler$h$, member:Cont$func$f$ -> Cont$func$f$1, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, $tmp -> tmp, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:f$ -> f$, member:doUnwind$ -> doUnwind$, member:Cont$func$f$$ -> Cont$func$f$$) -//│ curThis = S of N -//│ bindings = HashMap($args -> args) -//│ curThis = N -//│ bindings = HashMap(h -> h, $tmp -> tmp1, $tmp -> tmp2, $tmp -> tmp3, $tmp -> tmp4) -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' -//│ FAILURE: Unexpected runtime error -//│ FAILURE LOCATION: mkQuery (JSBackendDiffMaker.scala:159) -//│ ═══[RUNTIME ERROR] ReferenceError: doUnwind is not defined -//│ at f$ (REPL10:1:5833) -//│ at handleBlock$ (REPL10:1:6072) -//│ at REPL10:1:6523 -//│ at ContextifyScript.runInThisContext (node:vm:137:12) -//│ at REPLServer.defaultEval (node:repl:562:24) -//│ at bound (node:domain:433:15) -//│ at REPLServer.runBound [as eval] (node:domain:444:12) -//│ at REPLServer.onLine (node:repl:886:12) -//│ at REPLServer.emit (node:events:508:28) -//│ at REPLServer.emit (node:domain:489:12) +//│ = 2 :expect 1 fun f(x) = @@ -306,53 +226,7 @@ fun sum(n) = else n + sum(n - 1) sum(100) -//│ FAILURE: Unexpected compilation error -//│ FAILURE LOCATION: lookup_! (Scope.scala:112) -//│ FAILURE INFO: Tuple2: -//│ _1 = Tuple2: -//│ _1 = member:doUnwind -//│ _2 = class hkmc2.semantics.BlockMemberSymbol -//│ _2 = Scope: -//│ parent = S of Scope: -//│ parent = S of Scope: -//│ parent = N -//│ curThis = S of S of globalThis:globalThis -//│ bindings = HashMap($tmp -> tmp2, member:Test$ -> Test$1, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:g$ -> g$, $runtime -> runtime, member:h$ -> h$, $definitionMetadata -> definitionMetadata, class:Test -> Test6, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:Cont$func$sum$ -> Cont$func$sum$1, $block$res -> block$res5, $res -> res, $tmp -> tmp4, member:Test$ -> Test$3, member:‹stack safe body› -> $_stack$_safe$_body$_, $block$res -> block$res, member:g$ -> g$2, member:doUnwind$ -> doUnwind$2, member:A -> A3, member:h$ -> h$2, member:f -> f6, member:Cont$func$sum$$ -> Cont$func$sum$$, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:A -> A2, class:Handler$h$ -> Handler$h$, member:Test -> Test5, member:f -> f2, a -> a, member:Test -> Test1, member:f -> f, b -> b, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, class:Test -> Test, $tmp -> tmp, member:f$ -> f$, $block$res -> block$res8, member:doUnwind$ -> doUnwind$, $tmp -> tmp6, $tmp -> tmp7, class:Test -> Test4, member:A -> A1, member:Cont$func$f$$ -> Cont$func$f$$, member:f -> f4, member:A$ -> A$1, member:doUnwind$ -> doUnwind$1, class:A -> A, class:f$capture -> f$capture4, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, $block$res -> block$res2, $tmp -> tmp1, member:f$capture -> f$capture5, member:hdlrFun$ -> hdlrFun$, member:Test$ -> Test$, $block$res -> block$res4, $tmp -> tmp3, $block$res -> block$res6, member:Test$ -> Test$2, member:g$ -> g$1, member:A$ -> A$, $block$res -> block$res7, member:handleBlock$ -> handleBlock$, member:h$ -> h$1, class:f$capture -> f$capture, $tmp -> tmp5, member:Bad$ -> Bad$, member:f$capture -> f$capture1, member:Good$ -> Good$, member:Test -> Test3, member:sum -> sum, member:f -> f1, class:f$capture -> f$capture2, class:Cont$func$f$ -> Cont$func$f$, member:f$capture -> f$capture3, $block$res -> block$res9, member:Bad -> Bad1, member:Good -> Good1, class:Test -> Test2, member:f -> f5, member:Cont$func$f$ -> Cont$func$f$1, class:Good -> Good, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:Test -> Test7, member:f -> f3, class:Bad -> Bad, class:Cont$func$sum$ -> Cont$func$sum$, $block$res -> block$res3) -//│ curThis = S of N -//│ bindings = HashMap($args -> args, n -> n) -//│ curThis = N -//│ bindings = HashMap($scrut -> scrut, $curDepth -> curDepth, $stackDelayRes -> stackDelayRes, $tmp -> tmp8, $tmp -> tmp9) -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' -//│ FAILURE: Unexpected compilation error -//│ FAILURE LOCATION: lookup_! (Scope.scala:112) -//│ FAILURE INFO: Tuple2: -//│ _1 = Tuple2: -//│ _1 = member:doUnwind -//│ _2 = class hkmc2.semantics.BlockMemberSymbol -//│ _2 = Scope: -//│ parent = S of Scope: -//│ parent = S of Scope: -//│ parent = N -//│ curThis = S of S of globalThis:globalThis -//│ bindings = HashMap($tmp -> tmp2, member:Test$ -> Test$1, member:Cont$handleBlock$h$ -> Cont$handleBlock$h$1, member:g$ -> g$, $runtime -> runtime, member:h$ -> h$, $definitionMetadata -> definitionMetadata, class:Test -> Test6, $prettyPrint -> prettyPrint, $Term -> Term, member:hdlrFun -> hdlrFun, $Block -> Block, $Shape -> Shape, member:Cont$func$sum$ -> Cont$func$sum$1, $block$res -> block$res5, $res -> res, $tmp -> tmp4, member:Test$ -> Test$3, member:‹stack safe body› -> $_stack$_safe$_body$_, $block$res -> block$res, member:g$ -> g$2, member:doUnwind$ -> doUnwind$2, member:A -> A3, member:h$ -> h$2, member:f -> f6, member:Cont$func$sum$$ -> Cont$func$sum$$, member:Effect -> Effect1, member:Predef -> Predef, class:Effect -> Effect, class:A -> A2, class:Handler$h$ -> Handler$h$, member:Test -> Test5, member:f -> f2, a -> a, member:Test -> Test1, member:f -> f, b -> b, member:Handler$h$ -> Handler$h$1, $block$res -> block$res1, class:Test -> Test, $tmp -> tmp, member:f$ -> f$, $block$res -> block$res8, member:doUnwind$ -> doUnwind$, $tmp -> tmp6, $tmp -> tmp7, class:Test -> Test4, member:A -> A1, member:Cont$func$f$$ -> Cont$func$f$$, member:f -> f4, member:A$ -> A$1, member:doUnwind$ -> doUnwind$1, class:A -> A, class:f$capture -> f$capture4, member:Cont$handleBlock$h$$ -> Cont$handleBlock$h$$, $block$res -> block$res2, $tmp -> tmp1, member:f$capture -> f$capture5, member:hdlrFun$ -> hdlrFun$, member:Test$ -> Test$, $block$res -> block$res4, $tmp -> tmp3, $block$res -> block$res6, member:Test$ -> Test$2, member:g$ -> g$1, member:A$ -> A$, $block$res -> block$res7, member:handleBlock$ -> handleBlock$, member:h$ -> h$1, class:f$capture -> f$capture, $tmp -> tmp5, member:Bad$ -> Bad$, member:f$capture -> f$capture1, member:Good$ -> Good$, member:Test -> Test3, member:sum -> sum, member:f -> f1, class:f$capture -> f$capture2, class:Cont$func$f$ -> Cont$func$f$, member:f$capture -> f$capture3, $block$res -> block$res9, member:Bad -> Bad1, member:Good -> Good1, class:Test -> Test2, member:f -> f5, member:Cont$func$f$ -> Cont$func$f$1, class:Good -> Good, class:Cont$handleBlock$h$ -> Cont$handleBlock$h$, member:Test -> Test7, member:f -> f3, class:Bad -> Bad, class:Cont$func$sum$ -> Cont$func$sum$, $block$res -> block$res3) -//│ curThis = S of N -//│ bindings = HashMap($args -> args, n -> n) -//│ curThis = N -//│ bindings = HashMap($scrut -> scrut, $curDepth -> curDepth, $stackDelayRes -> stackDelayRes, $tmp -> tmp8, $tmp -> tmp9) -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'doUnwind' -//│ FAILURE: Unexpected runtime error -//│ FAILURE LOCATION: mkQuery (JSBackendDiffMaker.scala:159) -//│ ═══[RUNTIME ERROR] ReferenceError: doUnwind is not defined -//│ at sum (REPL36:1:2283) -//│ at sum (REPL36:1:2438) -//│ at sum (REPL36:1:2438) -//│ at sum (REPL36:1:2438) -//│ at sum (REPL36:1:2438) -//│ at sum (REPL36:1:2438) -//│ at sum (REPL36:1:2438) -//│ at sum (REPL36:1:2438) -//│ at sum (REPL36:1:2438) -//│ at $_stack$_safe$_body$_ (REPL36:1:2750) +//│ = 5050 // instance checks diff --git a/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls b/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls index ff48e0ad97..6446a8b3fe 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls @@ -60,7 +60,7 @@ module A with //│ static [definitionMetadata] = ["class", "Cont$ctor$Test$"]; //│ }); //│ doUnwind$ = function doUnwind$(Test$instance, tmp, res, pc) { -//│ res.contTrace.last.next = new Cont$ctor$Test$1(pc); +//│ res.contTrace.last.next = Cont$ctor$Test$$(true, Test$instance, tmp, pc); //│ res.contTrace.last = res.contTrace.last.next; //│ return res //│ }; diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls b/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls index 6aaef483f2..016e3bac1f 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls @@ -21,7 +21,6 @@ class A :todo @tailrec fun f = 2 -//│ ═══[WARNING] Tail call optimization is not yet implemented. :w @tailcall @@ -31,10 +30,6 @@ fun g = 2 :todo fun test = @tailcall f -//│ ╔══[WARNING] This annotation has no effect. -//│ ╟── Tail call optimization is not yet implemented. -//│ ║ l.33: @tailcall f -//│ ╙── ^ :w let f = 0 @@ -50,7 +45,6 @@ fun test = class A with @tailrec fun f = 2 -//│ ═══[WARNING] Tail call optimization is not yet implemented. :w class A with @@ -63,7 +57,6 @@ class A module A with @tailrec fun f = 2 -//│ ═══[WARNING] Tail call optimization is not yet implemented. :w class A From 217cbcf0d1f34fa44e5ed9c7c631e47c54fa7622 Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Mon, 1 Dec 2025 19:45:49 +0800 Subject: [PATCH 08/16] add warnings and update tests, clean up --- .../scala/hkmc2/codegen/LambdaRewriter.scala | 1 + .../src/main/scala/hkmc2/codegen/Lifter.scala | 2 +- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 9 +- .../hkmc2/semantics/ucs/Normalization.scala | 2 +- hkmc2/shared/src/test/mlscript/HkScratch.mls | 3 +- .../test/mlscript/lifter/EffectHandlers.mls | 93 ++----------------- .../src/test/mlscript/tailrec/Annots.mls | 12 +-- .../src/test/mlscript/tailrec/Errors.mls | 20 +++- 8 files changed, 44 insertions(+), 98 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala index cec6c76a24..06a23125af 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala @@ -11,6 +11,7 @@ import hkmc2.syntax.Tree object LambdaRewriter: def desugar(b: Block)(using State) = + def rewriteOneBlk(b: Block) = b match case Assign(lhs, Lambda(params, body), rest) if !lhs.isInstanceOf[TempSymbol] => val newSym = BlockMemberSymbol(lhs.nme, Nil, diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index cb1fd6e234..c8ec1200fe 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -359,7 +359,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): val imutVars = captureFnVars.vars imutVars.filter: s => !mutVars.contains(s) && candVars.contains(s) - + case class FunSyms[T <: DefinitionSymbol[?]](b: BlockMemberSymbol, d: T): def asPath = Value.Ref(b, S(d)) object FunSyms: diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 562ebecd01..9720aff78a 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -167,13 +167,16 @@ class TailRecOpt(using State, TL, Raise): S(ret) def optScc(scc: SccOfCalls, owner: Opt[InnerSymbol]): List[FunDefn] = - if scc.calls.size == 0 then return scc.funs - val nonTailCalls = scc.calls .collect: case c: NormalCall => c.f2 -> c.call .toMap + if nonTailCalls.size == scc.calls.length then + for f <- scc.funs if f.isTailRec do + raise(WarningReport(msg"This function does not directly self-recurse, but is marked @tailrec." -> f.dSym.toLoc :: Nil)) + return scc.funs + if !nonTailCalls.isEmpty then for f <- scc.funs if f.isTailRec do val reportLoc = nonTailCalls.get(f.dSym) match @@ -181,7 +184,7 @@ class TailRecOpt(using State, TL, Raise): case Some(value) => value.toLoc case None => nonTailCalls.head._2.toLoc raise(ErrorReport( - msg"`${f.sym.nme}` is not tail recursive." -> f.dSym.toLoc + msg"This function is not tail recursive." -> f.dSym.toLoc :: msg"It could self-recurse through this call, which is not a tail call." -> reportLoc :: Nil )) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index b817ac55f1..14c3de567e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -328,7 +328,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e // The symbol for the loop label if the term is a `while`. lazy val loopLabel = new TempSymbol(t) lazy val f = new BlockMemberSymbol("while", Nil, false) - lazy val tSym = new TermSymbol(syntax.Fun, N, Tree.Ident(f.nme)) + lazy val tSym = TermSymbol.fromFunBms(f, N) val normalized = tl.scoped("ucs:normalize"): normalize(inputSplit)(using VarSet()) tl.scoped("ucs:normalized"): diff --git a/hkmc2/shared/src/test/mlscript/HkScratch.mls b/hkmc2/shared/src/test/mlscript/HkScratch.mls index ef38e8b363..fd12527772 100644 --- a/hkmc2/shared/src/test/mlscript/HkScratch.mls +++ b/hkmc2/shared/src/test/mlscript/HkScratch.mls @@ -9,4 +9,5 @@ // :todo - +@tailrec +fun f = 2 diff --git a/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls b/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls index 6446a8b3fe..e628fd9870 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls @@ -3,96 +3,23 @@ fun f() = 3 -:sjs :effectHandlers module A with data class Test with f() val a = 1 -//│ JS (unsanitized): -//│ let A1, Cont$ctor$Test$1, doUnwind$, Cont$ctor$Test$$; -//│ Cont$ctor$Test$$ = function Cont$ctor$Test$$(isMut, Test$instance, tmp, pc) { -//│ let tmp1, tmp2; -//│ if (isMut === true) { -//│ tmp1 = new Cont$ctor$Test$1(pc); -//│ } else { -//│ tmp1 = globalThis.Object.freeze(new Cont$ctor$Test$1(pc)); -//│ } -//│ tmp2 = tmp1(Test$instance, tmp); -//│ return tmp2 -//│ }; -//│ globalThis.Object.freeze(class Cont$ctor$Test$ extends runtime.FunctionContFrame.class { -//│ static { -//│ Cont$ctor$Test$1 = this -//│ } -//│ constructor(pc) { -//│ return (Test$instance, tmp) => { -//│ let tmp1; -//│ tmp1 = super(null); -//│ this.tmp = tmp; -//│ this.Test$instance = Test$instance; -//│ this.pc = pc; -//│ return this; -//│ } -//│ } -//│ #tmp; -//│ #Test$instance; -//│ get tmp() { return this.#tmp; } -//│ set tmp(value) { this.#tmp = value; } -//│ get Test$instance() { return this.#Test$instance; } -//│ set Test$instance(value) { this.#Test$instance = value; } -//│ resume(value$) { -//│ if (this.pc === 1) { -//│ this.tmp = value$; -//│ } -//│ contLoop: while (true) { -//│ if (this.pc === 2) { -//│ return this.Test$instance -//│ } else if (this.pc === 1) { -//│ this.Test$instance.a = 1; -//│ this.pc = 2; -//│ continue contLoop -//│ } -//│ break; -//│ } -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "Cont$ctor$Test$"]; -//│ }); -//│ doUnwind$ = function doUnwind$(Test$instance, tmp, res, pc) { -//│ res.contTrace.last.next = Cont$ctor$Test$$(true, Test$instance, tmp, pc); -//│ res.contTrace.last = res.contTrace.last.next; -//│ return res -//│ }; -//│ globalThis.Object.freeze(class A { -//│ static { -//│ A1 = this -//│ } -//│ constructor() { -//│ runtime.Unit; -//│ } -//│ static { -//│ globalThis.Object.freeze(class Test { -//│ static { -//│ A.Test = this -//│ } -//│ constructor() { -//│ let tmp; -//│ tmp = f(); -//│ if (tmp instanceof runtime.EffectSig.class) { -//│ return doUnwind$(this, tmp, tmp, 1) -//│ } -//│ this.a = 1; -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "Test"]; -//│ }); -//│ } -//│ toString() { return runtime.render(this); } -//│ static [definitionMetadata] = ["class", "A"]; -//│ }); :effectHandlers class Test with f() val a = 1 + +class Eff + +:effectHandlers +:expect 7 +handle h = Eff with + fun perform(x)(k) = k(x) + 1 +let x = 5 +h.perform(x) + 1 +//│ = 7 diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls b/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls index 016e3bac1f..00afc4659f 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls @@ -18,16 +18,15 @@ class A //│ ╙── ^^^^^ //│ = 4 -:todo @tailrec -fun f = 2 +fun f = g +fun g = f :w @tailcall fun g = 2 //│ ═══[WARNING] This annotation has no effect. -:todo fun test = @tailcall f @@ -37,14 +36,15 @@ fun test = @tailcall f //│ ╔══[WARNING] This annotation has no effect. //│ ╟── This annotation is not supported on reference terms. -//│ ║ l.42: @tailcall f +//│ ║ l.36: @tailcall f //│ ╙── ^ //│ f = 0 :todo class A with @tailrec - fun f = 2 + fun f() = g() + fun g() = f() :w class A with @@ -77,5 +77,5 @@ fun test = @tailcall 1 + 2 //│ ╔══[WARNING] This annotation has no effect. //│ ╟── The @tailcall annotation has no effect on calls to built-in symbols. -//│ ║ l.84: @tailcall 1 + 2 +//│ ║ l.77: @tailcall 1 + 2 //│ ╙── ^^^^^ diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls index d9754b9aa4..1800e2c677 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls @@ -20,7 +20,7 @@ fun g(x) = @tailcall f(x) @tailrec fun f(x) = g(x) fun g(x) = f(x); f(x) -//│ ╔══[ERROR] `f` is not tail recursive. +//│ ╔══[ERROR] This function is not tail recursive. //│ ║ l.20: @tailrec fun f(x) = //│ ║ ^ //│ ╟── It could self-recurse through this call, which is not a tail call. @@ -35,7 +35,7 @@ fun g(x) = h(x) fun h(x) = g(x) -//│ ╔══[ERROR] `f` is not tail recursive. +//│ ╔══[ERROR] This function is not tail recursive. //│ ║ l.31: @tailrec fun f(x) = //│ ║ ^ //│ ╟── It could self-recurse through this call, which is not a tail call. @@ -50,7 +50,7 @@ fun g(x) = f(x) fun h(x) = f(x) -//│ ╔══[ERROR] `f` is not tail recursive. +//│ ╔══[ERROR] This function is not tail recursive. //│ ║ l.46: @tailrec fun f(x) = //│ ║ ^ //│ ╟── It could self-recurse through this call, which is not a tail call. @@ -67,3 +67,17 @@ module A with //│ ║ l.64: @tailcall f(x - 1) //│ ╙── ^^^^^^^^ +:w +@tailrec +fun f = 2 +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.72: fun f = 2 +//│ ╙── ^ + +:w +module A with + @tailrec + fun f() = 2 +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.80: fun f() = 2 +//│ ╙── ^ From e295bbdda0eaa82c3cf18d7ed9538c6166f2881a Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Fri, 5 Dec 2025 16:52:42 +0800 Subject: [PATCH 09/16] pr comments, fix some bugs --- .../shared/src/main/scala/hkmc2/Config.scala | 2 +- .../scala/hkmc2/codegen/BlockTraverser.scala | 1 + .../main/scala/hkmc2/codegen/TailRecOpt.scala | 108 ++++++++++------ hkmc2/shared/src/test/mlscript/HkScratch.mls | 3 +- .../src/test/mlscript/codegen/RandomStuff.mls | 18 ++- .../src/test/mlscript/codegen/TraceLog.mls | 17 ++- .../src/test/mlscript/codegen/While.mls | 36 +++--- .../test/mlscript/handlers/StackSafety.mls | 1 + .../test/mlscript/lifter/StackSafetyLift.mls | 1 + .../src/test/mlscript/llir/BasisLLIR.mls | 1 + .../shared/src/test/mlscript/llir/Classes.mls | 1 + .../src/test/mlscript/llir/ControlFlow.mls | 1 + .../src/test/mlscript/llir/HigherOrder.mls | 1 + hkmc2/shared/src/test/mlscript/nofib/mate.mls | 8 +- .../test/mlscript/std/FingerTreeListTest.mls | 27 ++-- .../src/test/mlscript/tailrec/Annots.mls | 10 +- .../src/test/mlscript/tailrec/Errors.mls | 23 ++-- .../src/test/mlscript/tailrec/TailRecOpt.mls | 115 +++++------------- .../mlscript/ucs/patterns/Compilation.mls | 4 +- .../src/test/scala/hkmc2/MLsDiffMaker.scala | 4 +- 20 files changed, 199 insertions(+), 183 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/Config.scala b/hkmc2/shared/src/main/scala/hkmc2/Config.scala index f46e001a21..e7b3ae69b8 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/Config.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/Config.scala @@ -41,7 +41,7 @@ object Config: target = CompilationTarget.JS, rewriteWhileLoops = true, stageCode = false, - tailRecOpt = false, + tailRecOpt = true, ) case class SanityChecks(light: Bool) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala index 761d12b8fb..5d1137e5e5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala @@ -127,6 +127,7 @@ class BlockTraverser: cls.traverse applyPath(path) case Case.Tup(len, inf) => () + case Case.Field(_, _) => () def applyHandler(hdr: Handler): Unit = hdr.sym.traverse diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 9720aff78a..ede4fb60fa 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -27,17 +27,16 @@ class TailRecOpt(using State, TL, Raise): object TailCallShape: def unapply(b: Block): Opt[(TermSymbol, Call)] = b match case Return(c @ CallToFun(r), _) => S((r, c)) - case Assign(a, c @ CallToFun(r), Return(b, _)) if a == b => S((r, c)) + case Assign(a, c @ CallToFun(r), Return(Value.Ref(b, _), _)) if a === b => S((r, c)) case _ => N - sealed abstract class CallEdge: + enum CallEdge: val f1: TermSymbol val f2: TermSymbol val call: Call - - case class TailCall(f1: TermSymbol, f2: TermSymbol)(val call: Call) extends CallEdge - case class NormalCall(f1: TermSymbol, f2: TermSymbol)(val call: Call) extends CallEdge + case TailCall(f1: TermSymbol, f2: TermSymbol)(val call: Call) + case NormalCall(f1: TermSymbol, f2: TermSymbol)(val call: Call) class CallFinder(f: FunDefn) extends BlockTraverserShallow: @@ -55,7 +54,7 @@ class TailRecOpt(using State, TL, Raise): edges override def applyBlock(b: Block): Unit = b match - case TailCallShape(r, c) => edges ::= TailCall(f.dSym, r)(c) + case TailCallShape(r, c) => edges ::= CallEdge.TailCall(f.dSym, r)(c) case Return(c: Call, _) => if c.explicitTailCall then raise(ErrorReport(msg"Only direct calls in tail position may be marked @tailcall." -> c.toLoc :: Nil)) @@ -66,7 +65,7 @@ class TailRecOpt(using State, TL, Raise): if c.explicitTailCall then raise(ErrorReport(msg"This call is not in tail position." -> c.toLoc :: Nil)) c match - case CallToFun(r) => edges ::= NormalCall(f.dSym, r)(c) + case CallToFun(r) => edges ::= CallEdge.NormalCall(f.dSym, r)(c) case _ => case _ => super.applyResult(r) @@ -84,7 +83,7 @@ class TailRecOpt(using State, TL, Raise): val cg = buildCallGraph(fs).filter: c => val cond = defnSyms.contains(c.f1) && defnSyms.contains(c.f2) c.match - case c: TailCall if c.call.explicitTailCall && !cond => + case c: CallEdge.TailCall if c.call.explicitTailCall && !cond => raise(ErrorReport( msg"This tail call exits the current scope is not optimized." -> c.call.toLoc :: Nil)) case _ => @@ -101,13 +100,13 @@ class TailRecOpt(using State, TL, Raise): .groupBy: c => val s1 = sccMap(c.f1) val s2 = sccMap(c.f2) - if s1 != s2 && c.call.explicitTailCall then + if s1 =/= s2 && c.call.explicitTailCall then raise(ErrorReport( msg"This call is not optimized as it does not directly recurse through its parent function." -> c.call.toLoc :: Nil)) -1 else s1 .filter: - (id, _) => id != -1 + (id, _) => id =/= -1 sccs.sccs.toList.map: v => val (id, tss) = v @@ -166,16 +165,21 @@ class TailRecOpt(using State, TL, Raise): case _ => return N S(ret) - def optScc(scc: SccOfCalls, owner: Opt[InnerSymbol]): List[FunDefn] = - val nonTailCalls = scc.calls + def optScc(scc: SccOfCalls, owner: Opt[InnerSymbol]): (Opt[FunDefn], List[FunDefn]) = + // remove calls which don't flow into this scc + val fSyms = scc.funs.map(_.dSym).toSet + + val calls = scc.calls.filter(c => fSyms.contains(c.f2)) + + val nonTailCalls = calls .collect: - case c: NormalCall => c.f2 -> c.call + case c: CallEdge.NormalCall => c.f2 -> c.call .toMap - if nonTailCalls.size == scc.calls.length then + if nonTailCalls.size === calls.length then for f <- scc.funs if f.isTailRec do raise(WarningReport(msg"This function does not directly self-recurse, but is marked @tailrec." -> f.dSym.toLoc :: Nil)) - return scc.funs + return (N, scc.funs) if !nonTailCalls.isEmpty then for f <- scc.funs if f.isTailRec do @@ -191,22 +195,22 @@ class TailRecOpt(using State, TL, Raise): val maxParamLen = maxInt(scc.funs, paramsLen) val paramSyms = - if scc.funs.length == 1 then (getParamSyms(scc.funs.head)) + if scc.funs.length === 1 then (getParamSyms(scc.funs.head)) else for i <- 0 to maxParamLen - 1 yield VarSymbol(Tree.Ident("param" + i)) .toList val paramSymsArr = ArrayBuffer.from(paramSyms) val dSymIds = scc.funs.map(_.dSym).zipWithIndex.toMap val bms = - if scc.funs.size == 1 then scc.funs.head.sym + if scc.funs.size === 1 then scc.funs.head.sym else BlockMemberSymbol(scc.funs.map(_.sym.nme).mkString("_"), Nil, true) val dSym = - if scc.funs.size == 1 then scc.funs.head.dSym + if scc.funs.size === 1 then scc.funs.head.dSym else TermSymbol(syntax.Fun, owner, Tree.Ident(bms.nme)) val loopSym = TempSymbol(N, "loopLabel") val curIdSym = VarSymbol(Tree.Ident("id")) - class FunRewriter(f: FunDefn) extends BlockTransformer(SymbolSubst()): + class FunRewriter(f: FunDefn) extends BlockTransformerShallow(SymbolSubst()): val params = getParamSyms(f) val paramsSet = f.params.toSet val paramsIdxes = params.zipWithIndex.toMap @@ -225,11 +229,13 @@ class TailRecOpt(using State, TL, Raise): val argVals = rewriteCallArgs(f, c) match case Some(value) => value case None => return super.applyBlock(b) - val cont = Assign(curIdSym, Value.Lit(Tree.IntLit(dSymIds(dSym))), Continue(loopSym)) + val cont = + if scc.funs.size === 1 then Continue(loopSym) + else Assign(curIdSym, Value.Lit(Tree.IntLit(dSymIds(dSym))), Continue(loopSym)) paramSyms.zip(argVals).foldRight[Block](cont): - case ((sym, res), acc) => res match - case Value.Ref(`sym`, _) => acc - case _ => applyResult(res)(Assign(sym, _, acc)) + case ((sym, res), acc) => applyResult(res)(Assign(sym, _, acc)) match + case Assign(sym, Value.Ref(sym1, _), rest) if sym === sym1 => rest + case x => x case None => super.applyBlock(b) case _ => super.applyBlock(b) @@ -237,7 +243,7 @@ class TailRecOpt(using State, TL, Raise): Case.Lit(Tree.IntLit(dSymIds(f.dSym))) -> FunRewriter(f).applyBlock(f.body) val switch = - if arms.length == 1 then arms.head._2 + if arms.length === 1 then arms.head._2 else Match(curIdSym.asPath, arms, N, End()) val loop = Label(loopSym, true, switch, End()) @@ -247,7 +253,7 @@ class TailRecOpt(using State, TL, Raise): case None => Value.Ref(bms, S(dSym)) val rewrittenFuns = - if scc.funs.size == 1 then Nil + if scc.funs.size === 1 then Nil else scc.funs.map: f => val paramArgs = getParamSyms(f).map(_.asPath.asArg) val args = @@ -262,17 +268,22 @@ class TailRecOpt(using State, TL, Raise): val params = val initial = paramSyms.map(Param.simple(_)) - if scc.funs.length == 1 then initial + if scc.funs.length === 1 then initial else Param.simple(curIdSym) :: initial - - FunDefn( + + val loopDefn = FunDefn( owner, bms, dSym, PlainParamList(params) :: Nil, - loop - )(false) :: rewrittenFuns + loop)(false) + + if scc.funs.size === 1 then (N, loopDefn :: Nil) + else (S(loopDefn), rewrittenFuns) def optFunctions(fs: List[FunDefn], owner: Opt[InnerSymbol]) = - partFns(fs).flatMap(optScc(_, owner)) + partFns(fs).map(optScc(_, owner)).foldLeft[(List[FunDefn], List[FunDefn])](Nil, Nil): + case ((newFns, fns), (newFnOpt, fns_)) => newFnOpt match + case Some(value) => (value :: newFns, fns_ ::: fns) + case None => (newFns, fns_ ::: fns) def reportClassesTailrec(c: ClsLikeDefn) = new BlockTraverserShallow(): @@ -285,19 +296,23 @@ class TailRecOpt(using State, TL, Raise): raise(ErrorReport(msg"Calls from class methods cannot yet be marked @tailcall." -> c.toLoc :: Nil)) case _ => super.applyResult(r) + def optFunctionsFlat(fs: List[FunDefn], owner: Opt[InnerSymbol]) = + val (a, b) = optFunctions(fs, owner) + a ::: b + def optClasses(cs: List[ClsLikeDefn]) = cs.map: c => // Class methods cannot yet be optimized as they cannot yet be marked final. if c.k is syntax.Cls then reportClassesTailrec(c) val companion = c.companion.map: comp => - val cMtds = optFunctions(comp.methods, S(comp.isym)) + val cMtds = optFunctionsFlat(comp.methods, S(comp.isym)) comp.copy(methods = cMtds) c.copy(companion = companion) else - val mtds = optFunctions(c.methods, S(c.isym)) + val mtds = optFunctionsFlat(c.methods, S(c.isym)) val companion = c.companion.map: comp => - val cMtds = optFunctions(comp.methods, S(comp.isym)) + val cMtds = optFunctionsFlat(comp.methods, S(comp.isym)) comp.copy(methods = cMtds) c.copy(methods = mtds, companion = companion) @@ -309,7 +324,26 @@ class TailRecOpt(using State, TL, Raise): case f: FunDefn => (f :: fs, cs) case c: ClsLikeDefn => (fs, c :: cs) case _ => (fs, cs) // unreachable as floatOutDefns only floats out FunDefns and ClsLikeDefns - val bod1 = optFunctions(funs, N).foldLeft(blk): + val (optFNew, optF) = optFunctions(funs, N) + val optC = optClasses(clses) + + val fMap = optF.map(f => f.dSym -> f).toMap + // Scala needs this annotation to type check for some reason + val cMap: Map[DefinitionSymbol[? <: ClassLikeDef] & InnerSymbol, ClsLikeDefn] = + optC.map(c => c.isym -> c).toMap + + // replace them in place + val transformer = new BlockTransformerShallow(SymbolSubst()): + override def applyDefn(defn: Defn)(k: Defn => Block): Block = defn match + case f: FunDefn => fMap.get(f.dSym) match + case Some(value) => k(value) + case None => k(f) + + case c: ClsLikeDefn => cMap.get(c.isym) match + case Some(value) => k(value) + case None => k(c) + + case _ => super.applyDefn(defn)(k) + + optFNew.foldLeft(transformer.applyBlock(b)): case (acc, f) => Define(f, acc) - optClasses(clses).foldLeft(bod1): - case (acc, c) => Define(c, acc) diff --git a/hkmc2/shared/src/test/mlscript/HkScratch.mls b/hkmc2/shared/src/test/mlscript/HkScratch.mls index fd12527772..ef38e8b363 100644 --- a/hkmc2/shared/src/test/mlscript/HkScratch.mls +++ b/hkmc2/shared/src/test/mlscript/HkScratch.mls @@ -9,5 +9,4 @@ // :todo -@tailrec -fun f = 2 + diff --git a/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls b/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls index 10940ea078..98b618b132 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls @@ -7,14 +7,17 @@ fun foo() = if false do foo() //│ let foo; //│ foo = function foo() { //│ let scrut; -//│ scrut = false; -//│ if (scrut === true) { return foo() } else { return runtime.Unit } +//│ loopLabel: while (true) { +//│ scrut = false; +//│ if (scrut === true) { continue loopLabel } else { return runtime.Unit } +//│ break; +//│ } //│ }; :sjs fun foo() = foo() //│ JS (unsanitized): -//│ let foo1; foo1 = function foo() { return foo1() }; +//│ let foo1; foo1 = function foo() { loopLabel: while (true) { continue loopLabel; break; } }; :sjs @@ -63,7 +66,10 @@ do fun f = f () //│ JS (unsanitized): -//│ let f, f1; f1 = 1; f = function f() { let tmp; tmp = f(); return tmp }; runtime.Unit +//│ let f, f1; +//│ f1 = 1; +//│ f = function f() { loopLabel: while (true) { continue loopLabel; break; } }; +//│ runtime.Unit //│ f = 1 :sjs @@ -71,10 +77,10 @@ do let foo = 1 fun foo(x) = foo //│ ╔══[ERROR] Name 'foo' is already used -//│ ║ l.71: let foo = 1 +//│ ║ l.81: let foo = 1 //│ ║ ^^^^^^^ //│ ╟── by a member declared in the same block -//│ ║ l.72: fun foo(x) = foo +//│ ║ l.82: fun foo(x) = foo //│ ╙── ^^^^^^^^^^^^^^^^ //│ JS (unsanitized): //│ let foo3, foo4; foo3 = function foo(x) { return foo4 }; foo4 = 1; diff --git a/hkmc2/shared/src/test/mlscript/codegen/TraceLog.mls b/hkmc2/shared/src/test/mlscript/codegen/TraceLog.mls index 1b3593632c..de72d61fde 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/TraceLog.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/TraceLog.mls @@ -10,10 +10,19 @@ fun fib(a) = if //│ let fib; //│ fib = function fib(a) { //│ let scrut, tmp, tmp1, tmp2, tmp3; -//│ scrut = a <= 1; -//│ if (scrut === true) { -//│ return a -//│ } else { tmp = a - 1; tmp1 = fib(tmp); tmp2 = a - 2; tmp3 = fib(tmp2); return tmp1 + tmp3 } +//│ loopLabel: while (true) { +//│ scrut = a <= 1; +//│ if (scrut === true) { +//│ return a +//│ } else { +//│ tmp = a - 1; +//│ tmp1 = fib(tmp); +//│ tmp2 = a - 2; +//│ tmp3 = fib(tmp2); +//│ return tmp1 + tmp3 +//│ } +//│ break; +//│ } //│ }; fun f(x) = g(x) diff --git a/hkmc2/shared/src/test/mlscript/codegen/While.mls b/hkmc2/shared/src/test/mlscript/codegen/While.mls index 407c7639f2..cff4bb2dd8 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/While.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/While.mls @@ -61,13 +61,16 @@ while x //│ tmp4 = undefined; //│ while3 = (undefined, function () { //│ let tmp6; -//│ if (x2 === true) { -//│ tmp6 = Predef.print("Hello World"); -//│ x2 = false; -//│ tmp4 = runtime.Unit; -//│ return while3() -//│ } else { tmp4 = 42; } -//│ return runtime.LoopEnd +//│ loopLabel: while (true) { +//│ if (x2 === true) { +//│ tmp6 = Predef.print("Hello World"); +//│ x2 = false; +//│ tmp4 = runtime.Unit; +//│ continue loopLabel +//│ } else { tmp4 = 42; } +//│ return runtime.LoopEnd; +//│ break; +//│ } //│ }); //│ tmp5 = while3(); //│ tmp4 @@ -297,10 +300,10 @@ while print("Hello World"); false then 0(0) else 1 //│ ╔══[PARSE ERROR] Unexpected 'then' keyword here -//│ ║ l.297: then 0(0) +//│ ║ l.301: then 0(0) //│ ╙── ^^^^ //│ ╔══[ERROR] Unrecognized term split (false literal) -//│ ║ l.296: while print("Hello World"); false +//│ ║ l.300: while print("Hello World"); false //│ ╙── ^^^^^ //│ > Hello World //│ ═══[RUNTIME ERROR] Error: match error @@ -310,12 +313,12 @@ while { print("Hello World"), false } then 0(0) else 1 //│ ╔══[ERROR] Unexpected infix use of keyword 'then' here -//│ ║ l.309: while { print("Hello World"), false } +//│ ║ l.313: while { print("Hello World"), false } //│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.310: then 0(0) +//│ ║ l.314: then 0(0) //│ ╙── ^^^^^^^^^^^ //│ ╔══[ERROR] Illegal position for prefix keyword 'else'. -//│ ║ l.311: else 1 +//│ ║ l.315: else 1 //│ ╙── ^^^^ :fixme @@ -325,14 +328,14 @@ while then 0(0) else 1 //│ ╔══[ERROR] Unexpected infix use of keyword 'then' here -//│ ║ l.323: print("Hello World") +//│ ║ l.327: print("Hello World") //│ ║ ^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.324: false +//│ ║ l.328: false //│ ║ ^^^^^^^^^ -//│ ║ l.325: then 0(0) +//│ ║ l.329: then 0(0) //│ ╙── ^^^^^^^^^^^ //│ ╔══[ERROR] Illegal position for prefix keyword 'else'. -//│ ║ l.326: else 1 +//│ ║ l.330: else 1 //│ ╙── ^^^^ @@ -362,6 +365,7 @@ class Lazy[out A](f: () -> A) with cached :expect [1, 2, 3] +:lift let arr = [1, 2, 3] let output = mut [] let i = 0 diff --git a/hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls b/hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls index 86684cf7ea..5c9785196e 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls @@ -1,4 +1,5 @@ :js +:noTailRec // sanity check :expect 5050 diff --git a/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls b/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls index 8b772c4748..abd2215af4 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls @@ -1,5 +1,6 @@ :js :lift +:noTailRec // sanity check :expect 5050 diff --git a/hkmc2/shared/src/test/mlscript/llir/BasisLLIR.mls b/hkmc2/shared/src/test/mlscript/llir/BasisLLIR.mls index bb419bc5ea..5203c021db 100644 --- a/hkmc2/shared/src/test/mlscript/llir/BasisLLIR.mls +++ b/hkmc2/shared/src/test/mlscript/llir/BasisLLIR.mls @@ -1,6 +1,7 @@ :js :llir :cpp +:noTailRec // The LLIR lowering does not yet support Label blocks // This file contains all tests for LLIR in the original MLscript compiler. diff --git a/hkmc2/shared/src/test/mlscript/llir/Classes.mls b/hkmc2/shared/src/test/mlscript/llir/Classes.mls index 2bfb5b2433..690d5cccc5 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Classes.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Classes.mls @@ -1,5 +1,6 @@ :llir :cpp +:noTailRec // The LLIR lowering does not yet support Label blocks :intl abstract class Callable diff --git a/hkmc2/shared/src/test/mlscript/llir/ControlFlow.mls b/hkmc2/shared/src/test/mlscript/llir/ControlFlow.mls index 53d6099063..b5bee1df93 100644 --- a/hkmc2/shared/src/test/mlscript/llir/ControlFlow.mls +++ b/hkmc2/shared/src/test/mlscript/llir/ControlFlow.mls @@ -1,6 +1,7 @@ :js :llir :cpp +:noTailRec // The LLIR lowering does not yet support Label blocks :sllir :intl diff --git a/hkmc2/shared/src/test/mlscript/llir/HigherOrder.mls b/hkmc2/shared/src/test/mlscript/llir/HigherOrder.mls index 147b491a39..71fb10d514 100644 --- a/hkmc2/shared/src/test/mlscript/llir/HigherOrder.mls +++ b/hkmc2/shared/src/test/mlscript/llir/HigherOrder.mls @@ -2,6 +2,7 @@ :llir :cpp :intl +:noTailRec // The LLIR lowering does not yet support Label blocks //│ //│ Interpreted: diff --git a/hkmc2/shared/src/test/mlscript/nofib/mate.mls b/hkmc2/shared/src/test/mlscript/nofib/mate.mls index a486598674..30abe7e4a9 100644 --- a/hkmc2/shared/src/test/mlscript/nofib/mate.mls +++ b/hkmc2/shared/src/test/mlscript/nofib/mate.mls @@ -6,6 +6,8 @@ :silent import "fs" +:noTailRec + :... //│ ———————————————————————————————————————————————————————————————————————————————— fun rqpart(le, x, ys, rle, rgt, r) = if ys is @@ -573,5 +575,9 @@ print(nofibListToString(testMate_nofib(0))) //│ > if K-K7; 2. N/QR2-QN4, //│ > if K-Q7; 3. Q/QN1-K1++ //│ > if K-K6; 3. Q/QN1-Q3++ -//│ > if P/QN6xN/QR7; 2. Q/QN1-QB2, ...; 3. B/KR4-KB2++ +//│ > if K-K6; 3. Q/QN1-Q3++ +//│ > if K-K7; 2. N/QR2-QN4, +//│ > if K-Q7; 3. Q/QN1-K1++ +//│ > if K-K6; 3. Q/QN1-Q3++ +//│ > if K-K6; 3. Q/QN1-Q3++ //│ > diff --git a/hkmc2/shared/src/test/mlscript/std/FingerTreeListTest.mls b/hkmc2/shared/src/test/mlscript/std/FingerTreeListTest.mls index 01db6e89df..a37c8bf69e 100644 --- a/hkmc2/shared/src/test/mlscript/std/FingerTreeListTest.mls +++ b/hkmc2/shared/src/test/mlscript/std/FingerTreeListTest.mls @@ -136,17 +136,22 @@ fun popByIndex(start, end, acc, lft) = //│ let popByIndex; //│ popByIndex = function popByIndex(start, end, acc, lft) { //│ let scrut, tmp34, tmp35, tmp36; -//│ scrut = start >= end; -//│ if (scrut === true) { -//│ return acc -//│ } else { -//│ tmp34 = start + 1; -//│ tmp35 = runtime.safeCall(lft.at(start)); -//│ tmp36 = globalThis.Object.freeze([ -//│ ...acc, -//│ tmp35 -//│ ]); -//│ return popByIndex(tmp34, end, tmp36, lft) +//│ loopLabel: while (true) { +//│ scrut = start >= end; +//│ if (scrut === true) { +//│ return acc +//│ } else { +//│ tmp34 = start + 1; +//│ tmp35 = runtime.safeCall(lft.at(start)); +//│ tmp36 = globalThis.Object.freeze([ +//│ ...acc, +//│ tmp35 +//│ ]); +//│ start = tmp34; +//│ acc = tmp36; +//│ continue loopLabel +//│ } +//│ break; //│ } //│ }; diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls b/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls index 00afc4659f..698e2a7a56 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls @@ -28,7 +28,7 @@ fun g = 2 //│ ═══[WARNING] This annotation has no effect. fun test = - @tailcall f + @tailcall test :w let f = 0 @@ -45,6 +45,9 @@ class A with @tailrec fun f() = g() fun g() = f() +//│ ╔══[ERROR] Class methods may not yet be marked @tailrec. +//│ ║ l.46: fun f() = g() +//│ ╙── ^ :w class A with @@ -57,6 +60,9 @@ class A module A with @tailrec fun f = 2 +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.62: fun f = 2 +//│ ╙── ^ :w class A @@ -77,5 +83,5 @@ fun test = @tailcall 1 + 2 //│ ╔══[WARNING] This annotation has no effect. //│ ╟── The @tailcall annotation has no effect on calls to built-in symbols. -//│ ║ l.77: @tailcall 1 + 2 +//│ ║ l.83: @tailcall 1 + 2 //│ ╙── ^^^^^ diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls index 1800e2c677..474ec7bbba 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls @@ -1,19 +1,18 @@ :js -:tailrec :e fun f(x) = @tailcall f(x) f(x) //│ ╔══[ERROR] This call is not in tail position. -//│ ║ l.6: @tailcall f(x) +//│ ║ l.5: @tailcall f(x) //│ ╙── ^^^^ :e fun f(x) = 2 fun g(x) = @tailcall f(x) //│ ╔══[ERROR] This call is not optimized as it does not directly recurse through its parent function. -//│ ║ l.14: fun g(x) = @tailcall f(x) +//│ ║ l.13: fun g(x) = @tailcall f(x) //│ ╙── ^^^^ :e @@ -21,10 +20,10 @@ fun g(x) = @tailcall f(x) g(x) fun g(x) = f(x); f(x) //│ ╔══[ERROR] This function is not tail recursive. -//│ ║ l.20: @tailrec fun f(x) = +//│ ║ l.19: @tailrec fun f(x) = //│ ║ ^ //│ ╟── It could self-recurse through this call, which is not a tail call. -//│ ║ l.22: fun g(x) = f(x); f(x) +//│ ║ l.21: fun g(x) = f(x); f(x) //│ ╙── ^^^^ :e @@ -36,10 +35,10 @@ fun g(x) = fun h(x) = g(x) //│ ╔══[ERROR] This function is not tail recursive. -//│ ║ l.31: @tailrec fun f(x) = +//│ ║ l.30: @tailrec fun f(x) = //│ ║ ^ //│ ╟── It could self-recurse through this call, which is not a tail call. -//│ ║ l.34: f(x) +//│ ║ l.33: f(x) //│ ╙── ^^^^ :e @@ -51,10 +50,10 @@ fun g(x) = fun h(x) = f(x) //│ ╔══[ERROR] This function is not tail recursive. -//│ ║ l.46: @tailrec fun f(x) = +//│ ║ l.45: @tailrec fun f(x) = //│ ║ ^ //│ ╟── It could self-recurse through this call, which is not a tail call. -//│ ║ l.49: h(x) +//│ ║ l.48: h(x) //│ ╙── ^^^^ :e @@ -64,14 +63,14 @@ module A with @tailcall f(x - 1) @tailcall f(x - 1) //│ ╔══[ERROR] This call is not in tail position. -//│ ║ l.64: @tailcall f(x - 1) +//│ ║ l.63: @tailcall f(x - 1) //│ ╙── ^^^^^^^^ :w @tailrec fun f = 2 //│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. -//│ ║ l.72: fun f = 2 +//│ ║ l.71: fun f = 2 //│ ╙── ^ :w @@ -79,5 +78,5 @@ module A with @tailrec fun f() = 2 //│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. -//│ ║ l.80: fun f() = 2 +//│ ║ l.79: fun f() = 2 //│ ╙── ^ diff --git a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls index 1fe1abf903..d50d939358 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls @@ -1,5 +1,4 @@ :js -:tailrec :expect 200010000 fun sum_impl(n, acc) = if n == 0 then acc else sum_impl(n - 1, n + acc) @@ -18,19 +17,11 @@ fun f(a, b, c) = f(10000, 20000, 0) //│ JS (unsanitized): //│ let f, g, g_f; -//│ f = function f(a, b, c) { -//│ return g_f(1, a, b, c, undefined) -//│ }; -//│ g = function g(a, b, c, d) { -//│ return g_f(0, a, b, c, d) -//│ }; //│ g_f = function g_f(id, param0, param1, param2, param3) { //│ let scrut, scrut1, tmp, tmp1, tmp2; //│ loopLabel: while (true) { //│ if (id === 0) { //│ tmp = param2 + param3; -//│ param0 = param0; -//│ param1 = param1; //│ param2 = tmp; //│ id = 1; //│ continue loopLabel @@ -39,8 +30,6 @@ f(10000, 20000, 0) //│ if (scrut === true) { //│ tmp1 = param0 - 1; //│ param0 = tmp1; -//│ param1 = param1; -//│ param2 = param2; //│ param3 = 1; //│ id = 0; //│ continue loopLabel @@ -48,18 +37,22 @@ f(10000, 20000, 0) //│ scrut1 = param1 > 0; //│ if (scrut1 === true) { //│ tmp2 = param1 - 1; -//│ param0 = param0; //│ param1 = tmp2; -//│ param2 = param2; //│ param3 = 2; //│ id = 0; //│ continue loopLabel -//│ } else { return param2 } +//│ } else { +//│ return param2 +//│ } //│ } //│ } //│ break; //│ } //│ }; +//│ g = function g(a, b, c, d) { +//│ return g_f(0, a, b, c, d) +//│ }; +//│ f = function f(a, b, c) { return g_f(1, a, b, c, undefined) }; //│ f(10000, 20000, 0) //│ = 50000 @@ -78,14 +71,8 @@ A.sum(20000) //│ constructor() { //│ runtime.Unit; //│ } -//│ static sum(n) { -//│ loopLabel: while (true) { -//│ return A.sum_impl(n, 0); -//│ break; -//│ } -//│ } //│ static sum_impl(n, acc) { -//│ let scrut, tmp, tmp1, id; +//│ let scrut, tmp, tmp1; //│ loopLabel: while (true) { //│ scrut = n == 0; //│ if (scrut === true) { @@ -95,11 +82,13 @@ A.sum(20000) //│ tmp1 = n + acc; //│ n = tmp; //│ acc = tmp1; -//│ id = 0; //│ continue loopLabel //│ } //│ break; //│ } +//│ } +//│ static sum(n) { +//│ return A.sum_impl(n, 0) //│ } //│ toString() { return runtime.render(this); } //│ static [definitionMetadata] = ["class", "A"]; @@ -122,8 +111,7 @@ fun sumOf(x, idx, acc) = sumOf(x, idx - 1, acc + x![idx]()) sumOf(x, n - 1, 0) -// check that spreads. where supported, are compiled correctly -:sjs +// Check that spreads, where supported, are compiled correctly fun f(x, ...z) = if x < 0 then 0 else g(x, 0, x, x) @@ -131,60 +119,13 @@ fun g(x, y, ...z) = if x < 0 then 0 else f(x - 1, 0, ...[x, x, x]) f(100) -//│ JS (unsanitized): -//│ let f1, g1, f_g; -//│ g1 = function g(x1, y, ...z) { -//│ return f_g(1, x1, y, z) -//│ }; -//│ f1 = function f(x1, ...z) { -//│ return f_g(0, x1, z, undefined) -//│ }; -//│ f_g = function f_g(id, param0, param1, param2) { -//│ let scrut, scrut1, tmp4, tmp5; -//│ loopLabel: while (true) { -//│ if (id === 0) { -//│ scrut = param0 < 0; -//│ if (scrut === true) { -//│ return 0 -//│ } else { -//│ param0 = param0; -//│ param1 = [ -//│ 0, -//│ param0, -//│ param0 -//│ ]; -//│ id = 1; -//│ continue loopLabel -//│ } -//│ } else if (id === 1) { -//│ scrut1 = param0 < 0; -//│ if (scrut1 === true) { -//│ return 0 -//│ } else { -//│ tmp4 = param0 - 1; -//│ tmp5 = globalThis.Object.freeze([ -//│ param0, -//│ param0, -//│ param0 -//│ ]); -//│ param0 = tmp4; -//│ param1 = 0; -//│ param2 = tmp5; -//│ id = 0; -//│ continue loopLabel -//│ } -//│ } -//│ break; -//│ } -//│ }; -//│ f1(100) //│ = 0 -:todo +:fixme fun f(x, y, z) = @tailcall f(...[1, 2, 3]) //│ ═══[ERROR] Spreads are not yet fully supported in calls marked @tailcall. -:todo +:fixme fun f(x, ...y) = if x < 0 then g(0, 0, 0) else 0 @@ -202,7 +143,7 @@ fun f(x) = @tailcall f(x) @tailcall g() //│ ╔══[ERROR] This call is not in tail position. -//│ ║ l.201: @tailcall f(x) +//│ ║ l.143: @tailcall f(x) //│ ╙── ^ :lift @@ -211,22 +152,24 @@ fun f(x) = @tailcall f(x) @tailcall g() -:todo +// Functions inside module definitions are lifted to the top level. This means they cannot yet be optimized. :lift +:fixme module A with fun f(x) = - fun g(x) = - @tailcall f(x) - @tailcall g(x) + fun g(x) = if x < 0 then 0 else @tailcall f(x) + @tailcall g(x - 1) +A.f(10000) //│ ╔══[ERROR] This tail call exits the current scope is not optimized. -//│ ║ l.219: @tailcall f(x) -//│ ╙── ^^^ +//│ ║ l.161: fun g(x) = if x < 0 then 0 else @tailcall f(x) +//│ ╙── ^^^ //│ ╔══[ERROR] This tail call exits the current scope is not optimized. -//│ ║ l.220: @tailcall g(x) -//│ ╙── ^^^^ +//│ ║ l.162: @tailcall g(x - 1) +//│ ╙── ^^^^^^^^ +//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded // These calls are represented as field selections and don't yet have the explicitTailCall parameter. -:todo +:breakme :e module A with fun f = g @@ -248,12 +191,12 @@ class A with fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) (new A).f(10) //│ ╔══[ERROR] Calls from class methods cannot yet be marked @tailcall. -//│ ║ l.247: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) +//│ ║ l.191: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) //│ ╙── ^^^^^^^^ //│ ╔══[ERROR] Class methods may not yet be marked @tailrec. -//│ ║ l.247: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) +//│ ║ l.191: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) //│ ╙── ^ //│ ╔══[ERROR] Calls from class methods cannot yet be marked @tailcall. -//│ ║ l.248: fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) +//│ ║ l.192: fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) //│ ╙── ^^^^^^^^ //│ = 0 diff --git a/hkmc2/shared/src/test/mlscript/ucs/patterns/Compilation.mls b/hkmc2/shared/src/test/mlscript/ucs/patterns/Compilation.mls index f8647114fc..c783aabc7c 100644 --- a/hkmc2/shared/src/test/mlscript/ucs/patterns/Compilation.mls +++ b/hkmc2/shared/src/test/mlscript/ucs/patterns/Compilation.mls @@ -43,6 +43,4 @@ fun trimStart(str) = //│ ╔══[ERROR] String concatenation is not supported in pattern compilation. //│ ║ l.42: if str is @compile { (" " | "\t") ~ rest } then trimStart(rest) else str //│ ╙── ^^^^^^^^^^^^^^^^^^ -//│ ╔══[COMPILATION ERROR] No definition found in scope for member 'rest' -//│ ║ l.42: if str is @compile { (" " | "\t") ~ rest } then trimStart(rest) else str -//│ ╙── ^^^^ +//│ ═══[COMPILATION ERROR] No definition found in scope for member 'rest' diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala index b8b60db9cb..0fbbce6bd1 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala @@ -72,7 +72,7 @@ abstract class MLsDiffMaker extends DiffMaker: val importQQ = NullaryCommand("qq") val stageCode = NullaryCommand("staging") val dontRewriteWhile = NullaryCommand("dontRewriteWhile") - val tailRecOpt = NullaryCommand("tailrec") + val noTailRecOpt = NullaryCommand("noTailRec") def mkConfig: Config = import Config.* @@ -101,7 +101,7 @@ abstract class MLsDiffMaker extends DiffMaker: stageCode = stageCode.isSet, target = if wasm.isSet then CompilationTarget.Wasm else CompilationTarget.JS, rewriteWhileLoops = !dontRewriteWhile.isSet, - tailRecOpt = tailRecOpt.isSet, + tailRecOpt = !noTailRecOpt.isSet, ) From 39fcbbc8fbfaf76a49df5bc32a765800e9e862a2 Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Fri, 5 Dec 2025 18:16:35 +0800 Subject: [PATCH 10/16] fix a bad bug and update tests --- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 37 +++++++++++++++++-- .../src/test/mlscript/codegen/RandomStuff.mls | 4 +- .../src/test/mlscript/codegen/While.mls | 18 ++++----- hkmc2/shared/src/test/mlscript/nofib/mate.mls | 8 +--- .../src/test/mlscript/tailrec/TailRecOpt.mls | 12 +++--- 5 files changed, 52 insertions(+), 27 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index ede4fb60fa..e8cecb5c61 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -232,10 +232,41 @@ class TailRecOpt(using State, TL, Raise): val cont = if scc.funs.size === 1 then Continue(loopSym) else Assign(curIdSym, Value.Lit(Tree.IntLit(dSymIds(dSym))), Continue(loopSym)) - paramSyms.zip(argVals).foldRight[Block](cont): - case ((sym, res), acc) => applyResult(res)(Assign(sym, _, acc)) match - case Assign(sym, Value.Ref(sym1, _), rest) if sym === sym1 => rest + + // In some cases, we could have assignments like this: + // param0 = whatever + // param1 = + // which means param1's value is incorrect. + // We should thus assign the params to temporary symbols + // if they are needed for a subsequent assignment. + var assignedSyms: Map[VarSymbol, TempSymbol] = paramSyms.map: + case sym => sym -> TempSymbol(N, sym.nme + "_tmp") + .toMap + var requiredTmps: Set[(VarSymbol, TempSymbol)] = Set.empty + + val paramRewriter = new BlockDataTransformer(SymbolSubst()): + override def applyValue(v: Value)(k: Value => Block): Block = v match + case Value.Ref(l: VarSymbol, disamb) => assignedSyms.get(l) match + case S(v) => + requiredTmps += (l, v) + k(Value.Ref(v, disamb)) + case _ => super.applyValue(v)(k) + case _ => super.applyValue(v)(k) + + // remove symbols from assignedSyms as we encounter them + // note that foldRight will call the function right to left + val assigns = paramSyms.zip(argVals).foldRight[Block](cont): (v, acc) => + val (sym, res) = v + assignedSyms -= sym + val ret = applyResult(res)(Assign(sym, _, acc)) match + case Assign(sym, res, rest) => paramRewriter.applyResult(res)(Assign(sym, _, rest)) match + case Assign(sym, Value.Ref(sym1, _), rest) if sym === sym1 => rest + case x => x case x => x + ret + // bind the tmps + requiredTmps.toList.foldRight(assigns): + case ((v, l), acc) => Assign(l, Value.Ref(v), acc) case None => super.applyBlock(b) case _ => super.applyBlock(b) diff --git a/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls b/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls index 98b618b132..c90c88145c 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls @@ -77,10 +77,10 @@ do let foo = 1 fun foo(x) = foo //│ ╔══[ERROR] Name 'foo' is already used -//│ ║ l.81: let foo = 1 +//│ ║ l.77: let foo = 1 //│ ║ ^^^^^^^ //│ ╟── by a member declared in the same block -//│ ║ l.82: fun foo(x) = foo +//│ ║ l.78: fun foo(x) = foo //│ ╙── ^^^^^^^^^^^^^^^^ //│ JS (unsanitized): //│ let foo3, foo4; foo3 = function foo(x) { return foo4 }; foo4 = 1; diff --git a/hkmc2/shared/src/test/mlscript/codegen/While.mls b/hkmc2/shared/src/test/mlscript/codegen/While.mls index cff4bb2dd8..a916c4e47a 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/While.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/While.mls @@ -300,10 +300,10 @@ while print("Hello World"); false then 0(0) else 1 //│ ╔══[PARSE ERROR] Unexpected 'then' keyword here -//│ ║ l.301: then 0(0) +//│ ║ l.300: then 0(0) //│ ╙── ^^^^ //│ ╔══[ERROR] Unrecognized term split (false literal) -//│ ║ l.300: while print("Hello World"); false +//│ ║ l.299: while print("Hello World"); false //│ ╙── ^^^^^ //│ > Hello World //│ ═══[RUNTIME ERROR] Error: match error @@ -313,12 +313,12 @@ while { print("Hello World"), false } then 0(0) else 1 //│ ╔══[ERROR] Unexpected infix use of keyword 'then' here -//│ ║ l.313: while { print("Hello World"), false } +//│ ║ l.312: while { print("Hello World"), false } //│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.314: then 0(0) +//│ ║ l.313: then 0(0) //│ ╙── ^^^^^^^^^^^ //│ ╔══[ERROR] Illegal position for prefix keyword 'else'. -//│ ║ l.315: else 1 +//│ ║ l.314: else 1 //│ ╙── ^^^^ :fixme @@ -328,14 +328,14 @@ while then 0(0) else 1 //│ ╔══[ERROR] Unexpected infix use of keyword 'then' here -//│ ║ l.327: print("Hello World") +//│ ║ l.326: print("Hello World") //│ ║ ^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.328: false +//│ ║ l.327: false //│ ║ ^^^^^^^^^ -//│ ║ l.329: then 0(0) +//│ ║ l.328: then 0(0) //│ ╙── ^^^^^^^^^^^ //│ ╔══[ERROR] Illegal position for prefix keyword 'else'. -//│ ║ l.330: else 1 +//│ ║ l.329: else 1 //│ ╙── ^^^^ diff --git a/hkmc2/shared/src/test/mlscript/nofib/mate.mls b/hkmc2/shared/src/test/mlscript/nofib/mate.mls index 30abe7e4a9..a486598674 100644 --- a/hkmc2/shared/src/test/mlscript/nofib/mate.mls +++ b/hkmc2/shared/src/test/mlscript/nofib/mate.mls @@ -6,8 +6,6 @@ :silent import "fs" -:noTailRec - :... //│ ———————————————————————————————————————————————————————————————————————————————— fun rqpart(le, x, ys, rle, rgt, r) = if ys is @@ -575,9 +573,5 @@ print(nofibListToString(testMate_nofib(0))) //│ > if K-K7; 2. N/QR2-QN4, //│ > if K-Q7; 3. Q/QN1-K1++ //│ > if K-K6; 3. Q/QN1-Q3++ -//│ > if K-K6; 3. Q/QN1-Q3++ -//│ > if K-K7; 2. N/QR2-QN4, -//│ > if K-Q7; 3. Q/QN1-K1++ -//│ > if K-K6; 3. Q/QN1-Q3++ -//│ > if K-K6; 3. Q/QN1-Q3++ +//│ > if P/QN6xN/QR7; 2. Q/QN1-QB2, ...; 3. B/KR4-KB2++ //│ > diff --git a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls index d50d939358..1b97dbfb94 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls @@ -143,7 +143,7 @@ fun f(x) = @tailcall f(x) @tailcall g() //│ ╔══[ERROR] This call is not in tail position. -//│ ║ l.143: @tailcall f(x) +//│ ║ l.142: @tailcall f(x) //│ ╙── ^ :lift @@ -161,10 +161,10 @@ module A with @tailcall g(x - 1) A.f(10000) //│ ╔══[ERROR] This tail call exits the current scope is not optimized. -//│ ║ l.161: fun g(x) = if x < 0 then 0 else @tailcall f(x) +//│ ║ l.160: fun g(x) = if x < 0 then 0 else @tailcall f(x) //│ ╙── ^^^ //│ ╔══[ERROR] This tail call exits the current scope is not optimized. -//│ ║ l.162: @tailcall g(x - 1) +//│ ║ l.161: @tailcall g(x - 1) //│ ╙── ^^^^^^^^ //│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded @@ -191,12 +191,12 @@ class A with fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) (new A).f(10) //│ ╔══[ERROR] Calls from class methods cannot yet be marked @tailcall. -//│ ║ l.191: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) +//│ ║ l.190: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) //│ ╙── ^^^^^^^^ //│ ╔══[ERROR] Class methods may not yet be marked @tailrec. -//│ ║ l.191: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) +//│ ║ l.190: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) //│ ╙── ^ //│ ╔══[ERROR] Calls from class methods cannot yet be marked @tailcall. -//│ ║ l.192: fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) +//│ ║ l.191: fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) //│ ╙── ^^^^^^^^ //│ = 0 From c0b9aa4a28239cf20b5ff6e60d973ba9763f4235 Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Fri, 5 Dec 2025 18:31:05 +0800 Subject: [PATCH 11/16] propagate locs --- hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala | 5 ++++- hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls | 2 +- hkmc2/shared/src/test/mlscript/ucs/patterns/Compilation.mls | 4 +++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index e8cecb5c61..6f5f87fbf9 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -220,7 +220,10 @@ class TailRecOpt(using State, TL, Raise): case _ => l override def applyValue(v: Value)(k: Value => Block): Block = v match - case Value.Ref(l: VarSymbol, d) => k(Value.Ref(applyVarSym(l), d)) + case Value.Ref(l: VarSymbol, d) => + val s = applyVarSym(l) + if s is l then k(v) + else k(Value.Ref(s, d)) case _ => super.applyValue(v)(k) override def applyBlock(b: Block): Block = b match diff --git a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls index 1b97dbfb94..ea4a8d459c 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls @@ -74,7 +74,7 @@ A.sum(20000) //│ static sum_impl(n, acc) { //│ let scrut, tmp, tmp1; //│ loopLabel: while (true) { -//│ scrut = n == 0; +//│ scrut = Predef.equals(n, 0); //│ if (scrut === true) { //│ return acc //│ } else { diff --git a/hkmc2/shared/src/test/mlscript/ucs/patterns/Compilation.mls b/hkmc2/shared/src/test/mlscript/ucs/patterns/Compilation.mls index c783aabc7c..f8647114fc 100644 --- a/hkmc2/shared/src/test/mlscript/ucs/patterns/Compilation.mls +++ b/hkmc2/shared/src/test/mlscript/ucs/patterns/Compilation.mls @@ -43,4 +43,6 @@ fun trimStart(str) = //│ ╔══[ERROR] String concatenation is not supported in pattern compilation. //│ ║ l.42: if str is @compile { (" " | "\t") ~ rest } then trimStart(rest) else str //│ ╙── ^^^^^^^^^^^^^^^^^^ -//│ ═══[COMPILATION ERROR] No definition found in scope for member 'rest' +//│ ╔══[COMPILATION ERROR] No definition found in scope for member 'rest' +//│ ║ l.42: if str is @compile { (" " | "\t") ~ rest } then trimStart(rest) else str +//│ ╙── ^^^^ From 1464dc4d9709401e975f195189ce2cc277d9a0fa Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Fri, 5 Dec 2025 18:59:38 +0800 Subject: [PATCH 12/16] rewrite nested defns, even when the lifter is disabled --- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 25 +- .../src/test/mlscript-compile/Predef.mjs | 204 ++--- .../src/test/mlscript-compile/Runtime.mjs | 816 +++++++++--------- 3 files changed, 525 insertions(+), 520 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 6f5f87fbf9..7e967152ed 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -215,16 +215,18 @@ class TailRecOpt(using State, TL, Raise): val paramsSet = f.params.toSet val paramsIdxes = params.zipWithIndex.toMap - def applyVarSym(l: VarSymbol): VarSymbol = paramsIdxes.get(l) match - case Some(idx) => paramSymsArr(idx) - case _ => l + val symRewriter = new BlockTransformer(SymbolSubst()): + def applyVarSym(l: VarSymbol): VarSymbol = paramsIdxes.get(l) match + case Some(idx) => paramSymsArr(idx) + case _ => l + + override def applyValue(v: Value)(k: Value => Block): Block = v match + case Value.Ref(l: VarSymbol, d) => + val s = applyVarSym(l) + if s is l then k(v) + else k(Value.Ref(s, d)) + case _ => super.applyValue(v)(k) - override def applyValue(v: Value)(k: Value => Block): Block = v match - case Value.Ref(l: VarSymbol, d) => - val s = applyVarSym(l) - if s is l then k(v) - else k(Value.Ref(s, d)) - case _ => super.applyValue(v)(k) override def applyBlock(b: Block): Block = b match case TailCallShape(dSym, c) => dSymIds.get(dSym) match @@ -272,9 +274,12 @@ class TailRecOpt(using State, TL, Raise): case ((v, l), acc) => Assign(l, Value.Ref(v), acc) case None => super.applyBlock(b) case _ => super.applyBlock(b) + + def rewrite(b: Block) = + applyBlock(symRewriter.applyBlock(b)) val arms = scc.funs.map: f => - Case.Lit(Tree.IntLit(dSymIds(f.dSym))) -> FunRewriter(f).applyBlock(f.body) + Case.Lit(Tree.IntLit(dSymIds(f.dSym))) -> FunRewriter(f).rewrite(f.body) val switch = if arms.length === 1 then arms.head._2 diff --git a/hkmc2/shared/src/test/mlscript-compile/Predef.mjs b/hkmc2/shared/src/test/mlscript-compile/Predef.mjs index 00aaa4f709..f611458447 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Predef.mjs +++ b/hkmc2/shared/src/test/mlscript-compile/Predef.mjs @@ -39,36 +39,61 @@ globalThis.Object.freeze(class Predef { this.assert = globalThis.console.assert; this.foldl = Predef.fold; } - static id(x) { - return x - } - static apply(f, ...args) { - return runtime.safeCall(f(...args)) - } - static pipeInto(x, f) { - return runtime.safeCall(f(x)) - } static pipeFrom(f, x) { return runtime.safeCall(f(x)) } - static pipeIntoHi(x, f) { - return runtime.safeCall(f(x)) + static call(receiver, f) { + return (...args) => { + return f.call(receiver, ...args) + } } - static pipeFromHi(f, x) { + static pipeIntoHi(x, f) { return runtime.safeCall(f(x)) } - static tap(x, f) { - let tmp; - tmp = runtime.safeCall(f(x)); - return (tmp , x) - } - static pat(f, x) { - let tmp; - tmp = runtime.safeCall(f(x)); - return (tmp , x) + static foldr(f) { + return (first, ...rest) => { + let len, scrut, i, init, scrut1, tmp, tmp1, tmp2, tmp3; + len = rest.length; + scrut = len === 0; + if (scrut === true) { + return first + } else { + i = len - 1; + init = runtime.safeCall(rest.at(i)); + tmp4: while (true) { + scrut1 = i > 0; + if (scrut1 === true) { + tmp = i - 1; + i = tmp; + tmp1 = runtime.safeCall(rest.at(i)); + tmp2 = runtime.safeCall(f(tmp1, init)); + init = tmp2; + tmp3 = runtime.Unit; + continue tmp4 + } else { + tmp3 = runtime.Unit; + } + break; + } + return runtime.safeCall(f(first, init)) + } + } } - static alsoDo(x, eff) { - return x + static mkStr(...xs) { + let lambda, tmp; + lambda = (undefined, function (acc, x) { + let tmp1, tmp2, tmp3; + if (typeof x === 'string') { + tmp1 = true; + } else { + tmp1 = false; + } + tmp2 = runtime.safeCall(Predef.assert(tmp1)); + tmp3 = acc + x; + return (tmp2 , tmp3) + }); + tmp = runtime.safeCall(Predef.fold(lambda)); + return runtime.safeCall(tmp(...xs)) } static andThen(f, g) { return (x) => { @@ -77,27 +102,32 @@ globalThis.Object.freeze(class Predef { return runtime.safeCall(g(tmp)) } } - static compose(f, g) { - return (x) => { - let tmp; - tmp = runtime.safeCall(g(x)); - return runtime.safeCall(f(tmp)) - } + static enterHandleBlock(handler, body) { + return Runtime.enterHandleBlock(handler, body) + } + static alsoDo(x, eff) { + return x + } + static notImplemented(msg) { + let tmp; + tmp = "Not implemented: " + msg; + throw globalThis.Error(tmp) + } + static use(instance) { + return instance } static passTo(receiver, f) { return (...args) => { return runtime.safeCall(f(receiver, ...args)) } } - static passToLo(receiver, f) { - return (...args) => { - return runtime.safeCall(f(receiver, ...args)) - } + static tap(x, f) { + let tmp; + tmp = runtime.safeCall(f(x)); + return (tmp , x) } - static call(receiver, f) { - return (...args) => { - return f.call(receiver, ...args) - } + static tuple(...xs) { + return xs } static equals(a, b) { let scrut, scrut1, scrut2, ac, scrut3, md, scrut4, scrut5, scrut6, scrut7, scrut8, scrut9, scrut10, scrut11, tmp, lambda, lambda1, tmp1, tmp2, tmp3; @@ -228,16 +258,8 @@ globalThis.Object.freeze(class Predef { } return tmp } - static nequals(a, b) { - let tmp; - tmp = Predef.equals(a, b); - return ! tmp - } - static print(...xs) { - let tmp, tmp1; - tmp = runtime.safeCall(Predef.map(Predef.renderAsStr)); - tmp1 = runtime.safeCall(tmp(...xs)); - return runtime.safeCall(globalThis.console.log(...tmp1)) + static apply(f, ...args) { + return runtime.safeCall(f(...args)) } static renderAsStr(arg) { if (typeof arg === 'string') { @@ -246,70 +268,48 @@ globalThis.Object.freeze(class Predef { return runtime.safeCall(Predef.render(arg)) } } - static notImplemented(msg) { - let tmp; - tmp = "Not implemented: " + msg; - throw globalThis.Error(tmp) - } static get notImplementedError() { throw globalThis.Error("Not implemented"); } - static tuple(...xs) { - return xs - } - static foldr(f) { - return (first, ...rest) => { - let len, scrut, i, init, scrut1, tmp, tmp1, tmp2, tmp3; - len = rest.length; - scrut = len === 0; - if (scrut === true) { - return first - } else { - i = len - 1; - init = runtime.safeCall(rest.at(i)); - tmp4: while (true) { - scrut1 = i > 0; - if (scrut1 === true) { - tmp = i - 1; - i = tmp; - tmp1 = runtime.safeCall(rest.at(i)); - tmp2 = runtime.safeCall(f(tmp1, init)); - init = tmp2; - tmp3 = runtime.Unit; - continue tmp4 - } else { - tmp3 = runtime.Unit; - } - break; - } - return runtime.safeCall(f(first, init)) - } - } + static id(x) { + return x } - static mkStr(...xs) { - let lambda, tmp; - lambda = (undefined, function (acc, x) { - let tmp1, tmp2, tmp3; - if (typeof x === 'string') { - tmp1 = true; - } else { - tmp1 = false; - } - tmp2 = runtime.safeCall(Predef.assert(tmp1)); - tmp3 = acc + x; - return (tmp2 , tmp3) - }); - tmp = runtime.safeCall(Predef.fold(lambda)); - return runtime.safeCall(tmp(...xs)) + static nequals(a, b) { + let tmp; + tmp = Predef.equals(a, b); + return ! tmp } - static use(instance) { - return instance + static compose(f, g) { + return (x) => { + let tmp; + tmp = runtime.safeCall(g(x)); + return runtime.safeCall(f(tmp)) + } } - static enterHandleBlock(handler, body) { - return Runtime.enterHandleBlock(handler, body) + static pat(f, x) { + let tmp; + tmp = runtime.safeCall(f(x)); + return (tmp , x) } static raiseUnhandledEffect() { return Runtime.mkEffect(Runtime.FatalEffect, null) + } + static pipeFromHi(f, x) { + return runtime.safeCall(f(x)) + } + static passToLo(receiver, f) { + return (...args) => { + return runtime.safeCall(f(receiver, ...args)) + } + } + static pipeInto(x, f) { + return runtime.safeCall(f(x)) + } + static print(...xs) { + let tmp, tmp1; + tmp = runtime.safeCall(Predef.map(Predef.renderAsStr)); + tmp1 = runtime.safeCall(tmp(...xs)); + return runtime.safeCall(globalThis.console.log(...tmp1)) } toString() { return runtime.render(this); } static [definitionMetadata] = ["class", "Predef"]; diff --git a/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs b/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs index d5cc8a4258..3e05bed851 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs +++ b/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs @@ -423,216 +423,6 @@ globalThis.Object.freeze(class Runtime { static [definitionMetadata] = ["class", "Int31", [null]]; }); } - static get unreachable() { - throw globalThis.Error("unreachable"); - } - static checkArgs(functionName, expected, isUB, got) { - let scrut, name, scrut1, scrut2, tmp, lambda, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12; - tmp = got < expected; - lambda = (undefined, function () { - let lambda1; - lambda1 = (undefined, function () { - return got > expected - }); - return runtime.short_and(isUB, lambda1) - }); - scrut = runtime.short_or(tmp, lambda); - if (scrut === true) { - scrut1 = functionName.length > 0; - if (scrut1 === true) { - tmp1 = " '" + functionName; - tmp2 = tmp1 + "'"; - } else { - tmp2 = ""; - } - name = tmp2; - tmp3 = "Function" + name; - tmp4 = tmp3 + " expected "; - if (isUB === true) { - tmp5 = ""; - } else { - tmp5 = "at least "; - } - tmp6 = tmp4 + tmp5; - tmp7 = tmp6 + expected; - tmp8 = tmp7 + " argument"; - scrut2 = expected === 1; - if (scrut2 === true) { - tmp9 = ""; - } else { - tmp9 = "s"; - } - tmp10 = tmp8 + tmp9; - tmp11 = tmp10 + " but got "; - tmp12 = tmp11 + got; - throw globalThis.Error(tmp12) - } else { - return runtime.Unit - } - } - static safeCall(x) { - if (x === undefined) { - return runtime.Unit - } else { - return x - } - } - static checkCall(x) { - if (x === undefined) { - throw globalThis.Error("MLscript call unexpectedly returned `undefined`, the forbidden value.") - } else { - return x - } - } - static deboundMethod(mtdName, clsName) { - let tmp, tmp1, tmp2, tmp3; - tmp = "[debinding error] Method '" + mtdName; - tmp1 = tmp + "' of class '"; - tmp2 = tmp1 + clsName; - tmp3 = tmp2 + "' was accessed without being called."; - throw globalThis.Error(tmp3) - } - static try(f) { - let res; - res = runtime.safeCall(f()); - if (res instanceof Runtime.EffectSig.class) { - return Runtime.EffectHandle(res) - } else { - return res - } - } - static printRaw(x) { - let rcd, tmp; - rcd = globalThis.Object.freeze({ - indent: 2, - breakLength: 76 - }); - tmp = Runtime.render(x, rcd); - return runtime.safeCall(globalThis.console.log(tmp)) - } - static raisePrintStackEffect(showLocals) { - return Runtime.mkEffect(Runtime.PrintStackEffect, showLocals) - } - static topLevelEffect(tr, debug) { - let scrut, tmp, tmp1, tmp2, tmp3, tmp4, tmp5; - tmp6: while (true) { - scrut = tr.handler === Runtime.PrintStackEffect; - if (scrut === true) { - tmp = Runtime.showStackTrace("Stack Trace:", tr, debug, tr.handlerFun); - tmp1 = runtime.safeCall(globalThis.console.log(tmp)); - tmp2 = Runtime.resume(tr.contTrace); - tmp3 = runtime.safeCall(tmp2(runtime.Unit)); - tr = tmp3; - tmp4 = runtime.Unit; - continue tmp6 - } else { - tmp4 = runtime.Unit; - } - break; - } - if (tr instanceof Runtime.EffectSig.class) { - tmp5 = "Error: Unhandled effect " + tr.handler.constructor.name; - throw Runtime.showStackTrace(tmp5, tr, debug, false) - } else { - return tr - } - } - static showStackTrace(header, tr, debug, showLocals) { - let msg, curHandler, atTail, scrut, cur, scrut1, locals, curLocals, loc, loc1, localsMsg, scrut2, scrut3, tmp, tmp1, lambda, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12, tmp13, tmp14, tmp15, tmp16, tmp17, tmp18; - msg = header; - curHandler = tr.contTrace; - atTail = true; - if (debug === true) { - tmp19: while (true) { - scrut = curHandler !== null; - if (scrut === true) { - cur = curHandler.next; - tmp20: while (true) { - scrut1 = cur !== null; - if (scrut1 === true) { - locals = cur.getLocals; - tmp = locals.length - 1; - curLocals = runtime.safeCall(locals.at(tmp)); - loc = cur.getLoc; - if (loc === null) { - tmp1 = "pc=" + cur.pc; - } else { - tmp1 = loc; - } - loc1 = tmp1; - split_root$: { - split_1$: { - if (showLocals === true) { - scrut2 = curLocals.locals.length > 0; - if (scrut2 === true) { - lambda = (undefined, function (l) { - let tmp21, tmp22; - tmp21 = l.localName + "="; - tmp22 = Rendering.render(l.value); - return tmp21 + tmp22 - }); - tmp2 = runtime.safeCall(curLocals.locals.map(lambda)); - tmp3 = runtime.safeCall(tmp2.join(", ")); - tmp4 = " with locals: " + tmp3; - break split_root$ - } else { - break split_1$ - } - } else { - break split_1$ - } - } - tmp4 = ""; - } - localsMsg = tmp4; - tmp5 = "\n\tat " + curLocals.fnName; - tmp6 = tmp5 + " ("; - tmp7 = tmp6 + loc1; - tmp8 = tmp7 + ")"; - tmp9 = msg + tmp8; - msg = tmp9; - tmp10 = msg + localsMsg; - msg = tmp10; - cur = cur.next; - atTail = false; - tmp11 = runtime.Unit; - continue tmp20 - } else { - tmp11 = runtime.Unit; - } - break; - } - curHandler = curHandler.nextHandler; - scrut3 = curHandler !== null; - if (scrut3 === true) { - tmp12 = "\n\twith handler " + curHandler.handler.constructor.name; - tmp13 = msg + tmp12; - msg = tmp13; - atTail = false; - tmp14 = runtime.Unit; - } else { - tmp14 = runtime.Unit; - } - tmp15 = tmp14; - continue tmp19 - } else { - tmp15 = runtime.Unit; - } - break; - } - if (atTail === true) { - tmp16 = msg + "\n\tat tail position"; - msg = tmp16; - tmp17 = runtime.Unit; - } else { - tmp17 = runtime.Unit; - } - tmp18 = tmp17; - } else { - tmp18 = runtime.Unit; - } - return msg - } static showFunctionContChain(cont, hl, vis, reps) { let result, scrut, scrut1, scrut2, tmp, lambda, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7; if (cont instanceof Runtime.FunctionContFrame.class) { @@ -724,49 +514,127 @@ globalThis.Object.freeze(class Runtime { } } } - static debugCont(cont) { - let tmp, tmp1, tmp2; - tmp = globalThis.Object.freeze(new globalThis.Map()); - tmp1 = globalThis.Object.freeze(new globalThis.Set()); - tmp2 = Runtime.showFunctionContChain(cont, tmp, tmp1, 0); - return runtime.safeCall(globalThis.console.log(tmp2)) - } - static debugHandler(cont) { - let tmp, tmp1, tmp2; - tmp = globalThis.Object.freeze(new globalThis.Map()); - tmp1 = globalThis.Object.freeze(new globalThis.Set()); - tmp2 = Runtime.showHandlerContChain(cont, tmp, tmp1, 0); - return runtime.safeCall(globalThis.console.log(tmp2)) - } - static debugContTrace(contTrace) { - let scrut, scrut1, vis, hl, cur, scrut2, tmp, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12, tmp13, tmp14; - if (contTrace instanceof Runtime.ContTrace.class) { - tmp = globalThis.console.log("resumed: ", contTrace.resumed); - scrut = contTrace.last === contTrace; + static resume(contTrace) { + return (value) => { + let scrut, tmp, tmp1; + scrut = contTrace.resumed; if (scrut === true) { - tmp1 = runtime.safeCall(globalThis.console.log("")); - } else { - tmp1 = runtime.Unit; - } - scrut1 = contTrace.lastHandler === contTrace; - if (scrut1 === true) { - tmp2 = runtime.safeCall(globalThis.console.log("")); + throw globalThis.Error("Multiple resumption") } else { - tmp2 = runtime.Unit; + tmp = runtime.Unit; } - vis = globalThis.Object.freeze(new globalThis.Set()); - hl = globalThis.Object.freeze(new globalThis.Map()); - tmp3 = globalThis.Object.freeze([ - contTrace.last - ]); - tmp4 = globalThis.Object.freeze(new globalThis.Set(tmp3)); - tmp5 = hl.set("last", tmp4); - tmp6 = globalThis.Object.freeze([ - contTrace.lastHandler - ]); - tmp7 = globalThis.Object.freeze(new globalThis.Set(tmp6)); - tmp8 = hl.set("last-handler", tmp7); - tmp9 = Runtime.showFunctionContChain(contTrace.next, hl, vis, 0); + contTrace.resumed = true; + tmp1 = Runtime.resumeContTrace(contTrace, value); + return Runtime.handleEffects(tmp1) + } + } + static checkCall(x) { + if (x === undefined) { + throw globalThis.Error("MLscript call unexpectedly returned `undefined`, the forbidden value.") + } else { + return x + } + } + static deboundMethod(mtdName, clsName) { + let tmp, tmp1, tmp2, tmp3; + tmp = "[debinding error] Method '" + mtdName; + tmp1 = tmp + "' of class '"; + tmp2 = tmp1 + clsName; + tmp3 = tmp2 + "' was accessed without being called."; + throw globalThis.Error(tmp3) + } + static topLevelEffect(tr, debug) { + let scrut, tmp, tmp1, tmp2, tmp3, tmp4, tmp5; + tmp6: while (true) { + scrut = tr.handler === Runtime.PrintStackEffect; + if (scrut === true) { + tmp = Runtime.showStackTrace("Stack Trace:", tr, debug, tr.handlerFun); + tmp1 = runtime.safeCall(globalThis.console.log(tmp)); + tmp2 = Runtime.resume(tr.contTrace); + tmp3 = runtime.safeCall(tmp2(runtime.Unit)); + tr = tmp3; + tmp4 = runtime.Unit; + continue tmp6 + } else { + tmp4 = runtime.Unit; + } + break; + } + if (tr instanceof Runtime.EffectSig.class) { + tmp5 = "Error: Unhandled effect " + tr.handler.constructor.name; + throw Runtime.showStackTrace(tmp5, tr, debug, false) + } else { + return tr + } + } + static debugHandler(cont) { + let tmp, tmp1, tmp2; + tmp = globalThis.Object.freeze(new globalThis.Map()); + tmp1 = globalThis.Object.freeze(new globalThis.Set()); + tmp2 = Runtime.showHandlerContChain(cont, tmp, tmp1, 0); + return runtime.safeCall(globalThis.console.log(tmp2)) + } + static checkDepth() { + let scrut, tmp, lambda; + tmp = Runtime.stackDepth >= Runtime.stackLimit; + lambda = (undefined, function () { + return Runtime.stackHandler !== null + }); + scrut = runtime.short_and(tmp, lambda); + if (scrut === true) { + return runtime.safeCall(Runtime.stackHandler.delay()) + } else { + return runtime.Unit + } + } + static plus_impl(lhs, rhs) { + let tmp; + split_root$: { + split_1$: { + if (lhs instanceof Runtime.Int31.class) { + if (rhs instanceof Runtime.Int31.class) { + tmp = lhs + rhs; + break split_root$ + } else { + break split_1$ + } + } else { + break split_1$ + } + } + tmp = Runtime.unreachable(); + } + return tmp + } + static debugContTrace(contTrace) { + let scrut, scrut1, vis, hl, cur, scrut2, tmp, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12, tmp13, tmp14; + if (contTrace instanceof Runtime.ContTrace.class) { + tmp = globalThis.console.log("resumed: ", contTrace.resumed); + scrut = contTrace.last === contTrace; + if (scrut === true) { + tmp1 = runtime.safeCall(globalThis.console.log("")); + } else { + tmp1 = runtime.Unit; + } + scrut1 = contTrace.lastHandler === contTrace; + if (scrut1 === true) { + tmp2 = runtime.safeCall(globalThis.console.log("")); + } else { + tmp2 = runtime.Unit; + } + vis = globalThis.Object.freeze(new globalThis.Set()); + hl = globalThis.Object.freeze(new globalThis.Map()); + tmp3 = globalThis.Object.freeze([ + contTrace.last + ]); + tmp4 = globalThis.Object.freeze(new globalThis.Set(tmp3)); + tmp5 = hl.set("last", tmp4); + tmp6 = globalThis.Object.freeze([ + contTrace.lastHandler + ]); + tmp7 = globalThis.Object.freeze(new globalThis.Set(tmp6)); + tmp8 = hl.set("last-handler", tmp7); + tmp9 = Runtime.showFunctionContChain(contTrace.next, hl, vis, 0); tmp10 = runtime.safeCall(globalThis.console.log(tmp9)); cur = contTrace.nextHandler; tmp15: while (true) { @@ -778,15 +646,261 @@ globalThis.Object.freeze(class Runtime { tmp13 = runtime.Unit; continue tmp15 } else { - tmp13 = runtime.Unit; + tmp13 = runtime.Unit; + } + break; + } + return runtime.safeCall(globalThis.console.log()) + } else { + tmp14 = runtime.safeCall(globalThis.console.log("Not a cont trace:")); + return runtime.safeCall(globalThis.console.log(contTrace)) + } + } + static handleBlockImpl(cur, handler) { + let handlerFrame; + handlerFrame = new Runtime.HandlerContFrame.class(null, null, handler); + cur.contTrace.lastHandler.nextHandler = handlerFrame; + cur.contTrace.lastHandler = handlerFrame; + cur.contTrace.last = handlerFrame; + return Runtime.handleEffects(cur) + } + static checkArgs(functionName, expected, isUB, got) { + let scrut, name, scrut1, scrut2, tmp, lambda, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12; + tmp = got < expected; + lambda = (undefined, function () { + let lambda1; + lambda1 = (undefined, function () { + return got > expected + }); + return runtime.short_and(isUB, lambda1) + }); + scrut = runtime.short_or(tmp, lambda); + if (scrut === true) { + scrut1 = functionName.length > 0; + if (scrut1 === true) { + tmp1 = " '" + functionName; + tmp2 = tmp1 + "'"; + } else { + tmp2 = ""; + } + name = tmp2; + tmp3 = "Function" + name; + tmp4 = tmp3 + " expected "; + if (isUB === true) { + tmp5 = ""; + } else { + tmp5 = "at least "; + } + tmp6 = tmp4 + tmp5; + tmp7 = tmp6 + expected; + tmp8 = tmp7 + " argument"; + scrut2 = expected === 1; + if (scrut2 === true) { + tmp9 = ""; + } else { + tmp9 = "s"; + } + tmp10 = tmp8 + tmp9; + tmp11 = tmp10 + " but got "; + tmp12 = tmp11 + got; + throw globalThis.Error(tmp12) + } else { + return runtime.Unit + } + } + static printRaw(x) { + let rcd, tmp; + rcd = globalThis.Object.freeze({ + indent: 2, + breakLength: 76 + }); + tmp = Runtime.render(x, rcd); + return runtime.safeCall(globalThis.console.log(tmp)) + } + static resumeContTrace(contTrace, value) { + let cont, handlerCont, curDepth, scrut, scrut1, tmp, tmp1, tmp2, tmp3, tmp4; + cont = contTrace.next; + handlerCont = contTrace.nextHandler; + curDepth = Runtime.stackDepth; + tmp5: while (true) { + if (cont instanceof Runtime.FunctionContFrame.class) { + tmp = runtime.safeCall(cont.resume(value)); + value = tmp; + Runtime.stackDepth = curDepth; + if (value instanceof Runtime.EffectSig.class) { + value.contTrace.last.next = cont.next; + value.contTrace.lastHandler.nextHandler = handlerCont; + scrut = contTrace.last !== cont; + if (scrut === true) { + value.contTrace.last = contTrace.last; + tmp1 = runtime.Unit; + } else { + tmp1 = runtime.Unit; + } + scrut1 = handlerCont !== null; + if (scrut1 === true) { + value.contTrace.lastHandler = contTrace.lastHandler; + tmp2 = runtime.Unit; + } else { + tmp2 = runtime.Unit; + } + return value + } else { + cont = cont.next; + tmp3 = runtime.Unit; + } + tmp4 = tmp3; + continue tmp5 + } else { + if (handlerCont instanceof Runtime.HandlerContFrame.class) { + cont = handlerCont.next; + handlerCont = handlerCont.nextHandler; + tmp4 = runtime.Unit; + continue tmp5 + } else { + return value + } + } + break; + } + return tmp4 + } + static get unreachable() { + throw globalThis.Error("unreachable"); + } + static enterHandleBlock(handler, body) { + let cur; + cur = runtime.safeCall(body()); + if (cur instanceof Runtime.EffectSig.class) { + return Runtime.handleBlockImpl(cur, handler) + } else { + return cur + } + } + static handleEffects(cur) { + let nxt, scrut, tmp, tmp1; + tmp2: while (true) { + if (cur instanceof Runtime.EffectSig.class) { + nxt = Runtime.handleEffect(cur); + scrut = cur === nxt; + if (scrut === true) { + return cur + } else { + cur = nxt; + tmp = runtime.Unit; + } + tmp1 = tmp; + continue tmp2 + } else { + return cur + } + break; + } + return tmp1 + } + static showStackTrace(header, tr, debug, showLocals) { + let msg, curHandler, atTail, scrut, cur, scrut1, locals, curLocals, loc, loc1, localsMsg, scrut2, scrut3, tmp, tmp1, lambda, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12, tmp13, tmp14, tmp15, tmp16, tmp17, tmp18; + msg = header; + curHandler = tr.contTrace; + atTail = true; + if (debug === true) { + tmp19: while (true) { + scrut = curHandler !== null; + if (scrut === true) { + cur = curHandler.next; + tmp20: while (true) { + scrut1 = cur !== null; + if (scrut1 === true) { + locals = cur.getLocals; + tmp = locals.length - 1; + curLocals = runtime.safeCall(locals.at(tmp)); + loc = cur.getLoc; + if (loc === null) { + tmp1 = "pc=" + cur.pc; + } else { + tmp1 = loc; + } + loc1 = tmp1; + split_root$: { + split_1$: { + if (showLocals === true) { + scrut2 = curLocals.locals.length > 0; + if (scrut2 === true) { + lambda = (undefined, function (l) { + let tmp21, tmp22; + tmp21 = l.localName + "="; + tmp22 = Rendering.render(l.value); + return tmp21 + tmp22 + }); + tmp2 = runtime.safeCall(curLocals.locals.map(lambda)); + tmp3 = runtime.safeCall(tmp2.join(", ")); + tmp4 = " with locals: " + tmp3; + break split_root$ + } else { + break split_1$ + } + } else { + break split_1$ + } + } + tmp4 = ""; + } + localsMsg = tmp4; + tmp5 = "\n\tat " + curLocals.fnName; + tmp6 = tmp5 + " ("; + tmp7 = tmp6 + loc1; + tmp8 = tmp7 + ")"; + tmp9 = msg + tmp8; + msg = tmp9; + tmp10 = msg + localsMsg; + msg = tmp10; + cur = cur.next; + atTail = false; + tmp11 = runtime.Unit; + continue tmp20 + } else { + tmp11 = runtime.Unit; + } + break; + } + curHandler = curHandler.nextHandler; + scrut3 = curHandler !== null; + if (scrut3 === true) { + tmp12 = "\n\twith handler " + curHandler.handler.constructor.name; + tmp13 = msg + tmp12; + msg = tmp13; + atTail = false; + tmp14 = runtime.Unit; + } else { + tmp14 = runtime.Unit; + } + tmp15 = tmp14; + continue tmp19 + } else { + tmp15 = runtime.Unit; } break; } - return runtime.safeCall(globalThis.console.log()) + if (atTail === true) { + tmp16 = msg + "\n\tat tail position"; + msg = tmp16; + tmp17 = runtime.Unit; + } else { + tmp17 = runtime.Unit; + } + tmp18 = tmp17; } else { - tmp14 = runtime.safeCall(globalThis.console.log("Not a cont trace:")); - return runtime.safeCall(globalThis.console.log(contTrace)) + tmp18 = runtime.Unit; } + return msg + } + static mkEffect(handler, handlerFun) { + let res, tmp; + tmp = new Runtime.ContTrace.class(null, null, null, null, false); + res = new Runtime.EffectSig.class(tmp, handler, handlerFun); + res.contTrace.last = res.contTrace; + res.contTrace.lastHandler = res.contTrace; + return res } static debugEff(eff) { let tmp, tmp1, tmp2, tmp3; @@ -800,51 +914,39 @@ globalThis.Object.freeze(class Runtime { return runtime.safeCall(globalThis.console.log(eff)) } } - static mkEffect(handler, handlerFun) { - let res, tmp; - tmp = new Runtime.ContTrace.class(null, null, null, null, false); - res = new Runtime.EffectSig.class(tmp, handler, handlerFun); - res.contTrace.last = res.contTrace; - res.contTrace.lastHandler = res.contTrace; - return res - } - static handleBlockImpl(cur, handler) { - let handlerFrame; - handlerFrame = new Runtime.HandlerContFrame.class(null, null, handler); - cur.contTrace.lastHandler.nextHandler = handlerFrame; - cur.contTrace.lastHandler = handlerFrame; - cur.contTrace.last = handlerFrame; - return Runtime.handleEffects(cur) - } - static enterHandleBlock(handler, body) { - let cur; - cur = runtime.safeCall(body()); - if (cur instanceof Runtime.EffectSig.class) { - return Runtime.handleBlockImpl(cur, handler) - } else { - return cur - } - } - static handleEffects(cur) { - let nxt, scrut, tmp, tmp1; + static runStackSafe(limit, f) { + let result, scrut, saved, tmp, tmp1; + Runtime.stackLimit = limit; + Runtime.stackDepth = 1; + Runtime.stackHandler = Runtime.StackDelayHandler; + result = Runtime.enterHandleBlock(Runtime.StackDelayHandler, f); + Runtime.stackDepth = 1; tmp2: while (true) { - if (cur instanceof Runtime.EffectSig.class) { - nxt = Runtime.handleEffect(cur); - scrut = cur === nxt; - if (scrut === true) { - return cur - } else { - cur = nxt; - tmp = runtime.Unit; - } - tmp1 = tmp; + scrut = Runtime.stackResume !== null; + if (scrut === true) { + saved = Runtime.stackResume; + Runtime.stackResume = null; + tmp = runtime.safeCall(saved()); + result = tmp; + Runtime.stackDepth = 1; + tmp1 = runtime.Unit; continue tmp2 } else { - return cur + tmp1 = runtime.Unit; } break; } - return tmp1 + Runtime.stackLimit = 0; + Runtime.stackDepth = 0; + Runtime.stackHandler = null; + return result + } + static debugCont(cont) { + let tmp, tmp1, tmp2; + tmp = globalThis.Object.freeze(new globalThis.Map()); + tmp1 = globalThis.Object.freeze(new globalThis.Set()); + tmp2 = Runtime.showFunctionContChain(cont, tmp, tmp1, 0); + return runtime.safeCall(globalThis.console.log(tmp2)) } static handleEffect(cur) { let prevHandlerFrame, scrut, scrut1, scrut2, handlerFrame, saved, scrut3, scrut4, tmp, tmp1, tmp2, tmp3, tmp4, tmp5; @@ -907,126 +1009,24 @@ globalThis.Object.freeze(class Runtime { return Runtime.resumeContTrace(saved, cur) } } - static resume(contTrace) { - return (value) => { - let scrut, tmp, tmp1; - scrut = contTrace.resumed; - if (scrut === true) { - throw globalThis.Error("Multiple resumption") - } else { - tmp = runtime.Unit; - } - contTrace.resumed = true; - tmp1 = Runtime.resumeContTrace(contTrace, value); - return Runtime.handleEffects(tmp1) - } - } - static resumeContTrace(contTrace, value) { - let cont, handlerCont, curDepth, scrut, scrut1, tmp, tmp1, tmp2, tmp3, tmp4; - cont = contTrace.next; - handlerCont = contTrace.nextHandler; - curDepth = Runtime.stackDepth; - tmp5: while (true) { - if (cont instanceof Runtime.FunctionContFrame.class) { - tmp = runtime.safeCall(cont.resume(value)); - value = tmp; - Runtime.stackDepth = curDepth; - if (value instanceof Runtime.EffectSig.class) { - value.contTrace.last.next = cont.next; - value.contTrace.lastHandler.nextHandler = handlerCont; - scrut = contTrace.last !== cont; - if (scrut === true) { - value.contTrace.last = contTrace.last; - tmp1 = runtime.Unit; - } else { - tmp1 = runtime.Unit; - } - scrut1 = handlerCont !== null; - if (scrut1 === true) { - value.contTrace.lastHandler = contTrace.lastHandler; - tmp2 = runtime.Unit; - } else { - tmp2 = runtime.Unit; - } - return value - } else { - cont = cont.next; - tmp3 = runtime.Unit; - } - tmp4 = tmp3; - continue tmp5 - } else { - if (handlerCont instanceof Runtime.HandlerContFrame.class) { - cont = handlerCont.next; - handlerCont = handlerCont.nextHandler; - tmp4 = runtime.Unit; - continue tmp5 - } else { - return value - } - } - break; - } - return tmp4 + static raisePrintStackEffect(showLocals) { + return Runtime.mkEffect(Runtime.PrintStackEffect, showLocals) } - static checkDepth() { - let scrut, tmp, lambda; - tmp = Runtime.stackDepth >= Runtime.stackLimit; - lambda = (undefined, function () { - return Runtime.stackHandler !== null - }); - scrut = runtime.short_and(tmp, lambda); - if (scrut === true) { - return runtime.safeCall(Runtime.stackHandler.delay()) - } else { + static safeCall(x) { + if (x === undefined) { return runtime.Unit + } else { + return x } } - static runStackSafe(limit, f) { - let result, scrut, saved, tmp, tmp1; - Runtime.stackLimit = limit; - Runtime.stackDepth = 1; - Runtime.stackHandler = Runtime.StackDelayHandler; - result = Runtime.enterHandleBlock(Runtime.StackDelayHandler, f); - Runtime.stackDepth = 1; - tmp2: while (true) { - scrut = Runtime.stackResume !== null; - if (scrut === true) { - saved = Runtime.stackResume; - Runtime.stackResume = null; - tmp = runtime.safeCall(saved()); - result = tmp; - Runtime.stackDepth = 1; - tmp1 = runtime.Unit; - continue tmp2 - } else { - tmp1 = runtime.Unit; - } - break; - } - Runtime.stackLimit = 0; - Runtime.stackDepth = 0; - Runtime.stackHandler = null; - return result - } - static plus_impl(lhs, rhs) { - let tmp; - split_root$: { - split_1$: { - if (lhs instanceof Runtime.Int31.class) { - if (rhs instanceof Runtime.Int31.class) { - tmp = lhs + rhs; - break split_root$ - } else { - break split_1$ - } - } else { - break split_1$ - } - } - tmp = Runtime.unreachable(); + static try(f) { + let res; + res = runtime.safeCall(f()); + if (res instanceof Runtime.EffectSig.class) { + return Runtime.EffectHandle(res) + } else { + return res } - return tmp } toString() { return runtime.render(this); } static [definitionMetadata] = ["class", "Runtime"]; From e94892d6a7ade7f0c92c7db7c9a4fda5ae52337f Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Fri, 5 Dec 2025 20:32:22 +0800 Subject: [PATCH 13/16] preserve method order --- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 6 +- .../src/test/mlscript-compile/Predef.mjs | 204 ++--- .../src/test/mlscript-compile/Runtime.mjs | 732 +++++++++--------- 3 files changed, 473 insertions(+), 469 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 7e967152ed..debdf91f37 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -319,10 +319,14 @@ class TailRecOpt(using State, TL, Raise): else (S(loopDefn), rewrittenFuns) def optFunctions(fs: List[FunDefn], owner: Opt[InnerSymbol]) = - partFns(fs).map(optScc(_, owner)).foldLeft[(List[FunDefn], List[FunDefn])](Nil, Nil): + val (newFsOpt, fsOpt) = partFns(fs).map(optScc(_, owner)).foldLeft[(List[FunDefn], List[FunDefn])](Nil, Nil): case ((newFns, fns), (newFnOpt, fns_)) => newFnOpt match case Some(value) => (value :: newFns, fns_ ::: fns) case None => (newFns, fns_ ::: fns) + // preserve the order of function defns + val fMap = fsOpt.map(f => (f.dSym, f)).toMap + val fsRet = fs.map(f => fMap(f.dSym)) + (newFsOpt, fsRet) def reportClassesTailrec(c: ClsLikeDefn) = new BlockTraverserShallow(): diff --git a/hkmc2/shared/src/test/mlscript-compile/Predef.mjs b/hkmc2/shared/src/test/mlscript-compile/Predef.mjs index f611458447..00aaa4f709 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Predef.mjs +++ b/hkmc2/shared/src/test/mlscript-compile/Predef.mjs @@ -39,61 +39,36 @@ globalThis.Object.freeze(class Predef { this.assert = globalThis.console.assert; this.foldl = Predef.fold; } - static pipeFrom(f, x) { + static id(x) { + return x + } + static apply(f, ...args) { + return runtime.safeCall(f(...args)) + } + static pipeInto(x, f) { return runtime.safeCall(f(x)) } - static call(receiver, f) { - return (...args) => { - return f.call(receiver, ...args) - } + static pipeFrom(f, x) { + return runtime.safeCall(f(x)) } static pipeIntoHi(x, f) { return runtime.safeCall(f(x)) } - static foldr(f) { - return (first, ...rest) => { - let len, scrut, i, init, scrut1, tmp, tmp1, tmp2, tmp3; - len = rest.length; - scrut = len === 0; - if (scrut === true) { - return first - } else { - i = len - 1; - init = runtime.safeCall(rest.at(i)); - tmp4: while (true) { - scrut1 = i > 0; - if (scrut1 === true) { - tmp = i - 1; - i = tmp; - tmp1 = runtime.safeCall(rest.at(i)); - tmp2 = runtime.safeCall(f(tmp1, init)); - init = tmp2; - tmp3 = runtime.Unit; - continue tmp4 - } else { - tmp3 = runtime.Unit; - } - break; - } - return runtime.safeCall(f(first, init)) - } - } + static pipeFromHi(f, x) { + return runtime.safeCall(f(x)) } - static mkStr(...xs) { - let lambda, tmp; - lambda = (undefined, function (acc, x) { - let tmp1, tmp2, tmp3; - if (typeof x === 'string') { - tmp1 = true; - } else { - tmp1 = false; - } - tmp2 = runtime.safeCall(Predef.assert(tmp1)); - tmp3 = acc + x; - return (tmp2 , tmp3) - }); - tmp = runtime.safeCall(Predef.fold(lambda)); - return runtime.safeCall(tmp(...xs)) + static tap(x, f) { + let tmp; + tmp = runtime.safeCall(f(x)); + return (tmp , x) + } + static pat(f, x) { + let tmp; + tmp = runtime.safeCall(f(x)); + return (tmp , x) + } + static alsoDo(x, eff) { + return x } static andThen(f, g) { return (x) => { @@ -102,32 +77,27 @@ globalThis.Object.freeze(class Predef { return runtime.safeCall(g(tmp)) } } - static enterHandleBlock(handler, body) { - return Runtime.enterHandleBlock(handler, body) - } - static alsoDo(x, eff) { - return x - } - static notImplemented(msg) { - let tmp; - tmp = "Not implemented: " + msg; - throw globalThis.Error(tmp) - } - static use(instance) { - return instance + static compose(f, g) { + return (x) => { + let tmp; + tmp = runtime.safeCall(g(x)); + return runtime.safeCall(f(tmp)) + } } static passTo(receiver, f) { return (...args) => { return runtime.safeCall(f(receiver, ...args)) } } - static tap(x, f) { - let tmp; - tmp = runtime.safeCall(f(x)); - return (tmp , x) + static passToLo(receiver, f) { + return (...args) => { + return runtime.safeCall(f(receiver, ...args)) + } } - static tuple(...xs) { - return xs + static call(receiver, f) { + return (...args) => { + return f.call(receiver, ...args) + } } static equals(a, b) { let scrut, scrut1, scrut2, ac, scrut3, md, scrut4, scrut5, scrut6, scrut7, scrut8, scrut9, scrut10, scrut11, tmp, lambda, lambda1, tmp1, tmp2, tmp3; @@ -258,8 +228,16 @@ globalThis.Object.freeze(class Predef { } return tmp } - static apply(f, ...args) { - return runtime.safeCall(f(...args)) + static nequals(a, b) { + let tmp; + tmp = Predef.equals(a, b); + return ! tmp + } + static print(...xs) { + let tmp, tmp1; + tmp = runtime.safeCall(Predef.map(Predef.renderAsStr)); + tmp1 = runtime.safeCall(tmp(...xs)); + return runtime.safeCall(globalThis.console.log(...tmp1)) } static renderAsStr(arg) { if (typeof arg === 'string') { @@ -268,48 +246,70 @@ globalThis.Object.freeze(class Predef { return runtime.safeCall(Predef.render(arg)) } } + static notImplemented(msg) { + let tmp; + tmp = "Not implemented: " + msg; + throw globalThis.Error(tmp) + } static get notImplementedError() { throw globalThis.Error("Not implemented"); } - static id(x) { - return x - } - static nequals(a, b) { - let tmp; - tmp = Predef.equals(a, b); - return ! tmp + static tuple(...xs) { + return xs } - static compose(f, g) { - return (x) => { - let tmp; - tmp = runtime.safeCall(g(x)); - return runtime.safeCall(f(tmp)) + static foldr(f) { + return (first, ...rest) => { + let len, scrut, i, init, scrut1, tmp, tmp1, tmp2, tmp3; + len = rest.length; + scrut = len === 0; + if (scrut === true) { + return first + } else { + i = len - 1; + init = runtime.safeCall(rest.at(i)); + tmp4: while (true) { + scrut1 = i > 0; + if (scrut1 === true) { + tmp = i - 1; + i = tmp; + tmp1 = runtime.safeCall(rest.at(i)); + tmp2 = runtime.safeCall(f(tmp1, init)); + init = tmp2; + tmp3 = runtime.Unit; + continue tmp4 + } else { + tmp3 = runtime.Unit; + } + break; + } + return runtime.safeCall(f(first, init)) + } } } - static pat(f, x) { - let tmp; - tmp = runtime.safeCall(f(x)); - return (tmp , x) - } - static raiseUnhandledEffect() { - return Runtime.mkEffect(Runtime.FatalEffect, null) - } - static pipeFromHi(f, x) { - return runtime.safeCall(f(x)) + static mkStr(...xs) { + let lambda, tmp; + lambda = (undefined, function (acc, x) { + let tmp1, tmp2, tmp3; + if (typeof x === 'string') { + tmp1 = true; + } else { + tmp1 = false; + } + tmp2 = runtime.safeCall(Predef.assert(tmp1)); + tmp3 = acc + x; + return (tmp2 , tmp3) + }); + tmp = runtime.safeCall(Predef.fold(lambda)); + return runtime.safeCall(tmp(...xs)) } - static passToLo(receiver, f) { - return (...args) => { - return runtime.safeCall(f(receiver, ...args)) - } + static use(instance) { + return instance } - static pipeInto(x, f) { - return runtime.safeCall(f(x)) + static enterHandleBlock(handler, body) { + return Runtime.enterHandleBlock(handler, body) } - static print(...xs) { - let tmp, tmp1; - tmp = runtime.safeCall(Predef.map(Predef.renderAsStr)); - tmp1 = runtime.safeCall(tmp(...xs)); - return runtime.safeCall(globalThis.console.log(...tmp1)) + static raiseUnhandledEffect() { + return Runtime.mkEffect(Runtime.FatalEffect, null) } toString() { return runtime.render(this); } static [definitionMetadata] = ["class", "Predef"]; diff --git a/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs b/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs index 3e05bed851..d5cc8a4258 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs +++ b/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs @@ -423,6 +423,216 @@ globalThis.Object.freeze(class Runtime { static [definitionMetadata] = ["class", "Int31", [null]]; }); } + static get unreachable() { + throw globalThis.Error("unreachable"); + } + static checkArgs(functionName, expected, isUB, got) { + let scrut, name, scrut1, scrut2, tmp, lambda, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12; + tmp = got < expected; + lambda = (undefined, function () { + let lambda1; + lambda1 = (undefined, function () { + return got > expected + }); + return runtime.short_and(isUB, lambda1) + }); + scrut = runtime.short_or(tmp, lambda); + if (scrut === true) { + scrut1 = functionName.length > 0; + if (scrut1 === true) { + tmp1 = " '" + functionName; + tmp2 = tmp1 + "'"; + } else { + tmp2 = ""; + } + name = tmp2; + tmp3 = "Function" + name; + tmp4 = tmp3 + " expected "; + if (isUB === true) { + tmp5 = ""; + } else { + tmp5 = "at least "; + } + tmp6 = tmp4 + tmp5; + tmp7 = tmp6 + expected; + tmp8 = tmp7 + " argument"; + scrut2 = expected === 1; + if (scrut2 === true) { + tmp9 = ""; + } else { + tmp9 = "s"; + } + tmp10 = tmp8 + tmp9; + tmp11 = tmp10 + " but got "; + tmp12 = tmp11 + got; + throw globalThis.Error(tmp12) + } else { + return runtime.Unit + } + } + static safeCall(x) { + if (x === undefined) { + return runtime.Unit + } else { + return x + } + } + static checkCall(x) { + if (x === undefined) { + throw globalThis.Error("MLscript call unexpectedly returned `undefined`, the forbidden value.") + } else { + return x + } + } + static deboundMethod(mtdName, clsName) { + let tmp, tmp1, tmp2, tmp3; + tmp = "[debinding error] Method '" + mtdName; + tmp1 = tmp + "' of class '"; + tmp2 = tmp1 + clsName; + tmp3 = tmp2 + "' was accessed without being called."; + throw globalThis.Error(tmp3) + } + static try(f) { + let res; + res = runtime.safeCall(f()); + if (res instanceof Runtime.EffectSig.class) { + return Runtime.EffectHandle(res) + } else { + return res + } + } + static printRaw(x) { + let rcd, tmp; + rcd = globalThis.Object.freeze({ + indent: 2, + breakLength: 76 + }); + tmp = Runtime.render(x, rcd); + return runtime.safeCall(globalThis.console.log(tmp)) + } + static raisePrintStackEffect(showLocals) { + return Runtime.mkEffect(Runtime.PrintStackEffect, showLocals) + } + static topLevelEffect(tr, debug) { + let scrut, tmp, tmp1, tmp2, tmp3, tmp4, tmp5; + tmp6: while (true) { + scrut = tr.handler === Runtime.PrintStackEffect; + if (scrut === true) { + tmp = Runtime.showStackTrace("Stack Trace:", tr, debug, tr.handlerFun); + tmp1 = runtime.safeCall(globalThis.console.log(tmp)); + tmp2 = Runtime.resume(tr.contTrace); + tmp3 = runtime.safeCall(tmp2(runtime.Unit)); + tr = tmp3; + tmp4 = runtime.Unit; + continue tmp6 + } else { + tmp4 = runtime.Unit; + } + break; + } + if (tr instanceof Runtime.EffectSig.class) { + tmp5 = "Error: Unhandled effect " + tr.handler.constructor.name; + throw Runtime.showStackTrace(tmp5, tr, debug, false) + } else { + return tr + } + } + static showStackTrace(header, tr, debug, showLocals) { + let msg, curHandler, atTail, scrut, cur, scrut1, locals, curLocals, loc, loc1, localsMsg, scrut2, scrut3, tmp, tmp1, lambda, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12, tmp13, tmp14, tmp15, tmp16, tmp17, tmp18; + msg = header; + curHandler = tr.contTrace; + atTail = true; + if (debug === true) { + tmp19: while (true) { + scrut = curHandler !== null; + if (scrut === true) { + cur = curHandler.next; + tmp20: while (true) { + scrut1 = cur !== null; + if (scrut1 === true) { + locals = cur.getLocals; + tmp = locals.length - 1; + curLocals = runtime.safeCall(locals.at(tmp)); + loc = cur.getLoc; + if (loc === null) { + tmp1 = "pc=" + cur.pc; + } else { + tmp1 = loc; + } + loc1 = tmp1; + split_root$: { + split_1$: { + if (showLocals === true) { + scrut2 = curLocals.locals.length > 0; + if (scrut2 === true) { + lambda = (undefined, function (l) { + let tmp21, tmp22; + tmp21 = l.localName + "="; + tmp22 = Rendering.render(l.value); + return tmp21 + tmp22 + }); + tmp2 = runtime.safeCall(curLocals.locals.map(lambda)); + tmp3 = runtime.safeCall(tmp2.join(", ")); + tmp4 = " with locals: " + tmp3; + break split_root$ + } else { + break split_1$ + } + } else { + break split_1$ + } + } + tmp4 = ""; + } + localsMsg = tmp4; + tmp5 = "\n\tat " + curLocals.fnName; + tmp6 = tmp5 + " ("; + tmp7 = tmp6 + loc1; + tmp8 = tmp7 + ")"; + tmp9 = msg + tmp8; + msg = tmp9; + tmp10 = msg + localsMsg; + msg = tmp10; + cur = cur.next; + atTail = false; + tmp11 = runtime.Unit; + continue tmp20 + } else { + tmp11 = runtime.Unit; + } + break; + } + curHandler = curHandler.nextHandler; + scrut3 = curHandler !== null; + if (scrut3 === true) { + tmp12 = "\n\twith handler " + curHandler.handler.constructor.name; + tmp13 = msg + tmp12; + msg = tmp13; + atTail = false; + tmp14 = runtime.Unit; + } else { + tmp14 = runtime.Unit; + } + tmp15 = tmp14; + continue tmp19 + } else { + tmp15 = runtime.Unit; + } + break; + } + if (atTail === true) { + tmp16 = msg + "\n\tat tail position"; + msg = tmp16; + tmp17 = runtime.Unit; + } else { + tmp17 = runtime.Unit; + } + tmp18 = tmp17; + } else { + tmp18 = runtime.Unit; + } + return msg + } static showFunctionContChain(cont, hl, vis, reps) { let result, scrut, scrut1, scrut2, tmp, lambda, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7; if (cont instanceof Runtime.FunctionContFrame.class) { @@ -514,107 +724,29 @@ globalThis.Object.freeze(class Runtime { } } } - static resume(contTrace) { - return (value) => { - let scrut, tmp, tmp1; - scrut = contTrace.resumed; + static debugCont(cont) { + let tmp, tmp1, tmp2; + tmp = globalThis.Object.freeze(new globalThis.Map()); + tmp1 = globalThis.Object.freeze(new globalThis.Set()); + tmp2 = Runtime.showFunctionContChain(cont, tmp, tmp1, 0); + return runtime.safeCall(globalThis.console.log(tmp2)) + } + static debugHandler(cont) { + let tmp, tmp1, tmp2; + tmp = globalThis.Object.freeze(new globalThis.Map()); + tmp1 = globalThis.Object.freeze(new globalThis.Set()); + tmp2 = Runtime.showHandlerContChain(cont, tmp, tmp1, 0); + return runtime.safeCall(globalThis.console.log(tmp2)) + } + static debugContTrace(contTrace) { + let scrut, scrut1, vis, hl, cur, scrut2, tmp, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12, tmp13, tmp14; + if (contTrace instanceof Runtime.ContTrace.class) { + tmp = globalThis.console.log("resumed: ", contTrace.resumed); + scrut = contTrace.last === contTrace; if (scrut === true) { - throw globalThis.Error("Multiple resumption") + tmp1 = runtime.safeCall(globalThis.console.log("")); } else { - tmp = runtime.Unit; - } - contTrace.resumed = true; - tmp1 = Runtime.resumeContTrace(contTrace, value); - return Runtime.handleEffects(tmp1) - } - } - static checkCall(x) { - if (x === undefined) { - throw globalThis.Error("MLscript call unexpectedly returned `undefined`, the forbidden value.") - } else { - return x - } - } - static deboundMethod(mtdName, clsName) { - let tmp, tmp1, tmp2, tmp3; - tmp = "[debinding error] Method '" + mtdName; - tmp1 = tmp + "' of class '"; - tmp2 = tmp1 + clsName; - tmp3 = tmp2 + "' was accessed without being called."; - throw globalThis.Error(tmp3) - } - static topLevelEffect(tr, debug) { - let scrut, tmp, tmp1, tmp2, tmp3, tmp4, tmp5; - tmp6: while (true) { - scrut = tr.handler === Runtime.PrintStackEffect; - if (scrut === true) { - tmp = Runtime.showStackTrace("Stack Trace:", tr, debug, tr.handlerFun); - tmp1 = runtime.safeCall(globalThis.console.log(tmp)); - tmp2 = Runtime.resume(tr.contTrace); - tmp3 = runtime.safeCall(tmp2(runtime.Unit)); - tr = tmp3; - tmp4 = runtime.Unit; - continue tmp6 - } else { - tmp4 = runtime.Unit; - } - break; - } - if (tr instanceof Runtime.EffectSig.class) { - tmp5 = "Error: Unhandled effect " + tr.handler.constructor.name; - throw Runtime.showStackTrace(tmp5, tr, debug, false) - } else { - return tr - } - } - static debugHandler(cont) { - let tmp, tmp1, tmp2; - tmp = globalThis.Object.freeze(new globalThis.Map()); - tmp1 = globalThis.Object.freeze(new globalThis.Set()); - tmp2 = Runtime.showHandlerContChain(cont, tmp, tmp1, 0); - return runtime.safeCall(globalThis.console.log(tmp2)) - } - static checkDepth() { - let scrut, tmp, lambda; - tmp = Runtime.stackDepth >= Runtime.stackLimit; - lambda = (undefined, function () { - return Runtime.stackHandler !== null - }); - scrut = runtime.short_and(tmp, lambda); - if (scrut === true) { - return runtime.safeCall(Runtime.stackHandler.delay()) - } else { - return runtime.Unit - } - } - static plus_impl(lhs, rhs) { - let tmp; - split_root$: { - split_1$: { - if (lhs instanceof Runtime.Int31.class) { - if (rhs instanceof Runtime.Int31.class) { - tmp = lhs + rhs; - break split_root$ - } else { - break split_1$ - } - } else { - break split_1$ - } - } - tmp = Runtime.unreachable(); - } - return tmp - } - static debugContTrace(contTrace) { - let scrut, scrut1, vis, hl, cur, scrut2, tmp, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12, tmp13, tmp14; - if (contTrace instanceof Runtime.ContTrace.class) { - tmp = globalThis.console.log("resumed: ", contTrace.resumed); - scrut = contTrace.last === contTrace; - if (scrut === true) { - tmp1 = runtime.safeCall(globalThis.console.log("")); - } else { - tmp1 = runtime.Unit; + tmp1 = runtime.Unit; } scrut1 = contTrace.lastHandler === contTrace; if (scrut1 === true) { @@ -656,6 +788,26 @@ globalThis.Object.freeze(class Runtime { return runtime.safeCall(globalThis.console.log(contTrace)) } } + static debugEff(eff) { + let tmp, tmp1, tmp2, tmp3; + if (eff instanceof Runtime.EffectSig.class) { + tmp = runtime.safeCall(globalThis.console.log("Debug EffectSig:")); + tmp1 = globalThis.console.log("handler: ", eff.handler.constructor.name); + tmp2 = globalThis.console.log("handlerFun: ", eff.handlerFun); + return Runtime.debugContTrace(eff.contTrace) + } else { + tmp3 = runtime.safeCall(globalThis.console.log("Not an effect:")); + return runtime.safeCall(globalThis.console.log(eff)) + } + } + static mkEffect(handler, handlerFun) { + let res, tmp; + tmp = new Runtime.ContTrace.class(null, null, null, null, false); + res = new Runtime.EffectSig.class(tmp, handler, handlerFun); + res.contTrace.last = res.contTrace; + res.contTrace.lastHandler = res.contTrace; + return res + } static handleBlockImpl(cur, handler) { let handlerFrame; handlerFrame = new Runtime.HandlerContFrame.class(null, null, handler); @@ -664,110 +816,6 @@ globalThis.Object.freeze(class Runtime { cur.contTrace.last = handlerFrame; return Runtime.handleEffects(cur) } - static checkArgs(functionName, expected, isUB, got) { - let scrut, name, scrut1, scrut2, tmp, lambda, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12; - tmp = got < expected; - lambda = (undefined, function () { - let lambda1; - lambda1 = (undefined, function () { - return got > expected - }); - return runtime.short_and(isUB, lambda1) - }); - scrut = runtime.short_or(tmp, lambda); - if (scrut === true) { - scrut1 = functionName.length > 0; - if (scrut1 === true) { - tmp1 = " '" + functionName; - tmp2 = tmp1 + "'"; - } else { - tmp2 = ""; - } - name = tmp2; - tmp3 = "Function" + name; - tmp4 = tmp3 + " expected "; - if (isUB === true) { - tmp5 = ""; - } else { - tmp5 = "at least "; - } - tmp6 = tmp4 + tmp5; - tmp7 = tmp6 + expected; - tmp8 = tmp7 + " argument"; - scrut2 = expected === 1; - if (scrut2 === true) { - tmp9 = ""; - } else { - tmp9 = "s"; - } - tmp10 = tmp8 + tmp9; - tmp11 = tmp10 + " but got "; - tmp12 = tmp11 + got; - throw globalThis.Error(tmp12) - } else { - return runtime.Unit - } - } - static printRaw(x) { - let rcd, tmp; - rcd = globalThis.Object.freeze({ - indent: 2, - breakLength: 76 - }); - tmp = Runtime.render(x, rcd); - return runtime.safeCall(globalThis.console.log(tmp)) - } - static resumeContTrace(contTrace, value) { - let cont, handlerCont, curDepth, scrut, scrut1, tmp, tmp1, tmp2, tmp3, tmp4; - cont = contTrace.next; - handlerCont = contTrace.nextHandler; - curDepth = Runtime.stackDepth; - tmp5: while (true) { - if (cont instanceof Runtime.FunctionContFrame.class) { - tmp = runtime.safeCall(cont.resume(value)); - value = tmp; - Runtime.stackDepth = curDepth; - if (value instanceof Runtime.EffectSig.class) { - value.contTrace.last.next = cont.next; - value.contTrace.lastHandler.nextHandler = handlerCont; - scrut = contTrace.last !== cont; - if (scrut === true) { - value.contTrace.last = contTrace.last; - tmp1 = runtime.Unit; - } else { - tmp1 = runtime.Unit; - } - scrut1 = handlerCont !== null; - if (scrut1 === true) { - value.contTrace.lastHandler = contTrace.lastHandler; - tmp2 = runtime.Unit; - } else { - tmp2 = runtime.Unit; - } - return value - } else { - cont = cont.next; - tmp3 = runtime.Unit; - } - tmp4 = tmp3; - continue tmp5 - } else { - if (handlerCont instanceof Runtime.HandlerContFrame.class) { - cont = handlerCont.next; - handlerCont = handlerCont.nextHandler; - tmp4 = runtime.Unit; - continue tmp5 - } else { - return value - } - } - break; - } - return tmp4 - } - static get unreachable() { - throw globalThis.Error("unreachable"); - } static enterHandleBlock(handler, body) { let cur; cur = runtime.safeCall(body()); @@ -798,156 +846,6 @@ globalThis.Object.freeze(class Runtime { } return tmp1 } - static showStackTrace(header, tr, debug, showLocals) { - let msg, curHandler, atTail, scrut, cur, scrut1, locals, curLocals, loc, loc1, localsMsg, scrut2, scrut3, tmp, tmp1, lambda, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12, tmp13, tmp14, tmp15, tmp16, tmp17, tmp18; - msg = header; - curHandler = tr.contTrace; - atTail = true; - if (debug === true) { - tmp19: while (true) { - scrut = curHandler !== null; - if (scrut === true) { - cur = curHandler.next; - tmp20: while (true) { - scrut1 = cur !== null; - if (scrut1 === true) { - locals = cur.getLocals; - tmp = locals.length - 1; - curLocals = runtime.safeCall(locals.at(tmp)); - loc = cur.getLoc; - if (loc === null) { - tmp1 = "pc=" + cur.pc; - } else { - tmp1 = loc; - } - loc1 = tmp1; - split_root$: { - split_1$: { - if (showLocals === true) { - scrut2 = curLocals.locals.length > 0; - if (scrut2 === true) { - lambda = (undefined, function (l) { - let tmp21, tmp22; - tmp21 = l.localName + "="; - tmp22 = Rendering.render(l.value); - return tmp21 + tmp22 - }); - tmp2 = runtime.safeCall(curLocals.locals.map(lambda)); - tmp3 = runtime.safeCall(tmp2.join(", ")); - tmp4 = " with locals: " + tmp3; - break split_root$ - } else { - break split_1$ - } - } else { - break split_1$ - } - } - tmp4 = ""; - } - localsMsg = tmp4; - tmp5 = "\n\tat " + curLocals.fnName; - tmp6 = tmp5 + " ("; - tmp7 = tmp6 + loc1; - tmp8 = tmp7 + ")"; - tmp9 = msg + tmp8; - msg = tmp9; - tmp10 = msg + localsMsg; - msg = tmp10; - cur = cur.next; - atTail = false; - tmp11 = runtime.Unit; - continue tmp20 - } else { - tmp11 = runtime.Unit; - } - break; - } - curHandler = curHandler.nextHandler; - scrut3 = curHandler !== null; - if (scrut3 === true) { - tmp12 = "\n\twith handler " + curHandler.handler.constructor.name; - tmp13 = msg + tmp12; - msg = tmp13; - atTail = false; - tmp14 = runtime.Unit; - } else { - tmp14 = runtime.Unit; - } - tmp15 = tmp14; - continue tmp19 - } else { - tmp15 = runtime.Unit; - } - break; - } - if (atTail === true) { - tmp16 = msg + "\n\tat tail position"; - msg = tmp16; - tmp17 = runtime.Unit; - } else { - tmp17 = runtime.Unit; - } - tmp18 = tmp17; - } else { - tmp18 = runtime.Unit; - } - return msg - } - static mkEffect(handler, handlerFun) { - let res, tmp; - tmp = new Runtime.ContTrace.class(null, null, null, null, false); - res = new Runtime.EffectSig.class(tmp, handler, handlerFun); - res.contTrace.last = res.contTrace; - res.contTrace.lastHandler = res.contTrace; - return res - } - static debugEff(eff) { - let tmp, tmp1, tmp2, tmp3; - if (eff instanceof Runtime.EffectSig.class) { - tmp = runtime.safeCall(globalThis.console.log("Debug EffectSig:")); - tmp1 = globalThis.console.log("handler: ", eff.handler.constructor.name); - tmp2 = globalThis.console.log("handlerFun: ", eff.handlerFun); - return Runtime.debugContTrace(eff.contTrace) - } else { - tmp3 = runtime.safeCall(globalThis.console.log("Not an effect:")); - return runtime.safeCall(globalThis.console.log(eff)) - } - } - static runStackSafe(limit, f) { - let result, scrut, saved, tmp, tmp1; - Runtime.stackLimit = limit; - Runtime.stackDepth = 1; - Runtime.stackHandler = Runtime.StackDelayHandler; - result = Runtime.enterHandleBlock(Runtime.StackDelayHandler, f); - Runtime.stackDepth = 1; - tmp2: while (true) { - scrut = Runtime.stackResume !== null; - if (scrut === true) { - saved = Runtime.stackResume; - Runtime.stackResume = null; - tmp = runtime.safeCall(saved()); - result = tmp; - Runtime.stackDepth = 1; - tmp1 = runtime.Unit; - continue tmp2 - } else { - tmp1 = runtime.Unit; - } - break; - } - Runtime.stackLimit = 0; - Runtime.stackDepth = 0; - Runtime.stackHandler = null; - return result - } - static debugCont(cont) { - let tmp, tmp1, tmp2; - tmp = globalThis.Object.freeze(new globalThis.Map()); - tmp1 = globalThis.Object.freeze(new globalThis.Set()); - tmp2 = Runtime.showFunctionContChain(cont, tmp, tmp1, 0); - return runtime.safeCall(globalThis.console.log(tmp2)) - } static handleEffect(cur) { let prevHandlerFrame, scrut, scrut1, scrut2, handlerFrame, saved, scrut3, scrut4, tmp, tmp1, tmp2, tmp3, tmp4, tmp5; prevHandlerFrame = cur.contTrace; @@ -1009,24 +907,126 @@ globalThis.Object.freeze(class Runtime { return Runtime.resumeContTrace(saved, cur) } } - static raisePrintStackEffect(showLocals) { - return Runtime.mkEffect(Runtime.PrintStackEffect, showLocals) + static resume(contTrace) { + return (value) => { + let scrut, tmp, tmp1; + scrut = contTrace.resumed; + if (scrut === true) { + throw globalThis.Error("Multiple resumption") + } else { + tmp = runtime.Unit; + } + contTrace.resumed = true; + tmp1 = Runtime.resumeContTrace(contTrace, value); + return Runtime.handleEffects(tmp1) + } } - static safeCall(x) { - if (x === undefined) { - return runtime.Unit - } else { - return x + static resumeContTrace(contTrace, value) { + let cont, handlerCont, curDepth, scrut, scrut1, tmp, tmp1, tmp2, tmp3, tmp4; + cont = contTrace.next; + handlerCont = contTrace.nextHandler; + curDepth = Runtime.stackDepth; + tmp5: while (true) { + if (cont instanceof Runtime.FunctionContFrame.class) { + tmp = runtime.safeCall(cont.resume(value)); + value = tmp; + Runtime.stackDepth = curDepth; + if (value instanceof Runtime.EffectSig.class) { + value.contTrace.last.next = cont.next; + value.contTrace.lastHandler.nextHandler = handlerCont; + scrut = contTrace.last !== cont; + if (scrut === true) { + value.contTrace.last = contTrace.last; + tmp1 = runtime.Unit; + } else { + tmp1 = runtime.Unit; + } + scrut1 = handlerCont !== null; + if (scrut1 === true) { + value.contTrace.lastHandler = contTrace.lastHandler; + tmp2 = runtime.Unit; + } else { + tmp2 = runtime.Unit; + } + return value + } else { + cont = cont.next; + tmp3 = runtime.Unit; + } + tmp4 = tmp3; + continue tmp5 + } else { + if (handlerCont instanceof Runtime.HandlerContFrame.class) { + cont = handlerCont.next; + handlerCont = handlerCont.nextHandler; + tmp4 = runtime.Unit; + continue tmp5 + } else { + return value + } + } + break; } + return tmp4 } - static try(f) { - let res; - res = runtime.safeCall(f()); - if (res instanceof Runtime.EffectSig.class) { - return Runtime.EffectHandle(res) + static checkDepth() { + let scrut, tmp, lambda; + tmp = Runtime.stackDepth >= Runtime.stackLimit; + lambda = (undefined, function () { + return Runtime.stackHandler !== null + }); + scrut = runtime.short_and(tmp, lambda); + if (scrut === true) { + return runtime.safeCall(Runtime.stackHandler.delay()) } else { - return res + return runtime.Unit } + } + static runStackSafe(limit, f) { + let result, scrut, saved, tmp, tmp1; + Runtime.stackLimit = limit; + Runtime.stackDepth = 1; + Runtime.stackHandler = Runtime.StackDelayHandler; + result = Runtime.enterHandleBlock(Runtime.StackDelayHandler, f); + Runtime.stackDepth = 1; + tmp2: while (true) { + scrut = Runtime.stackResume !== null; + if (scrut === true) { + saved = Runtime.stackResume; + Runtime.stackResume = null; + tmp = runtime.safeCall(saved()); + result = tmp; + Runtime.stackDepth = 1; + tmp1 = runtime.Unit; + continue tmp2 + } else { + tmp1 = runtime.Unit; + } + break; + } + Runtime.stackLimit = 0; + Runtime.stackDepth = 0; + Runtime.stackHandler = null; + return result + } + static plus_impl(lhs, rhs) { + let tmp; + split_root$: { + split_1$: { + if (lhs instanceof Runtime.Int31.class) { + if (rhs instanceof Runtime.Int31.class) { + tmp = lhs + rhs; + break split_root$ + } else { + break split_1$ + } + } else { + break split_1$ + } + } + tmp = Runtime.unreachable(); + } + return tmp } toString() { return runtime.render(this); } static [definitionMetadata] = ["class", "Runtime"]; From fe960a39c1ffc42a3f52942fb178e85e177a5087 Mon Sep 17 00:00:00 2001 From: CAG2Mark Date: Sat, 6 Dec 2025 02:24:48 +0800 Subject: [PATCH 14/16] PR comments --- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 50 +++++++++---------- .../main/scala/hkmc2/semantics/Symbol.scala | 2 + .../src/test/mlscript/codegen/TraceLog.mls | 17 ++----- 3 files changed, 31 insertions(+), 38 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index debdf91f37..8de73d92bd 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -166,23 +166,25 @@ class TailRecOpt(using State, TL, Raise): S(ret) def optScc(scc: SccOfCalls, owner: Opt[InnerSymbol]): (Opt[FunDefn], List[FunDefn]) = + // sort the functions so the order is more predictable + val funs = scc.funs.sortBy(f => f.dSym.uid) // remove calls which don't flow into this scc - val fSyms = scc.funs.map(_.dSym).toSet + val fSyms = funs.map(_.dSym).toSet val calls = scc.calls.filter(c => fSyms.contains(c.f2)) - val nonTailCalls = calls + val nonTailCallsLs = calls .collect: case c: CallEdge.NormalCall => c.f2 -> c.call - .toMap - - if nonTailCalls.size === calls.length then - for f <- scc.funs if f.isTailRec do + val nonTailCalls = nonTailCallsLs.toMap + + if nonTailCallsLs.size === calls.length then + for f <- funs if f.isTailRec do raise(WarningReport(msg"This function does not directly self-recurse, but is marked @tailrec." -> f.dSym.toLoc :: Nil)) - return (N, scc.funs) + return (N, funs) if !nonTailCalls.isEmpty then - for f <- scc.funs if f.isTailRec do + for f <- funs if f.isTailRec do val reportLoc = nonTailCalls.get(f.dSym) match // always display a call to f, if possible case Some(value) => value.toLoc @@ -193,19 +195,19 @@ class TailRecOpt(using State, TL, Raise): :: Nil )) - val maxParamLen = maxInt(scc.funs, paramsLen) + val maxParamLen = maxInt(funs, paramsLen) val paramSyms = - if scc.funs.length === 1 then (getParamSyms(scc.funs.head)) + if funs.length === 1 then (getParamSyms(funs.head)) else for i <- 0 to maxParamLen - 1 yield VarSymbol(Tree.Ident("param" + i)) .toList val paramSymsArr = ArrayBuffer.from(paramSyms) - val dSymIds = scc.funs.map(_.dSym).zipWithIndex.toMap + val dSymIds = funs.map(_.dSym).zipWithIndex.toMap val bms = - if scc.funs.size === 1 then scc.funs.head.sym - else BlockMemberSymbol(scc.funs.map(_.sym.nme).mkString("_"), Nil, true) + if funs.size === 1 then funs.head.sym + else BlockMemberSymbol(funs.map(_.sym.nme).mkString("_"), Nil, true) val dSym = - if scc.funs.size === 1 then scc.funs.head.dSym + if funs.size === 1 then funs.head.dSym else TermSymbol(syntax.Fun, owner, Tree.Ident(bms.nme)) val loopSym = TempSymbol(N, "loopLabel") val curIdSym = VarSymbol(Tree.Ident("id")) @@ -278,7 +280,7 @@ class TailRecOpt(using State, TL, Raise): def rewrite(b: Block) = applyBlock(symRewriter.applyBlock(b)) - val arms = scc.funs.map: f => + val arms = funs.map: f => Case.Lit(Tree.IntLit(dSymIds(f.dSym))) -> FunRewriter(f).rewrite(f.body) val switch = @@ -292,8 +294,8 @@ class TailRecOpt(using State, TL, Raise): case None => Value.Ref(bms, S(dSym)) val rewrittenFuns = - if scc.funs.size === 1 then Nil - else scc.funs.map: f => + if funs.size === 1 then Nil + else funs.map: f => val paramArgs = getParamSyms(f).map(_.asPath.asArg) val args = Value.Lit(Tree.IntLit(dSymIds(f.dSym))).asArg @@ -307,7 +309,7 @@ class TailRecOpt(using State, TL, Raise): val params = val initial = paramSyms.map(Param.simple(_)) - if scc.funs.length === 1 then initial + if funs.length === 1 then initial else Param.simple(curIdSym) :: initial val loopDefn = FunDefn( @@ -315,7 +317,7 @@ class TailRecOpt(using State, TL, Raise): PlainParamList(params) :: Nil, loop)(false) - if scc.funs.size === 1 then (N, loopDefn :: Nil) + if funs.size === 1 then (N, loopDefn :: Nil) else (S(loopDefn), rewrittenFuns) def optFunctions(fs: List[FunDefn], owner: Opt[InnerSymbol]) = @@ -361,12 +363,10 @@ class TailRecOpt(using State, TL, Raise): def transform(b: Block) = val (blk, defns) = b.floatOutDefns() - val (funs, clses) = - defns.foldLeft[(List[FunDefn], List[ClsLikeDefn])](Nil, Nil): - case ((fs, cs), d) => d match - case f: FunDefn => (f :: fs, cs) - case c: ClsLikeDefn => (fs, c :: cs) - case _ => (fs, cs) // unreachable as floatOutDefns only floats out FunDefns and ClsLikeDefns + val (funs, clses) = defns.partitionMap: + case f: FunDefn => L(f) + case c: ClsLikeDefn => R(c) + case _ => die // unreachable as floatOutDefns only floats out FunDefns and ClsLikeDefns val (optFNew, optF) = optFunctions(funs, N) val optC = optClasses(clses) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index e44f7a188c..ae0909c2e8 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -265,10 +265,12 @@ class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.I override def toString: Str = s"term:${owner.map(o => s"${o}.").getOrElse("")}${id.name}${State.dbgUid(uid)}" def subst(using sub: SymbolSubst): TermSymbol = sub.mapTermSym(this) + object TermSymbol: def fromFunBms(b: BlockMemberSymbol, owner: Opt[InnerSymbol])(using State) = TermSymbol(syntax.Fun, owner, Tree.Ident(b.nme)) + sealed trait CtorSymbol extends Symbol: def subst(using sub: SymbolSubst): CtorSymbol = sub.mapCtorSym(this) diff --git a/hkmc2/shared/src/test/mlscript/codegen/TraceLog.mls b/hkmc2/shared/src/test/mlscript/codegen/TraceLog.mls index de72d61fde..1b3593632c 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/TraceLog.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/TraceLog.mls @@ -10,19 +10,10 @@ fun fib(a) = if //│ let fib; //│ fib = function fib(a) { //│ let scrut, tmp, tmp1, tmp2, tmp3; -//│ loopLabel: while (true) { -//│ scrut = a <= 1; -//│ if (scrut === true) { -//│ return a -//│ } else { -//│ tmp = a - 1; -//│ tmp1 = fib(tmp); -//│ tmp2 = a - 2; -//│ tmp3 = fib(tmp2); -//│ return tmp1 + tmp3 -//│ } -//│ break; -//│ } +//│ scrut = a <= 1; +//│ if (scrut === true) { +//│ return a +//│ } else { tmp = a - 1; tmp1 = fib(tmp); tmp2 = a - 2; tmp3 = fib(tmp2); return tmp1 + tmp3 } //│ }; fun f(x) = g(x) From b8be10d8e33b210ec02e4f56d5ac3bdb226f024e Mon Sep 17 00:00:00 2001 From: Lionel Parreaux Date: Sat, 6 Dec 2025 14:20:38 +0800 Subject: [PATCH 15/16] Minor changes --- .../src/main/scala/hkmc2/codegen/Block.scala | 10 ++-- .../hkmc2/codegen/BlockTransformer.scala | 2 +- .../hkmc2/codegen/BufferableTransform.scala | 2 +- .../scala/hkmc2/codegen/HandlerLowering.scala | 2 +- .../src/main/scala/hkmc2/codegen/Lifter.scala | 10 ++-- .../main/scala/hkmc2/codegen/Lowering.scala | 8 +-- .../hkmc2/codegen/StackSafeTransform.scala | 4 +- .../main/scala/hkmc2/codegen/TailRecOpt.scala | 12 ++--- .../hkmc2/semantics/ucs/Normalization.scala | 2 +- .../src/test/mlscript/tailrec/Errors.mls | 12 +++++ .../src/test/mlscript/tailrec/Simple.mls | 50 +++++++++++++++++++ 11 files changed, 88 insertions(+), 26 deletions(-) create mode 100644 hkmc2/shared/src/test/mlscript/tailrec/Simple.mls diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 42f2a275f6..47a717919d 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -219,7 +219,7 @@ sealed abstract class Block extends Product: val newBody = d.body.flattened if newBody is d.body then d - else d.copy(body = newBody)(isTailRec = d.isTailRec) + else d.copy(body = newBody)(forceTailRec = d.forceTailRec) case v: ValDefn => v case c: ClsLikeDefn => val newPreCtor = c.preCtor.flattened @@ -227,7 +227,7 @@ sealed abstract class Block extends Product: val newMethods = c.methods.mapConserve: case f@FunDefn(owner, sym, dSym, params, body) => val newBody = body.flattened - if newBody is body then f else f.copy(body = newBody)(isTailRec = f.isTailRec) + if newBody is body then f else f.copy(body = newBody)(forceTailRec = f.forceTailRec) if (newPreCtor is c.preCtor) && (newCtor is c.ctor) && (newMethods is c.methods) then c else c.copy(preCtor = newPreCtor, ctor = newCtor, methods = newMethods) @@ -343,13 +343,13 @@ final case class FunDefn( params: Ls[ParamList], body: Block, )( - val isTailRec: Bool, + val forceTailRec: Bool, ) extends Defn: val innerSym = N val asPath = Value.Ref(sym, S(dSym)) object FunDefn: - def withFreshSymbol(owner: Opt[InnerSymbol], sym: BlockMemberSymbol, params: Ls[ParamList], body: Block)(isTailRec: Bool)(using State) = - FunDefn(owner, sym, TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme)), params, body)(isTailRec) + def withFreshSymbol(owner: Opt[InnerSymbol], sym: BlockMemberSymbol, params: Ls[ParamList], body: Block)(forceTailRec: Bool)(using State) = + FunDefn(owner, sym, TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme)), params, body)(forceTailRec) final case class ValDefn( tsym: TermSymbol, diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala index f4e974a505..b57747173a 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala @@ -161,7 +161,7 @@ class BlockTransformer(subst: SymbolSubst): val params2 = fun.params.mapConserve(applyParamList) val body2 = applySubBlock(fun.body) if (own2 is fun.owner) && (sym2 is fun.sym) && (dSym2 is fun.dSym) && (params2 is fun.params) && (body2 is fun.body) - then fun else FunDefn(own2, sym2, dSym2, params2, body2)(fun.isTailRec) + then fun else FunDefn(own2, sym2, dSym2, params2, body2)(fun.forceTailRec) def applyValDefn(defn: ValDefn)(k: ValDefn => Block): Block = val ValDefn(tsym, sym, rhs) = defn diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala index e2861330ea..f50b8f36fd 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala @@ -68,7 +68,7 @@ class BufferableTransform()(using Ctx, State, Raise): val blk = mkFieldReplacer(buf, idx).applyBlock(f.body) FunDefn(f.owner, f.sym, f.dSym, PlainParamList( Param(FldFlags.empty, buf, N, Modulefulness.none) :: Param(FldFlags.empty, idx, N, Modulefulness.none) :: Nil) :: f.params, - if isCtor then Begin(blk, Return(idx.asPath, false)) else blk)(isTailRec = f.isTailRec) + if isCtor then Begin(blk, Return(idx.asPath, false)) else blk)(forceTailRec = f.forceTailRec) val fakeCtor = transformFunDefn(FunDefn.withFreshSymbol( S(companionSym), BlockMemberSymbol("ctor", Nil, false), diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index 3a99d81ed4..b35e350613 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -605,7 +605,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, callSelf, L(f.sym), functionHandlerCtx(s"Cont$$func$$${symToStr(f.sym)}$$", f.sym.nme)) - )(isTailRec = f.isTailRec) + )(forceTailRec = f.forceTailRec) private def translateBody(cls: ClsLikeBody, sym: BlockMemberSymbol)(using HandlerCtx): ClsLikeBody = val curCtorCtx = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index c8ec1200fe..6d237681b7 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -943,7 +943,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): val newDef = FunDefn( base.owner, f.sym, f.dSym, PlainParamList(extraParams) :: f.params, f.body - )(f.isTailRec) + )(f.forceTailRec) val Lifted(lifted, extras) = liftDefnsInFn(newDef, newCtx) val args1 = extraParamsCpy.map(p => p.sym.asPath.asArg) @@ -953,7 +953,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): .ret(Call(singleCallBms.asPath, args1 ++ args2)(true, false, false)) // TODO: restParams not considered val mainDefn = FunDefn(f.owner, f.sym, f.dSym, PlainParamList(extraParamsCpy) :: headPlistCopy :: Nil, bdy)(false) - val auxDefn = FunDefn(N, singleCallBms.b, singleCallBms.d, flatPlist, lifted.body)(isTailRec = f.isTailRec) + val auxDefn = FunDefn(N, singleCallBms.b, singleCallBms.d, flatPlist, lifted.body)(forceTailRec = f.forceTailRec) if ctx.firstClsFns.contains(f.sym) then Lifted(mainDefn, auxDefn :: extras) @@ -1186,7 +1186,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): case value => S(value.copy(methods = cMethods, ctor = newCCtor.get)) val extras = (ctorDefnsLifted ++ fExtra ++ cfExtra ++ ctorIgnoredExtra).map: - case f: FunDefn => f.copy(owner = N)(isTailRec = f.isTailRec) + case f: FunDefn => f.copy(owner = N)(forceTailRec = f.forceTailRec) case c: ClsLikeDefn => c.copy(owner = N) case d => d @@ -1251,7 +1251,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): val transformed = BlockRewriter(ctx.inScopeISyms, captureCtx.addreplacedDefns(ignoredRewrite)).applyBlock(blk) if thisVars.reqCapture.size == 0 then - Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, transformed)(isTailRec = f.isTailRec), newDefns) + Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, transformed)(forceTailRec = f.forceTailRec), newDefns) else // move the function's parameters to the capture val paramsSet = f.params.flatMap(_.paramSyms) @@ -1262,7 +1262,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): .assign(captureSym, Instantiate(mut = true, // * Note: `mut` is needed for capture classes captureCls.sym.asPath, paramsList)) .rest(transformed) - Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, bod)(isTailRec = f.isTailRec), captureCls :: newDefns) + Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, bod)(forceTailRec = f.forceTailRec), captureCls :: newDefns) end liftDefnsInFn diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index c5cc7cb134..a4a0bef4c8 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -254,7 +254,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): case (sym, params, split) => val paramLists = params :: Nil val bodyBlock = ucs.Normalization(this)(split)(Ret) - FunDefn.withFreshSymbol(N, sym, paramLists, bodyBlock)(isTailRec = false) + FunDefn.withFreshSymbol(N, sym, paramLists, bodyBlock)(forceTailRec = false) // The return type is intended to be consistent with `gatherMembers` (mtds, Nil, Nil, End()) case _ => gatherMembers(defn.body) @@ -468,7 +468,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): val isOr = sym is State.orSymbol if isAnd || isOr then val lamSym = BlockMemberSymbol("lambda", Nil, false) - val lamDef = FunDefn.withFreshSymbol(N, lamSym, PlainParamList(Nil) :: Nil, returnedTerm(arg2))(isTailRec = false) + val lamDef = FunDefn.withFreshSymbol(N, lamSym, PlainParamList(Nil) :: Nil, returnedTerm(arg2))(forceTailRec = false) Define( lamDef, k(Call( @@ -599,7 +599,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): then k(Lambda(paramLists.head, bodyBlock)) else val lamSym = new BlockMemberSymbol("lambda", Nil, false) - val lamDef = FunDefn.withFreshSymbol(N, lamSym, paramLists, bodyBlock)(isTailRec = false) + val lamDef = FunDefn.withFreshSymbol(N, lamSym, paramLists, bodyBlock)(forceTailRec = false) Define( lamDef, k(lamDef.asPath)) @@ -939,7 +939,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): case p: Path => k(p) case Lambda(params, body) => val lamSym = BlockMemberSymbol("lambda", Nil, false) - val lamDef = FunDefn.withFreshSymbol(N, lamSym, params :: Nil, body)(isTailRec = false) + val lamDef = FunDefn.withFreshSymbol(N, lamSym, params :: Nil, body)(forceTailRec = false) Define(lamDef, k(lamDef.asPath)) case r => val l = new TempSymbol(N) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala index 98a2d85d01..a98afafca9 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala @@ -40,7 +40,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, doUnwindMap: Map[ def wrapStackSafe(body: Block, resSym: Local, rest: Block) = val bodSym = BlockMemberSymbol("‹stack safe body›", Nil, false) - val bodFun = FunDefn.withFreshSymbol(N, bodSym, ParamList(ParamListFlags.empty, Nil, N) :: Nil, body)(isTailRec = false) + val bodFun = FunDefn.withFreshSymbol(N, bodSym, ParamList(ParamListFlags.empty, Nil, N) :: Nil, body)(forceTailRec = false) Define(bodFun, Assign(resSym, Call(runStackSafePath, intLit(depthLimit).asArg :: bodSym.asPath.asArg :: Nil)(true, true, false), rest)) def extractResTopLevel(res: Result, isTailCall: Bool, f: Result => Block, sym: Option[Symbol], curDepth: => Symbol) = @@ -173,6 +173,6 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, doUnwindMap: Map[ def rewriteFn(defn: FunDefn) = if doUnwindFns.contains(defn.sym) then defn - else FunDefn(defn.owner, defn.sym, defn.dSym, defn.params, rewriteBlk(defn.body, L(defn.sym), 1))(defn.isTailRec) + else FunDefn(defn.owner, defn.sym, defn.dSym, defn.params, rewriteBlk(defn.body, L(defn.sym), 1))(defn.forceTailRec) def transformTopLevel(b: Block) = transform(b, TempSymbol(N), true) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala index 8de73d92bd..3bb85d90d3 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -45,7 +45,7 @@ class TailRecOpt(using State, TL, Raise): def find = // Ignore functions with multiple parameter lists if f.params.length > 1 then - if f.isTailRec then + if f.forceTailRec then raise(ErrorReport(msg"Functions with more than one parameter list may not be marked @tailrec." -> f.dSym.toLoc :: Nil)) Nil else @@ -177,14 +177,14 @@ class TailRecOpt(using State, TL, Raise): .collect: case c: CallEdge.NormalCall => c.f2 -> c.call val nonTailCalls = nonTailCallsLs.toMap - - if nonTailCallsLs.size === calls.length then - for f <- funs if f.isTailRec do + + if nonTailCallsLs.sizeCompare(calls) === 0 then + for f <- funs if f.forceTailRec do raise(WarningReport(msg"This function does not directly self-recurse, but is marked @tailrec." -> f.dSym.toLoc :: Nil)) return (N, funs) if !nonTailCalls.isEmpty then - for f <- funs if f.isTailRec do + for f <- funs if f.forceTailRec do val reportLoc = nonTailCalls.get(f.dSym) match // always display a call to f, if possible case Some(value) => value.toLoc @@ -334,7 +334,7 @@ class TailRecOpt(using State, TL, Raise): new BlockTraverserShallow(): for f <- c.methods do applyBlock(f.body) - if f.isTailRec then + if f.forceTailRec then raise(ErrorReport(msg"Class methods may not yet be marked @tailrec." -> f.dSym.toLoc :: Nil)) override def applyResult(r: Result): Unit = r match case c: Call if c.explicitTailCall => diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index 14c3de567e..9377f0881b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -399,7 +399,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e Select(Value.Ref(State.runtimeSymbol), Tree.Ident("LoopEnd"))(S(State.loopEndSymbol)) val blk = blockBuilder .assign(l, Value.Lit(Tree.UnitLit(false))) - .define(FunDefn(N, f, tSym, PlainParamList(Nil) :: Nil, Begin(body, Return(loopEnd, false)))(isTailRec = false)) + .define(FunDefn(N, f, tSym, PlainParamList(Nil) :: Nil, Begin(body, Return(loopEnd, false)))(forceTailRec = false)) .assign(loopResult, Call(Value.Ref(f, S(tSym)), Nil)(true, true, false)) if summon[LoweringCtx].mayRet then blk diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls index 474ec7bbba..b42a706c8a 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls @@ -80,3 +80,15 @@ module A with //│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. //│ ║ l.79: fun f() = 2 //│ ╙── ^ + +:w +@tailrec +fun foo() = Foo.bar() +module Foo with + fun bar() = foo() +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.86: fun foo() = Foo.bar() +//│ ╙── ^^^ + + + diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Simple.mls b/hkmc2/shared/src/test/mlscript/tailrec/Simple.mls new file mode 100644 index 0000000000..ee607e3a54 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/tailrec/Simple.mls @@ -0,0 +1,50 @@ +:js + + +:sjs +@tailrec +fun f = f +//│ JS (unsanitized): +//│ let f; f = function f() { loopLabel: while (true) { continue loopLabel; break; } }; + +:fixme +module Foo with + @tailrec + fun bar = bar +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.13: fun bar = bar +//│ ╙── ^^^ + +@tailrec +fun f() = f() + +module Foo with + @tailrec + fun bar() = bar() + +fun f(x) = f(x + 1) + +:sjs +fun f(x) = f(f(x) + 1) +//│ JS (unsanitized): +//│ let f3; +//│ f3 = function f(x) { +//│ let tmp, tmp1; +//│ loopLabel: while (true) { tmp = f3(x); tmp1 = tmp + 1; x = tmp1; continue loopLabel; break; } +//│ }; + +module A with + fun f(x) = @tailcall f(f(x) + 1) + + +:fixme // TODO: support +module A with + @tailrec + fun f(x) = B.g(x + 1) +module B with + fun g(x) = A.f(x * 2) +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.43: fun f(x) = B.g(x + 1) +//│ ╙── ^ + + From 243de680bc046d5d1b38a65c39477371f18cf4bd Mon Sep 17 00:00:00 2001 From: Lionel Parreaux Date: Sat, 6 Dec 2025 15:06:17 +0800 Subject: [PATCH 16/16] Update hkmc2/shared/src/test/mlscript/tailrec/Errors.mls --- hkmc2/shared/src/test/mlscript/tailrec/Errors.mls | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls index b42a706c8a..04059a8fba 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls @@ -81,7 +81,7 @@ module A with //│ ║ l.79: fun f() = 2 //│ ╙── ^ -:w +:fixme // TODO: support @tailrec fun foo() = Foo.bar() module Foo with