diff --git a/hkmc2/shared/src/main/scala/hkmc2/Config.scala b/hkmc2/shared/src/main/scala/hkmc2/Config.scala index 6e6a7f183..e7b3ae69b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/Config.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/Config.scala @@ -23,6 +23,7 @@ case class Config( stageCode: Bool, target: CompilationTarget, rewriteWhileLoops: Bool, + tailRecOpt: Bool, ): def stackSafety: Opt[StackSafety] = effectHandlers.flatMap(_.stackSafety) @@ -40,6 +41,7 @@ object Config: target = CompilationTarget.JS, rewriteWhileLoops = true, stageCode = false, + tailRecOpt = true, ) case class SanityChecks(light: Bool) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 6b17fbca2..47a717919 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -219,7 +219,7 @@ sealed abstract class Block extends Product: val newBody = d.body.flattened if newBody is d.body then d - else d.copy(body = newBody)(isTailRec = d.isTailRec) + else d.copy(body = newBody)(forceTailRec = d.forceTailRec) case v: ValDefn => v case c: ClsLikeDefn => val newPreCtor = c.preCtor.flattened @@ -227,7 +227,7 @@ sealed abstract class Block extends Product: val newMethods = c.methods.mapConserve: case f@FunDefn(owner, sym, dSym, params, body) => val newBody = body.flattened - if newBody is body then f else f.copy(body = newBody)(isTailRec = f.isTailRec) + if newBody is body then f else f.copy(body = newBody)(forceTailRec = f.forceTailRec) if (newPreCtor is c.preCtor) && (newCtor is c.ctor) && (newMethods is c.methods) then c else c.copy(preCtor = newPreCtor, ctor = newCtor, methods = newMethods) @@ -343,12 +343,13 @@ final case class FunDefn( params: Ls[ParamList], body: Block, )( - val isTailRec: Bool, + val forceTailRec: Bool, ) extends Defn: val innerSym = N + val asPath = Value.Ref(sym, S(dSym)) object FunDefn: - def withFreshSymbol(owner: Opt[InnerSymbol], sym: BlockMemberSymbol, params: Ls[ParamList], body: Block)(isTailRec: Bool)(using State) = - FunDefn(owner, sym, TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme)), params, body)(isTailRec) + def withFreshSymbol(owner: Opt[InnerSymbol], sym: BlockMemberSymbol, params: Ls[ParamList], body: Block)(forceTailRec: Bool)(using State) = + FunDefn(owner, sym, TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme)), params, body)(forceTailRec) final case class ValDefn( tsym: TermSymbol, diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala index f4e974a50..b57747173 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala @@ -161,7 +161,7 @@ class BlockTransformer(subst: SymbolSubst): val params2 = fun.params.mapConserve(applyParamList) val body2 = applySubBlock(fun.body) if (own2 is fun.owner) && (sym2 is fun.sym) && (dSym2 is fun.dSym) && (params2 is fun.params) && (body2 is fun.body) - then fun else FunDefn(own2, sym2, dSym2, params2, body2)(fun.isTailRec) + then fun else FunDefn(own2, sym2, dSym2, params2, body2)(fun.forceTailRec) def applyValDefn(defn: ValDefn)(k: ValDefn => Block): Block = val ValDefn(tsym, sym, rhs) = defn diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala index 761d12b8f..5d1137e5e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala @@ -127,6 +127,7 @@ class BlockTraverser: cls.traverse applyPath(path) case Case.Tup(len, inf) => () + case Case.Field(_, _) => () def applyHandler(hdr: Handler): Unit = hdr.sym.traverse diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala index e2861330e..f50b8f36f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/BufferableTransform.scala @@ -68,7 +68,7 @@ class BufferableTransform()(using Ctx, State, Raise): val blk = mkFieldReplacer(buf, idx).applyBlock(f.body) FunDefn(f.owner, f.sym, f.dSym, PlainParamList( Param(FldFlags.empty, buf, N, Modulefulness.none) :: Param(FldFlags.empty, idx, N, Modulefulness.none) :: Nil) :: f.params, - if isCtor then Begin(blk, Return(idx.asPath, false)) else blk)(isTailRec = f.isTailRec) + if isCtor then Begin(blk, Return(idx.asPath, false)) else blk)(forceTailRec = f.forceTailRec) val fakeCtor = transformFunDefn(FunDefn.withFreshSymbol( S(companionSym), BlockMemberSymbol("ctor", Nil, false), diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala index e9dd264b3..b35e35061 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala @@ -476,7 +476,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, // all the code in the first state private def translateBlock(b: Block, extraLocals: Set[Local], callSelf: Opt[Result], fnOrCls: FnOrCls, h: HandlerCtx): Block = val getLocalsFn = createGetLocalsFn(b, extraLocals)(using h) - given HandlerCtx = h.nestDebugScope(b.userDefinedVars ++ extraLocals, getLocalsFn.sym.asPath) + given HandlerCtx = h.nestDebugScope(b.userDefinedVars ++ extraLocals, getLocalsFn.asPath) val stage1 = firstPass(b) val stage2 = secondPass(stage1, fnOrCls, callSelf, getLocalsFn) if h.isTopLevel then stage2 else thirdPass(stage2) @@ -542,23 +542,26 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, val cls = if handlerCtx.isTopLevel then N else genContClass(b, callSelf) val ret = cls match - case None => genNormalBody(b, BlockMemberSymbol("", Nil), N) + case None => genNormalBody(b, BlockMemberSymbol("", Nil).asPath, N) case Some(cls) => // create the doUnwind function val doUnwindSym = BlockMemberSymbol(doUnwindNme, Nil, true) - doUnwindMap += fnOrCls -> doUnwindSym.asPath val pcSym = VarSymbol(Tree.Ident("pc")) val resSym = VarSymbol(Tree.Ident("res")) val doUnwindBlk = h.linkAndHandle( - LinkState(resSym, cls.sym.asPath, pcSym.asPath) + LinkState(resSym, Value.Ref(cls.sym, S(cls.isym)), pcSym.asPath) ) val doUnwindDef = FunDefn.withFreshSymbol( N, doUnwindSym, PlainParamList(Param.simple(resSym) :: Param.simple(pcSym) :: Nil) :: Nil, doUnwindBlk )(false) - val doUnwindLazy = Lazy(doUnwindSym.asPath) - val rst = genNormalBody(b, cls.sym, S(doUnwindLazy)) + + val doUnwindPath: Path = doUnwindDef.asPath + doUnwindMap += fnOrCls -> doUnwindPath + + val doUnwindLazy = Lazy(doUnwindPath) + val rst = genNormalBody(b, Value.Ref(cls.sym, S(cls.isym)), S(doUnwindLazy)) if doUnwindLazy.isEmpty && opt.stackSafety.isEmpty then blockBuilder @@ -592,7 +595,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, case pList :: Nil => val params = pList.params.map(p => p.sym.asPath.asArg) f.owner match - case None => S(Call(f.sym.asPath, params)(true, true, false)) + case None => S(Call(f.asPath, params)(true, true, false)) case Some(owner) => S(Call(Select(owner.asPath, Tree.Ident(f.sym.nme))(N), params)(true, true, false)) case _ => None // TODO: more than one plist @@ -602,7 +605,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, callSelf, L(f.sym), functionHandlerCtx(s"Cont$$func$$${symToStr(f.sym)}$$", f.sym.nme)) - )(isTailRec = f.isTailRec) + )(forceTailRec = f.forceTailRec) private def translateBody(cls: ClsLikeBody, sym: BlockMemberSymbol)(using HandlerCtx): ClsLikeBody = val curCtorCtx = @@ -631,11 +634,12 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, // Handle block becomes a FunDefn and CallPlaceholder private def translateHandleBlock(h: HandleBlock)(using HandlerCtx): Block = val sym = BlockMemberSymbol(s"handleBlock$$", Nil) + val tSym = TermSymbol.fromFunBms(sym, N) val lbl = freshTmp("handlerBody") val lblLoop = freshTmp("handlerLoop") val handlerBody = translateBlock( - h.body, Set.empty, S(Call(sym.asPath, Nil)(true, false, false)), L(sym), + h.body, Set.empty, S(Call(Value.Ref(sym, S(tSym)), Nil)(true, false, false)), L(sym), HandlerCtx( false, true, s"Cont$$handleBlock$$${symToStr(h.lhs)}$$", N, @@ -662,7 +666,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, handler.params, Define( fDef, - Return(PureCall(paths.mkEffectPath, h.cls.asPath :: Value.Ref(sym, N) :: Nil), false)))(false) + Return(PureCall(paths.mkEffectPath, h.cls.asPath :: fDef.asPath :: Nil), false)))(false) // Some limited handling of effects extending classes and having access to their fields. // Currently does not support super() raising effects. @@ -695,14 +699,14 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, .assign(h.lhs, Instantiate(mut = true, Value.Ref(clsDefn.sym, S(h.cls)), Nil)) .rest(handlerBody) - val defn = FunDefn.withFreshSymbol( + val defn = FunDefn( N, // no owner - sym, PlainParamList(Nil) :: Nil, body)(false) + sym, tSym, PlainParamList(Nil) :: Nil, body)(false) val result = blockBuilder .define(defn) .rest( - ResultPlaceholder(h.res, freshId(), Call(sym.asPath, Nil)(true, true, false), h.rest) + ResultPlaceholder(h.res, freshId(), Call(defn.asPath, Nil)(true, true, false), h.rest) ) result @@ -928,12 +932,12 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, N, // TODO: bufferable? )) - private def genNormalBody(b: Block, clsSym: BlockMemberSymbol, doUnwind: Opt[Lazy[Path]])(using HandlerCtx): Block = + private def genNormalBody(b: Block, clsPath: Path, doUnwind: Opt[Lazy[Path]])(using HandlerCtx): Block = val transform = new BlockTransformerShallow(SymbolSubst()): override def applyBlock(b: Block): Block = b match case ResultPlaceholder(res, uid, c, rest) => val doUnwindBlk = doUnwind match - case None => Assign(res, topLevelCall(LinkState(res, clsSym.asPath, Value.Lit(Tree.IntLit(uid)))), End()) + case None => Assign(res, topLevelCall(LinkState(res, clsPath, Value.Lit(Tree.IntLit(uid)))), End()) case Some(doUnwind) => Return(PureCall(doUnwind.get_!, res.asPath :: Value.Lit(Tree.IntLit(uid)) :: Nil), false) blockBuilder .assign(res, c) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala index be48baefa..06a23125a 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala @@ -17,19 +17,21 @@ object LambdaRewriter: val newSym = BlockMemberSymbol(lhs.nme, Nil, nameIsMeaningful = true // TODO: lhs.nme is not always meaningful ) + val defn = FunDefn.withFreshSymbol(N, newSym, params :: Nil, body)(false) val blk = blockBuilder - .define(FunDefn.withFreshSymbol(N, newSym, params :: Nil, body)(false)) - .assign(lhs, newSym.asPath) + .define(defn) + .assign(lhs, defn.asPath) .rest(rest) (blk, Nil) case _ => - var lambdasList: List[(BlockMemberSymbol, Lambda)] = Nil + var lambdasList: List[((BlockMemberSymbol, TermSymbol), Lambda)] = Nil val lambdaRewriter = new BlockDataTransformer(SymbolSubst()): override def applyResult(r: Result)(k: Result => Block): Block = r match case lam: Lambda => val sym = BlockMemberSymbol("lambda", Nil, nameIsMeaningful = false) - lambdasList ::= (sym -> super.applyLam(lam)) - k(Value.Ref(sym, N)) + val tSym = TermSymbol.fromFunBms(sym, N) + lambdasList ::= ((sym, tSym) -> super.applyLam(lam)) + k(Value.Ref(sym, S(tSym))) case _ => super.applyResult(r)(k) val blk = lambdaRewriter.applyBlock(b) (blk, lambdasList) @@ -39,7 +41,7 @@ object LambdaRewriter: val (newBlk, lambdasList) = rewriteOneBlk(b) val lambdaDefns = lambdasList.map: case (sym, Lambda(params, body)) => - FunDefn.withFreshSymbol(N, sym, params :: Nil, body)(false) + FunDefn(N, sym._1, sym._2, params :: Nil, body)(false) val ret = lambdaDefns.foldLeft(newBlk): case (acc, defn) => Define(defn, acc) super.applyBlock(ret) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala index f642716f7..6d237681b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala @@ -110,8 +110,8 @@ object Lifter: object InstSel: def unapply(p: Path) = p match - case Value.Ref(l: BlockMemberSymbol, _) => S(l) - case s @ Select(Value.Ref(l: BlockMemberSymbol, _), Tree.Ident("class")) => S(l) + case Value.Ref(l: BlockMemberSymbol, d) => S((l, d)) + case s @ Select(Value.Ref(l: BlockMemberSymbol, _), Tree.Ident("class")) => S((l, s.symbol)) case _ => N def modOrObj(d: Defn) = d match @@ -360,14 +360,20 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): imutVars.filter: s => !mutVars.contains(s) && candVars.contains(s) + case class FunSyms[T <: DefinitionSymbol[?]](b: BlockMemberSymbol, d: T): + def asPath = Value.Ref(b, S(d)) + object FunSyms: + def fromFun(b: BlockMemberSymbol, owner: Opt[InnerSymbol] = N) = + FunSyms(b, TermSymbol.fromFunBms(b, owner)) + // Info required for lifting a definition. case class LiftedInfo( val reqdCaptures: List[BlockMemberSymbol], // The mutable captures a lifted definition must take. val reqdVars: List[Local], // The (passed by value) variables a lifted definition must take. val reqdInnerSyms: List[InnerSymbol], // The inner symbols a lifted definition must take. val reqdBms: List[BlockMemberSymbol], // BMS's belonging to unlifted definitions that this definition references. - val fakeCtorBms: Option[BlockMemberSymbol], // only for classes - val singleCallBms: BlockMemberSymbol, // optimization + val fakeCtorBms: Option[FunSyms[TermSymbol]], // only for classes + val singleCallBms: FunSyms[TermSymbol], // optimization ) case class Lifted[+T <: Defn]( @@ -579,7 +585,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): val info = LiftedInfo( includedCaptures, includedLocals, clsCaptures, - refBms, fakeCtorBms, singleCallBms + refBms, fakeCtorBms.map(FunSyms.fromFun(_)), FunSyms.fromFun(singleCallBms) ) d match @@ -624,27 +630,27 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): // Does *not* rewrite references to non-lifted BMS symbols. def rewriteBms(b: Block) = // BMS's that need to be created - val syms: LinkedHashMap[BlockMemberSymbol, Local] = LinkedHashMap.empty + val syms: LinkedHashMap[FunSyms[?], Local] = LinkedHashMap.empty val walker = new BlockDataTransformer(SymbolSubst()): // only scan within the block. don't traverse override def applyResult(r: Result)(k: Result => Block): Block = r match // if possible, directly rewrite the call using the efficient version - case c @ Call(RefOfBms(l, _), args) => ctx.bmsReqdInfo.get(l) match + case c @ Call(RefOfBms(l, S(d)), args) => ctx.bmsReqdInfo.get(l) match case Some(info) if !ctx.isModOrObj(l) => val extraArgs = ctx.defns.get(l) match // If it's a class, we need to add the isMut parameter. // Instantiation without `new mut` is always immutable - case Some(c: ClsLikeDefn) => Value.Lit(Tree.BoolLit(false)).asArg :: getCallArgs(l, ctx) - case _ => getCallArgs(l, ctx) + case Some(c: ClsLikeDefn) => Value.Lit(Tree.BoolLit(false)).asArg :: getCallArgs(FunSyms(l, d), ctx) + case _ => getCallArgs(FunSyms(l, d), ctx) applyListOf(args, applyArg(_)(_)): newArgs => k(Call(info.singleCallBms.asPath, extraArgs ++ newArgs)(c.isMlsFun, false, c.explicitTailCall)) case _ => super.applyResult(r)(k) - case c @ Instantiate(mut, InstSel(l), args) => + case c @ Instantiate(mut, InstSel(l, S(d)), args) => ctx.bmsReqdInfo.get(l) match case Some(info) if !ctx.isModOrObj(l) => - val extraArgs = Value.Lit(Tree.BoolLit(mut)).asArg :: getCallArgs(l, ctx) + val extraArgs = Value.Lit(Tree.BoolLit(mut)).asArg :: getCallArgs(FunSyms(l, d), ctx) applyListOf(args, applyArg(_)(_)): newArgs => k(Call(info.singleCallBms.asPath, extraArgs ++ newArgs)(true, false, false)) case _ => super.applyResult(r)(k) @@ -658,13 +664,13 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): // extract the call override def applyPath(p: Path)(k: Path => Block): Block = p match - case RefOfBms(l, disamb) if ctx.bmsReqdInfo.contains(l) && !ctx.isModOrObj(l) => + case RefOfBms(l, S(d)) if ctx.bmsReqdInfo.contains(l) && !ctx.isModOrObj(l) => val newSym = closureMap.get(l) match case None => // $this was previously used, but it may be confused with the `this` keyword // let's use $here instead val newSym = TempSymbol(N, l.nme + "$here") - syms.addOne(l -> newSym) // add to `syms`: this closure will be initialized in `applyBlock` + syms.addOne(FunSyms(l, d) -> newSym) // add to `syms`: this closure will be initialized in `applyBlock` closureMap.addOne(l -> newSym) // add to `closureMap`: `newSym` refers to the closure and can be used later newSym @@ -672,9 +678,9 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): case Some(value) if activeClosures.contains(value) => value // symbol exists, needs initialization case Some(value) => - syms.addOne(l -> value) + syms.addOne(FunSyms(l, d) -> value) value - k(Value.Ref(newSym, disamb)) + k(Value.Ref(newSym, S(d))) case _ => super.applyPath(p)(k) (walker.applyBlock(b), syms.toList) end rewriteBms @@ -692,7 +698,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): val pre = syms.foldLeft(blockBuilder): case (blk, (bms, local)) => val initial = blk.assign(local, createCall(bms, ctx)) - ctx.defns(bms) match + ctx.defns(bms.b) match case c: ClsLikeDefn => initial.assignFieldN(local.asPath, Tree.Ident("class"), bms.asPath) case _ => initial @@ -769,7 +775,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): case Some(sym) if !ctx.ignored(d.sym) => ctx.getBmsReqdInfo(d.sym) match case Some(_) => // has args blockBuilder - .assign(sym, Instantiate(mut = false, d.sym.asPath, getCallArgs(d.sym, ctx))) + .assign(sym, Instantiate(mut = false, d.sym.asPath, getCallArgs(FunSyms(d.sym, d.isym), ctx))) .rest(applyBlock(rest)) case None => // has no args // Objects with no parameters are instantiated statically @@ -831,8 +837,8 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): // When calling a lifted function or constructor, we need to pass, as arguments, the local variables, // inner symbols, etc that it needs to access. This function creates those arguments for that in // the correct order. - def getCallArgs(sym: BlockMemberSymbol, ctx: LifterCtx) = - val info = ctx.getBmsReqdInfo(sym).get + def getCallArgs(sym: FunSyms[?], ctx: LifterCtx) = + val info = ctx.getBmsReqdInfo(sym.b).get val localsArgs = info.reqdVars.map(s => ctx.getLocalPath(s).get.asArg) val capturesArgs = info.reqdCaptures.map(ctx.getCapturePath(_).get.asArg) val iSymArgs = info.reqdInnerSyms.map(ctx.getIsymPath(_).get.asArg) @@ -840,12 +846,12 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): bmsArgs ++ iSymArgs ++ localsArgs ++ capturesArgs // This creates a call to a lifted function or constructor. - def createCall(sym: BlockMemberSymbol, ctx: LifterCtx): Call = - val info = ctx.getBmsReqdInfo(sym).get + def createCall(syms: FunSyms[?], ctx: LifterCtx): Call = + val info = ctx.getBmsReqdInfo(syms.b).get val callSym = info.fakeCtorBms match case Some(v) => v - case None => sym - Call(callSym.asPath, getCallArgs(sym, ctx))(false, false, false) + case None => syms + Call(callSym.asPath, getCallArgs(syms, ctx))(false, false, false) /* * Explanation of liftOutDefnCont, liftDefnsInCls, liftDefnsInFn: @@ -937,7 +943,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): val newDef = FunDefn( base.owner, f.sym, f.dSym, PlainParamList(extraParams) :: f.params, f.body - )(f.isTailRec) + )(f.forceTailRec) val Lifted(lifted, extras) = liftDefnsInFn(newDef, newCtx) val args1 = extraParamsCpy.map(p => p.sym.asPath.asArg) @@ -947,7 +953,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): .ret(Call(singleCallBms.asPath, args1 ++ args2)(true, false, false)) // TODO: restParams not considered val mainDefn = FunDefn(f.owner, f.sym, f.dSym, PlainParamList(extraParamsCpy) :: headPlistCopy :: Nil, bdy)(false) - val auxDefn = FunDefn.withFreshSymbol(N, singleCallBms, flatPlist, lifted.body)(isTailRec = f.isTailRec) + val auxDefn = FunDefn(N, singleCallBms.b, singleCallBms.d, flatPlist, lifted.body)(forceTailRec = f.forceTailRec) if ctx.firstClsFns.contains(f.sym) then Lifted(mainDefn, auxDefn :: extras) @@ -1078,7 +1084,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): case Some(value) => (ParamList(value.flags, extraPlist.params ++ value.params, value.restParam), auxPlist) - val auxCtorDefn_ = FunDefn.withFreshSymbol(None, singleCallBms, headParams :: newAuxPlist, bod)(false) + val auxCtorDefn_ = FunDefn(None, singleCallBms.b, singleCallBms.d, headParams :: newAuxPlist, bod)(false) val auxCtorDefn = BlockTransformer(subst).applyFunDefn(auxCtorDefn_) // Lifted(lifted, extras ::: (fakeCtorDefn :: auxCtorDefn :: Nil)) @@ -1180,7 +1186,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): case value => S(value.copy(methods = cMethods, ctor = newCCtor.get)) val extras = (ctorDefnsLifted ++ fExtra ++ cfExtra ++ ctorIgnoredExtra).map: - case f: FunDefn => f.copy(owner = N)(isTailRec = f.isTailRec) + case f: FunDefn => f.copy(owner = N)(forceTailRec = f.forceTailRec) case c: ClsLikeDefn => c.copy(owner = N) case d => d @@ -1245,7 +1251,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): val transformed = BlockRewriter(ctx.inScopeISyms, captureCtx.addreplacedDefns(ignoredRewrite)).applyBlock(blk) if thisVars.reqCapture.size == 0 then - Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, transformed)(isTailRec = f.isTailRec), newDefns) + Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, transformed)(forceTailRec = f.forceTailRec), newDefns) else // move the function's parameters to the capture val paramsSet = f.params.flatMap(_.paramSyms) @@ -1256,7 +1262,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise): .assign(captureSym, Instantiate(mut = true, // * Note: `mut` is needed for capture classes captureCls.sym.asPath, paramsList)) .rest(transformed) - Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, bod)(isTailRec = f.isTailRec), captureCls :: newDefns) + Lifted(FunDefn(f.owner, f.sym, f.dSym, f.params, bod)(forceTailRec = f.forceTailRec), captureCls :: newDefns) end liftDefnsInFn diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 12a236cf8..ca6242e94 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -294,7 +294,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): case (sym, params, split) => val paramLists = params :: Nil val bodyBlock = ucs.Normalization(this)(split)(Ret) - FunDefn.withFreshSymbol(N, sym, paramLists, bodyBlock)(isTailRec = false) + FunDefn.withFreshSymbol(N, sym, paramLists, bodyBlock)(forceTailRec = false) // The return type is intended to be consistent with `gatherMembers` (mtds, Nil, Nil, End()) case _ => gatherMembers(defn.body) @@ -510,12 +510,12 @@ class Lowering()(using Config, TL, Raise, State, Ctx): val isOr = sym is State.orSymbol if isAnd || isOr then val lamSym = BlockMemberSymbol("lambda", Nil, false) - val lamDef = FunDefn.withFreshSymbol(N, lamSym, PlainParamList(Nil) :: Nil, returnedTerm(arg2))(isTailRec = false) + val lamDef = FunDefn.withFreshSymbol(N, lamSym, PlainParamList(Nil) :: Nil, returnedTerm(arg2))(forceTailRec = false) Define( lamDef, k(Call( Value.Ref(State.runtimeSymbol).selN(Tree.Ident(if isAnd then "short_and" else "short_or")), - Arg(N, ar1) :: Arg(N, Value.Ref(lamSym, N)) :: Nil + Arg(N, ar1) :: Arg(N, lamDef.asPath) :: Nil )(true, false, false))) else subTerm_nonTail(arg2): ar2 => @@ -647,10 +647,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx): then k(Lambda(paramLists.head, bodyBlock)) else val lamSym = new BlockMemberSymbol("lambda", Nil, false) - val lamDef = FunDefn.withFreshSymbol(N, lamSym, paramLists, bodyBlock)(isTailRec = false) + val lamDef = FunDefn.withFreshSymbol(N, lamSym, paramLists, bodyBlock)(forceTailRec = false) Define( lamDef, - k(Value.Ref(lamSym, N))) + k(lamDef.asPath)) case iftrm: st.IfLike => ucs.Normalization(this)(iftrm)(k) @@ -987,8 +987,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx): case p: Path => k(p) case Lambda(params, body) => val lamSym = BlockMemberSymbol("lambda", Nil, false) - val lamDef = FunDefn.withFreshSymbol(N, lamSym, params :: Nil, body)(isTailRec = false) - Define(lamDef, k(Value.Ref(lamSym, N))) + val lamDef = FunDefn.withFreshSymbol(N, lamSym, params :: Nil, body)(forceTailRec = false) + Define(lamDef, k(lamDef.asPath)) case r => val l = new TempSymbol(N) Assign(l, r, k(l |> Value.Ref.apply)) @@ -1021,10 +1021,14 @@ class Lowering()(using Config, TL, Raise, State, Ctx): val merged = MergeMatchArmTransformer.applyBlock(bufferable) - val res = + val staged = if config.stageCode then Instrumentation(using summon).applyBlock(merged) else merged + val res = + if config.tailRecOpt then TailRecOpt().transform(staged) + else staged + Program( imps.map(imp => imp.sym -> imp.str), res @@ -1055,7 +1059,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): case Annot.Untyped => () case a @ Annot.TailRec => target match - case TermDefinition(body = S(bod), k = syntax.Fun) => warn(a, S(msg"Tail call optimization is not yet implemented.")) + case TermDefinition(body = S(bod), k = syntax.Fun) => () case TermDefinition(k = syntax.Fun) => warn(a, S(msg"Only functions with a body may be marked as @tailrec.")) case _ => warn(a) @@ -1076,9 +1080,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx): case Annot.Untyped => () case a @ Annot.TailCall => receiver match case st.App(Ref(_: BuiltinSymbol), _) => warn(a, S(msg"The @tailcall annotation has no effect on calls to built-in symbols.")) - case st.App(_, _) => warn(a, S(msg"Tail call optimization is not yet implemented.")) + case st.App(_, _) => () case st.Resolved(_, defnSym) => defnSym.defn match - case S(td: TermDefinition) if (td.k is syntax.Fun) && td.params.isEmpty => warn(a, S(msg"Tail call optimization is not yet implemented.")) + case S(td: TermDefinition) if (td.k is syntax.Fun) && td.params.isEmpty => () case _ => warn(a) case _ => warn(a) case annot => warn(annot) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala index 98a2d85d0..a98afafca 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala @@ -40,7 +40,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, doUnwindMap: Map[ def wrapStackSafe(body: Block, resSym: Local, rest: Block) = val bodSym = BlockMemberSymbol("‹stack safe body›", Nil, false) - val bodFun = FunDefn.withFreshSymbol(N, bodSym, ParamList(ParamListFlags.empty, Nil, N) :: Nil, body)(isTailRec = false) + val bodFun = FunDefn.withFreshSymbol(N, bodSym, ParamList(ParamListFlags.empty, Nil, N) :: Nil, body)(forceTailRec = false) Define(bodFun, Assign(resSym, Call(runStackSafePath, intLit(depthLimit).asArg :: bodSym.asPath.asArg :: Nil)(true, true, false), rest)) def extractResTopLevel(res: Result, isTailCall: Bool, f: Result => Block, sym: Option[Symbol], curDepth: => Symbol) = @@ -173,6 +173,6 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, doUnwindMap: Map[ def rewriteFn(defn: FunDefn) = if doUnwindFns.contains(defn.sym) then defn - else FunDefn(defn.owner, defn.sym, defn.dSym, defn.params, rewriteBlk(defn.body, L(defn.sym), 1))(defn.isTailRec) + else FunDefn(defn.owner, defn.sym, defn.dSym, defn.params, rewriteBlk(defn.body, L(defn.sym), 1))(defn.forceTailRec) def transformTopLevel(b: Block) = transform(b, TempSymbol(N), true) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala new file mode 100644 index 000000000..3bb85d90d --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/TailRecOpt.scala @@ -0,0 +1,392 @@ +package hkmc2 + +import scala.collection.mutable + +import mlscript.utils.*, shorthands.* +import utils.* + +import hkmc2.codegen.* +import hkmc2.semantics.* +import hkmc2.Message.* +import hkmc2.semantics.Elaborator.State +import hkmc2.syntax.Tree +import scala.collection.mutable.ArrayBuffer +import java.lang.instrument.ClassDefinition + +// This optimization assumes the lifter has been run. +class TailRecOpt(using State, TL, Raise): + + object CallToFun: + def unapply(c: Call): Opt[TermSymbol] = c match + case Call(fun = Value.Ref(b, S(r: TermSymbol))) => S(r) + case Call(fun = s: Select) => s.symbol match + case Some(r: TermSymbol) => S(r) + case _ => N + case _ => N + + object TailCallShape: + def unapply(b: Block): Opt[(TermSymbol, Call)] = b match + case Return(c @ CallToFun(r), _) => S((r, c)) + case Assign(a, c @ CallToFun(r), Return(Value.Ref(b, _), _)) if a === b => S((r, c)) + case _ => N + + + enum CallEdge: + val f1: TermSymbol + val f2: TermSymbol + val call: Call + case TailCall(f1: TermSymbol, f2: TermSymbol)(val call: Call) + case NormalCall(f1: TermSymbol, f2: TermSymbol)(val call: Call) + + class CallFinder(f: FunDefn) extends BlockTraverserShallow: + + var edges: List[CallEdge] = Nil + + def find = + // Ignore functions with multiple parameter lists + if f.params.length > 1 then + if f.forceTailRec then + raise(ErrorReport(msg"Functions with more than one parameter list may not be marked @tailrec." -> f.dSym.toLoc :: Nil)) + Nil + else + edges = Nil + applyBlock(f.body) + edges + + override def applyBlock(b: Block): Unit = b match + case TailCallShape(r, c) => edges ::= CallEdge.TailCall(f.dSym, r)(c) + case Return(c: Call, _) => + if c.explicitTailCall then + raise(ErrorReport(msg"Only direct calls in tail position may be marked @tailcall." -> c.toLoc :: Nil)) + case _ => super.applyBlock(b) + + override def applyResult(r: Result): Unit = r match + case c: Call => + if c.explicitTailCall then + raise(ErrorReport(msg"This call is not in tail position." -> c.toLoc :: Nil)) + c match + case CallToFun(r) => edges ::= CallEdge.NormalCall(f.dSym, r)(c) + case _ => + case _ => super.applyResult(r) + + def buildCallGraph(fs: List[FunDefn]): List[CallEdge] = + fs.flatMap(f => CallFinder(f).find) + + case class SccOfCalls(funs: List[FunDefn], calls: List[CallEdge]) + + def partFns(fs: List[FunDefn]): List[SccOfCalls] = + val defnSyms = fs.map(_.dSym) + val tsToDefn = fs.map(f => f.dSym -> f).toMap + + // Only care about calls to functions in the same scope + // Note that the results may differ if the lifter has been run. + val cg = buildCallGraph(fs).filter: c => + val cond = defnSyms.contains(c.f1) && defnSyms.contains(c.f2) + c.match + case c: CallEdge.TailCall if c.call.explicitTailCall && !cond => + raise(ErrorReport( + msg"This tail call exits the current scope is not optimized." -> c.call.toLoc :: Nil)) + case _ => + cond + + val cgTup = cg.map(c => (c.f1, c.f2)) + val sccs = algorithms.sccsWithInfo(cgTup, defnSyms) + + // partition the call graph + val sccMap = sccs.sccs.flatMap: + case (id, scc) => scc.map(f => f -> id) + + val cgLabelled = cg + .groupBy: c => + val s1 = sccMap(c.f1) + val s2 = sccMap(c.f2) + if s1 =/= s2 && c.call.explicitTailCall then + raise(ErrorReport( + msg"This call is not optimized as it does not directly recurse through its parent function." -> c.call.toLoc :: Nil)) + -1 + else s1 + .filter: + (id, _) => id =/= -1 + + sccs.sccs.toList.map: v => + val (id, tss) = v + val cgs = cgLabelled.get(id) match + case Some(value) => value + case None => Nil + SccOfCalls(tss.map(tsToDefn), cgs) + + def maxInt[T](items: List[T], f: T => Int): Int = items.foldLeft(0): + case (l, item) => + val x = f(item) + if x > l then x else l + + def getParamSyms(f: FunDefn) = f.params.headOption match + case Some(ParamList(_, params, S(rest))) => + params.map(_.sym).appended(rest.sym) + case Some(p) => p.params.map(_.sym) + case None => Nil + + // assume only one parameter list + def paramsLen(f: FunDefn): Int = f.params match + case head :: next => + if head.restParam.isDefined then 1 + head.params.length + else head.params.length + case Nil => 0 + + def rewriteCallArgs(f: FunDefn, c: Call): Opt[List[Result]] = + // need to be careful in handling restParams + // if any arg is a spread that spreads across multiple parameters, then + // we ignore it for now + val ret = f.params match + case head :: Nil => + val (headArgs, restArgs) = head.restParam match + case Some(value) => c.args.splitAt(head.params.length) + case None => (c.args, Nil) + + var bad = false + val hd = for a <- headArgs yield a.spread match + case Some(true) => + if c.explicitTailCall then + raise(ErrorReport(msg"Spreads are not yet fully supported in calls marked @tailcall." -> a.value.toLoc :: Nil)) + bad = true + a.value + case _ => a.value + if bad then return N + + if head.restParam.isDefined then + val rest = + restArgs match + case Arg(S(true), value) :: Nil => value + case _ => Tuple(true, restArgs) + hd.appended(rest) + else + hd + case Nil => c.args.map(_.value) + case _ => return N + S(ret) + + def optScc(scc: SccOfCalls, owner: Opt[InnerSymbol]): (Opt[FunDefn], List[FunDefn]) = + // sort the functions so the order is more predictable + val funs = scc.funs.sortBy(f => f.dSym.uid) + // remove calls which don't flow into this scc + val fSyms = funs.map(_.dSym).toSet + + val calls = scc.calls.filter(c => fSyms.contains(c.f2)) + + val nonTailCallsLs = calls + .collect: + case c: CallEdge.NormalCall => c.f2 -> c.call + val nonTailCalls = nonTailCallsLs.toMap + + if nonTailCallsLs.sizeCompare(calls) === 0 then + for f <- funs if f.forceTailRec do + raise(WarningReport(msg"This function does not directly self-recurse, but is marked @tailrec." -> f.dSym.toLoc :: Nil)) + return (N, funs) + + if !nonTailCalls.isEmpty then + for f <- funs if f.forceTailRec do + val reportLoc = nonTailCalls.get(f.dSym) match + // always display a call to f, if possible + case Some(value) => value.toLoc + case None => nonTailCalls.head._2.toLoc + raise(ErrorReport( + msg"This function is not tail recursive." -> f.dSym.toLoc + :: msg"It could self-recurse through this call, which is not a tail call." -> reportLoc + :: Nil + )) + + val maxParamLen = maxInt(funs, paramsLen) + val paramSyms = + if funs.length === 1 then (getParamSyms(funs.head)) + else + for i <- 0 to maxParamLen - 1 yield VarSymbol(Tree.Ident("param" + i)) + .toList + val paramSymsArr = ArrayBuffer.from(paramSyms) + val dSymIds = funs.map(_.dSym).zipWithIndex.toMap + val bms = + if funs.size === 1 then funs.head.sym + else BlockMemberSymbol(funs.map(_.sym.nme).mkString("_"), Nil, true) + val dSym = + if funs.size === 1 then funs.head.dSym + else TermSymbol(syntax.Fun, owner, Tree.Ident(bms.nme)) + val loopSym = TempSymbol(N, "loopLabel") + val curIdSym = VarSymbol(Tree.Ident("id")) + + class FunRewriter(f: FunDefn) extends BlockTransformerShallow(SymbolSubst()): + val params = getParamSyms(f) + val paramsSet = f.params.toSet + val paramsIdxes = params.zipWithIndex.toMap + + val symRewriter = new BlockTransformer(SymbolSubst()): + def applyVarSym(l: VarSymbol): VarSymbol = paramsIdxes.get(l) match + case Some(idx) => paramSymsArr(idx) + case _ => l + + override def applyValue(v: Value)(k: Value => Block): Block = v match + case Value.Ref(l: VarSymbol, d) => + val s = applyVarSym(l) + if s is l then k(v) + else k(Value.Ref(s, d)) + case _ => super.applyValue(v)(k) + + + override def applyBlock(b: Block): Block = b match + case TailCallShape(dSym, c) => dSymIds.get(dSym) match + case Some(id) => + val argVals = rewriteCallArgs(f, c) match + case Some(value) => value + case None => return super.applyBlock(b) + val cont = + if scc.funs.size === 1 then Continue(loopSym) + else Assign(curIdSym, Value.Lit(Tree.IntLit(dSymIds(dSym))), Continue(loopSym)) + + // In some cases, we could have assignments like this: + // param0 = whatever + // param1 = + // which means param1's value is incorrect. + // We should thus assign the params to temporary symbols + // if they are needed for a subsequent assignment. + var assignedSyms: Map[VarSymbol, TempSymbol] = paramSyms.map: + case sym => sym -> TempSymbol(N, sym.nme + "_tmp") + .toMap + var requiredTmps: Set[(VarSymbol, TempSymbol)] = Set.empty + + val paramRewriter = new BlockDataTransformer(SymbolSubst()): + override def applyValue(v: Value)(k: Value => Block): Block = v match + case Value.Ref(l: VarSymbol, disamb) => assignedSyms.get(l) match + case S(v) => + requiredTmps += (l, v) + k(Value.Ref(v, disamb)) + case _ => super.applyValue(v)(k) + case _ => super.applyValue(v)(k) + + // remove symbols from assignedSyms as we encounter them + // note that foldRight will call the function right to left + val assigns = paramSyms.zip(argVals).foldRight[Block](cont): (v, acc) => + val (sym, res) = v + assignedSyms -= sym + val ret = applyResult(res)(Assign(sym, _, acc)) match + case Assign(sym, res, rest) => paramRewriter.applyResult(res)(Assign(sym, _, rest)) match + case Assign(sym, Value.Ref(sym1, _), rest) if sym === sym1 => rest + case x => x + case x => x + ret + // bind the tmps + requiredTmps.toList.foldRight(assigns): + case ((v, l), acc) => Assign(l, Value.Ref(v), acc) + case None => super.applyBlock(b) + case _ => super.applyBlock(b) + + def rewrite(b: Block) = + applyBlock(symRewriter.applyBlock(b)) + + val arms = funs.map: f => + Case.Lit(Tree.IntLit(dSymIds(f.dSym))) -> FunRewriter(f).rewrite(f.body) + + val switch = + if arms.length === 1 then arms.head._2 + else Match(curIdSym.asPath, arms, N, End()) + + val loop = Label(loopSym, true, switch, End()) + + val sel = owner match + case Some(value) => Select(Value.Ref(value, N), Tree.Ident(bms.nme))(S(dSym)) + case None => Value.Ref(bms, S(dSym)) + + val rewrittenFuns = + if funs.size === 1 then Nil + else funs.map: f => + val paramArgs = getParamSyms(f).map(_.asPath.asArg) + val args = + Value.Lit(Tree.IntLit(dSymIds(f.dSym))).asArg + :: paramArgs + ::: List.fill(maxParamLen - paramArgs.length)(Value.Lit(Tree.UnitLit(false)).asArg) + val newBod = Return( + Call(sel, args)(true, false, false), + false + ) + FunDefn(f.owner, f.sym, f.dSym, f.params, newBod)(false) + + val params = + val initial = paramSyms.map(Param.simple(_)) + if funs.length === 1 then initial + else Param.simple(curIdSym) :: initial + + val loopDefn = FunDefn( + owner, bms, dSym, + PlainParamList(params) :: Nil, + loop)(false) + + if funs.size === 1 then (N, loopDefn :: Nil) + else (S(loopDefn), rewrittenFuns) + + def optFunctions(fs: List[FunDefn], owner: Opt[InnerSymbol]) = + val (newFsOpt, fsOpt) = partFns(fs).map(optScc(_, owner)).foldLeft[(List[FunDefn], List[FunDefn])](Nil, Nil): + case ((newFns, fns), (newFnOpt, fns_)) => newFnOpt match + case Some(value) => (value :: newFns, fns_ ::: fns) + case None => (newFns, fns_ ::: fns) + // preserve the order of function defns + val fMap = fsOpt.map(f => (f.dSym, f)).toMap + val fsRet = fs.map(f => fMap(f.dSym)) + (newFsOpt, fsRet) + + def reportClassesTailrec(c: ClsLikeDefn) = + new BlockTraverserShallow(): + for f <- c.methods do + applyBlock(f.body) + if f.forceTailRec then + raise(ErrorReport(msg"Class methods may not yet be marked @tailrec." -> f.dSym.toLoc :: Nil)) + override def applyResult(r: Result): Unit = r match + case c: Call if c.explicitTailCall => + raise(ErrorReport(msg"Calls from class methods cannot yet be marked @tailcall." -> c.toLoc :: Nil)) + case _ => super.applyResult(r) + + def optFunctionsFlat(fs: List[FunDefn], owner: Opt[InnerSymbol]) = + val (a, b) = optFunctions(fs, owner) + a ::: b + + def optClasses(cs: List[ClsLikeDefn]) = cs.map: c => + // Class methods cannot yet be optimized as they cannot yet be marked final. + + if c.k is syntax.Cls then + reportClassesTailrec(c) + val companion = c.companion.map: comp => + val cMtds = optFunctionsFlat(comp.methods, S(comp.isym)) + comp.copy(methods = cMtds) + c.copy(companion = companion) + else + val mtds = optFunctionsFlat(c.methods, S(c.isym)) + val companion = c.companion.map: comp => + val cMtds = optFunctionsFlat(comp.methods, S(comp.isym)) + comp.copy(methods = cMtds) + c.copy(methods = mtds, companion = companion) + + def transform(b: Block) = + val (blk, defns) = b.floatOutDefns() + val (funs, clses) = defns.partitionMap: + case f: FunDefn => L(f) + case c: ClsLikeDefn => R(c) + case _ => die // unreachable as floatOutDefns only floats out FunDefns and ClsLikeDefns + val (optFNew, optF) = optFunctions(funs, N) + val optC = optClasses(clses) + + val fMap = optF.map(f => f.dSym -> f).toMap + // Scala needs this annotation to type check for some reason + val cMap: Map[DefinitionSymbol[? <: ClassLikeDef] & InnerSymbol, ClsLikeDefn] = + optC.map(c => c.isym -> c).toMap + + // replace them in place + val transformer = new BlockTransformerShallow(SymbolSubst()): + override def applyDefn(defn: Defn)(k: Defn => Block): Block = defn match + case f: FunDefn => fMap.get(f.dSym) match + case Some(value) => k(value) + case None => k(f) + + case c: ClsLikeDefn => cMap.get(c.isym) match + case Some(value) => k(value) + case None => k(c) + + case _ => super.applyDefn(defn)(k) + + optFNew.foldLeft(transformer.applyBlock(b)): + case (acc, f) => Define(f, acc) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala index 31432b48e..622e5d8f0 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/UsedVarAnalyzer.scala @@ -368,7 +368,7 @@ class UsedVarAnalyzer(b: Block, handlerPaths: Opt[HandlerPaths])(using State): handleCalledBms(l) case Instantiate(mut, InstSel(l), args) => args.map(super.applyArg) - handleCalledBms(l) + handleCalledBms(l._1) case _ => super.applyResult(r) override def applyPath(p: Path): Unit = p match diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala index 1f9113e75..ae0909c2e 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala @@ -266,6 +266,10 @@ class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.I def subst(using sub: SymbolSubst): TermSymbol = sub.mapTermSym(this) +object TermSymbol: + def fromFunBms(b: BlockMemberSymbol, owner: Opt[InnerSymbol])(using State) = + TermSymbol(syntax.Fun, owner, Tree.Ident(b.nme)) + sealed trait CtorSymbol extends Symbol: def subst(using sub: SymbolSubst): CtorSymbol = sub.mapCtorSym(this) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala index 066c807d2..9377f0881 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala @@ -328,6 +328,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e // The symbol for the loop label if the term is a `while`. lazy val loopLabel = new TempSymbol(t) lazy val f = new BlockMemberSymbol("while", Nil, false) + lazy val tSym = TermSymbol.fromFunBms(f, N) val normalized = tl.scoped("ucs:normalize"): normalize(inputSplit)(using VarSet()) tl.scoped("ucs:normalized"): @@ -338,7 +339,7 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e lazy val breakRoot = (r: Result) => Assign(l, r, Break(rootBreakLabel)) lazy val assignResult = (r: Result) => Assign(l, r, End()) val loopCont = if config.rewriteWhileLoops - then Return(Call(Value.Ref(f, N), Nil)(true, true, false), false) + then Return(Call(Value.Ref(f, S(tSym)), Nil)(true, true, false), false) else Continue(loopLabel) val cont = if kw === `while` then @@ -398,8 +399,8 @@ class Normalization(lowering: Lowering)(using tl: TL)(using Raise, Ctx, State) e Select(Value.Ref(State.runtimeSymbol), Tree.Ident("LoopEnd"))(S(State.loopEndSymbol)) val blk = blockBuilder .assign(l, Value.Lit(Tree.UnitLit(false))) - .define(FunDefn.withFreshSymbol(N, f, PlainParamList(Nil) :: Nil, Begin(body, Return(loopEnd, false)))(isTailRec = false)) - .assign(loopResult, Call(Value.Ref(f, N), Nil)(true, true, false)) + .define(FunDefn(N, f, tSym, PlainParamList(Nil) :: Nil, Begin(body, Return(loopEnd, false)))(forceTailRec = false)) + .assign(loopResult, Call(Value.Ref(f, S(tSym)), Nil)(true, true, false)) if summon[LoweringCtx].mayRet then blk .assign(isReturned, Call(Value.Ref(State.builtinOpsMap("!==")), diff --git a/hkmc2/shared/src/test/mlscript/codegen/FieldSymbols.mls b/hkmc2/shared/src/test/mlscript/codegen/FieldSymbols.mls index 6466aeb02..dd62328b1 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/FieldSymbols.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/FieldSymbols.mls @@ -141,9 +141,9 @@ case //│ rest = End of "" //│ rest = Assign: \ //│ lhs = $block$res -//│ rhs = Ref: +//│ rhs = Ref{disamb=term:lambda}: //│ l = member:lambda -//│ disamb = N +//│ disamb = S of term:lambda //│ rest = Return: \ //│ res = Lit of UnitLit of false //│ implct = true diff --git a/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls b/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls index 10940ea07..c90c88145 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/RandomStuff.mls @@ -7,14 +7,17 @@ fun foo() = if false do foo() //│ let foo; //│ foo = function foo() { //│ let scrut; -//│ scrut = false; -//│ if (scrut === true) { return foo() } else { return runtime.Unit } +//│ loopLabel: while (true) { +//│ scrut = false; +//│ if (scrut === true) { continue loopLabel } else { return runtime.Unit } +//│ break; +//│ } //│ }; :sjs fun foo() = foo() //│ JS (unsanitized): -//│ let foo1; foo1 = function foo() { return foo1() }; +//│ let foo1; foo1 = function foo() { loopLabel: while (true) { continue loopLabel; break; } }; :sjs @@ -63,7 +66,10 @@ do fun f = f () //│ JS (unsanitized): -//│ let f, f1; f1 = 1; f = function f() { let tmp; tmp = f(); return tmp }; runtime.Unit +//│ let f, f1; +//│ f1 = 1; +//│ f = function f() { loopLabel: while (true) { continue loopLabel; break; } }; +//│ runtime.Unit //│ f = 1 :sjs @@ -71,10 +77,10 @@ do let foo = 1 fun foo(x) = foo //│ ╔══[ERROR] Name 'foo' is already used -//│ ║ l.71: let foo = 1 +//│ ║ l.77: let foo = 1 //│ ║ ^^^^^^^ //│ ╟── by a member declared in the same block -//│ ║ l.72: fun foo(x) = foo +//│ ║ l.78: fun foo(x) = foo //│ ╙── ^^^^^^^^^^^^^^^^ //│ JS (unsanitized): //│ let foo3, foo4; foo3 = function foo(x) { return foo4 }; foo4 = 1; diff --git a/hkmc2/shared/src/test/mlscript/codegen/While.mls b/hkmc2/shared/src/test/mlscript/codegen/While.mls index 407c7639f..a916c4e47 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/While.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/While.mls @@ -61,13 +61,16 @@ while x //│ tmp4 = undefined; //│ while3 = (undefined, function () { //│ let tmp6; -//│ if (x2 === true) { -//│ tmp6 = Predef.print("Hello World"); -//│ x2 = false; -//│ tmp4 = runtime.Unit; -//│ return while3() -//│ } else { tmp4 = 42; } -//│ return runtime.LoopEnd +//│ loopLabel: while (true) { +//│ if (x2 === true) { +//│ tmp6 = Predef.print("Hello World"); +//│ x2 = false; +//│ tmp4 = runtime.Unit; +//│ continue loopLabel +//│ } else { tmp4 = 42; } +//│ return runtime.LoopEnd; +//│ break; +//│ } //│ }); //│ tmp5 = while3(); //│ tmp4 @@ -297,10 +300,10 @@ while print("Hello World"); false then 0(0) else 1 //│ ╔══[PARSE ERROR] Unexpected 'then' keyword here -//│ ║ l.297: then 0(0) +//│ ║ l.300: then 0(0) //│ ╙── ^^^^ //│ ╔══[ERROR] Unrecognized term split (false literal) -//│ ║ l.296: while print("Hello World"); false +//│ ║ l.299: while print("Hello World"); false //│ ╙── ^^^^^ //│ > Hello World //│ ═══[RUNTIME ERROR] Error: match error @@ -310,12 +313,12 @@ while { print("Hello World"), false } then 0(0) else 1 //│ ╔══[ERROR] Unexpected infix use of keyword 'then' here -//│ ║ l.309: while { print("Hello World"), false } +//│ ║ l.312: while { print("Hello World"), false } //│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.310: then 0(0) +//│ ║ l.313: then 0(0) //│ ╙── ^^^^^^^^^^^ //│ ╔══[ERROR] Illegal position for prefix keyword 'else'. -//│ ║ l.311: else 1 +//│ ║ l.314: else 1 //│ ╙── ^^^^ :fixme @@ -325,14 +328,14 @@ while then 0(0) else 1 //│ ╔══[ERROR] Unexpected infix use of keyword 'then' here -//│ ║ l.323: print("Hello World") +//│ ║ l.326: print("Hello World") //│ ║ ^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.324: false +//│ ║ l.327: false //│ ║ ^^^^^^^^^ -//│ ║ l.325: then 0(0) +//│ ║ l.328: then 0(0) //│ ╙── ^^^^^^^^^^^ //│ ╔══[ERROR] Illegal position for prefix keyword 'else'. -//│ ║ l.326: else 1 +//│ ║ l.329: else 1 //│ ╙── ^^^^ @@ -362,6 +365,7 @@ class Lazy[out A](f: () -> A) with cached :expect [1, 2, 3] +:lift let arr = [1, 2, 3] let output = mut [] let i = 0 diff --git a/hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls b/hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls index 01dafc5f6..3b67e7af1 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/StackSafety.mls @@ -1,4 +1,5 @@ :js +:noTailRec // * FIXME: Why doesn't the following work when using Predef function `(==) equals`? diff --git a/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls b/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls index ceac941c0..e628fd987 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/EffectHandlers.mls @@ -13,3 +13,13 @@ module A with class Test with f() val a = 1 + +class Eff + +:effectHandlers +:expect 7 +handle h = Eff with + fun perform(x)(k) = k(x) + 1 +let x = 5 +h.perform(x) + 1 +//│ = 7 diff --git a/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls b/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls index 4baf4109e..075ed1108 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/StackSafetyLift.mls @@ -1,5 +1,6 @@ :js :lift +:noTailRec // * FIXME: Why doesn't the following work when using Predef function `(==) equals`? diff --git a/hkmc2/shared/src/test/mlscript/llir/BasisLLIR.mls b/hkmc2/shared/src/test/mlscript/llir/BasisLLIR.mls index de1d89150..10f1968a4 100644 --- a/hkmc2/shared/src/test/mlscript/llir/BasisLLIR.mls +++ b/hkmc2/shared/src/test/mlscript/llir/BasisLLIR.mls @@ -1,6 +1,7 @@ :js :llir :cpp +:noTailRec // The LLIR lowering does not yet support Label blocks // This file contains all tests for LLIR in the original MLscript compiler. diff --git a/hkmc2/shared/src/test/mlscript/llir/Classes.mls b/hkmc2/shared/src/test/mlscript/llir/Classes.mls index 2bfb5b243..690d5cccc 100644 --- a/hkmc2/shared/src/test/mlscript/llir/Classes.mls +++ b/hkmc2/shared/src/test/mlscript/llir/Classes.mls @@ -1,5 +1,6 @@ :llir :cpp +:noTailRec // The LLIR lowering does not yet support Label blocks :intl abstract class Callable diff --git a/hkmc2/shared/src/test/mlscript/llir/ControlFlow.mls b/hkmc2/shared/src/test/mlscript/llir/ControlFlow.mls index c7bb017e9..ad6494b04 100644 --- a/hkmc2/shared/src/test/mlscript/llir/ControlFlow.mls +++ b/hkmc2/shared/src/test/mlscript/llir/ControlFlow.mls @@ -1,6 +1,7 @@ :js :llir :cpp +:noTailRec // The LLIR lowering does not yet support Label blocks :sllir :intl diff --git a/hkmc2/shared/src/test/mlscript/llir/HigherOrder.mls b/hkmc2/shared/src/test/mlscript/llir/HigherOrder.mls index 867c7146f..3492d2323 100644 --- a/hkmc2/shared/src/test/mlscript/llir/HigherOrder.mls +++ b/hkmc2/shared/src/test/mlscript/llir/HigherOrder.mls @@ -2,6 +2,7 @@ :llir :cpp :intl +:noTailRec // The LLIR lowering does not yet support Label blocks //│ //│ Interpreted: diff --git a/hkmc2/shared/src/test/mlscript/std/FingerTreeListTest.mls b/hkmc2/shared/src/test/mlscript/std/FingerTreeListTest.mls index 01db6e89d..a37c8bf69 100644 --- a/hkmc2/shared/src/test/mlscript/std/FingerTreeListTest.mls +++ b/hkmc2/shared/src/test/mlscript/std/FingerTreeListTest.mls @@ -136,17 +136,22 @@ fun popByIndex(start, end, acc, lft) = //│ let popByIndex; //│ popByIndex = function popByIndex(start, end, acc, lft) { //│ let scrut, tmp34, tmp35, tmp36; -//│ scrut = start >= end; -//│ if (scrut === true) { -//│ return acc -//│ } else { -//│ tmp34 = start + 1; -//│ tmp35 = runtime.safeCall(lft.at(start)); -//│ tmp36 = globalThis.Object.freeze([ -//│ ...acc, -//│ tmp35 -//│ ]); -//│ return popByIndex(tmp34, end, tmp36, lft) +//│ loopLabel: while (true) { +//│ scrut = start >= end; +//│ if (scrut === true) { +//│ return acc +//│ } else { +//│ tmp34 = start + 1; +//│ tmp35 = runtime.safeCall(lft.at(start)); +//│ tmp36 = globalThis.Object.freeze([ +//│ ...acc, +//│ tmp35 +//│ ]); +//│ start = tmp34; +//│ acc = tmp36; +//│ continue loopLabel +//│ } +//│ break; //│ } //│ }; diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls b/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls index 6aaef483f..698e2a7a5 100644 --- a/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls +++ b/hkmc2/shared/src/test/mlscript/tailrec/Annots.mls @@ -18,23 +18,17 @@ class A //│ ╙── ^^^^^ //│ = 4 -:todo @tailrec -fun f = 2 -//│ ═══[WARNING] Tail call optimization is not yet implemented. +fun f = g +fun g = f :w @tailcall fun g = 2 //│ ═══[WARNING] This annotation has no effect. -:todo fun test = - @tailcall f -//│ ╔══[WARNING] This annotation has no effect. -//│ ╟── Tail call optimization is not yet implemented. -//│ ║ l.33: @tailcall f -//│ ╙── ^ + @tailcall test :w let f = 0 @@ -42,15 +36,18 @@ fun test = @tailcall f //│ ╔══[WARNING] This annotation has no effect. //│ ╟── This annotation is not supported on reference terms. -//│ ║ l.42: @tailcall f +//│ ║ l.36: @tailcall f //│ ╙── ^ //│ f = 0 :todo class A with @tailrec - fun f = 2 -//│ ═══[WARNING] Tail call optimization is not yet implemented. + fun f() = g() + fun g() = f() +//│ ╔══[ERROR] Class methods may not yet be marked @tailrec. +//│ ║ l.46: fun f() = g() +//│ ╙── ^ :w class A with @@ -63,7 +60,9 @@ class A module A with @tailrec fun f = 2 -//│ ═══[WARNING] Tail call optimization is not yet implemented. +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.62: fun f = 2 +//│ ╙── ^ :w class A @@ -84,5 +83,5 @@ fun test = @tailcall 1 + 2 //│ ╔══[WARNING] This annotation has no effect. //│ ╟── The @tailcall annotation has no effect on calls to built-in symbols. -//│ ║ l.84: @tailcall 1 + 2 +//│ ║ l.83: @tailcall 1 + 2 //│ ╙── ^^^^^ diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls new file mode 100644 index 000000000..04059a8fb --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/tailrec/Errors.mls @@ -0,0 +1,94 @@ +:js + +:e +fun f(x) = + @tailcall f(x) + f(x) +//│ ╔══[ERROR] This call is not in tail position. +//│ ║ l.5: @tailcall f(x) +//│ ╙── ^^^^ + +:e +fun f(x) = 2 +fun g(x) = @tailcall f(x) +//│ ╔══[ERROR] This call is not optimized as it does not directly recurse through its parent function. +//│ ║ l.13: fun g(x) = @tailcall f(x) +//│ ╙── ^^^^ + +:e +@tailrec fun f(x) = + g(x) +fun g(x) = f(x); f(x) +//│ ╔══[ERROR] This function is not tail recursive. +//│ ║ l.19: @tailrec fun f(x) = +//│ ║ ^ +//│ ╟── It could self-recurse through this call, which is not a tail call. +//│ ║ l.21: fun g(x) = f(x); f(x) +//│ ╙── ^^^^ + +:e +@tailrec fun f(x) = + g(x) +fun g(x) = + f(x) + h(x) +fun h(x) = + g(x) +//│ ╔══[ERROR] This function is not tail recursive. +//│ ║ l.30: @tailrec fun f(x) = +//│ ║ ^ +//│ ╟── It could self-recurse through this call, which is not a tail call. +//│ ║ l.33: f(x) +//│ ╙── ^^^^ + +:e +@tailrec fun f(x) = + g(x) +fun g(x) = + h(x) + f(x) +fun h(x) = + f(x) +//│ ╔══[ERROR] This function is not tail recursive. +//│ ║ l.45: @tailrec fun f(x) = +//│ ║ ^ +//│ ╟── It could self-recurse through this call, which is not a tail call. +//│ ║ l.48: h(x) +//│ ╙── ^^^^ + +:e +module A with + fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) + fun g(x) = if x == 0 then 1 else + @tailcall f(x - 1) + @tailcall f(x - 1) +//│ ╔══[ERROR] This call is not in tail position. +//│ ║ l.63: @tailcall f(x - 1) +//│ ╙── ^^^^^^^^ + +:w +@tailrec +fun f = 2 +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.71: fun f = 2 +//│ ╙── ^ + +:w +module A with + @tailrec + fun f() = 2 +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.79: fun f() = 2 +//│ ╙── ^ + +:fixme // TODO: support +@tailrec +fun foo() = Foo.bar() +module Foo with + fun bar() = foo() +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.86: fun foo() = Foo.bar() +//│ ╙── ^^^ + + + diff --git a/hkmc2/shared/src/test/mlscript/tailrec/Simple.mls b/hkmc2/shared/src/test/mlscript/tailrec/Simple.mls new file mode 100644 index 000000000..ee607e3a5 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/tailrec/Simple.mls @@ -0,0 +1,50 @@ +:js + + +:sjs +@tailrec +fun f = f +//│ JS (unsanitized): +//│ let f; f = function f() { loopLabel: while (true) { continue loopLabel; break; } }; + +:fixme +module Foo with + @tailrec + fun bar = bar +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.13: fun bar = bar +//│ ╙── ^^^ + +@tailrec +fun f() = f() + +module Foo with + @tailrec + fun bar() = bar() + +fun f(x) = f(x + 1) + +:sjs +fun f(x) = f(f(x) + 1) +//│ JS (unsanitized): +//│ let f3; +//│ f3 = function f(x) { +//│ let tmp, tmp1; +//│ loopLabel: while (true) { tmp = f3(x); tmp1 = tmp + 1; x = tmp1; continue loopLabel; break; } +//│ }; + +module A with + fun f(x) = @tailcall f(f(x) + 1) + + +:fixme // TODO: support +module A with + @tailrec + fun f(x) = B.g(x + 1) +module B with + fun g(x) = A.f(x * 2) +//│ ╔══[WARNING] This function does not directly self-recurse, but is marked @tailrec. +//│ ║ l.43: fun f(x) = B.g(x + 1) +//│ ╙── ^ + + diff --git a/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls new file mode 100644 index 000000000..ea4a8d459 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/tailrec/TailRecOpt.mls @@ -0,0 +1,202 @@ +:js + +:expect 200010000 +fun sum_impl(n, acc) = if n == 0 then acc else sum_impl(n - 1, n + acc) +fun sum(n) = sum_impl(n, 0) +sum(20000) +//│ = 200010000 + +:sjs +:expect 50000 +fun g(a, b, c, d) = f(a, b, c + d) +fun f(a, b, c) = + if + a > 0 then g(a - 1, b, c, 1) + b > 0 then g(a, b - 1, c, 2) + else c +f(10000, 20000, 0) +//│ JS (unsanitized): +//│ let f, g, g_f; +//│ g_f = function g_f(id, param0, param1, param2, param3) { +//│ let scrut, scrut1, tmp, tmp1, tmp2; +//│ loopLabel: while (true) { +//│ if (id === 0) { +//│ tmp = param2 + param3; +//│ param2 = tmp; +//│ id = 1; +//│ continue loopLabel +//│ } else if (id === 1) { +//│ scrut = param0 > 0; +//│ if (scrut === true) { +//│ tmp1 = param0 - 1; +//│ param0 = tmp1; +//│ param3 = 1; +//│ id = 0; +//│ continue loopLabel +//│ } else { +//│ scrut1 = param1 > 0; +//│ if (scrut1 === true) { +//│ tmp2 = param1 - 1; +//│ param1 = tmp2; +//│ param3 = 2; +//│ id = 0; +//│ continue loopLabel +//│ } else { +//│ return param2 +//│ } +//│ } +//│ } +//│ break; +//│ } +//│ }; +//│ g = function g(a, b, c, d) { +//│ return g_f(0, a, b, c, d) +//│ }; +//│ f = function f(a, b, c) { return g_f(1, a, b, c, undefined) }; +//│ f(10000, 20000, 0) +//│ = 50000 + +:sjs +:expect 200010000 +module A with + fun sum_impl(n, acc) = if n == 0 then acc else @tailcall sum_impl(n - 1, n + acc) + fun sum(n) = sum_impl(n, 0) +A.sum(20000) +//│ JS (unsanitized): +//│ let A1; +//│ globalThis.Object.freeze(class A { +//│ static { +//│ A1 = this +//│ } +//│ constructor() { +//│ runtime.Unit; +//│ } +//│ static sum_impl(n, acc) { +//│ let scrut, tmp, tmp1; +//│ loopLabel: while (true) { +//│ scrut = Predef.equals(n, 0); +//│ if (scrut === true) { +//│ return acc +//│ } else { +//│ tmp = n - 1; +//│ tmp1 = n + acc; +//│ n = tmp; +//│ acc = tmp1; +//│ continue loopLabel +//│ } +//│ break; +//│ } +//│ } +//│ static sum(n) { +//│ return A.sum_impl(n, 0) +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "A"]; +//│ }); +//│ A1.sum(20000) +//│ = 200010000 + +:silent +:lift +:expect 200010000 +val x = mut [] +let n = 20000 +let i = 0 +while i < n do + x.push of () => i + set i += 1 +fun sumOf(x, idx, acc) = + if idx == 0 then acc + else + sumOf(x, idx - 1, acc + x![idx]()) +sumOf(x, n - 1, 0) + +// Check that spreads, where supported, are compiled correctly +fun f(x, ...z) = + if x < 0 then 0 + else g(x, 0, x, x) +fun g(x, y, ...z) = + if x < 0 then 0 + else f(x - 1, 0, ...[x, x, x]) +f(100) +//│ = 0 + +:fixme +fun f(x, y, z) = @tailcall f(...[1, 2, 3]) +//│ ═══[ERROR] Spreads are not yet fully supported in calls marked @tailcall. + +:fixme +fun f(x, ...y) = + if x < 0 then g(0, 0, 0) + else 0 +fun g(x, y, z) = + @tailcall f(...[1, 1, 2]) +g(0, 0, 0) +//│ ═══[ERROR] Spreads are not yet fully supported in calls marked @tailcall. +//│ = 0 + +:e +:lift +fun f(x) = + fun g() = + @tailcall f(x) + @tailcall f(x) + @tailcall g() +//│ ╔══[ERROR] This call is not in tail position. +//│ ║ l.142: @tailcall f(x) +//│ ╙── ^ + +:lift +fun f(x) = + fun g() = + @tailcall f(x) + @tailcall g() + +// Functions inside module definitions are lifted to the top level. This means they cannot yet be optimized. +:lift +:fixme +module A with + fun f(x) = + fun g(x) = if x < 0 then 0 else @tailcall f(x) + @tailcall g(x - 1) +A.f(10000) +//│ ╔══[ERROR] This tail call exits the current scope is not optimized. +//│ ║ l.160: fun g(x) = if x < 0 then 0 else @tailcall f(x) +//│ ╙── ^^^ +//│ ╔══[ERROR] This tail call exits the current scope is not optimized. +//│ ║ l.161: @tailcall g(x - 1) +//│ ╙── ^^^^^^^^ +//│ ═══[RUNTIME ERROR] RangeError: Maximum call stack size exceeded + +// These calls are represented as field selections and don't yet have the explicitTailCall parameter. +:breakme +:e +module A with + fun f = g + fun g = + @tailcall f + f + +:expect 0 +module A with + fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) + fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) +A.f(10) +//│ = 0 + +:todo +:expect 0 +class A with + @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) + fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) +(new A).f(10) +//│ ╔══[ERROR] Calls from class methods cannot yet be marked @tailcall. +//│ ║ l.190: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) +//│ ╙── ^^^^^^^^ +//│ ╔══[ERROR] Class methods may not yet be marked @tailrec. +//│ ║ l.190: @tailrec fun f(x) = if x == 0 then 0 else @tailcall g(x - 1) +//│ ╙── ^ +//│ ╔══[ERROR] Calls from class methods cannot yet be marked @tailcall. +//│ ║ l.191: fun g(x) = if x == 0 then 1 else @tailcall f(x - 1) +//│ ╙── ^^^^^^^^ +//│ = 0 diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala index f527cb10d..0fbbce6bd 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala @@ -72,6 +72,7 @@ abstract class MLsDiffMaker extends DiffMaker: val importQQ = NullaryCommand("qq") val stageCode = NullaryCommand("staging") val dontRewriteWhile = NullaryCommand("dontRewriteWhile") + val noTailRecOpt = NullaryCommand("noTailRec") def mkConfig: Config = import Config.* @@ -100,6 +101,7 @@ abstract class MLsDiffMaker extends DiffMaker: stageCode = stageCode.isSet, target = if wasm.isSet then CompilationTarget.Wasm else CompilationTarget.JS, rewriteWhileLoops = !dontRewriteWhile.isSet, + tailRecOpt = !noTailRecOpt.isSet, )