Skip to content

Commit ea7dc6e

Browse files
chengluyuLPTK
andauthored
Unify patterns in UCS and UPS (#336)
Co-authored-by: Lionel Parreaux <[email protected]>
1 parent 481546f commit ea7dc6e

File tree

136 files changed

+5316
-3130
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

136 files changed

+5316
-3130
lines changed

hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package bbml
44

55
import scala.collection.mutable.{HashSet, HashMap, ListBuffer}
66
import scala.annotation.tailrec
7+
import sourcecode.{FileName, Line, Name}
78

89
import mlscript.utils.*, shorthands.*
910
import utils.*
@@ -73,7 +74,7 @@ object BbCtx:
7374
end BbCtx
7475

7576

76-
class BBTyper(using elState: Elaborator.State, tl: TL):
77+
class BBTyper(using elState: Elaborator.State, tl: TL)(using Ctx):
7778
import tl.{trace, log}
7879

7980
private val infVarState = new InfVarUid.State()
@@ -100,7 +101,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
100101
state.lowerBounds = ctx.getRegEnv :: Nil
101102
InfVar(ctx.lvl, infVarState.nextUid, state, false)(sym, "")
102103

103-
private def error(msg: Ls[Message -> Opt[Loc]])(using BbCtx) =
104+
private def error(using Line, FileName, Name, Raise)(msg: Ls[Message -> Opt[Loc]])(using BbCtx) =
104105
raise(ErrorReport(msg))
105106
Bot // TODO: error type?
106107

@@ -258,7 +259,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
258259
val res = freshVar(new TempSymbol(S(blk), "ctx"))(using ctx)
259260
constrain(bodyCtx, sk | res)
260261
(bodyTy, rhsCtx | res, rhsEff | bodyEff)
261-
case Term.IfLike(Keyword.`if`, Split.Let(_, cond, Split.Cons(Branch(_, FlatPattern.Lit(BoolLit(true)), Split.Else(cons)), Split.Else(alts)))) =>
262+
case Term.IfLike(Keyword.`if`, SimpleSplit.IfThenElse(cond, cons, alts)) =>
262263
val (condTy, condCtx, condEff) = typeCode(cond)
263264
val (consTy, consCtx, consEff) = typeCode(cons)
264265
val (altsTy, altsCtx, altsEff) = typeCode(alts)
@@ -363,8 +364,8 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
363364
given BbCtx = nextCtx
364365
constrain(ascribe(term, skolemize(pt))._2, Bot) // * never generalize terms with effects
365366
(pt, Bot)
366-
case (Term.IfLike(Keyword.`if`, branches), ty) => // * propagate
367-
typeSplit(branches, S(ty))
367+
case (Term.IfLike(Keyword.`if`, split), ty) => // * propagate
368+
typeSplit(split.getExpandedSplit, S(ty))
368369
case (Term.Asc(term, ty), rhs) =>
369370
ascribe(term, typeType(ty))
370371
ascribe(term, rhs)
@@ -550,7 +551,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
550551
case Term.Asc(term, ty) =>
551552
val res = typeType(ty)(using ctx)
552553
ascribe(term, res)
553-
case Term.IfLike(Keyword.`if`, branches) => typeSplit(branches, N)
554+
case Term.IfLike(Keyword.`if`, split) => typeSplit(split.getExpandedSplit, N)
554555
case reg @ Term.Region(sym, body) =>
555556
val sk = freshReg(sym)(using ctx)
556557
val nestCtx = ctx.nestReg(sk)

hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ sealed abstract class Block extends Product:
4848
// Note that the handler's LHS and body are not part of the current block, so we do not consider them here.
4949
case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) => rst.definedVars + res
5050
case TryBlock(sub, fin, rst) => sub.definedVars ++ fin.definedVars ++ rst.definedVars
51-
case Label(lbl, bod, rst) => bod.definedVars ++ rst.definedVars
51+
case Label(lbl, _, bod, rst) => bod.definedVars ++ rst.definedVars
5252

