Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ case class Config(
stageCode: Bool,
target: CompilationTarget,
rewriteWhileLoops: Bool,
tailRecOpt: Bool,
):

def stackSafety: Opt[StackSafety] = effectHandlers.flatMap(_.stackSafety)
Expand All @@ -40,6 +41,7 @@ object Config:
target = CompilationTarget.JS,
rewriteWhileLoops = true,
stageCode = false,
tailRecOpt = true,
)

case class SanityChecks(light: Bool)
Expand Down
11 changes: 6 additions & 5 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,15 @@ sealed abstract class Block extends Product:
val newBody = d.body.flattened
if newBody is d.body
then d
else d.copy(body = newBody)(isTailRec = d.isTailRec)
else d.copy(body = newBody)(forceTailRec = d.forceTailRec)
case v: ValDefn => v
case c: ClsLikeDefn =>
val newPreCtor = c.preCtor.flattened
val newCtor = c.ctor.flattened
val newMethods = c.methods.mapConserve:
case f@FunDefn(owner, sym, dSym, params, body) =>
val newBody = body.flattened
if newBody is body then f else f.copy(body = newBody)(isTailRec = f.isTailRec)
if newBody is body then f else f.copy(body = newBody)(forceTailRec = f.forceTailRec)
if (newPreCtor is c.preCtor) && (newCtor is c.ctor) && (newMethods is c.methods)
then c
else c.copy(preCtor = newPreCtor, ctor = newCtor, methods = newMethods)
Expand Down Expand Up @@ -343,12 +343,13 @@ final case class FunDefn(
params: Ls[ParamList],
body: Block,
)(
val isTailRec: Bool,
val forceTailRec: Bool,
) extends Defn:
val innerSym = N
val asPath = Value.Ref(sym, S(dSym))
object FunDefn:
def withFreshSymbol(owner: Opt[InnerSymbol], sym: BlockMemberSymbol, params: Ls[ParamList], body: Block)(isTailRec: Bool)(using State) =
FunDefn(owner, sym, TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme)), params, body)(isTailRec)
def withFreshSymbol(owner: Opt[InnerSymbol], sym: BlockMemberSymbol, params: Ls[ParamList], body: Block)(forceTailRec: Bool)(using State) =
FunDefn(owner, sym, TermSymbol(syntax.Fun, owner, Tree.Ident(sym.nme)), params, body)(forceTailRec)

