diff --git a/Node.scala b/Node.scala index cbf67fa..82ab3f8 100644 --- a/Node.scala +++ b/Node.scala @@ -385,7 +385,7 @@ object Node: require( results.size == 1, - s"token(s) not found ${(tok +: toks).map(_.name).mkString(", ")}" + s"token(s) not found ${(tok +: toks).map(_.name).mkString(", ")}, in ${this.toShortString()}" ) results.head diff --git a/PassSeq.scala b/PassSeq.scala index 60b8267..f28866b 100644 --- a/PassSeq.scala +++ b/PassSeq.scala @@ -21,15 +21,16 @@ import dsl.* transparent trait PassSeq: def inputWellformed: Wellformed final def outputWellformed: Wellformed = - assert(entriesSealed) + allPasses // compute allPasses, needed to enforce complete initialization entries.last.wellformed private val entries = mutable.ListBuffer.empty[PassSeq.Entry] private var entriesSealed = false protected def prevWellformed(using BuildCtx): Wellformed = - require(entries.nonEmpty, "there is no previous Wellformed") - entries.last.wellformed + if entries.isEmpty + then inputWellformed + else entries.last.wellformed protected object wellformed: def :=(using ctx: BuildCtx)(wellformed: Wellformed): Unit = diff --git a/Wellformed.scala b/Wellformed.scala index 91d93cc..b3c8fe1 100644 --- a/Wellformed.scala +++ b/Wellformed.scala @@ -524,6 +524,15 @@ object Wellformed: s"$token's shape is not appropriate for adding cases ($shape)" ) + def deleteShape(): Unit = + token match + case Node.Top => + require(topShapeOpt.nonEmpty) + topShapeOpt = None + case token: Token => + require(assigns.contains(token)) + assigns.remove(token) + def importFrom(wf2: Wellformed): Unit = def fillFromShape(shape: Shape): Unit = shape match @@ -550,6 +559,7 @@ object Wellformed: topShapeOpt = Some(wf2.topShape) fillFromShape(wf2.topShape) case token: Token => fillFromTokenOrEmbed(token) + end importFrom private[Wellformed] def build(): Wellformed = require(topShapeOpt.nonEmpty) diff --git a/tla/ExprParser.scala b/tla/ExprParser.scala new file mode 100644 index 0000000..e80d553 --- /dev/null +++ 
b/tla/ExprParser.scala @@ -0,0 +1,970 @@ +package distcompiler.tla + +import cats.syntax.all.given + +import distcompiler.* +import dsl.* +import distcompiler.Builtin.{Error, SourceMarker} +import distcompiler.tla.TLAReader +import distcompiler.tla.TLAParser.rawExpression +import distcompiler.Manip.ops.pass.bottomUp +import distcompiler.tla.defns.THEOREM +import distcompiler.tla.TLAParser.RawExpression +import distcompiler.tla.TLAParser.rawConjunction + +object ExprParser extends PassSeq: + import distcompiler.dsl.* + import distcompiler.Builtin.{Error, SourceMarker} + import TLAReader.* + def inputWellformed: Wellformed = TLAParser.outputWellformed + // TODO: make private + def highPredInfixInfix(op: defns.InfixOperator): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node, Node)]]] = + field(tokens.Expr.withChildren( + field(tokens.Expr.TmpInfixGroup.withChildren( + field(tok(defns.InfixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token : defns.InfixOperator => + op.highPrecedence > op2Token.highPrecedence)) + ~ field(tokens.Expr) + ~ field(tokens.Expr) + ~ eof + )) + ~ eof + )) + def highPredInfixUnary(op: defns.InfixOperator): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node)]]] = + field(tokens.Expr.withChildren( + field(tokens.Expr.TmpUnaryGroup.withChildren( + field(tok(defns.PrefixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token : defns.PrefixOperator => + op.highPrecedence > op2Token.highPrecedence + case op2Token : defns.PostfixOperator => + op.highPrecedence > op2Token.precedence)) + ~ field(tokens.Expr) + ~ eof + )) + ~ eof + )) + def highPredUnaryInfix(op: defns.Operator): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node, Node)]]] = + field(tokens.Expr.withChildren( + field(tokens.Expr.TmpInfixGroup.withChildren( + field(tok(defns.InfixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token : defns.InfixOperator => + op.highPrecedence > op2Token.highPrecedence)) + ~ field(tokens.Expr) + ~ 
field(tokens.Expr) + ~ eof + )) + ~ eof + )) + + def badPredInfixInfix(op: defns.InfixOperator): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node, Node)]]] = + field(tokens.Expr.withChildren( + field(tokens.Expr.TmpInfixGroup.withChildren( + field(tok(defns.InfixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token : defns.InfixOperator => + !((op.highPrecedence > op2Token.highPrecedence) + || (op.lowPrecedence < op2Token.lowPrecedence) + || ((op == op2Token) && op.isAssociative)) + )) + ~ field(tokens.Expr) + ~ field(tokens.Expr) + ~ eof + )) + ~ eof + )) + def badPredInfixUnary(op: defns.InfixOperator): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node)]]] = + field(tokens.Expr.withChildren( + field(tokens.Expr.TmpUnaryGroup.withChildren( + field(tok(defns.PrefixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token : defns.PrefixOperator => + !((op.highPrecedence > op2Token.highPrecedence) + || (op.lowPrecedence < op2Token.lowPrecedence)) + case op2Token : defns.PostfixOperator => + !((op.highPrecedence > op2Token.precedence) + || (op.lowPrecedence < op2Token.precedence)))) + ~ field(tokens.Expr) + ~ eof + )) + ~ eof + )) + def badPredUnaryInfix(op: defns.Operator): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node, Node)]]] = + field(tokens.Expr.withChildren( + field(tokens.Expr.TmpInfixGroup.withChildren( + field(tok(defns.InfixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token : defns.InfixOperator => + !((op.highPrecedence > op2Token.highPrecedence) + || (op.lowPrecedence < op2Token.lowPrecedence)))) + ~ field(tokens.Expr) + ~ field(tokens.Expr) + ~ eof + )) + ~ eof + )) + def matchQuantifierId() :SeqPattern[(Node, TLAParser.RawExpression)] = + parent(tokens.Expr) *> + field(TLAReader.Alpha) + ~ skip(defns.`\\in`) + ~ field(rawExpression) + ~ trailing + def matchQuantifierIds() :SeqPattern[(List[Node], TLAParser.RawExpression)] = + parent(tokens.Expr) *> + field(tok(TLAReader.TupleGroup).withChildren( + 
field(repeatedSepBy(`,`)(tok(TLAReader.Alpha))) + ~ eof + )) + ~ skip(defns.`\\in`) + ~ field(rawExpression) + ~ trailing + + // replace all tokens.Expr(contents...) with + // tokens.Expr(tokens.ExprTry, contents...) + + val buildExpressions = passDef: + wellformed := prevWellformed.makeDerived: + val removedCases = Seq( + TLAReader.StringLiteral, + TLAReader.NumberLiteral, + TLAReader.TupleGroup + // TODO: remove cases + ) + TLAReader.groupTokens.foreach: tok => + tok.removeCases(removedCases*) + tok.addCases(tokens.Expr) + + tokens.Expr.TmpInfixGroup ::= fields( + choice(defns.InfixOperator.instances*), + tokens.Expr, + tokens.Expr + ) + tokens.Expr.TmpUnaryGroup ::= fields( + choice((defns.PrefixOperator.instances ++ defns.PostfixOperator.instances)*), + tokens.Expr, + ) + + tokens.Expr.deleteShape() + tokens.Expr.importFrom(tla.wellformed) + tokens.Expr.addCases( + tokens.Expr, + tokens.Expr.TmpInfixGroup, + tokens.Expr.TmpUnaryGroup) + + // TODO: assign the correct source with .like() + + pass(once = false, strategy = pass.bottomUp) // conjunction alignment + .rules: + on( + parent(tokens.Expr) *> + (field(tok(defns./\)) + ~ field(rawConjunction(1)) + ~ field(tok(defns./\)) + ~ field(rawConjunction(1)) + ~ eof + ).filter(things => + things match + case (and1 : Node, r1, and2 : Node, r2) => + val s1 = and1.sourceRange + val s2 = and2.sourceRange + // println("\nCol 1: " + s1.source.lines.lineColAtOffset(s1.offset)) + // println("\nCol 2: " + s2.source.lines.lineColAtOffset(s2.offset)) + s1.source.lines.lineColAtOffset(s1.offset)._2 == s2.source.lines.lineColAtOffset(s2.offset)._2 + ) + ).rewrite: (and1, r1, and2, r2) => + splice( + and1.unparent(), + tokens.Expr( + tokens.Expr.OpCall( + tokens.OpSym(and2.unparent()), + tokens.Expr.OpCall.Params( + r1.mkNode, + r2.mkNode + )))) + *> pass(once = false, strategy = pass.bottomUp) // remove leading /\, remove paren + .rules: + // on( + // field(tokens.Expr.withChildren( + // skip(defns./\) + // ~ field(tokens.Expr) 
+ // ~ eof + // )) + // ~ eof + // ).rewrite: expr => + // splice(expr.unparent()) + on( + field(tokens.Expr.withChildren( + skip(defns./\) + ~ field(rawExpression) + ~ eof + )) + ~ eof + ).rewrite: expr => + splice(expr.mkNode) + *> pass(once = false, strategy = pass.bottomUp) // resolve quantifiers/opCall + .rules: + on( + parent(tokens.Expr) *> + field(TLAReader.Alpha) + ~ field(TLAReader.ParenthesesGroup.withChildren( + field(repeatedSepBy(`,`)(rawExpression)) + ~ eof + )) + ~ trailing + ).rewrite: (fun, args) => + splice( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id().like(fun), + tokens.Expr.OpCall.Params( + args.iterator.map(_.mkNode) + )))) + | on( + parent(tokens.Expr) *> + skip(tok(TLAReader.LaTexLike).src("\\E")) + ~ field(repeatedSepBy1(`,`)(matchQuantifierId())) + ~ skip(TLAReader.`:`) + ~ field(rawExpression) + ~ trailing + ).rewrite: (qBounds, expr) => + splice( + tokens.Expr(tokens.Expr.Exists( + tokens.QuantifierBounds( + qBounds.iterator.map( + (id, qExpr) => + tokens.QuantifierBound( + tokens.Id().like(id), + qExpr.mkNode + ))), + expr.mkNode + ))) + | on( + parent(tokens.Expr) *> + skip(tok(TLAReader.LaTexLike).src("\\E")) + ~ field(repeatedSepBy1(`,`)(matchQuantifierIds())) + ~ skip(TLAReader.`:`) + ~ field(rawExpression) + ~ trailing + ).rewrite: (qBounds, expr) => + splice( + tokens.Expr(tokens.Expr.Exists( + tokens.QuantifierBounds( + qBounds.iterator.map( + (ids, qExpr) => + tokens.QuantifierBound( + tokens.Ids( + ids.iterator.map( + id => tokens.Id().like(id) + ) + ), + qExpr.mkNode + ))), + expr.mkNode + ))) + | on( + parent(tokens.Expr) *> + skip(tok(TLAReader.LaTexLike).src("\\A")) + ~ field(repeatedSepBy1(`,`)(matchQuantifierId())) + ~ skip(TLAReader.`:`) + ~ field(rawExpression) + ~ trailing + ).rewrite: (qBounds, expr) => + splice( + tokens.Expr(tokens.Expr.Forall( + tokens.QuantifierBounds( + qBounds.iterator.map( + (id, qExpr) => + tokens.QuantifierBound( + tokens.Id().like(id), + qExpr.mkNode + ))), + expr.mkNode + ))) + | on( + 
parent(tokens.Expr) *> + skip(tok(TLAReader.LaTexLike).src("\\A")) + ~ field(repeatedSepBy1(`,`)(matchQuantifierIds())) + ~ skip(TLAReader.`:`) + ~ field(rawExpression) + ~ trailing + ).rewrite: (qBounds, expr) => + splice( + tokens.Expr(tokens.Expr.Forall( + tokens.QuantifierBounds( + qBounds.iterator.map( + (ids, qExpr) => + tokens.QuantifierBound( + tokens.Ids( + ids.iterator.map( + id => tokens.Id().like(id) + ) + ), + qExpr.mkNode + ))), + expr.mkNode + ))) + | on( + parent(tokens.Expr) *> + skip(tok(defns.CHOOSE)) + ~ field(matchQuantifierId()) + ~ skip(TLAReader.`:`) + ~ field(rawExpression) + ~ trailing + ).rewrite: (qBound, expr) => + qBound match + case (id, qExpr) => + splice( + tokens.Expr(tokens.Expr.Choose( + tokens.QuantifierBound( + tokens.Id().like(id), + qExpr.mkNode + ), + expr.mkNode + ))) + // TODO: tuple qbound + // id nil + // tuple nil + *> pass(once = false, strategy = pass.bottomUp) + .rules: + on( + parent(tokens.Expr) *> + TLAReader.Alpha + ).rewrite: name => + splice( + tokens.Expr( + tokens.Id().like(name) + )) + | on( + parent(tokens.Expr) *> + TLAReader.NumberLiteral + ).rewrite: lit => + splice(tokens.Expr(tokens.Expr.NumberLiteral().like(lit))) + | on( + parent(tokens.Expr) *> + TLAReader.StringLiteral + ).rewrite: lit => + splice(tokens.Expr(tokens.Expr.StringLiteral().like(lit))) + | on( + parent(tokens.Expr) *> + tok(TLAReader.BracesGroup) *> + children: + field(repeatedSepBy(`,`)(rawExpression)) + ~ eof + ).rewrite: exprs => + splice( + tokens.Expr( + tokens.Expr.SetLiteral(exprs.iterator.map(_.mkNode)) + )) + | on( + parent(tokens.Expr) *> + tok(TLAReader.TupleGroup).product( + children( + field(repeatedSepBy(`,`)(rawExpression)) + ~ eof + )) + ).rewrite: (lit, elems) => + splice( + tokens.Expr( + tokens.Expr.TupleLiteral(elems.iterator.map(_.mkNode)).like(lit) + )) + | on( + parent(tokens.Expr) *> + tok(TLAReader.SqBracketsGroup) *> + children: + field( + repeatedSepBy1(`,`)( + field(TLAReader.Alpha) + ~ skip(`|->`) + ~ 
field(rawExpression) + ~ trailing + )) + ~ eof + ).rewrite: fields => + splice( + tokens.Expr( + tokens.Expr.RecordLiteral( + fields.iterator.map( + (alpha, expr) => + tokens.Expr.RecordLiteral.Field( + tokens.Id().like(alpha.unparent()), + expr.mkNode + ))))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr) + ~ skip(defns.`.`) + ~ field(tokens.Expr.withChildren( + field(tokens.Id) + ~ eof + )) + ~ trailing + ).rewrite: (expr, id) => + splice( + tokens.Expr( + tokens.Expr.Project( + tokens.Expr( + expr.unparent() + ), + id.unparent() + ))) + | on( + parent(tokens.Expr) *> + tok(TLAReader.SqBracketsGroup) *> + children: + field( + repeatedSepBy1(`,`)( + field(TLAReader.Alpha) + ~ skip(`:`) + ~ field(rawExpression) + ~ trailing + )) + ~ eof + ).rewrite: fields => + splice( + tokens.Expr( + tokens.Expr.RecordSetLiteral( + fields.iterator.map( + (alpha, expr) => + tokens.Expr.RecordSetLiteral.Field( + tokens.Id().like(alpha.unparent()), + expr.mkNode + ))))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr) + ~ field(tok(defns.InfixOperator.instances.filter(_ != defns.`.`)*)) + ~ field(tokens.Expr) + ~ eof + ).rewrite: (left, op, right) => + splice( + tokens.Expr(tokens.Expr.TmpInfixGroup( + op.unparent(), + left.unparent(), + right.unparent() + ))) + | on( + parent(tokens.Expr) *> + field(tok(defns.PrefixOperator.instances*)) + ~ field(tokens.Expr) + ~ eof + ).rewrite: (op, expr) => + splice( + tokens.Expr(tokens.Expr.TmpUnaryGroup( + op.unparent(), + expr.unparent(), + ))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr) + ~ field(tok(defns.PostfixOperator.instances*)) + ~ trailing + ).rewrite: (expr, op) => + splice( + tokens.Expr(tokens.Expr.TmpUnaryGroup( + op.unparent(), + expr.unparent(), + ))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr) + ~ field(tok(TLAReader.SqBracketsGroup) *> + children: + field(repeatedSepBy(`,`)(rawExpression)) + ~ eof + ) + ~ eof + ).rewrite: (callee, args) => + splice( + tokens.Expr(tokens.Expr.FnCall( + 
callee.unparent(), + args match + case List(expr) => + expr.mkNode + case _ => + tokens.Expr(tokens.Expr.TupleLiteral( + args.iterator.map(_.mkNode) + ))))) + | on( + parent(tokens.Expr) *> + skip(defns.CASE) + ~ field( + repeatedSepBy(defns.`[]`)( + field(tokens.Expr) + ~ skip(defns.-) + ~ skip(defns.>) + ~ field(tokens.Expr) + ~ trailing + )) + ~ field( + optional( + skip(defns.OTHER) + ~ skip(defns.-) + ~ skip(defns.>) + ~ field(tokens.Expr) + ~ eof + ) + ) + ~ eof + ).rewrite: (cases, other) => + splice( + tokens.Expr( + tokens.Expr.Case( + tokens.Expr.Case.Branches( + cases.iterator.map((pred, branch) => + tokens.Expr.Case.Branch( + pred.unparent(), + branch.unparent(), + ))), + tokens.Expr.Case.Other( + other match + case None => tokens.Expr.Case.Other.None() + case Some(expr) => expr.unparent() + ) + ))) + | on( + parent(tokens.Expr) *> + field(tok(TLAReader.LetGroup).product( + children( + field( + repeated1: + tok(tokens.Operator) + | tok(tokens.ModuleDefinition) + | tok(tokens.Recursive) + ) + ~ eof + ) + )) + ~ field(tokens.Expr) + ~ trailing + ).rewrite: (let, expr) => + splice( + tokens.Expr( + tokens.Expr.Let( + tokens.Expr.Let.Defns( + let._2.iterator.map(_.unparent()) + ), + expr.unparent() + )).like(let._1) + ) + // TODO: Function + // TODO: SetComprehension {expr : x \in y} + // TODO: SetRefinement {x \in y : bool} + // {x \in y : p \in q} TODO: look in the book + // TODO: Choose + // TODO: Except + // TODO: Lambda + // TODO a \x b \ a + | on( + parent(tokens.Expr) *> + skip(defns.IF) + ~ field(tokens.Expr) + ~ skip(defns.THEN) + ~ field(tokens.Expr) + ~ skip(defns.ELSE) + ~ field(tokens.Expr) + ~ trailing + ).rewrite: (pred, t, f) => + splice( + tokens.Expr.If( + pred.unparent(), + t.unparent(), + f.unparent() + )) + | on( + parent(tokens.Expr) *> + tok(TLAReader.ParenthesesGroup) *> + children: + field(rawExpression) + ~ eof + ).rewrite: rawExpr => + splice( + tokens.Expr(rawExpr.mkNode) + ) + *> pass(once = false, strategy = pass.bottomUp) // 
resolve Alphas + .rules: + on( + parent(tokens.Expr) *> + tokens.Id + ).rewrite: name => + splice( + tokens.Expr( + tokens.Expr.OpCall( + name.unparent(), + tokens.Expr.OpCall.Params(), + ))) + end buildExpressions + + val reorderOperations = passDef: + wellformed := prevWellformed.makeDerived: + tokens.Expr.removeCases( + tokens.Expr.TmpInfixGroup, + tokens.Expr.TmpUnaryGroup) + pass(once = false, strategy = pass.bottomUp) + .rules: + on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpInfixGroup.withChildren: + defns.InfixOperator.instances + .iterator + .map: op => + field(op) + ~ highPredInfixInfix(op) + ~ field(tokens.Expr) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, left, right) => + left match + case (op, expr1, expr2) => + splice( + tokens.Expr.TmpInfixGroup( + op.unparent(), + expr1.unparent(), + tokens.Expr(tokens.Expr.TmpInfixGroup( + curOp.unparent(), + expr2.unparent(), + right.unparent() + )))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpInfixGroup.withChildren: + defns.InfixOperator.instances + .iterator + .map: op => + field(op) + ~ field(tokens.Expr) + ~ highPredInfixInfix(op) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, left, right) => + right match + case (op, expr1, expr2) => + splice( + tokens.Expr.TmpInfixGroup( + op.unparent(), + tokens.Expr(tokens.Expr.TmpInfixGroup( + curOp.unparent(), + left.unparent(), + expr1.unparent() + )), + expr2.unparent())) + | on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpInfixGroup.withChildren: + defns.InfixOperator.instances + .iterator + .map: op => + field(op) + ~ highPredInfixUnary(op) + ~ field(tokens.Expr) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, left, right) => + left match + case (op, expr) => + splice( + tokens.Expr.TmpUnaryGroup( + op.unparent(), + tokens.Expr(tokens.Expr.TmpInfixGroup( + curOp.unparent(), + expr.unparent(), + right.unparent() + )))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpInfixGroup.withChildren: + 
defns.InfixOperator.instances + .iterator + .map: op => + field(op) + ~ field(tokens.Expr) + ~ highPredInfixUnary(op) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, left, right) => + right match + case (op, expr) => + splice( + tokens.Expr.TmpUnaryGroup( + op.unparent(), + tokens.Expr(tokens.Expr.TmpInfixGroup( + curOp.unparent(), + left.unparent(), + expr.unparent() + )))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpUnaryGroup.withChildren: + defns.PrefixOperator.instances + .iterator + .map: op => + field(op) + ~ highPredUnaryInfix(op) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, infixGroup) => + infixGroup match + case (op, left, right) => + splice( + tokens.Expr.TmpInfixGroup( + op.unparent(), + tokens.Expr(tokens.Expr.TmpUnaryGroup( + curOp.unparent(), + left.unparent() + )), + right.unparent())) + | on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpUnaryGroup.withChildren: + defns.PostfixOperator.instances + .iterator + .map: op => + field(op) + ~ highPredUnaryInfix(op) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, infixGroup) => + infixGroup match + case (op, left, right) => + splice( + tokens.Expr.TmpInfixGroup( + op.unparent(), + left.unparent(), + tokens.Expr(tokens.Expr.TmpUnaryGroup( + curOp.unparent(), + right.unparent() + )))) + *> pass(once = false, strategy = pass.bottomUp) // assoc related errors + .rules: + on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpInfixGroup.withChildren: + defns.InfixOperator.instances + .iterator + .map: op => + field(op) + ~ badPredInfixInfix(op) + ~ field(tokens.Expr) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, left, right) => + left match + case (op, expr1, expr2) => + splice( + Node(Builtin.Error)( + Builtin.Error.Message( // todo: use src + s"$curOp and $op must have different precedence, or be duplicates of an associative operator." 
+ ), + Builtin.Error.AST( + tokens.Expr.TmpInfixGroup( + curOp.unparent(), + tokens.Expr.TmpInfixGroup( + op.unparent(), + expr1.unparent(), + expr2.unparent() + ), + right.unparent())))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpInfixGroup.withChildren: + defns.InfixOperator.instances + .iterator + .map: op => + field(op) + ~ field(tokens.Expr) + ~ badPredInfixInfix((op)) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, left, right) => + right match + case (op, expr1, expr2) => + splice( + Node(Builtin.Error)( + Builtin.Error.Message( + s"$curOp and $op must have different precedence, or be duplicates of an associative operator." + ), + Builtin.Error.AST( + tokens.Expr.TmpInfixGroup( + curOp.unparent(), + left.unparent(), + tokens.Expr.TmpInfixGroup( + op.unparent(), + expr1.unparent(), + expr2.unparent() + ))))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpInfixGroup.withChildren: + defns.InfixOperator.instances + .iterator + .map: op => + field(op) + ~ badPredInfixUnary(op) + ~ field(tokens.Expr) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, left, right) => + left match + case (op, expr) => + splice( + Node(Builtin.Error)( + Builtin.Error.Message( + s"$curOp and $op must have different precedence." + ), + Builtin.Error.AST( + tokens.Expr.TmpInfixGroup( + curOp.unparent(), + tokens.Expr.TmpUnaryGroup( + op.unparent(), + expr.unparent(), + )), + right.unparent() + ))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpInfixGroup.withChildren: + defns.InfixOperator.instances + .iterator + .map: op => + field(op) + ~ field(tokens.Expr) + ~ badPredInfixUnary(op) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, left, right) => + right match + case (op, expr) => + splice( + Node(Builtin.Error)( + Builtin.Error.Message( + s"$curOp and $op must have different precedence." 
+ ), + Builtin.Error.AST( + tokens.Expr.TmpInfixGroup( + curOp.unparent(), + left.unparent(), + tokens.Expr.TmpUnaryGroup( + op.unparent(), + expr.unparent() + ))))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpUnaryGroup.withChildren: + defns.PrefixOperator.instances + .iterator + .map: op => + field(op) + ~ badPredUnaryInfix(op) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, infixGroup) => + infixGroup match + case (op, left, right) => + splice( + Node(Builtin.Error)( + Builtin.Error.Message( + s"$curOp and $op must have different precedence." + ), + Builtin.Error.AST( + tokens.Expr.UnaryGroup( + curOp.unparent(), + tokens.Expr.TmpInfixGroup( + op.unparent(), + left.unparent(), + right.unparent() + ))))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpUnaryGroup.withChildren: + defns.PostfixOperator.instances + .iterator + .map: op => + field(op) + ~ badPredUnaryInfix(op) + ~ eof + .reduce(_ | _)) + ~ eof + ).rewrite: (curOp, infixGroup) => + infixGroup match + case (op, left, right) => + splice( + Node(Builtin.Error)( + Builtin.Error.Message( + s"$curOp and $op must have different precedence." 
+ ), + Builtin.Error.AST( + tokens.Expr.UnaryGroup( + curOp.unparent(), + tokens.Expr.TmpInfixGroup( + op.unparent(), + left.unparent(), + right.unparent() + ))))) + *> pass(once = false, strategy = pass.bottomUp) + .rules: + on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpInfixGroup.withChildren( + field(tok(defns.InfixOperator.instances*)) + ~ field(tokens.Expr) + ~ field(tokens.Expr) + ~ eof + )) + ~ eof + ).rewrite: (op, right, left) => + splice( + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(op.unparent()), + tokens.Expr.OpCall.Params( + right.unparent(), + left.unparent() + )))) + | on( + parent(tokens.Expr) *> + field(tokens.Expr.TmpUnaryGroup.withChildren( + field(tok(defns.PrefixOperator.instances*) + | tok(defns.PostfixOperator.instances*)) + ~ field(tokens.Expr) + ~ eof + )) + ~ eof + ).rewrite: (op, expr) => + splice( + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(op.unparent()), + tokens.Expr.OpCall.Params( + expr.unparent(), + )))) + end reorderOperations + + val removeNestedExpr = passDef: + wellformed := prevWellformed.makeDerived: + tokens.Expr.removeCases(tokens.Expr) + + pass(once = true, strategy = bottomUp) + .rules: + on( + tok(tokens.Expr) *> + onlyChild(tokens.Expr) + ).rewrite: child => + splice( + child.unparent() + ) + end removeNestedExpr diff --git a/tla/ExprParser.test.scala b/tla/ExprParser.test.scala new file mode 100644 index 0000000..e28c838 --- /dev/null +++ b/tla/ExprParser.test.scala @@ -0,0 +1,900 @@ +package distcompiler.tla + +import distcompiler.* +import Builtin.{Error, SourceMarker} + +// scala-cli test . 
-- '*ExprParser*' + +class ExprParserTests extends munit.FunSuite: + extension (str: String) + def parseStr: Node.Top = + val wrapped_str = + Source.fromString( + s"""---- MODULE TestMod ---- + |EXTENDS Naturals + |VARIABLE temp + | + |Init == $str + |==== + """.stripMargin + ) + val top = TLAReader(SourceRange.entire(wrapped_str)) + TLAParser(top) + ExprParser(top) + Node.Top( + top(tokens.Module)(tokens.Module.Defns)(tokens.Operator)( + tokens.Expr + ).unparentedChildren + ) + extension (str: String) + def withoutParse: Node.Top = + val wrapped_str = + Source.fromString( + s"""---- MODULE TestMod ---- + |EXTENDS Naturals + |VARIABLE temp + | + |Init == $str + |==== + """.stripMargin + ) + val top = TLAReader(SourceRange.entire(wrapped_str)) + TLAParser(top) + Node.Top( + top(tokens.Module)(tokens.Module.Defns)(tokens.Operator)( + tokens.Expr + ).unparentedChildren + ) + + extension (top: Node.Top) + def parseNode: Node.Top = + val freshTop = Node.Top( + tokens.Module( + tokens.Id("TestMod"), + tokens.Module.Extends(), + tokens.Module.Defns( + tokens.Operator( + tokens.Id("test"), + tokens.Operator.Params(), + tokens.Expr( + top.unparentedChildren + ))))) + ExprParser(freshTop) + Node.Top( + freshTop(tokens.Module)(tokens.Module.Defns)(tokens.Operator)( + tokens.Expr + ).unparentedChildren + ) + + test("NumberLiteral"): + assertEquals("1".parseStr, Node.Top(tokens.Expr.NumberLiteral("1"))) + + test("StringLiteral"): + assertEquals( + "\"string\"".parseStr, + Node.Top(tokens.Expr.StringLiteral("string")) + ) + assertEquals( + "\"string\\nnewline\"".parseStr, + Node.Top(tokens.Expr.StringLiteral("string\nnewline")) + ) + + test("Set Literal"): + assertEquals( + "{1, 2, 3}".parseStr, + Node.Top( + tokens.Expr.SetLiteral( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + ))) + assertEquals( + "{}".parseStr, + Node.Top(tokens.Expr.SetLiteral()) + ) + + test("TupleLiteral"): + 
assertEquals( + "<<1, 2, 3>>".parseStr, + Node.Top( + tokens.Expr.TupleLiteral( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + ))) + assertEquals( + "<<>>".parseStr, + Node.Top( + tokens.Expr.TupleLiteral() + )) + + test("RecordLiteral"): + assertEquals( + "[x |-> 2, y |-> 3]".parseStr, + Node.Top( + tokens.Expr.RecordLiteral( + tokens.Expr.RecordLiteral.Field( + tokens.Id("x"), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + ), + tokens.Expr.RecordLiteral.Field( + tokens.Id("y"), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + )))) + + test("Projection (Record Field Acess)"): + assertEquals( + "x.y".parseStr, + Node.Top( + tokens.Expr.Project( + tokens.Expr( + tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )), + tokens.Id("y")))) + assertEquals( + "x.y.z".parseStr, + Node.Top( + tokens.Expr.Project( + tokens.Expr(tokens.Expr.Project( + tokens.Expr( + tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )), + tokens.Id("y"))), + tokens.Id("z")))) + assertEquals( + "[y |-> 2].y".parseStr, + Node.Top( + tokens.Expr.Project( + tokens.Expr( + tokens.Expr.RecordLiteral( + tokens.Expr.RecordLiteral.Field( + tokens.Id("y"), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + ) + )), + tokens.Id("y")))) + + test("RecordSetLiteral"): + assertEquals( + "[x : {1, 2}]".parseStr, + Node.Top( + tokens.Expr.RecordSetLiteral( + tokens.Expr.RecordSetLiteral.Field( + tokens.Id("x"), + tokens.Expr( + tokens.Expr.SetLiteral( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + )))))) + + test("Name"): + assertEquals( + "x".parseStr, + Node.Top( + tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + ))) + + test("Single Binary Operator"): + assertEquals( + "5 + 6".parseStr, + Node.Top( + tokens.Expr.OpCall( + tokens.OpSym( + defns.+("+") + ), + tokens.Expr.OpCall.Params( + 
tokens.Expr(tokens.Expr.NumberLiteral("5")), + tokens.Expr(tokens.Expr.NumberLiteral("6")) + )))) + assertEquals( + "5 $ x".parseStr, + Node.Top( + tokens.Expr.OpCall( + tokens.OpSym(defns.$("$")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.NumberLiteral("5")), + tokens.Expr( + tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )))))) + + test("Single Unary Operator"): + assertEquals( + "UNION A".parseStr, + Node.Top(tokens.Expr.OpCall( + tokens.OpSym(defns.UNION("UNION")), + tokens.Expr.OpCall.Params( + tokens.Expr( + tokens.Expr.OpCall( + tokens.Id("A"), + tokens.Expr.OpCall.Params() + )))))) + assertEquals( + "x'".parseStr, + Node.Top(tokens.Expr.OpCall( + tokens.OpSym(defns.`'`("'")), + tokens.Expr.OpCall.Params( + tokens.Expr( + tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )))))) + + test("Precedence: Infix - Infix"): + assertEquals( + "5 * 6 + 7".parseStr, + Node.Top( + tokens.Expr.OpCall( + tokens.OpSym( + defns.+("+") + ), + tokens.Expr.OpCall.Params( + tokens.Expr( + tokens.Expr.OpCall( + tokens.OpSym( + defns.*("*") + ), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.NumberLiteral("5")), + tokens.Expr(tokens.Expr.NumberLiteral("6")) + ))), + tokens.Expr(tokens.Expr.NumberLiteral("7")))))) + assertEquals( + "1 + 8 * 6 - 9 * 3".parseStr, + Node.Top( + tokens.Expr.OpCall( + tokens.OpSym(defns.+("+")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.-("-")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.*("*")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.NumberLiteral("8")), + tokens.Expr(tokens.Expr.NumberLiteral("6")) + ))), + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.*("*")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.NumberLiteral("9")), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + )))))))))) + + test("Precedence: Infix - Prefix"): + 
assertEquals( + "~x /\\ y".parseStr, + Node.Top(tokens.Expr.OpCall( + tokens.OpSym(defns./\("/\\")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.~("~")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + ))))), + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("y"), + tokens.Expr.OpCall.Params() + )))))) + assertEquals( + "~x \\in y".parseStr, + Node.Top(tokens.Expr.OpCall( + tokens.OpSym(defns.~("~")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.`\\in`("\\in")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("y"), + tokens.Expr.OpCall.Params() + ))))))))) + + test("Precedence: Infix - Postfix"): + assertEquals( + "1 + x' * 2".parseStr, + Node.Top(tokens.Expr.OpCall( + tokens.OpSym(defns.+("+")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.*("*")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.`'`("'")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + ))))), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + ))))))) + + test("Precedence: Unary - Unary"): + assertEquals( + "[] ~x".parseStr, + Node.Top( + tokens.Expr.OpCall( + tokens.OpSym(defns.`[]`("[]")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.~("~")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + ))))))))) + + // TODO???: postfix postfix: never allowed + + test("Precedence: Associative"): + assertEquals( + "1 + 2 + 3".parseStr, + Node.Top( + tokens.Expr.OpCall( + tokens.OpSym(defns.+("+")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + 
tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.+("+")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.NumberLiteral("2")), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + ))))))) + assertEquals( + "1 * 2 / 3".parseStr, + Node.Top(tokens.Expr( + Node(Builtin.Error)( + Builtin.Error.Message( + s"distcompiler.tla.defns.* and distcompiler.tla.defns./ must have different precedence, or be duplicates of an associative operator." + ), + Builtin.Error.AST( + tokens.Expr.TmpInfixGroup( + defns.*("*"), + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr.TmpInfixGroup( + defns./("/"), + tokens.Expr(tokens.Expr.NumberLiteral("2")), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + ))))))) + + // test("Precedence: Error"): + // assertEquals( + // "x \\in A \\in B".parseStr, + // Node.Top(tokens.Expr( + // Node(Builtin.Error)( + // Builtin.Error.Message( + // s"26:distcompiler.tla.defns.\\in and 26:distcompiler.tla.defns.\\in must have different precedence, or be duplicates of an associative operator." + // ), + // Builtin.Error.AST( + // tokens.Expr.TmpInfixGroup( + // defns.`\\in`("\\in"), + // tokens.Expr(tokens.Id("x")), + // tokens.Expr.TmpInfixGroup( + // defns.`\\in`("\\in"), + // tokens.Expr(tokens.Id("A")), + // tokens.Expr(tokens.Id("B")) + // ))))))) + // assertEquals( + // "x \\in A = B".parseStr, + // Node.Top(tokens.Expr( + // Node(Builtin.Error)( + // Builtin.Error.Message( + // s"26:distcompiler.tla.defns.\\in and distcompiler.tla.defns.= must have different precedence, or be duplicates of an associative operator." 
+ // ), + // Builtin.Error.AST( + // tokens.Expr.TmpInfixGroup( + // defns.`\\in`("\\in"), + // tokens.Expr(tokens.Id("x")), + // tokens.Expr.TmpInfixGroup( + // defns.`=`("="), + // tokens.Expr(tokens.Id("A")), + // tokens.Expr(tokens.Id("B")) + // ))))))) + + test("OpCall"): + assertEquals( + "testFun(1, 2, 3)".parseStr, + Node.Top( + tokens.Expr.OpCall( + tokens.Id("testFun"), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + )))) + + // TODO: conjunction alignment + + test("FnCall"): + assertEquals( + "x[\"y\"]".parseStr, + Node.Top( + tokens.Expr.FnCall( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.StringLiteral("y")) + ))) + assertEquals( + "x[1, 2, 3]".parseStr, + Node.Top( + tokens.Expr.FnCall( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.TupleLiteral( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")), + tokens.Expr(tokens.Expr.NumberLiteral("3")), + )) + ))) + + test("If"): + assertEquals( + """IF A + |THEN 1 + |ELSE 2""".stripMargin.parseStr, + Node.Top( + tokens.Expr.If( + tokens.Expr( + tokens.Expr.OpCall( + tokens.Id("A"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + ))) + + test("Case"): + assertEquals( + """CASE A -> 1 + | [] B -> 2""".stripMargin.parseStr, + Node.Top( + tokens.Expr.Case( + tokens.Expr.Case.Branches( + tokens.Expr.Case.Branch( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("A"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.NumberLiteral("1")) + ), + tokens.Expr.Case.Branch( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("B"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + )), + 
tokens.Expr.Case.Other(tokens.Expr.Case.Other.None()) + ))) + assertEquals( + """CASE A -> 1 + | [] B -> 2 + | OTHER -> 3""".stripMargin.parseStr, + Node.Top( + tokens.Expr.Case( + tokens.Expr.Case.Branches( + tokens.Expr.Case.Branch( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("A"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.NumberLiteral("1")) + ), + tokens.Expr.Case.Branch( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("B"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + )), + tokens.Expr.Case.Other( + tokens.Expr(tokens.Expr.NumberLiteral("3")) + )))) + + test("LET"): + assertEquals( + """LET x == 1 + |y == 2 + |IN {y, x}""".stripMargin.parseStr, + Node.Top( + tokens.Expr.Let( + tokens.Expr.Let.Defns( + tokens.Operator( + tokens.Id("x"), + tokens.Operator.Params(), + tokens.Expr(tokens.Expr.NumberLiteral("1")) + ), + tokens.Operator( + tokens.Id("y"), + tokens.Operator.Params(), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + ) + ), + tokens.Expr( + tokens.Expr.SetLiteral( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("y"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + ))))))) + + test("Exists"): + assertEquals( + "\\E x \\in {1, 2, 3} : x = 2".parseStr, + Node.Top( + tokens.Expr.Exists( + tokens.QuantifierBounds( + tokens.QuantifierBound( + tokens.Id("x"), + tokens.Expr(tokens.Expr.SetLiteral( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + )))), + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.`=`("=")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + )))))) + assertEquals( + "\\E <<x, y, z>> \\in {1, 2, 3} : x = 2".parseStr, + Node.Top( + tokens.Expr.Exists( + tokens.QuantifierBounds( + tokens.QuantifierBound( + tokens.Ids( 
+ tokens.Id("x"), + tokens.Id("y"), + tokens.Id("z") + ), + tokens.Expr(tokens.Expr.SetLiteral( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + )))), + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.`=`("=")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + )))))) + assertEquals( + "\\E x \\in {1, 2, 3}, y \\in {4, 5, 6} : x = 2".parseStr, + Node.Top( + tokens.Expr.Exists( + tokens.QuantifierBounds( + tokens.QuantifierBound( + tokens.Id("x"), + tokens.Expr(tokens.Expr.SetLiteral( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + ))), + tokens.QuantifierBound( + tokens.Id("y"), + tokens.Expr(tokens.Expr.SetLiteral( + tokens.Expr(tokens.Expr.NumberLiteral("4")), + tokens.Expr(tokens.Expr.NumberLiteral("5")), + tokens.Expr(tokens.Expr.NumberLiteral("6")) + )))), + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.`=`("=")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + )))))) + + test("Forall"): + assertEquals( + "\\A x \\in {1, 2, 3} : x = 2".parseStr, + Node.Top( + tokens.Expr.Forall( + tokens.QuantifierBounds( + tokens.QuantifierBound( + tokens.Id("x"), + tokens.Expr(tokens.Expr.SetLiteral( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + )))), + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.`=`("=")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + )))))) + + // TODO: AA EE ?? 
+ + test("Temporal Logic Combined"): + assertEquals( + "s /\\ \\E x \\in y : z /\\ \\E p \\in q : r".parseStr, + Node.Top( + tokens.Expr.OpCall( + tokens.OpSym(defns./\("/\\")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("s"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.Exists( + tokens.QuantifierBounds( + tokens.QuantifierBound( + tokens.Id("x"), + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("y"), + tokens.Expr.OpCall.Params() + )))), + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns./\("/\\")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("z"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.Exists( + tokens.QuantifierBounds( + tokens.QuantifierBound( + tokens.Id("p"), + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("q"), + tokens.Expr.OpCall.Params() + )))), + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("r"), + tokens.Expr.OpCall.Params() + ))))))))))))) + + // todo: function + // todo: set comprehension + // todo: set refinement + // todo: lambda + + test("Choose"): + assertEquals( + "CHOOSE x \\in {1, 2, 3} : x = 2".parseStr, + Node.Top( + tokens.Expr.Choose( + tokens.QuantifierBound( + tokens.Id("x"), + tokens.Expr(tokens.Expr.SetLiteral( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + ))), + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym(defns.`=`("=")), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + )), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + )))))) + // todo: id nil, tuple expr, tuple nil + + // todo: except + + test("Parentheses"): + assertEquals( + "(1)".parseStr, + Node.Top(tokens.Expr.NumberLiteral("1")) + ) + assertEquals( + "(((x)))".parseStr, + Node.Top( + tokens.Expr.OpCall( + tokens.Id("x"), + tokens.Expr.OpCall.Params() + ))) + assertEquals( + "(5 + 6)".parseStr, + Node.Top( + 
tokens.Expr.OpCall( + tokens.OpSym( + defns.+("+") + ), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.NumberLiteral("5")), + tokens.Expr(tokens.Expr.NumberLiteral("6")) + )))) + assertEquals( + "(1 + 2) * 3".parseStr, + Node.Top( + tokens.Expr.OpCall( + tokens.OpSym( + defns.*("*") + ), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.OpCall( + tokens.OpSym( + defns.+("+") + ), + tokens.Expr.OpCall.Params( + tokens.Expr(tokens.Expr.NumberLiteral("1")), + tokens.Expr(tokens.Expr.NumberLiteral("2")) + ) + )), + tokens.Expr(tokens.Expr.NumberLiteral("3")) + )))) + // assertEquals( + // "\\A x \\in {1, 2, 3} : x = (2)".parseStr, + // Node.Top( + // tokens.Expr.Forall( + // tokens.QuantifierBounds( + // tokens.QuantifierBound( + // tokens.Id("x"), + // tokens.Expr(tokens.Expr.SetLiteral( + // tokens.Expr(tokens.Expr.NumberLiteral("1")), + // tokens.Expr(tokens.Expr.NumberLiteral("2")), + // tokens.Expr(tokens.Expr.NumberLiteral("3")) + // )))), + // tokens.Expr(tokens.Expr.OpCall( + // tokens.OpSym(defns.`=`("=")), + // tokens.Expr.OpCall.Params( + // tokens.Expr(tokens.Expr.OpCall( + // tokens.Id("x"), + // tokens.Expr.OpCall.Params() + // )), + // tokens.Expr(tokens.Expr.NumberLiteral("2")) + // )))))) + + // test("Conjunctions"): + // assertEquals( + // "/\\ 1".parseStr, + // Node.Top( + // tokens.Expr.NumberLiteral("1") + // )) + // assertEquals( + // "/\\ 1 \\/ 2".parseStr, + // Node.Top( + // tokens.Expr.OpCall( + // tokens.OpSym(defns.\/("\\/")), + // tokens.Expr.OpCall.Params( + // tokens.Expr(tokens.Expr.NumberLiteral("1")), + // tokens.Expr(tokens.Expr.NumberLiteral("2")) + // )))) + // assertEquals( + // s""" + // |/\\ 1 + // |/\\ 2 + // |""".stripMargin.parseStr, + // Node.Top( + // tokens.Expr.OpCall( + // tokens.OpSym(defns./\("/\\")), + // tokens.Expr.OpCall.Params( + // tokens.Expr(tokens.Expr.NumberLiteral("1")), + // tokens.Expr(tokens.Expr.NumberLiteral("2")) + // )))) + // assertEquals( + // s""" + // |/\\ 1 \\/ 1 + // |/\\ 2 \\/ 2 + 
// |/\\ 3 \\/ 3 + // |""".stripMargin.parseStr, + // Node.Top(tokens.Expr.OpCall( + // tokens.OpSym(defns./\("/\\")), + // tokens.Expr.OpCall.Params( + // tokens.Expr(tokens.Expr.OpCall( + // tokens.OpSym(defns.\/("\\/")), + // tokens.Expr.OpCall.Params( + // tokens.Expr(tokens.Expr.NumberLiteral("1")), + // tokens.Expr(tokens.Expr.NumberLiteral("1")) + // ) + // )), + // tokens.Expr(tokens.Expr.OpCall( + // tokens.OpSym(defns./\("/\\")), + // tokens.Expr.OpCall.Params( + // tokens.Expr(tokens.Expr.OpCall( + // tokens.OpSym(defns.\/("\\/")), + // tokens.Expr.OpCall.Params( + // tokens.Expr(tokens.Expr.NumberLiteral("2")), + // tokens.Expr(tokens.Expr.NumberLiteral("2")) + // ))), + // tokens.Expr(tokens.Expr.OpCall( + // tokens.OpSym(defns.\/("\\/")), + // tokens.Expr.OpCall.Params( + // tokens.Expr(tokens.Expr.NumberLiteral("3")), + // tokens.Expr(tokens.Expr.NumberLiteral("3")) + // )))))))))) + + // assertEquals( + // s""" + // |/\\ 1 + // | /\\ 2 + // |/\\ 3 + // |""".stripMargin.parseStr, + // Node.Top( + // tokens.Expr.OpCall( + // tokens.OpSym(defns./\("/\\")), + // tokens.Expr.OpCall.Params( + // tokens.Expr(tokens.Expr.OpCall( + // tokens.OpSym(defns./\("/\\")), + // tokens.Expr.OpCall.Params( + // tokens.Expr(tokens.Expr.NumberLiteral("1")), + // tokens.Expr(tokens.Expr.NumberLiteral("2")) + // ) + // )), + // tokens.Expr(tokens.Expr.NumberLiteral("3")) + // ) + // ) + // )) + + // assertEquals( + // s""" + // |/\\ \\E x \in {1, 2, 3} : 1 + // | /\\ 2 + // |""".stripMargin.parseStr, + // Node.Top( + // tokens.Expr.NumberLiteral("1") + // )) + + // assertEquals( + // s""" + // |/\\ \\E x \in {1, 2, 3} : 1 + // |/\\ 2 + // |""".stripMargin.parseStr, + // Node.Top( + // tokens.Expr.NumberLiteral("1") + // )) + + // assertEquals( + // s""" + // | /\\ \\E x \in {1, 2, 3} : 1 + // |/\\ 2 + // |""".stripMargin.parseStr, + // Node.Top( + // tokens.Expr.NumberLiteral("1") + // )) + + // assertEquals( + // s""" + // |/\\ 1 + // | /\\ 2 + // | /\\ 3 + // | /\\ 4 + 
// |/\\ 5 + // | /\\ 6 + // |""".stripMargin.parseStr, + // Node.Top( + // tokens.Expr.NumberLiteral("1") + // )) + +// TODO +// paren +// conjucntion +// lamda function set diff --git a/tla/TLAParser.scala b/tla/TLAParser.scala index 88fbe71..1a85a99 100644 --- a/tla/TLAParser.scala +++ b/tla/TLAParser.scala @@ -173,6 +173,71 @@ object TLAParser extends PassSeq: ).void nodeSpanMatchedBy(impl).map(RawExpression.apply) + end rawExpression + + def rawConjunction(col : Integer): SeqPattern[RawExpression] = + val simpleCases: SeqPattern[Unit] = + anyChild.void <* not( + tok(expressionDelimiters*) + | tok(defns./\).filter(node => + val s = node.sourceRange + s.source.lines.lineColAtOffset(s.offset)._2 > col) + // stop at operator definitions: all valid patterns leading to == here + | operatorDefnBeginnings + ) + + lazy val quantifierBound: SeqPattern[EmptyTuple] = + skip( + tok(TupleGroup).as(EmptyTuple) + | repeatedSepBy1(`,`)(Alpha) + ) + ~ skip(defns.`\\in`) + ~ skip(defer(impl)) + ~ trailing + + lazy val quantifierBounds: SeqPattern[Unit] = + repeatedSepBy1(`,`)(quantifierBound).void + + lazy val forallExists: SeqPattern[EmptyTuple] = + skip( + tok(LaTexLike).src("\\A") | tok(LaTexLike).src("\\AA") | tok(LaTexLike) + .src("\\E") | tok(LaTexLike).src("\\EE") + ) + ~ skip(quantifierBounds | repeatedSepBy1(`,`)(Alpha)) + ~ skip(`:`) + ~ trailing + + lazy val choose: SeqPattern[EmptyTuple] = + skip(defns.CHOOSE) + ~ skip(quantifierBound | repeatedSepBy1(`,`)(Alpha)) + ~ skip(`:`) + ~ trailing + + lazy val lambda: SeqPattern[EmptyTuple] = + skip(defns.LAMBDA) + ~ skip(repeatedSepBy1(`,`)(Alpha)) + ~ skip(`:`) + ~ trailing + + // this is an over-approximation of what can show up + // in an identifier prefix + // TODO: why does the grammar make it look like it goes the other way round? 
+ lazy val idFrag: SeqPattern[EmptyTuple] = + skip(`!`) + ~ skip(`:`) + ~ trailing + + lazy val impl: SeqPattern[Unit] = + repeated1( + forallExists + | choose + | lambda + | idFrag + | simpleCases // last, otherwise it eats parts of the above + ).void + + nodeSpanMatchedBy(impl).map(RawExpression.apply) + end rawConjunction val proofDelimiters: Seq[Token] = Seq( defns.ASSUME, @@ -711,64 +776,64 @@ object TLAParser extends PassSeq: ).rewrite: (local, op) => splice(tokens.Local(op.unparent()).like(local)) - // TODO: finish expression parsing - // val buildExpressions = passDef: - // wellformed := prevWellformed.makeDerived: - // val removedCases = Seq( - // TLAReader.StringLiteral, - // TLAReader.NumberLiteral, - // TLAReader.TupleGroup, - // ) - // tokens.Module.Defns.removeCases(removedCases*) - // tokens.Module.Defns.addCases(tokens.Expr) - // TLAReader.groupTokens.foreach: tok => - // tok.removeCases(removedCases*) - // tok.addCases(tokens.Expr) - - // tokens.Expr.importFrom(tla.wellformed) - // tokens.Expr.addCases(tokens.TmpGroupExpr) - - // tokens.TmpGroupExpr ::= tokens.Expr - - // pass(once = false, strategy = pass.bottomUp) - // .rules: - // on( - // TLAReader.StringLiteral - // ).rewrite: lit => - // splice(tokens.Expr(tokens.Expr.StringLiteral().like(lit))) - // | on( - // TLAReader.NumberLiteral - // ).rewrite: lit => - // splice(tokens.Expr(tokens.Expr.NumberLiteral().like(lit))) - // | on( - // field(TLAReader.Alpha) - // ~ field( - // tok(TLAReader.ParenthesesGroup) *> children( - // repeatedSepBy(`,`)(tokens.Expr) - // ) - // ) - // ~ trailing - // ).rewrite: (name, params) => - // splice(tokens.Expr(tokens.Expr.OpCall( - // tokens.Id().like(name), - // tokens.Expr.OpCall.Params(params.iterator.map(_.unparent())), - // ))) - // | on( - // TLAReader.Alpha - // ).rewrite: name => - // splice(tokens.Expr(tokens.Expr.OpCall( - // tokens.Id().like(name), - // tokens.Expr.OpCall.Params(), - // ))) - // | on( - // tok(TLAReader.ParenthesesGroup) *> 
onlyChild(tokens.Expr) - // ).rewrite: expr => - // // mark this group as an expression, but leave evidence that it is a group (for operator precedence handling) - // splice(tokens.Expr(tokens.TmpGroupExpr(expr.unparent()))) - // | on( - // tok(TLAReader.TupleGroup).product(children( - // field(repeatedSepBy(`,`)(tokens.Expr)) - // ~ eof - // )) - // ).rewrite: (lit, elems) => - // splice(tokens.Expr(tokens.Expr.TupleLiteral(elems.iterator.map(_.unparent())))) +// TODO: finish expression parsing +// val buildExpressions = passDef: +// wellformed := prevWellformed.makeDerived: +// val removedCases = Seq( +// TLAReader.StringLiteral, +// TLAReader.NumberLiteral, +// TLAReader.TupleGroup, +// ) +// tokens.Module.Defns.removeCases(removedCases*) +// tokens.Module.Defns.addCases(tokens.Expr) +// TLAReader.groupTokens.foreach: tok => +// tok.removeCases(removedCases*) +// tok.addCases(tokens.Expr) + +// tokens.Expr.importFrom(tla.wellformed) +// tokens.Expr.addCases(tokens.TmpGroupExpr) + +// tokens.TmpGroupExpr ::= tokens.Expr + +// pass(once = false, strategy = pass.bottomUp) +// .rules: +// on( +// TLAReader.StringLiteral +// ).rewrite: lit => +// splice(tokens.Expr(tokens.Expr.StringLiteral().like(lit))) +// | on( +// TLAReader.NumberLiteral +// ).rewrite: lit => +// splice(tokens.Expr(tokens.Expr.NumberLiteral().like(lit))) +// | on( +// field(TLAReader.Alpha) +// ~ field( +// tok(TLAReader.ParenthesesGroup) *> children( +// repeatedSepBy(`,`)(tokens.Expr) +// ) +// ) +// ~ trailing +// ).rewrite: (name, params) => +// splice(tokens.Expr(tokens.Expr.OpCall( +// tokens.Id().like(name), +// tokens.Expr.OpCall.Params(params.iterator.map(_.unparent())), +// ))) +// | on( +// TLAReader.Alpha +// ).rewrite: name => +// splice(tokens.Expr(tokens.Expr.OpCall( +// tokens.Id().like(name), +// tokens.Expr.OpCall.Params(), +// ))) +// | on( +// tok(TLAReader.ParenthesesGroup) *> onlyChild(tokens.Expr) +// ).rewrite: expr => +// // mark this group as an expression, but leave 
evidence that it is a group (for operator precedence handling) +// splice(tokens.Expr(tokens.TmpGroupExpr(expr.unparent()))) +// | on( +// tok(TLAReader.TupleGroup).product(children( +// field(repeatedSepBy(`,`)(tokens.Expr)) +// ~ eof +// )) +// ).rewrite: (lit, elems) => +// splice(tokens.Expr(tokens.Expr.TupleLiteral(elems.iterator.map(_.unparent())))) diff --git a/tla/TLAParser.test.scala b/tla/TLAParser.test.scala index 988f4a3..94ada11 100644 --- a/tla/TLAParser.test.scala +++ b/tla/TLAParser.test.scala @@ -14,7 +14,7 @@ package distcompiler.tla -import distcompiler.* +import distcompiler.*, distcompiler.tla.ExprParser class TLAParserTests extends munit.FunSuite, test.WithTLACorpus: self => @@ -36,6 +36,7 @@ class TLAParserTests extends munit.FunSuite, test.WithTLACorpus: // ) // , tracer = Manip.RewriteDebugTracer(os.pwd / "dbg_passes") ) + ExprParser(top) // re-enable if interesting: // val folder = os.SubPath(file.subRelativeTo(clonesDir).segments.init) diff --git a/tla/defns.scala b/tla/defns.scala index dcf1cb3..397f1cf 100644 --- a/tla/defns.scala +++ b/tla/defns.scala @@ -74,7 +74,9 @@ object defns: case object PROPOSITION extends ReservedWord case object ONLY extends ReservedWord - sealed trait Operator extends Token, HasSpelling + sealed trait Operator extends Token, HasSpelling: + def highPrecedence: Int + def lowPrecedence: Int object Operator: lazy val instances: IArray[Operator] = @@ -100,7 +102,7 @@ object defns: sealed trait InfixOperator( val lowPrecedence: Int, - val highPredecence: Int, + val highPrecedence: Int, val isAssociative: Boolean = false ) extends Operator object InfixOperator extends util.HasInstanceArray[InfixOperator] @@ -209,10 +211,12 @@ object defns: case object `\\supset` extends InfixOperator(5, 5) case object `%%` extends InfixOperator(10, 11, true) - sealed trait PostfixOperator(val predecence: Int) extends Operator + sealed trait PostfixOperator(val precedence: Int) extends Operator: + def highPrecedence: Int = 
precedence + def lowPrecedence: Int = precedence object PostfixOperator extends util.HasInstanceArray[PostfixOperator] case object `^+` extends PostfixOperator(15) case object `^*` extends PostfixOperator(15) case object `^#` extends PostfixOperator(15) - case object `'` extends PostfixOperator(15) + case object `'` extends PostfixOperator(15) diff --git a/tla/package.scala b/tla/package.scala index 16b5149..1c898ec 100644 --- a/tla/package.scala +++ b/tla/package.scala @@ -170,6 +170,7 @@ val wellformed: Wellformed = t.Expr.TupleLiteral, t.Expr.RecordLiteral, t.Expr.Project, + t.Expr.RecordSetLiteral, t.Expr.OpCall, t.Expr.FnCall, t.Expr.If, @@ -200,7 +201,7 @@ val wellformed: Wellformed = t.Id ) t.Expr.RecordSetLiteral ::= repeated( - t.Expr.RecordLiteral.Field, + t.Expr.RecordSetLiteral.Field, minCount = 1 ) t.Expr.RecordSetLiteral.Field ::= fields( @@ -222,11 +223,20 @@ val wellformed: Wellformed = t.Expr ) - t.Expr.Case ::= repeated(t.Expr.Case.Branch, minCount = 1) + t.Expr.Case ::= fields( + t.Expr.Case.Branches, + t.Expr.Case.Other + ) + t.Expr.Case.Branches ::= repeated(t.Expr.Case.Branch, minCount = 1) t.Expr.Case.Branch ::= fields( t.Expr, t.Expr ) + t.Expr.Case.Other ::= choice( + t.Expr, + t.Expr.Case.Other.None + ) + t.Expr.Case.Other.None ::= Atom t.Expr.Let ::= fields( t.Expr.Let.Defns, @@ -286,10 +296,7 @@ val wellformed: Wellformed = t.Expr ) t.Expr.Lambda.Params ::= repeated( - choice( - t.Id, - t.Order2 - ), + t.Id, minCount = 1 )