5353
lazy val size: Int = this match
5454
case _: Return | _: Throw | _: End | _: Break | _: Continue => 1
@@ -60,7 +60,7 @@ sealed abstract class Block extends Product:
6060
1 + arms.map(_._2.size).sum + dflt.map(_.size).getOrElse(0) + rst.size
6161
case Define(_, rst) => 1 + rst.size
6262
case TryBlock(sub, fin, rst) => 1 + sub.size + fin.size + rst.size
63-
case Label(_, bod, rst) => 1 + bod.size + rst.size
63+
case Label(_, _, bod, rst) => 1 + bod.size + rst.size
6464
case HandleBlock(lhs, res, par, args, cls, handlers, bdy, rst) => 1 + handlers.map(_.body.size).sum + bdy.size + rst.size
6565

6666
// TODO conserve if no changes
@@ -75,7 +75,7 @@ sealed abstract class Block extends Product:
7575
Match(scrut, arms.map(_ -> _.mapTail(f)), dflt.map(_.mapTail(f)), rst)
7676
case Match(scrut, arms, dflt, rst) =>
7777
Match(scrut, arms, dflt, rst.mapTail(f))
78-
case Label(label, body, rest) => Label(label, body, rest.mapTail(f))
78+
case Label(label, loop, body, rest) => Label(label, loop, body.mapTail(f), rest.mapTail(f))
7979
case af @ AssignField(lhs, nme, rhs, rest) =>
8080
AssignField(lhs, nme, rhs, rest.mapTail(f))(af.symbol)
8181
case adf @ AssignDynField(lhs, fld, arrayIdx, rhs, rest) =>
@@ -90,7 +90,7 @@ sealed abstract class Block extends Product:
9090
(pat, arm) => arm.freeVars -- pat.freeVars
9191
case Return(res, implct) => res.freeVars
9292
case Throw(exc) => exc.freeVars
93-
case Label(label, body, rest) => (body.freeVars - label) ++ rest.freeVars
93+
case Label(label, _, body, rest) => (body.freeVars - label) ++ rest.freeVars
9494
case Break(label) => Set(label)
9595
case Continue(label) => Set(label)
9696
case Begin(sub, rest) => sub.freeVars ++ rest.freeVars
@@ -110,7 +110,7 @@ sealed abstract class Block extends Product:
110110
(pat, arm) => arm.freeVarsLLIR -- pat.freeVarsLLIR
111111
case Return(res, implct) => res.freeVarsLLIR
112112
case Throw(exc) => exc.freeVarsLLIR
113-
case Label(label, body, rest) => (body.freeVarsLLIR - label) ++ rest.freeVarsLLIR
113+
case Label(label, _, body, rest) => (body.freeVarsLLIR - label) ++ rest.freeVarsLLIR
114114
case Break(label) => Set.empty
115115
case Continue(label) => Set.empty
116116
case Begin(sub, rest) => sub.freeVarsLLIR ++ rest.freeVarsLLIR
@@ -132,7 +132,7 @@ sealed abstract class Block extends Product:
132132
case AssignDynField(_, _, _, rhs, rest) => rhs.subBlocks ::: rest :: Nil
133133
case Define(d, rest) => d.subBlocks ::: rest :: Nil
134134
case HandleBlock(_, _, par, args, _, handlers, body, rest) => par.subBlocks ++ args.flatMap(_.subBlocks) ++ handlers.map(_.body) :+ body :+ rest
135-
case Label(_, body, rest) => body :: rest :: Nil
135+
case Label(_, _, body, rest) => body :: rest :: Nil
136136

137137
// TODO rm Lam from values and thus the need for these cases
138138
case Return(r, _) => r.subBlocks
@@ -177,12 +177,12 @@ sealed abstract class Block extends Product:
177177
then this
178178
else Match(scrut, newArms, newDflt, newRest)
179179

180-
case Label(label, body, rest) =>
180+
case Label(label, loop, body, rest) =>
181181
val newBody = body.flattened
182182
val newRest = rest.flatten(k)
183183
if (newBody is body) && (newRest is rest)
184184
then this
185-
else Label(label, newBody, newRest)
185+
else Label(label, loop, newBody, newRest)
186186

187187
case Begin(sub, rest) =>
188188
sub.flatten(_ => rest.flatten(k))
@@ -267,7 +267,7 @@ case class Return(res: Result, implct: Bool) extends BlockTail
267267

268268
case class Throw(exc: Result) extends BlockTail
269269

270-
case class Label(label: Local, body: Block, rest: Block) extends Block
270+
case class Label(label: Local, loop: Bool, body: Block, rest: Block) extends Block
271271

