diff --git a/langs/tla/ExprParser.scala b/langs/tla/ExprParser.scala new file mode 100644 index 0000000..d2ccc34 --- /dev/null +++ b/langs/tla/ExprParser.scala @@ -0,0 +1,1150 @@ +// Copyright 2024-2025 Forja Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package forja.langs.tla + +import cats.syntax.all.given + +import forja.* +import forja.dsl.* +import forja.wf.Wellformed + +import TLAReader.* +import TLAParser.{rawExpression, rawConjunction} + +object ExprParser extends PassSeq: + object TmpInfixGroup extends Token + object TmpUnaryGroup extends Token + + def inputWellformed: Wellformed = TLAParser.outputWellformed + // TODO: make private + def highPredInfixInfix( + op: defns.InfixOperator, + ): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node, Node)]]] = + field( + lang.Expr.withChildren( + field( + TmpInfixGroup.withChildren( + field( + tok(defns.InfixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token: defns.InfixOperator => + op.highPrecedence > op2Token.highPrecedence, + ), + ) + ~ field(lang.Expr) + ~ field(lang.Expr) + ~ eof, + ), + ) + ~ eof, + ), + ) + def highPredInfixUnary( + op: defns.InfixOperator, + ): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node)]]] = + field( + lang.Expr.withChildren( + field( + TmpUnaryGroup.withChildren( + field( + tok(defns.PrefixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token: defns.PrefixOperator => + op.highPrecedence > op2Token.highPrecedence + case op2Token: defns.PostfixOperator => + op.highPrecedence > op2Token.precedence, + ), + ) + ~ field(lang.Expr) + ~ eof, + ), + ) + ~ eof, + ), + ) + def highPredUnaryInfix( + op: defns.Operator, + ): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node, Node)]]] = + field( + lang.Expr.withChildren( + field( + TmpInfixGroup.withChildren( + field( + tok(defns.InfixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token: defns.InfixOperator => + op.highPrecedence > op2Token.highPrecedence, + ), + ) + ~ field(lang.Expr) + ~ field(lang.Expr) + ~ eof, + ), + ) + ~ eof, + ), + ) + + def badPredInfixInfix( + op: defns.InfixOperator, + ): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node, Node)]]] = + field( + lang.Expr.withChildren( + field( + TmpInfixGroup.withChildren( + field( + tok(defns.InfixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token: defns.InfixOperator => + !((op.highPrecedence > op2Token.highPrecedence) + || (op.lowPrecedence < op2Token.lowPrecedence) + || ((op == op2Token) && op.isAssociative)), + ), + ) + ~ field(lang.Expr) + ~ field(lang.Expr) + ~ eof, + ), + ) + ~ eof, + ), + ) + def badPredInfixUnary( + op: defns.InfixOperator, + ): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node)]]] = + field( + lang.Expr.withChildren( + field( + TmpUnaryGroup.withChildren( + field( + tok(defns.PrefixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token: defns.PrefixOperator => + !((op.highPrecedence > op2Token.highPrecedence) + || (op.lowPrecedence < op2Token.lowPrecedence)) + case op2Token: defns.PostfixOperator => + !((op.highPrecedence > op2Token.precedence) + || (op.lowPrecedence < op2Token.precedence)), + ), + ) + ~ field(lang.Expr) + ~ eof, + ), + ) + ~ eof, + ), + ) + def badPredUnaryInfix( + op: defns.Operator, + ): SeqPattern[SeqPattern.Fields[Tuple1[(Node, Node, Node)]]] = + field( + lang.Expr.withChildren( + field( + TmpInfixGroup.withChildren( + field( + tok(defns.InfixOperator.instances*) + .filter(op2 => + op2.token match + case op2Token: defns.InfixOperator => + !((op.highPrecedence > op2Token.highPrecedence) + || (op.lowPrecedence < op2Token.lowPrecedence)), + ), + ) + ~ field(lang.Expr) + ~ field(lang.Expr) + ~ eof, + ), + ) + ~ eof, + ), + ) + def matchQuantifierId(): SeqPattern[(Node, TLAParser.RawExpression)] = + parent(lang.Expr) *> + field(TLAReader.Alpha) + ~ skip(defns.`\\in`) + ~ field(rawExpression) + ~ trailing + def matchQuantifierIds(): SeqPattern[(List[Node], TLAParser.RawExpression)] = + parent(lang.Expr) *> + field( + tok(TLAReader.TupleGroup).withChildren( + field(repeatedSepBy(`,`)(tok(TLAReader.Alpha))) + ~ eof, + ), + ) + ~ skip(defns.`\\in`) + ~ field(rawExpression) + ~ trailing + + // replace all lang.Expr(contents...) with + // lang.Expr(lang.ExprTry, contents...) + + val buildExpressions = passDef: + wellformed := prevWellformed.makeDerived: + val removedCases = Seq( + TLAReader.StringLiteral, + TLAReader.NumberLiteral, + TLAReader.TupleGroup, + // TODO: remove cases + ) + TLAReader.groupTokens.foreach: tok => + tok.removeCases(removedCases*) + tok.addCases(lang.Expr) + + TmpInfixGroup ::= fields( + choice(defns.InfixOperator.instances*), + lang.Expr, + lang.Expr, + ) + TmpUnaryGroup ::= fields( + choice( + (defns.PrefixOperator.instances ++ defns.PostfixOperator.instances)*, + ), + lang.Expr, + ) + + lang.Expr.deleteShape() + lang.Expr.importFrom(lang.wf) + lang.Expr.addCases(lang.Expr, TmpInfixGroup, TmpUnaryGroup) + + // TODO: assign the correct source with .like() + + pass(once = false, strategy = pass.bottomUp) // conjunction alignment + .rules: + on( + parent(lang.Expr) *> + (field(tok(defns./\)) + ~ field(rawConjunction(1)) + ~ field(tok(defns./\)) + ~ field(rawConjunction(1)) + ~ eof).filter(things => + things match + case (and1: Node, r1, and2: Node, r2) => + val s1 = and1.sourceRange + val s2 = and2.sourceRange + /* println("\nCol 1: " + + * s1.source.lines.lineColAtOffset(s1.offset)) */ + /* println("\nCol 2: " + + * s2.source.lines.lineColAtOffset(s2.offset)) */ + s1.source.lines + .lineColAtOffset(s1.offset) + ._2 == s2.source.lines.lineColAtOffset(s2.offset)._2, + ), + ).rewrite: (and1, r1, and2, r2) => + splice( + and1.unparent(), + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(and2.unparent()), + lang.Expr.OpCall.Params( + r1.mkNode, + r2.mkNode, + ), + ), + ), + ) + *> pass( + once = false, + strategy = pass.bottomUp, + ) // remove leading /\, remove paren + .rules: + // on( + // field(lang.Expr.withChildren( + // skip(defns./\) + // ~ field(lang.Expr) + // ~ eof + // )) + // ~ eof + // ).rewrite: expr => + // splice(expr.unparent()) + on( + field( + lang.Expr.withChildren( + skip(defns./\) + ~ field(rawExpression) + ~ eof, + ), + ) + ~ eof, + ).rewrite: expr => + splice(expr.mkNode) + *> pass( + once = false, + strategy = pass.bottomUp, + ) // resolve quantifiers/opCall + .rules: + on( + parent(lang.Expr) *> + field(TLAReader.Alpha) + ~ field( + TLAReader.ParenthesesGroup.withChildren( + field(repeatedSepBy(`,`)(rawExpression)) + ~ eof, + ), + ) + ~ trailing, + ).rewrite: (fun, args) => + splice( + lang.Expr( + lang.Expr.OpCall( + lang.Id().like(fun), + lang.Expr.OpCall.Params( + args.iterator.map(_.mkNode), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + skip(tok(TLAReader.LaTexLike).src("\\E")) + ~ field(repeatedSepBy1(`,`)(matchQuantifierId())) + ~ skip(TLAReader.`:`) + ~ field(rawExpression) + ~ trailing, + ).rewrite: (qBounds, expr) => + splice( + lang.Expr( + lang.Expr.Exists( + lang.QuantifierBounds( + qBounds.iterator.map((id, qExpr) => + lang.QuantifierBound( + lang.Id().like(id), + qExpr.mkNode, + ), + ), + ), + expr.mkNode, + ), + ), + ) + | on( + parent(lang.Expr) *> + skip(tok(TLAReader.LaTexLike).src("\\E")) + ~ field(repeatedSepBy1(`,`)(matchQuantifierIds())) + ~ skip(TLAReader.`:`) + ~ field(rawExpression) + ~ trailing, + ).rewrite: (qBounds, expr) => + splice( + lang.Expr( + lang.Expr.Exists( + lang.QuantifierBounds( + qBounds.iterator.map((ids, qExpr) => + lang.QuantifierBound( + lang.Ids( + ids.iterator.map(id => lang.Id().like(id)), + ), + qExpr.mkNode, + ), + ), + ), + expr.mkNode, + ), + ), + ) + | on( + parent(lang.Expr) *> + skip(tok(TLAReader.LaTexLike).src("\\A")) + ~ field(repeatedSepBy1(`,`)(matchQuantifierId())) + ~ skip(TLAReader.`:`) + ~ field(rawExpression) + ~ trailing, + ).rewrite: (qBounds, expr) => + splice( + lang.Expr( + lang.Expr.Forall( + lang.QuantifierBounds( + qBounds.iterator.map((id, qExpr) => + lang.QuantifierBound( + lang.Id().like(id), + qExpr.mkNode, + ), + ), + ), + expr.mkNode, + ), + ), + ) + | on( + parent(lang.Expr) *> + skip(tok(TLAReader.LaTexLike).src("\\A")) + ~ field(repeatedSepBy1(`,`)(matchQuantifierIds())) + ~ skip(TLAReader.`:`) + ~ field(rawExpression) + ~ trailing, + ).rewrite: (qBounds, expr) => + splice( + lang.Expr( + lang.Expr.Forall( + lang.QuantifierBounds( + qBounds.iterator.map((ids, qExpr) => + lang.QuantifierBound( + lang.Ids( + ids.iterator.map(id => lang.Id().like(id)), + ), + qExpr.mkNode, + ), + ), + ), + expr.mkNode, + ), + ), + ) + | on( + parent(lang.Expr) *> + skip(tok(defns.CHOOSE)) + ~ field(matchQuantifierId()) + ~ skip(TLAReader.`:`) + ~ field(rawExpression) + ~ trailing, + ).rewrite: (qBound, expr) => + qBound match + case (id, qExpr) => + splice( + lang.Expr( + lang.Expr.Choose( + lang.QuantifierBound( + lang.Id().like(id), + qExpr.mkNode, + ), + expr.mkNode, + ), + ), + ) + // TODO: tuple qbound + // id nil + // tuple nil + *> pass(once = false, strategy = pass.bottomUp) + .rules: + on( + parent(lang.Expr) *> + TLAReader.Alpha, + ).rewrite: name => + splice( + lang.Expr( + lang.Id().like(name), + ), + ) + | on( + parent(lang.Expr) *> + TLAReader.NumberLiteral, + ).rewrite: lit => + splice(lang.Expr(lang.Expr.NumberLiteral().like(lit))) + | on( + parent(lang.Expr) *> + TLAReader.StringLiteral, + ).rewrite: lit => + splice(lang.Expr(lang.Expr.StringLiteral().like(lit))) + | on( + parent(lang.Expr) *> + tok(TLAReader.BracesGroup) *> + children: + field(repeatedSepBy(`,`)(rawExpression)) + ~ eof, + ).rewrite: exprs => + splice( + lang.Expr( + lang.Expr.SetLiteral(exprs.iterator.map(_.mkNode)), + ), + ) + | on( + parent(lang.Expr) *> + tok(TLAReader.TupleGroup).product( + children( + field(repeatedSepBy(`,`)(rawExpression)) + ~ eof, + ), + ), + ).rewrite: (lit, elems) => + splice( + lang.Expr( + lang.Expr.TupleLiteral(elems.iterator.map(_.mkNode)).like(lit), + ), + ) + | on( + parent(lang.Expr) *> + tok(TLAReader.SqBracketsGroup) *> + children: + field( + repeatedSepBy1(`,`)( + field(TLAReader.Alpha) + ~ skip(`|->`) + ~ field(rawExpression) + ~ trailing, + ), + ) + ~ eof, + ).rewrite: fields => + splice( + lang.Expr( + lang.Expr.RecordLiteral( + fields.iterator.map((alpha, expr) => + lang.Expr.RecordLiteral.Field( + lang.Id().like(alpha.unparent()), + expr.mkNode, + ), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(lang.Expr) + ~ skip(defns.`.`) + ~ field( + lang.Expr.withChildren( + field(lang.Id) + ~ eof, + ), + ) + ~ trailing, + ).rewrite: (expr, id) => + splice( + lang.Expr( + lang.Expr.Project( + lang.Expr( + expr.unparent(), + ), + id.unparent(), + ), + ), + ) + | on( + parent(lang.Expr) *> + tok(TLAReader.SqBracketsGroup) *> + children: + field( + repeatedSepBy1(`,`)( + field(TLAReader.Alpha) + ~ skip(`:`) + ~ field(rawExpression) + ~ trailing, + ), + ) + ~ eof, + ).rewrite: fields => + splice( + lang.Expr( + lang.Expr.RecordSetLiteral( + fields.iterator.map((alpha, expr) => + lang.Expr.RecordSetLiteral.Field( + lang.Id().like(alpha.unparent()), + expr.mkNode, + ), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(lang.Expr) + ~ field(tok(defns.InfixOperator.instances.filter(_ != defns.`.`)*)) + ~ field(lang.Expr) + ~ eof, + ).rewrite: (left, op, right) => + splice( + lang.Expr( + TmpInfixGroup( + op.unparent(), + left.unparent(), + right.unparent(), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(tok(defns.PrefixOperator.instances*)) + ~ field(lang.Expr) + ~ eof, + ).rewrite: (op, expr) => + splice( + lang.Expr( + TmpUnaryGroup( + op.unparent(), + expr.unparent(), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(lang.Expr) + ~ field(tok(defns.PostfixOperator.instances*)) + ~ trailing, + ).rewrite: (expr, op) => + splice( + lang.Expr( + TmpUnaryGroup( + op.unparent(), + expr.unparent(), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(lang.Expr) + ~ field( + tok(TLAReader.SqBracketsGroup) *> + children: + field(repeatedSepBy(`,`)(rawExpression)) + ~ eof, + ) + ~ eof, + ).rewrite: (callee, args) => + splice( + lang.Expr( + lang.Expr.FnCall( + callee.unparent(), + args match + case List(expr) => + expr.mkNode + case _ => + lang.Expr( + lang.Expr.TupleLiteral( + args.iterator.map(_.mkNode), + ), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + skip(defns.CASE) + ~ field( + repeatedSepBy(defns.`[]`)( + field(lang.Expr) + ~ skip(defns.-) + ~ skip(defns.>) + ~ field(lang.Expr) + ~ trailing, + ), + ) + ~ field( + optional( + skip(defns.OTHER) + ~ skip(defns.-) + ~ skip(defns.>) + ~ field(lang.Expr) + ~ eof, + ), + ) + ~ eof, + ).rewrite: (cases, other) => + splice( + lang.Expr( + lang.Expr.Case( + lang.Expr.Case.Branches( + cases.iterator.map((pred, branch) => + lang.Expr.Case.Branch( + pred.unparent(), + branch.unparent(), + ), + ), + ), + lang.Expr.Case.Other( + other match + case None => lang.Expr.Case.Other.None() + case Some(expr) => expr.unparent(), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + field( + tok(TLAReader.LetGroup).product( + children( + field( + repeated1: + tok(lang.Operator) + | tok(lang.ModuleDefinition) + | tok(lang.Recursive), + ) + ~ eof, + ), + ), + ) + ~ field(lang.Expr) + ~ trailing, + ).rewrite: (let, expr) => + splice( + lang + .Expr( + lang.Expr.Let( + lang.Expr.Let.Defns( + let._2.iterator.map(_.unparent()), + ), + expr.unparent(), + ), + ) + .like(let._1), + ) + // TODO: Function + // TODO: SetComprehension {expr : x \in y} + // TODO: SetRefinement {x \in y : bool} + // {x \in y : p \in q} TODO: look in the book + // TODO: Choose + // TODO: Except + // TODO: Lambda + // TODO a \x b \ a + | on( + parent(lang.Expr) *> + skip(defns.IF) + ~ field(lang.Expr) + ~ skip(defns.THEN) + ~ field(lang.Expr) + ~ skip(defns.ELSE) + ~ field(lang.Expr) + ~ trailing, + ).rewrite: (pred, t, f) => + splice( + lang.Expr.If( + pred.unparent(), + t.unparent(), + f.unparent(), + ), + ) + | on( + parent(lang.Expr) *> + tok(TLAReader.ParenthesesGroup) *> + children: + field(rawExpression) + ~ eof, + ).rewrite: rawExpr => + splice( + lang.Expr(rawExpr.mkNode), + ) + *> pass(once = false, strategy = pass.bottomUp) // resolve Alphas + .rules: + on( + parent(lang.Expr) *> + lang.Id, + ).rewrite: name => + splice( + lang.Expr( + lang.Expr.OpCall( + name.unparent(), + lang.Expr.OpCall.Params(), + ), + ), + ) + end buildExpressions + + val reorderOperations = passDef: + wellformed := prevWellformed.makeDerived: + lang.Expr.removeCases(TmpInfixGroup, TmpUnaryGroup) + pass(once = false, strategy = pass.bottomUp) + .rules: + on( + parent(lang.Expr) *> + field(TmpInfixGroup.withChildren: + defns.InfixOperator.instances.iterator + .map: op => + field(op) + ~ highPredInfixInfix(op) + ~ field(lang.Expr) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, left, right) => + left match + case (op, expr1, expr2) => + splice( + TmpInfixGroup( + op.unparent(), + expr1.unparent(), + lang.Expr( + TmpInfixGroup( + curOp.unparent(), + expr2.unparent(), + right.unparent(), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(TmpInfixGroup.withChildren: + defns.InfixOperator.instances.iterator + .map: op => + field(op) + ~ field(lang.Expr) + ~ highPredInfixInfix(op) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, left, right) => + right match + case (op, expr1, expr2) => + splice( + TmpInfixGroup( + op.unparent(), + lang.Expr( + TmpInfixGroup( + curOp.unparent(), + left.unparent(), + expr1.unparent(), + ), + ), + expr2.unparent(), + ), + ) + | on( + parent(lang.Expr) *> + field(TmpInfixGroup.withChildren: + defns.InfixOperator.instances.iterator + .map: op => + field(op) + ~ highPredInfixUnary(op) + ~ field(lang.Expr) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, left, right) => + left match + case (op, expr) => + splice( + TmpUnaryGroup( + op.unparent(), + lang.Expr( + TmpInfixGroup( + curOp.unparent(), + expr.unparent(), + right.unparent(), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(TmpInfixGroup.withChildren: + defns.InfixOperator.instances.iterator + .map: op => + field(op) + ~ field(lang.Expr) + ~ highPredInfixUnary(op) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, left, right) => + right match + case (op, expr) => + splice( + TmpUnaryGroup( + op.unparent(), + lang.Expr( + TmpInfixGroup( + curOp.unparent(), + left.unparent(), + expr.unparent(), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(TmpUnaryGroup.withChildren: + defns.PrefixOperator.instances.iterator + .map: op => + field(op) + ~ highPredUnaryInfix(op) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, infixGroup) => + infixGroup match + case (op, left, right) => + splice( + TmpInfixGroup( + op.unparent(), + lang.Expr( + TmpUnaryGroup( + curOp.unparent(), + left.unparent(), + ), + ), + right.unparent(), + ), + ) + | on( + parent(lang.Expr) *> + field(TmpUnaryGroup.withChildren: + defns.PostfixOperator.instances.iterator + .map: op => + field(op) + ~ highPredUnaryInfix(op) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, infixGroup) => + infixGroup match + case (op, left, right) => + splice( + TmpInfixGroup( + op.unparent(), + left.unparent(), + lang.Expr( + TmpUnaryGroup( + curOp.unparent(), + right.unparent(), + ), + ), + ), + ) + *> pass(once = false, strategy = pass.bottomUp) // assoc related errors + .rules: + on( + parent(lang.Expr) *> + field(TmpInfixGroup.withChildren: + defns.InfixOperator.instances.iterator + .map: op => + field(op) + ~ badPredInfixInfix(op) + ~ field(lang.Expr) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, left, right) => + left match + case (op, expr1, expr2) => + splice( + Node(Builtin.Error)( + Builtin.Error.Message( // todo: use src + s"$curOp and $op must have different precedence, or be duplicates of an associative operator.", + ), + Builtin.Error.AST( + TmpInfixGroup( + curOp.unparent(), + TmpInfixGroup( + op.unparent(), + expr1.unparent(), + expr2.unparent(), + ), + right.unparent(), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(TmpInfixGroup.withChildren: + defns.InfixOperator.instances.iterator + .map: op => + field(op) + ~ field(lang.Expr) + ~ badPredInfixInfix((op)) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, left, right) => + right match + case (op, expr1, expr2) => + splice( + Node(Builtin.Error)( + Builtin.Error.Message( + s"$curOp and $op must have different precedence, or be duplicates of an associative operator.", + ), + Builtin.Error.AST( + TmpInfixGroup( + curOp.unparent(), + left.unparent(), + TmpInfixGroup( + op.unparent(), + expr1.unparent(), + expr2.unparent(), + ), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(TmpInfixGroup.withChildren: + defns.InfixOperator.instances.iterator + .map: op => + field(op) + ~ badPredInfixUnary(op) + ~ field(lang.Expr) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, left, right) => + left match + case (op, expr) => + splice( + Node(Builtin.Error)( + Builtin.Error.Message( + s"$curOp and $op must have different precedence.", + ), + Builtin.Error.AST( + TmpInfixGroup( + curOp.unparent(), + TmpUnaryGroup( + op.unparent(), + expr.unparent(), + ), + ), + right.unparent(), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(TmpInfixGroup.withChildren: + defns.InfixOperator.instances.iterator + .map: op => + field(op) + ~ field(lang.Expr) + ~ badPredInfixUnary(op) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, left, right) => + right match + case (op, expr) => + splice( + Node(Builtin.Error)( + Builtin.Error.Message( + s"$curOp and $op must have different precedence.", + ), + Builtin.Error.AST( + TmpInfixGroup( + curOp.unparent(), + left.unparent(), + TmpUnaryGroup( + op.unparent(), + expr.unparent(), + ), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(TmpUnaryGroup.withChildren: + defns.PrefixOperator.instances.iterator + .map: op => + field(op) + ~ badPredUnaryInfix(op) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, infixGroup) => + infixGroup match + case (op, left, right) => + splice( + Builtin.Error( + s"$curOp and $op must have different precedence.", + curOp.unparent(), + TmpInfixGroup( + op.unparent(), + left.unparent(), + right.unparent(), + ), + ), + ) + | on( + parent(lang.Expr) *> + field(TmpUnaryGroup.withChildren: + defns.PostfixOperator.instances.iterator + .map: op => + field(op) + ~ badPredUnaryInfix(op) + ~ eof + .reduce(_ | _)) + ~ eof, + ).rewrite: (curOp, infixGroup) => + infixGroup match + case (op, left, right) => + splice( + Builtin.Error( + s"$curOp and $op must have different precedence.", + curOp.unparent(), + TmpInfixGroup( + op.unparent(), + left.unparent(), + right.unparent(), + ), + ), + ) + *> pass(once = false, strategy = pass.bottomUp) + .rules: + on( + parent(lang.Expr) *> + field( + TmpInfixGroup.withChildren( + field(tok(defns.InfixOperator.instances*)) + ~ field(lang.Expr) + ~ field(lang.Expr) + ~ eof, + ), + ) + ~ eof, + ).rewrite: (op, right, left) => + splice( + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(op.unparent()), + lang.Expr.OpCall.Params( + right.unparent(), + left.unparent(), + ), + ), + ), + ) + | on( + parent(lang.Expr) *> + field( + TmpUnaryGroup.withChildren( + field( + tok(defns.PrefixOperator.instances*) + | tok(defns.PostfixOperator.instances*), + ) + ~ field(lang.Expr) + ~ eof, + ), + ) + ~ eof, + ).rewrite: (op, expr) => + splice( + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(op.unparent()), + lang.Expr.OpCall.Params( + expr.unparent(), + ), + ), + ), + ) + end reorderOperations + + val removeNestedExpr = passDef: + wellformed := prevWellformed.makeDerived: + lang.Expr.removeCases(lang.Expr) + + pass(once = true, strategy = pass.bottomUp) + .rules: + on( + tok(lang.Expr) *> + onlyChild(lang.Expr), + ).rewrite: child => + splice( + child.unparent(), + ) + end removeNestedExpr diff --git a/langs/tla/ExprParser.test.scala b/langs/tla/ExprParser.test.scala new file mode 100644 index 0000000..941f5a2 --- /dev/null +++ b/langs/tla/ExprParser.test.scala @@ -0,0 +1,1193 @@ +// Copyright 2024-2025 Forja Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package forja.langs.tla + +import forja.* +import forja.dsl.* +import forja.source.{Source, SourceRange} + +import ExprParser.TmpInfixGroup + +// scala-cli test . -- '*ExprParser*' + +class ExprParserTests extends munit.FunSuite: + extension (str: String) + def parseStr: Node.Top = + val wrapped_str = + Source.fromString( + s"""---- MODULE TestMod ---- + |EXTENDS Naturals + |VARIABLE temp + | + |Init == $str + |==== + """.stripMargin, + ) + val top = TLAReader(SourceRange.entire(wrapped_str)) + TLAParser(top) + ExprParser(top) + Node.Top( + top(lang.Module)(lang.Module.Defns)(lang.Operator)( + lang.Expr, + ).unparentedChildren, + ) + extension (str: String) + def withoutParse: Node.Top = + val wrapped_str = + Source.fromString( + s"""---- MODULE TestMod ---- + |EXTENDS Naturals + |VARIABLE temp + | + |Init == $str + |==== + """.stripMargin, + ) + val top = TLAReader(SourceRange.entire(wrapped_str)) + TLAParser(top) + Node.Top( + top(lang.Module)(lang.Module.Defns)(lang.Operator)( + lang.Expr, + ).unparentedChildren, + ) + + extension (top: Node.Top) + def parseNode: Node.Top = + val freshTop = Node.Top( + lang.Module( + lang.Id("TestMod"), + lang.Module.Extends(), + lang.Module.Defns( + lang.Operator( + lang.Id("test"), + lang.Operator.Params(), + lang.Expr( + top.unparentedChildren, + ), + ), + ), + ), + ) + ExprParser(freshTop) + Node.Top( + freshTop(lang.Module)(lang.Module.Defns)(lang.Operator)( + lang.Expr, + ).unparentedChildren, + ) + + test("NumberLiteral"): + assertEquals("1".parseStr, Node.Top(lang.Expr.NumberLiteral("1"))) + + test("StringLiteral"): + assertEquals( + "\"string\"".parseStr, + Node.Top(lang.Expr.StringLiteral("string")), + ) + assertEquals( + "\"string\\nnewline\"".parseStr, + Node.Top(lang.Expr.StringLiteral("string\nnewline")), + ) + + test("Set Literal"): + assertEquals( + "{1, 2, 3}".parseStr, + Node.Top( + lang.Expr.SetLiteral( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ) + assertEquals( + "{}".parseStr, + Node.Top(lang.Expr.SetLiteral()), + ) + + test("TupleLiteral"): + assertEquals( + "<<1, 2, 3>>".parseStr, + Node.Top( + lang.Expr.TupleLiteral( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ) + assertEquals( + "<<>>".parseStr, + Node.Top( + lang.Expr.TupleLiteral(), + ), + ) + + test("RecordLiteral"): + assertEquals( + "[x |-> 2, y |-> 3]".parseStr, + Node.Top( + lang.Expr.RecordLiteral( + lang.Expr.RecordLiteral.Field( + lang.Id("x"), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + lang.Expr.RecordLiteral.Field( + lang.Id("y"), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + ) + + test("Projection (Record Field Acess)"): + assertEquals( + "x.y".parseStr, + Node.Top( + lang.Expr.Project( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Id("y"), + ), + ), + ) + assertEquals( + "x.y.z".parseStr, + Node.Top( + lang.Expr.Project( + lang.Expr( + lang.Expr.Project( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Id("y"), + ), + ), + lang.Id("z"), + ), + ), + ) + assertEquals( + "[y |-> 2].y".parseStr, + Node.Top( + lang.Expr.Project( + lang.Expr( + lang.Expr.RecordLiteral( + lang.Expr.RecordLiteral.Field( + lang.Id("y"), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + ), + lang.Id("y"), + ), + ), + ) + + test("RecordSetLiteral"): + assertEquals( + "[x : {1, 2}]".parseStr, + Node.Top( + lang.Expr.RecordSetLiteral( + lang.Expr.RecordSetLiteral.Field( + lang.Id("x"), + lang.Expr( + lang.Expr.SetLiteral( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + ), + ), + ), + ) + + test("Name"): + assertEquals( + "x".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + ) + + test("Single Binary Operator"): + assertEquals( + "5 + 6".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym( + defns.+("+"), + ), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("5")), + lang.Expr(lang.Expr.NumberLiteral("6")), + ), + ), + ), + ) + assertEquals( + "5 $ x".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym(defns.$("$")), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("5")), + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + ), + ) + + test("Single Unary Operator"): + assertEquals( + "UNION A".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym(defns.UNION("UNION")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("A"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + ), + ) + assertEquals( + "x'".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym(defns.`'`("'")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + ), + ) + + test("Precedence: Infix - Infix"): + assertEquals( + "5 * 6 + 7".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym( + defns.+("+"), + ), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.OpSym( + defns.*("*"), + ), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("5")), + lang.Expr(lang.Expr.NumberLiteral("6")), + ), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("7")), + ), + ), + ), + ) + assertEquals( + "1 + 8 * 6 - 9 * 3".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym(defns.+("+")), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.-("-")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.*("*")), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("8")), + lang.Expr(lang.Expr.NumberLiteral("6")), + ), + ), + ), + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.*("*")), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("9")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ) + + test("Precedence: Infix - Prefix"): + assertEquals( + "~x /\\ y".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym(defns./\("/\\")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.~("~")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + ), + lang.Expr( + lang.Expr.OpCall( + lang.Id("y"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + ), + ) + assertEquals( + "~x \\in y".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym(defns.~("~")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.`\\in`("\\in")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr( + lang.Expr.OpCall( + lang.Id("y"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + ), + ), + ), + ), + ) + + test("Precedence: Infix - Postfix"): + assertEquals( + "1 + x' * 2".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym(defns.+("+")), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.*("*")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.`'`("'")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + ), + ), + ), + ), + ) + + test("Precedence: Unary - Unary"): + assertEquals( + "[] ~x".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym(defns.`[]`("[]")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.~("~")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + ), + ), + ), + ), + ) + + // TODO???: postfix postfix: never allowed + + test("Precedence: Associative"): + assertEquals( + "1 + 2 + 3".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym(defns.+("+")), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.+("+")), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("2")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + ), + ), + ), + ) + assertEquals( + "1 * 2 / 3".parseStr, + Node.Top( + lang.Expr( + Builtin.Error( + s"forja.langs.tla.defns.* and forja.langs.tla.defns./ must have different precedence, or be duplicates of an associative operator.", + TmpInfixGroup( + defns.*("*"), + lang.Expr(lang.Expr.NumberLiteral("1")), + TmpInfixGroup( + defns./("/"), + lang.Expr(lang.Expr.NumberLiteral("2")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + ), + ), + ) + + // test("Precedence: Error"): + // assertEquals( + // "x \\in A \\in B".parseStr, + // Node.Top(lang.Expr( + // Node(Builtin.Error)( + // Builtin.Error.Message( + /* s"26:distcompiler.tla.defns.\\in and 26:distcompiler.tla.defns.\\in must + * have different precedence, or be duplicates of an associative operator." */ + // ), + // Builtin.Error.AST( + // lang.Expr.TmpInfixGroup( + // defns.`\\in`("\\in"), + // lang.Expr(lang.Id("x")), + // lang.Expr.TmpInfixGroup( + // defns.`\\in`("\\in"), + // lang.Expr(lang.Id("A")), + // lang.Expr(lang.Id("B")) + // ))))))) + // assertEquals( + // "x \\in A = B".parseStr, + // Node.Top(lang.Expr( + // Node(Builtin.Error)( + // Builtin.Error.Message( + /* s"26:distcompiler.tla.defns.\\in and distcompiler.tla.defns.= must have + * different precedence, or be duplicates of an associative operator." */ + // ), + // Builtin.Error.AST( + // lang.Expr.TmpInfixGroup( + // defns.`\\in`("\\in"), + // lang.Expr(lang.Id("x")), + // lang.Expr.TmpInfixGroup( + // defns.`=`("="), + // lang.Expr(lang.Id("A")), + // lang.Expr(lang.Id("B")) + // ))))))) + + test("OpCall"): + assertEquals( + "testFun(1, 2, 3)".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.Id("testFun"), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + ) + + // TODO: conjunction alignment + + test("FnCall"): + assertEquals( + "x[\"y\"]".parseStr, + Node.Top( + lang.Expr.FnCall( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr(lang.Expr.StringLiteral("y")), + ), + ), + ) + assertEquals( + "x[1, 2, 3]".parseStr, + Node.Top( + lang.Expr.FnCall( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr( + lang.Expr.TupleLiteral( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + ), + ) + + test("If"): + assertEquals( + """IF A + |THEN 1 + |ELSE 2""".stripMargin.parseStr, + Node.Top( + lang.Expr.If( + lang.Expr( + lang.Expr.OpCall( + lang.Id("A"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + ) + + test("Case"): + assertEquals( + """CASE A -> 1 + | [] B -> 2""".stripMargin.parseStr, + Node.Top( + lang.Expr.Case( + lang.Expr.Case.Branches( + lang.Expr.Case.Branch( + lang.Expr( + lang.Expr.OpCall( + lang.Id("A"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("1")), + ), + lang.Expr.Case.Branch( + lang.Expr( + lang.Expr.OpCall( + lang.Id("B"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + lang.Expr.Case.Other(lang.Expr.Case.Other.None()), + ), + ), + ) + assertEquals( + """CASE A -> 1 + | [] B -> 2 + | OTHER -> 3""".stripMargin.parseStr, + Node.Top( + lang.Expr.Case( + lang.Expr.Case.Branches( + lang.Expr.Case.Branch( + lang.Expr( + lang.Expr.OpCall( + lang.Id("A"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("1")), + ), + lang.Expr.Case.Branch( + lang.Expr( + lang.Expr.OpCall( + lang.Id("B"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + lang.Expr.Case.Other( + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + ) + + test("LET"): + assertEquals( + """LET x == 1 + |y == 2 + |IN {y, x}""".stripMargin.parseStr, + Node.Top( + lang.Expr.Let( + lang.Expr.Let.Defns( + lang.Operator( + lang.Id("x"), + lang.Operator.Params(), + lang.Expr(lang.Expr.NumberLiteral("1")), + ), + lang.Operator( + lang.Id("y"), + lang.Operator.Params(), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + lang.Expr( + lang.Expr.SetLiteral( + lang.Expr( + lang.Expr.OpCall( + lang.Id("y"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + ), + ), + ) + + test("Exists"): + assertEquals( + "\\E x \\in {1, 2, 3} : x = 2".parseStr, + Node.Top( + lang.Expr.Exists( + lang.QuantifierBounds( + lang.QuantifierBound( + lang.Id("x"), + lang.Expr( + lang.Expr.SetLiteral( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + ), + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.`=`("=")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + ), + ), + ), + ) + assertEquals( + "\\E <> \\in {1, 2, 3} : x = 2".parseStr, + Node.Top( + lang.Expr.Exists( + lang.QuantifierBounds( + lang.QuantifierBound( + lang.Ids( + lang.Id("x"), + lang.Id("y"), + lang.Id("z"), + ), + lang.Expr( + lang.Expr.SetLiteral( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + ), + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.`=`("=")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + ), + ), + ), + ) + assertEquals( + "\\E x \\in {1, 2, 3}, y \\in {4, 5, 6} : x = 2".parseStr, + Node.Top( + lang.Expr.Exists( + lang.QuantifierBounds( + lang.QuantifierBound( + lang.Id("x"), + lang.Expr( + lang.Expr.SetLiteral( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + lang.QuantifierBound( + lang.Id("y"), + lang.Expr( + lang.Expr.SetLiteral( + lang.Expr(lang.Expr.NumberLiteral("4")), + lang.Expr(lang.Expr.NumberLiteral("5")), + lang.Expr(lang.Expr.NumberLiteral("6")), + ), + ), + ), + ), + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.`=`("=")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + ), + ), + ), + ) + + test("Forall"): + assertEquals( + "\\A x \\in {1, 2, 3} : x = 2".parseStr, + Node.Top( + lang.Expr.Forall( + lang.QuantifierBounds( + lang.QuantifierBound( + lang.Id("x"), + lang.Expr( + lang.Expr.SetLiteral( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + ), + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.`=`("=")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + ), + ), + ), + ) + + // TODO: AA EE ?? + + test("Temporal Logic Combined"): + assertEquals( + "s /\\ \\E x \\in y : z /\\ \\E p \\in q : r".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym(defns./\("/\\")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("s"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr( + lang.Expr.Exists( + lang.QuantifierBounds( + lang.QuantifierBound( + lang.Id("x"), + lang.Expr( + lang.Expr.OpCall( + lang.Id("y"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns./\("/\\")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("z"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr( + lang.Expr.Exists( + lang.QuantifierBounds( + lang.QuantifierBound( + lang.Id("p"), + lang.Expr( + lang.Expr.OpCall( + lang.Id("q"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + lang.Expr( + lang.Expr.OpCall( + lang.Id("r"), + lang.Expr.OpCall.Params(), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ) + + // todo: function + // todo: set comprehension + // todo: set refinement + // todo: lambda + + test("Choose"): + assertEquals( + "CHOOSE x \\in {1, 2, 3} : x = 2".parseStr, + Node.Top( + lang.Expr.Choose( + lang.QuantifierBound( + lang.Id("x"), + lang.Expr( + lang.Expr.SetLiteral( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + lang.Expr( + lang.Expr.OpCall( + lang.OpSym(defns.`=`("=")), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + ), + ), + ), + ) + // todo: id nil, tuple expr, tuple nil + + // todo: except + + test("Parentheses"): + assertEquals( + "(1)".parseStr, + Node.Top(lang.Expr.NumberLiteral("1")), + ) + assertEquals( + "(((x)))".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.Id("x"), + lang.Expr.OpCall.Params(), + ), + ), + ) + assertEquals( + "(5 + 6)".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym( + defns.+("+"), + ), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("5")), + lang.Expr(lang.Expr.NumberLiteral("6")), + ), + ), + ), + ) + assertEquals( + "(1 + 2) * 3".parseStr, + Node.Top( + lang.Expr.OpCall( + lang.OpSym( + defns.*("*"), + ), + lang.Expr.OpCall.Params( + lang.Expr( + lang.Expr.OpCall( + lang.OpSym( + defns.+("+"), + ), + lang.Expr.OpCall.Params( + lang.Expr(lang.Expr.NumberLiteral("1")), + lang.Expr(lang.Expr.NumberLiteral("2")), + ), + ), + ), + lang.Expr(lang.Expr.NumberLiteral("3")), + ), + ), + ), + ) + // assertEquals( + // "\\A x \\in {1, 2, 3} : x = (2)".parseStr, + // Node.Top( + // lang.Expr.Forall( + // lang.QuantifierBounds( + // lang.QuantifierBound( + // lang.Id("x"), + // lang.Expr(lang.Expr.SetLiteral( + // lang.Expr(lang.Expr.NumberLiteral("1")), + // lang.Expr(lang.Expr.NumberLiteral("2")), + // lang.Expr(lang.Expr.NumberLiteral("3")) + // )))), + // lang.Expr(lang.Expr.OpCall( + // lang.OpSym(defns.`=`("=")), + // lang.Expr.OpCall.Params( + // lang.Expr(lang.Expr.OpCall( + // lang.Id("x"), + // lang.Expr.OpCall.Params() + // )), + // lang.Expr(lang.Expr.NumberLiteral("2")) + // )))))) + + // test("Conjunctions"): + // assertEquals( + // "/\\ 1".parseStr, + // Node.Top( + // lang.Expr.NumberLiteral("1") + // )) + // assertEquals( + // "/\\ 1 \\/ 2".parseStr, + // Node.Top( + // lang.Expr.OpCall( + // lang.OpSym(defns.\/("\\/")), + // lang.Expr.OpCall.Params( + // lang.Expr(lang.Expr.NumberLiteral("1")), + // lang.Expr(lang.Expr.NumberLiteral("2")) + // )))) + // assertEquals( + // s""" + // |/\\ 1 + // |/\\ 2 + // |""".stripMargin.parseStr, + // Node.Top( + // lang.Expr.OpCall( + // lang.OpSym(defns./\("/\\")), + // lang.Expr.OpCall.Params( + // lang.Expr(lang.Expr.NumberLiteral("1")), + // lang.Expr(lang.Expr.NumberLiteral("2")) + // )))) + // assertEquals( + // s""" + // |/\\ 1 \\/ 1 + // |/\\ 2 \\/ 2 + // |/\\ 3 \\/ 3 + // |""".stripMargin.parseStr, + // Node.Top(lang.Expr.OpCall( + // lang.OpSym(defns./\("/\\")), + // lang.Expr.OpCall.Params( + // lang.Expr(lang.Expr.OpCall( + // lang.OpSym(defns.\/("\\/")), + // lang.Expr.OpCall.Params( + // lang.Expr(lang.Expr.NumberLiteral("1")), + // lang.Expr(lang.Expr.NumberLiteral("1")) + // ) + // )), + // lang.Expr(lang.Expr.OpCall( + // lang.OpSym(defns./\("/\\")), + // lang.Expr.OpCall.Params( + // lang.Expr(lang.Expr.OpCall( + // lang.OpSym(defns.\/("\\/")), + // lang.Expr.OpCall.Params( + // lang.Expr(lang.Expr.NumberLiteral("2")), + // lang.Expr(lang.Expr.NumberLiteral("2")) + // ))), + // lang.Expr(lang.Expr.OpCall( + // lang.OpSym(defns.\/("\\/")), + // lang.Expr.OpCall.Params( + // lang.Expr(lang.Expr.NumberLiteral("3")), + // lang.Expr(lang.Expr.NumberLiteral("3")) + // )))))))))) + + // assertEquals( + // s""" + // |/\\ 1 + // | /\\ 2 + // |/\\ 3 + // |""".stripMargin.parseStr, + // Node.Top( + // lang.Expr.OpCall( + // lang.OpSym(defns./\("/\\")), + // lang.Expr.OpCall.Params( + // lang.Expr(lang.Expr.OpCall( + // lang.OpSym(defns./\("/\\")), + // lang.Expr.OpCall.Params( + // lang.Expr(lang.Expr.NumberLiteral("1")), + // lang.Expr(lang.Expr.NumberLiteral("2")) + // ) + // )), + // lang.Expr(lang.Expr.NumberLiteral("3")) + // ) + // ) + // )) + + // assertEquals( + // s""" + // |/\\ \\E x \in {1, 2, 3} : 1 + // | /\\ 2 + // |""".stripMargin.parseStr, + // Node.Top( + // lang.Expr.NumberLiteral("1") + // )) + + // assertEquals( + // s""" + // |/\\ \\E x \in {1, 2, 3} : 1 + // |/\\ 2 + // |""".stripMargin.parseStr, + // Node.Top( + // lang.Expr.NumberLiteral("1") + // )) + + // assertEquals( + // s""" + // | /\\ \\E x \in {1, 2, 3} : 1 + // |/\\ 2 + // |""".stripMargin.parseStr, + // Node.Top( + // lang.Expr.NumberLiteral("1") + // )) + + // assertEquals( + // s""" + // |/\\ 1 + // | /\\ 2 + // | /\\ 3 + // | /\\ 4 + // |/\\ 5 + // | /\\ 6 + // |""".stripMargin.parseStr, + // Node.Top( + // lang.Expr.NumberLiteral("1") + // )) + +// TODO +// paren +// conjucntion +// lamda function set diff --git a/langs/tla/TLAParser.scala b/langs/tla/TLAParser.scala index aafb8ab..3f0f577 100644 --- a/langs/tla/TLAParser.scala +++ b/langs/tla/TLAParser.scala @@ -177,6 +177,72 @@ object TLAParser extends PassSeq: ).void nodeSpanMatchedBy(impl).map(RawExpression.apply) + end rawExpression + + def rawConjunction(col: Integer): SeqPattern[RawExpression] = + val simpleCases: SeqPattern[Unit] = + anyChild.void <* not( + tok(expressionDelimiters*) + | tok(defns./\).filter(node => + val s = node.sourceRange + s.source.lines.lineColAtOffset(s.offset)._2 > col, + ) + // stop at operator definitions: all valid patterns leading to == here + | operatorDefnBeginnings, + ) + + lazy val quantifierBound: SeqPattern[EmptyTuple] = + skip( + tok(TupleGroup).as(EmptyTuple) + | repeatedSepBy1(`,`)(Alpha), + ) + ~ skip(defns.`\\in`) + ~ skip(defer(impl)) + ~ trailing + + lazy val quantifierBounds: SeqPattern[Unit] = + repeatedSepBy1(`,`)(quantifierBound).void + + lazy val forallExists: SeqPattern[EmptyTuple] = + skip( + tok(LaTexLike).src("\\A") | tok(LaTexLike).src("\\AA") | tok(LaTexLike) + .src("\\E") | tok(LaTexLike).src("\\EE"), + ) + ~ skip(quantifierBounds | repeatedSepBy1(`,`)(Alpha)) + ~ skip(`:`) + ~ trailing + + lazy val choose: SeqPattern[EmptyTuple] = + skip(defns.CHOOSE) + ~ skip(quantifierBound | repeatedSepBy1(`,`)(Alpha)) + ~ skip(`:`) + ~ trailing + + lazy val lambda: SeqPattern[EmptyTuple] = + skip(defns.LAMBDA) + ~ skip(repeatedSepBy1(`,`)(Alpha)) + ~ skip(`:`) + ~ trailing + + // this is an over-approximation of what can show up + // in an identifier prefix + // TODO: why does the grammar make it look like it goes the other way round? + lazy val idFrag: SeqPattern[EmptyTuple] = + skip(`!`) + ~ skip(`:`) + ~ trailing + + lazy val impl: SeqPattern[Unit] = + repeated1( + forallExists + | choose + | lambda + | idFrag + | simpleCases, // last, otherwise it eats parts of the above + ).void + + nodeSpanMatchedBy(impl).map(RawExpression.apply) + end rawConjunction val proofDelimiters: Seq[Token] = Seq( defns.ASSUME, @@ -715,66 +781,3 @@ object TLAParser extends PassSeq: ~ trailing, ).rewrite: (local, op) => splice(lang.Local(op.unparent()).like(local)) - - // TODO: finish expression parsing - // val buildExpressions = passDef: - // wellformed := prevWellformed.makeDerived: - // val removedCases = Seq( - // TLAReader.StringLiteral, - // TLAReader.NumberLiteral, - // TLAReader.TupleGroup, - // ) - // lang.Module.Defns.removeCases(removedCases*) - // lang.Module.Defns.addCases(lang.Expr) - // TLAReader.groupTokens.foreach: tok => - // tok.removeCases(removedCases*) - // tok.addCases(lang.Expr) - - // lang.Expr.importFrom(tla.wellformed) - // lang.Expr.addCases(lang.TmpGroupExpr) - - // lang.TmpGroupExpr ::= lang.Expr - - // pass(once = false, strategy = pass.bottomUp) - // .rules: - // on( - // TLAReader.StringLiteral - // ).rewrite: lit => - // splice(lang.Expr(lang.Expr.StringLiteral().like(lit))) - // | on( - // TLAReader.NumberLiteral - // ).rewrite: lit => - // splice(lang.Expr(lang.Expr.NumberLiteral().like(lit))) - // | on( - // field(TLAReader.Alpha) - // ~ field( - // tok(TLAReader.ParenthesesGroup) *> children( - // repeatedSepBy(`,`)(lang.Expr) - // ) - // ) - // ~ trailing - // ).rewrite: (name, params) => - // splice(lang.Expr(lang.Expr.OpCall( - // lang.Id().like(name), - // lang.Expr.OpCall.Params(params.iterator.map(_.unparent())), - // ))) - // | on( - // TLAReader.Alpha - // ).rewrite: name => - // splice(lang.Expr(lang.Expr.OpCall( - // lang.Id().like(name), - // lang.Expr.OpCall.Params(), - // ))) - // | on( - // tok(TLAReader.ParenthesesGroup) *> onlyChild(lang.Expr) - // ).rewrite: expr => - /* // mark this group as an expression, but leave evidence that it is a group - * (for operator precedence handling) */ - // splice(lang.Expr(lang.TmpGroupExpr(expr.unparent()))) - // | on( - // tok(TLAReader.TupleGroup).product(children( - // field(repeatedSepBy(`,`)(lang.Expr)) - // ~ eof - // )) - // ).rewrite: (lit, elems) => - // splice(lang.Expr(lang.Expr.TupleLiteral(elems.iterator.map(_.unparent())))) diff --git a/langs/tla/TLAParser.test.scala b/langs/tla/TLAParser.test.scala index f6dbffa..1b31489 100644 --- a/langs/tla/TLAParser.test.scala +++ b/langs/tla/TLAParser.test.scala @@ -16,8 +16,9 @@ package forja.langs.tla import forja.* import forja.source.{Source, SourceRange} +import forja.test.WithTLACorpus -class TLAParserTests extends munit.FunSuite, test.WithTLACorpus: +class TLAParserTests extends munit.FunSuite, WithTLACorpus: self => /* TODO: skip the TLAPS files; parsing that seems like a waste of time for @@ -37,6 +38,7 @@ class TLAParserTests extends munit.FunSuite, test.WithTLACorpus: // ) // , tracer = Manip.RewriteDebugTracer(os.pwd / "dbg_passes") ) + ExprParser(top) // re-enable if interesting: // val folder = os.SubPath(file.subRelativeTo(clonesDir).segments.init) diff --git a/langs/tla/defns.scala b/langs/tla/defns.scala index 0372867..fb8dbd1 100644 --- a/langs/tla/defns.scala +++ b/langs/tla/defns.scala @@ -74,7 +74,9 @@ object defns: case object PROPOSITION extends ReservedWord case object ONLY extends ReservedWord - sealed trait Operator extends Token, HasSpelling + sealed trait Operator extends Token, HasSpelling: + def highPrecedence: Int + def lowPrecedence: Int object Operator: lazy val instances: IArray[Operator] = @@ -100,7 +102,7 @@ object defns: sealed trait InfixOperator( val lowPrecedence: Int, - val highPredecence: Int, + val highPrecedence: Int, val isAssociative: Boolean = false, ) extends Operator object InfixOperator extends util.HasInstanceArray[InfixOperator] @@ -209,7 +211,9 @@ object defns: case object `\\supset` extends InfixOperator(5, 5) case object `%%` extends InfixOperator(10, 11, true) - sealed trait PostfixOperator(val predecence: Int) extends Operator + sealed trait PostfixOperator(val precedence: Int) extends Operator: + def highPrecedence: Int = precedence + def lowPrecedence: Int = precedence object PostfixOperator extends util.HasInstanceArray[PostfixOperator] case object `^+` extends PostfixOperator(15) diff --git a/langs/tla/package.scala b/langs/tla/package.scala index be312f0..492d348 100644 --- a/langs/tla/package.scala +++ b/langs/tla/package.scala @@ -100,6 +100,7 @@ object lang extends WellformedDef: Expr.SetLiteral, Expr.TupleLiteral, Expr.RecordLiteral, + Expr.RecordSetLiteral, Expr.Project, Expr.OpCall, Expr.FnCall, @@ -180,7 +181,8 @@ object lang extends WellformedDef: ), ) - object Case extends t(repeated(Case.Branch, minCount = 1)): + object Case extends t(fields(Case.Branches, Case.Other)): + object Branches extends t(repeated(Branch, minCount = 1)) object Branch extends t( fields( @@ -188,6 +190,11 @@ object lang extends WellformedDef: Expr, ), ) + object Other + extends t( + choice(Expr, Other.None), + ): + object None extends t(Atom) end Case object Let diff --git a/src/Builtin.scala b/src/Builtin.scala index 7772d1f..a1e60f3 100644 --- a/src/Builtin.scala +++ b/src/Builtin.scala @@ -18,7 +18,7 @@ import cats.syntax.all.given object Builtin: object Error extends Token: - def apply(msg: String, ast: Node.Child): Node = + def apply(msg: String, ast: Node.Child*): Node = Error( Error.Message().at(msg), Error.AST(ast), diff --git a/src/Node.scala b/src/Node.scala index 953b660..015fb44 100644 --- a/src/Node.scala +++ b/src/Node.scala @@ -391,7 +391,7 @@ object Node: require( results.size == 1, - s"token(s) not found ${(tok +: toks).map(_.name).mkString(", ")}", + s"token(s) not found ${(tok +: toks).map(_.name).mkString(", ")}, in ${this.toShortString()}", ) results.head diff --git a/src/PassSeq.scala b/src/PassSeq.scala index 9625802..d2b38d5 100644 --- a/src/PassSeq.scala +++ b/src/PassSeq.scala @@ -24,15 +24,16 @@ import scala.collection.mutable transparent trait PassSeq: def inputWellformed: Wellformed final def outputWellformed: Wellformed = - assert(entriesSealed) + allPasses // compute allPasses, needed to enforce complete initialization entries.last.wellformed private val entries = mutable.ListBuffer.empty[PassSeq.Entry] private var entriesSealed = false protected def prevWellformed(using BuildCtx): Wellformed = - require(entries.nonEmpty, "there is no previous Wellformed") - entries.last.wellformed + if entries.isEmpty + then inputWellformed + else entries.last.wellformed protected object wellformed: def :=(using ctx: BuildCtx)(wellformed: Wellformed): Unit = diff --git a/src/wf/Wellformed.scala b/src/wf/Wellformed.scala index cf633d4..47c97b9 100644 --- a/src/wf/Wellformed.scala +++ b/src/wf/Wellformed.scala @@ -166,11 +166,9 @@ final class Wellformed private ( done: val wrongSize = parent.children.size parent.children = List( - Node(Builtin.Error)( - Builtin.Error.Message( - s"$desc should have exactly ${fields.size} children, but it had $wrongSize instead", - ), - Builtin.Error.AST(parent.unparentedChildren), + Builtin.Error( + s"$desc should have exactly ${fields.size} children, but it had $wrongSize instead", + parent.unparentedChildren.toSeq*, ), ) else @@ -206,11 +204,9 @@ final class Wellformed private ( done: val wrongSize = parent.children.size parent.children = List( - Node(Builtin.Error)( - Builtin.Error.Message( - s"$desc should have at least $minCount children, but it had $wrongSize instead", - ), - Builtin.Error.AST(parent.unparentedChildren), + Builtin.Error( + s"$desc should have at least $minCount children, but it had $wrongSize instead", + parent.unparentedChildren.toSeq*, ), ) else @@ -531,6 +527,15 @@ object Wellformed: s"$token's shape is not appropriate for adding cases ($shape)", ) + def deleteShape(): Unit = + token match + case Node.Top => + require(topShapeOpt.nonEmpty) + topShapeOpt = None + case token: Token => + require(assigns.contains(token)) + assigns.remove(token) + def importFrom(wf2: Wellformed): Unit = def fillFromShape(shape: Shape): Unit = shape match @@ -557,6 +562,7 @@ object Wellformed: topShapeOpt = Some(wf2.topShape) fillFromShape(wf2.topShape) case token: Token => fillFromTokenOrEmbed(token) + end importFrom private[forja] def build(): Wellformed = require(topShapeOpt.nonEmpty)