final case class ValDefn(
tsym: TermSymbol,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class BlockTransformer(subst: SymbolSubst):
val params2 = fun.params.mapConserve(applyParamList)
val body2 = applySubBlock(fun.body)
if (own2 is fun.owner) && (sym2 is fun.sym) && (dSym2 is fun.dSym) && (params2 is fun.params) && (body2 is fun.body)
then fun else FunDefn(own2, sym2, dSym2, params2, body2)(fun.isTailRec)
then fun else FunDefn(own2, sym2, dSym2, params2, body2)(fun.forceTailRec)

def applyValDefn(defn: ValDefn)(k: ValDefn => Block): Block =
val ValDefn(tsym, sym, rhs) = defn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class BlockTraverser:
cls.traverse
applyPath(path)
case Case.Tup(len, inf) => ()
case Case.Field(_, _) => ()

def applyHandler(hdr: Handler): Unit =
hdr.sym.traverse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class BufferableTransform()(using Ctx, State, Raise):
val blk = mkFieldReplacer(buf, idx).applyBlock(f.body)
FunDefn(f.owner, f.sym, f.dSym, PlainParamList(
Param(FldFlags.empty, buf, N, Modulefulness.none) :: Param(FldFlags.empty, idx, N, Modulefulness.none) :: Nil) :: f.params,
if isCtor then Begin(blk, Return(idx.asPath, false)) else blk)(isTailRec = f.isTailRec)
if isCtor then Begin(blk, Return(idx.asPath, false)) else blk)(forceTailRec = f.forceTailRec)
val fakeCtor = transformFunDefn(FunDefn.withFreshSymbol(
S(companionSym),
BlockMemberSymbol("ctor", Nil, false),
Expand Down
34 changes: 19 additions & 15 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
// all the code in the first state
private def translateBlock(b: Block, extraLocals: Set[Local], callSelf: Opt[Result], fnOrCls: FnOrCls, h: HandlerCtx): Block =
val getLocalsFn = createGetLocalsFn(b, extraLocals)(using h)
given HandlerCtx = h.nestDebugScope(b.userDefinedVars ++ extraLocals, getLocalsFn.sym.asPath)
given HandlerCtx = h.nestDebugScope(b.userDefinedVars ++ extraLocals, getLocalsFn.asPath)
val stage1 = firstPass(b)
val stage2 = secondPass(stage1, fnOrCls, callSelf, getLocalsFn)
if h.isTopLevel then stage2 else thirdPass(stage2)
Expand Down Expand Up @@ -542,23 +542,26 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
val cls = if handlerCtx.isTopLevel then N else genContClass(b, callSelf)

val ret = cls match
case None => genNormalBody(b, BlockMemberSymbol("", Nil), N)
case None => genNormalBody(b, BlockMemberSymbol("", Nil).asPath, N)
case Some(cls) =>
// create the doUnwind function
val doUnwindSym = BlockMemberSymbol(doUnwindNme, Nil, true)
doUnwindMap += fnOrCls -> doUnwindSym.asPath
val pcSym = VarSymbol(Tree.Ident("pc"))
val resSym = VarSymbol(Tree.Ident("res"))
val doUnwindBlk = h.linkAndHandle(
LinkState(resSym, cls.sym.asPath, pcSym.asPath)
LinkState(resSym, Value.Ref(cls.sym, S(cls.isym)), pcSym.asPath)
)
val doUnwindDef = FunDefn.withFreshSymbol(
N, doUnwindSym,
PlainParamList(Param.simple(resSym) :: Param.simple(pcSym) :: Nil) :: Nil,
doUnwindBlk
)(false)
val doUnwindLazy = Lazy(doUnwindSym.asPath)
val rst = genNormalBody(b, cls.sym, S(doUnwindLazy))

val doUnwindPath: Path = doUnwindDef.asPath
doUnwindMap += fnOrCls -> doUnwindPath

val doUnwindLazy = Lazy(doUnwindPath)
val rst = genNormalBody(b, Value.Ref(cls.sym, S(cls.isym)), S(doUnwindLazy))

if doUnwindLazy.isEmpty && opt.stackSafety.isEmpty then
blockBuilder
Expand Down Expand Up @@ -592,7 +595,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
case pList :: Nil =>
val params = pList.params.map(p => p.sym.asPath.asArg)
f.owner match
case None => S(Call(f.sym.asPath, params)(true, true, false))
case None => S(Call(f.asPath, params)(true, true, false))
case Some(owner) =>
S(Call(Select(owner.asPath, Tree.Ident(f.sym.nme))(N), params)(true, true, false))
case _ => None // TODO: more than one plist
Expand All @@ -602,7 +605,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
callSelf,
L(f.sym),
functionHandlerCtx(s"Cont$$func$$${symToStr(f.sym)}$$", f.sym.nme))
)(isTailRec = f.isTailRec)
)(forceTailRec = f.forceTailRec)

private def translateBody(cls: ClsLikeBody, sym: BlockMemberSymbol)(using HandlerCtx): ClsLikeBody =
val curCtorCtx =
Expand Down Expand Up @@ -631,11 +634,12 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
// Handle block becomes a FunDefn and CallPlaceholder
private def translateHandleBlock(h: HandleBlock)(using HandlerCtx): Block =
val sym = BlockMemberSymbol(s"handleBlock$$", Nil)
val tSym = TermSymbol.fromFunBms(sym, N)
val lbl = freshTmp("handlerBody")
val lblLoop = freshTmp("handlerLoop")

val handlerBody = translateBlock(
h.body, Set.empty, S(Call(sym.asPath, Nil)(true, false, false)), L(sym),
h.body, Set.empty, S(Call(Value.Ref(sym, S(tSym)), Nil)(true, false, false)), L(sym),
HandlerCtx(
false, true,
s"Cont$$handleBlock$$${symToStr(h.lhs)}$$", N,
Expand All @@ -662,7 +666,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
handler.params,
Define(
fDef,
Return(PureCall(paths.mkEffectPath, h.cls.asPath :: Value.Ref(sym, N) :: Nil), false)))(false)
Return(PureCall(paths.mkEffectPath, h.cls.asPath :: fDef.asPath :: Nil), false)))(false)