272272
case class Break(label: Local) extends BlockTail
273273
case class Continue(label: Local) extends BlockTail
@@ -595,7 +595,7 @@ extension (k: Block => Block)
595595
def end = k.rest(End())
596596
def ifthen(scrut: Path, cse: Case, trm: Block, els: Opt[Block] = N): Block => Block =
597597
k.chain(Match(scrut, cse -> trm :: Nil, els, _))
598-
def label(label: Local, body: Block) = k.chain(Label(label, body, _))
598+
def label(label: Local, loop: Bool, body: Block) = k.chain(Label(label, loop, body, _))
599599
def ret(r: Result) = k.rest(Return(r, false))
600600
def staticif(b: Boolean, f: (Block => Block) => (Block => Block)) = if b then k.transform(f) else k
601601
def foldLeft[A](xs: Iterable[A])(f: (Block => Block, A) => Block => Block) = xs.foldLeft(k)(f)

hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ class BlockTransformer(subst: SymbolSubst):
4545
(arms2 is arms) &&
4646
(dflt2 is dflt) && (rst2 is rst)
4747
then b else Match(scrut2, arms2, dflt2, rst2)
48-
case Label(lbl, bod, rst) =>
48+
case Label(lbl, loop, bod, rst) =>
4949
val lbl2 = applyLocal(lbl)
5050
val bod2 = applySubBlock(bod)
5151
val rst2 = applySubBlock(rst)
52-
if (lbl2 is lbl) && (bod2 is bod) && (rst2 is rst) then b else Label(lbl2, bod2, rst2)
52+
if (lbl2 is lbl) && (bod2 is bod) && (rst2 is rst) then b else Label(lbl2, loop, bod2, rst2)
5353
case Begin(sub, rst) =>
5454
val sub2 = applySubBlock(sub)
5555
val rst2 = applySubBlock(rst)

hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class BlockTraverser:
3131
applyCase(arm._1); applySubBlock(arm._2)
3232
dflt.foreach(applySubBlock)
3333
applySubBlock(rst)
34-
case Label(lbl, bod, rst) => applyLocal(lbl); applySubBlock(bod); applySubBlock(rst)
34+
case Label(lbl, loop, bod, rst) => applyLocal(lbl); applySubBlock(bod); applySubBlock(rst)
3535
case Begin(sub, rst) => applySubBlock(sub); applySubBlock(rst)
3636
case TryBlock(sub, fin, rst) => applySubBlock(sub); applySubBlock(fin); applySubBlock(rst)
3737
case Assign(l, r, rst) => applyLocal(l); applyResult(r); applySubBlock(rst)

hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
332332
Match(scrut, newArms, dfltParts.map(_.head), StateTransition(restId)),
333333
BlockState(restId, restParts.head, N) :: states
334334
)
335-
case l @ Label(label, body, rest) =>
335+
case l @ Label(label, loop, body, rest) =>
336336
val startId = freshId() // start of body
337337

338338
val PartRet(restNew, restParts) = go(rest)
@@ -820,7 +820,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
820820
AssignField(runtimePath, stackDepthIdent, depthSym.asPath, mainMatchBlk)(N)
821821
else mainMatchBlk
822822

823-
val lbl = blockBuilder.label(loopLbl, withResetDepth).rest(End())
823+
val lbl = blockBuilder.label(loopLbl, loop = true, withResetDepth).rest(End())
824824

825825
def createAssignment(sym: Local) = Assign(sym, resumedVal.asPath, End())
826826

hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -718,11 +718,11 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
718718
(dflt2 is dflt) && (rst2 is rst)
719719
then b else Match(scrut2, arms2, dflt2, rst2)
720720

721-
case Label(lbl, bod, rst) =>
721+
case Label(lbl, loop, bod, rst) =>
722722
val lbl2 = applyLocal(lbl)
723723
val bod2 = applySubBlockAndReset(bod)
724724
val rst2 = applySubBlock(rst)
725-
if (lbl2 is lbl) && (bod2 is bod) && (rst2 is rst) then b else Label(lbl2, bod2, rst2)
725+
if (lbl2 is lbl) && (bod2 is bod) && (rst2 is rst) then b else Label(lbl2, loop, bod2, rst2)
726726
case TryBlock(sub, fin, rst) =>
727727
val sub2 = applySubBlockAndReset(sub)
728728
val fin2 = applySubBlockAndReset(fin)

0 commit comments

Comments
 (0)