// Some limited handling of effects extending classes and having access to their fields.
// Currently does not support super() raising effects.
Expand Down Expand Up @@ -695,14 +699,14 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
.assign(h.lhs, Instantiate(mut = true, Value.Ref(clsDefn.sym, S(h.cls)), Nil))
.rest(handlerBody)

val defn = FunDefn.withFreshSymbol(
val defn = FunDefn(
N, // no owner
sym, PlainParamList(Nil) :: Nil, body)(false)
sym, tSym, PlainParamList(Nil) :: Nil, body)(false)

val result = blockBuilder
.define(defn)
.rest(
ResultPlaceholder(h.res, freshId(), Call(sym.asPath, Nil)(true, true, false), h.rest)
ResultPlaceholder(h.res, freshId(), Call(defn.asPath, Nil)(true, true, false), h.rest)
)
result

Expand Down Expand Up @@ -928,12 +932,12 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
N, // TODO: bufferable?
))

private def genNormalBody(b: Block, clsSym: BlockMemberSymbol, doUnwind: Opt[Lazy[Path]])(using HandlerCtx): Block =
private def genNormalBody(b: Block, clsPath: Path, doUnwind: Opt[Lazy[Path]])(using HandlerCtx): Block =
val transform = new BlockTransformerShallow(SymbolSubst()):
override def applyBlock(b: Block): Block = b match
case ResultPlaceholder(res, uid, c, rest) =>
val doUnwindBlk = doUnwind match
case None => Assign(res, topLevelCall(LinkState(res, clsSym.asPath, Value.Lit(Tree.IntLit(uid)))), End())
case None => Assign(res, topLevelCall(LinkState(res, clsPath, Value.Lit(Tree.IntLit(uid)))), End())
case Some(doUnwind) => Return(PureCall(doUnwind.get_!, res.asPath :: Value.Lit(Tree.IntLit(uid)) :: Nil), false)
blockBuilder
.assign(res, c)
Expand Down
14 changes: 8 additions & 6 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/LambdaRewriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@ object LambdaRewriter:
val newSym = BlockMemberSymbol(lhs.nme, Nil,
nameIsMeaningful = true // TODO: lhs.nme is not always meaningful
)
val defn = FunDefn.withFreshSymbol(N, newSym, params :: Nil, body)(false)
val blk = blockBuilder
.define(FunDefn.withFreshSymbol(N, newSym, params :: Nil, body)(false))
.assign(lhs, newSym.asPath)
.define(defn)
.assign(lhs, defn.asPath)
.rest(rest)
(blk, Nil)
case _ =>
var lambdasList: List[(BlockMemberSymbol, Lambda)] = Nil
var lambdasList: List[((BlockMemberSymbol, TermSymbol), Lambda)] = Nil
val lambdaRewriter = new BlockDataTransformer(SymbolSubst()):
override def applyResult(r: Result)(k: Result => Block): Block = r match
case lam: Lambda =>
val sym = BlockMemberSymbol("lambda", Nil, nameIsMeaningful = false)
lambdasList ::= (sym -> super.applyLam(lam))
k(Value.Ref(sym, N))
val tSym = TermSymbol.fromFunBms(sym, N)
lambdasList ::= ((sym, tSym) -> super.applyLam(lam))
k(Value.Ref(sym, S(tSym)))
case _ => super.applyResult(r)(k)
val blk = lambdaRewriter.applyBlock(b)
(blk, lambdasList)
Expand All @@ -39,7 +41,7 @@ object LambdaRewriter:
val (newBlk, lambdasList) = rewriteOneBlk(b)
val lambdaDefns = lambdasList.map:
case (sym, Lambda(params, body)) =>
FunDefn.withFreshSymbol(N, sym, params :: Nil, body)(false)
FunDefn(N, sym._1, sym._2, params :: Nil, body)(false)
val ret = lambdaDefns.foldLeft(newBlk):
case (acc, defn) => Define(defn, acc)
super.applyBlock(ret)
Expand Down
Loading
Loading