diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index 07864aca56a..bb58bcd1415 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION. +# Copyright (c) 2022-2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -422,6 +422,73 @@ def test_regexp_replace(): 'regexp_replace(a, "a|b|c", "A")'), conf=_regexp_conf) + +# https://github.com/NVIDIA/spark-rapids/issues/14742 +# Replacement-string parser must match java.util.regex.Matcher#appendReplacement. +# We deliberately use the DataFrame API (not selectExpr) because Spark SQL's +# Hive-inherited variable substitution (spark.sql.variable.substitute=true by +# default) silently expands ${...} inside SQL string literals at parse time, +# which would mask sub-bugs 4 and 5. +# The pyspark Spark-3.3 signature is regexp_replace(col, pattern: str, replacement: str) +# where pattern and replacement are passed verbatim as Python strings. +def test_regexp_replace_subbug1_backslash_digit_is_literal_14742(): + # \1 in the replacement is the literal character "1" per Java spec, + # not a backref. Both CPU and GPU should produce "1bc". + from pyspark.sql.functions import regexp_replace, col + assert_gpu_and_cpu_are_equal_collect( + lambda spark: spark.createDataFrame([("abc",)], ["a"]).select( + regexp_replace(col("a"), "(a)", "\\1")), + conf=_regexp_conf) + + +@allow_non_gpu('ProjectExec', 'RegExpReplace') +def test_regexp_replace_subbug2_trailing_backslash_throws_14742(): + # A trailing bare \ in the replacement must throw on both engines. + # The spark-rapids parser raises RegexUnsupportedException during meta + # tagging, which forces a CPU fallback; Spark's CPU Matcher then throws. 
+ from pyspark.sql.functions import regexp_replace, col + assert_gpu_and_cpu_error( + lambda spark: spark.createDataFrame([("a",)], ["a"]).select( + regexp_replace(col("a"), "a", "\\")).collect(), + conf=_regexp_conf, + error_message="character to be escaped is missing") + + +@allow_non_gpu('ProjectExec', 'RegExpReplace') +def test_regexp_replace_subbug3_dollar_non_digit_throws_14742(): + # $X for non-digit X must throw on both engines. GPU falls back to CPU. + from pyspark.sql.functions import regexp_replace, col + assert_gpu_and_cpu_error( + lambda spark: spark.createDataFrame([("a",)], ["a"]).select( + regexp_replace(col("a"), "a", "$x")).collect(), + conf=_regexp_conf, + error_message="Illegal group reference") + + +@allow_non_gpu('ProjectExec', 'RegExpReplace') +def test_regexp_replace_subbug4_digit_leading_named_group_throws_14742(): + # ${1} has a digit-leading "name" and must throw on both engines. + # GPU falls back to CPU; CPU raises Java's exact error message. + from pyspark.sql.functions import regexp_replace, col + assert_gpu_and_cpu_error( + lambda spark: spark.createDataFrame([("a",)], ["a"]).select( + regexp_replace(col("a"), "(a)", "${1}")).collect(), + conf=_regexp_conf, + error_message="capturing group name {1} starts with digit character") + + +@allow_non_gpu('ProjectExec', 'RegExpReplace') +def test_regexp_replace_subbug5_unknown_named_group_throws_14742(): + # ${name} with no such named group must throw on both engines. + # GPU falls back to CPU because spark-rapids does not support named-group + # references; CPU then raises "No group with name {name}". 
+ from pyspark.sql.functions import regexp_replace, col + assert_gpu_and_cpu_error( + lambda spark: spark.createDataFrame([("a",)], ["a"]).select( + regexp_replace(col("a"), "(a)", "${name}")).collect(), + conf=_regexp_conf, + error_message="No group with name") + @pytest.mark.skipif(is_before_spark_320(), reason='regexp is synonym for RLike starting in Spark 3.2.0') def test_regexp(): gen = mk_str_gen('[abcd]{1,3}') diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index ff1d858002b..727a987ff1a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -363,54 +363,101 @@ class RegexParser(pattern: String) { } } + /** + * Parse the character that follows a `\` in a replacement string. + * + * Matches the semantics of `java.util.regex.Matcher#appendReplacement`: + * - `\X` for any character X (including digits) emits the literal character X. + * - A trailing bare `\` (end-of-string after `\`) raises an error matching + * Java's `IllegalArgumentException("character to be escaped is missing")`. + * + * The AST shape returned for `\X` is a two-node sequence containing + * `RegexChar('\\') + RegexChar(X)`. Keeping the leading backslash in the AST + * preserves the existing transpile pipeline contract: downstream + * `GpuRegExpUtils.backrefConversion` plus `unescapeReplaceString` strip one + * backslash per `\X` pair when producing the cuDF replacement string. The + * key correctness change is that we no longer emit `RegexBackref` for + * `\digit`, which is what diverged from Java's spec (Java treats `\digit` + * as the literal digit in the replacement, not a backref). + * + * Note: the leading `\` has already been consumed by `parseReplacementBase`. 
+ */ private def parseBackrefOrEscaped(): RegexAST = { - val start = pos - - consumeInt match { - case Some(refNum) => - RegexBackref(refNum) + peek() match { case None => - pos = start - RegexChar('\\') + throw new RegexUnsupportedException( + "character to be escaped is missing", Some(pos)) + case Some(_) => + RegexSequence(ListBuffer(RegexChar('\\'), RegexChar(consume()))) } } + /** + * Parse the character(s) that follow a `$` in a replacement string. + * + * Matches the semantics of `java.util.regex.Matcher#appendReplacement`: + * - `$` followed by one or more digits is a numbered backref (longest run). + * - `${name}` with name matching `[A-Za-z][A-Za-z0-9]*` is a named-group + * backref. spark-rapids' regex layer does not support named groups in the + * pattern, so well-formed `${name}` references raise an unsupported + * exception rather than executing on the GPU. + * - Any other shape (`$` followed by EOF, by a non-digit non-`{` character, + * by `${` with a missing or malformed name, or by `${name}` with no + * matching `}`) raises an error whose message starts with + * "Illegal group reference"; Java's own message surfaces after CPU fallback. + * + * Note: the leading `$` has already been consumed by `parseReplacementBase`. + */ private def parseBackrefOrLiteralDollar(): RegexAST = { - val start = pos - - def treatAsLiteralDollar() = { - pos = start - RegexChar('$') - } - peek() match { case Some('{') => consumeExpected('{') - val num = consumeInt() - if (peek().contains('}')) { - consumeExpected('}') - num match { - case Some(n) => - RegexBackref(n) - case _ => - treatAsLiteralDollar() - } - } else { - treatAsLiteralDollar() - } - case Some(ch) if ch >= '1' && ch <= '9' => - val num = consumeInt() - num match { - case Some(n) => - RegexBackref(n) + peek() match { + case Some(ch) if isAsciiDigit(ch) => + // Java: name starts with a digit -> throw. 
+ throw new RegexUnsupportedException( + "Illegal group reference: group name starts with digit character", + Some(pos)) + case Some(ch) if isLetter(ch) => + // Named-group reference. Consume the name then require `}`. + val nameStart = pos + while (!eof() && peek().exists(c => isLetter(c) || isAsciiDigit(c))) { + skip() + } + val name = pattern.substring(nameStart, pos) + if (!peek().contains('}')) { + throw new RegexUnsupportedException( + "Illegal group reference: malformed " + "$" + "{name} reference", + Some(pos)) + } + consumeExpected('}') + // cuDF has no support for named groups in the pattern, so any + // syntactically valid `${name}` reference is unsupported on the GPU. + // Throwing RegexUnsupportedException causes a CPU fallback (where + // Java will then raise "No group with name {name}" if the name is + // unknown), preserving the user-visible error. + throw new RegexUnsupportedException( + s"named-group reference $${$name} is not supported on the GPU", + Some(pos)) case _ => - treatAsLiteralDollar() + throw new RegexUnsupportedException( + "Illegal group reference: empty or malformed " + "$" + "{name}", + Some(pos)) } + case Some(ch) if isAsciiDigit(ch) => + RegexBackref(consumeInt().get) case _ => - treatAsLiteralDollar() + // Java: bare `$` not followed by a digit or `{` -> throw. + throw new RegexUnsupportedException( + "Illegal group reference", Some(pos)) } } + private def isLetter(ch: Char): Boolean = + (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') + + private def isAsciiDigit(ch: Char): Boolean = ch >= '0' && ch <= '9' + private def parseEscapedCharacter(): RegexAST = { peek() match { case None => @@ -617,7 +664,8 @@ class RegexParser(pattern: String) { private def consumeInt(): Option[Int] = { val start = pos - while (!eof() && peek().exists(_.isDigit)) { + // ASCII only: Java's appendReplacement accepts only '0'-'9' as group-reference digits. 
+ while (!eof() && peek().exists(c => c >= '0' && c <= '9')) { skip() } if (start == pos) { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 33644982cf1..0906ad6d401 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1073,22 +1073,28 @@ object GpuRegExpUtils { } /** - * Convert symbols of back-references if input string contains any. - * In spark's regex rule, there are two patterns of back-references: - * \group_index and \$group_index - * This method transforms above two patterns into cuDF pattern \${group_index}, except they are - * preceded by escape character. + * Convert numbered back-reference symbols into cuDF's `${group_index}` form. * - * @param rep replacement string - * @return A pair consists of a boolean indicating whether containing any backref and the - * converted replacement. + * Java `Matcher#appendReplacement` recognizes only the `$digit` form as a + * numbered backref. The `\digit` form is the literal digit character + * (Java's `\X` escape just emits X). The escaped `\$digit` is the literal + * `$digit`. + * + * This method is called after `RegexParser.parseReplacement`, which emits + * `\X` pairs verbatim into the replacement-string AST. We therefore: + * - rewrite an unescaped `$digit...` to `${digits}` (a cuDF backref); + * - leave `\digit` alone (it is a literal digit, not a backref) — the + * following `unescapeReplaceString` will strip the leading `\`; + * - leave any other `\X` pair alone (same reason). + * + * @param rep replacement string (already normalized by RegexParser) + * @return A pair (containsBackref, converted-replacement). 
*/ def backrefConversion(rep: String): (Boolean, String) = { val b = new StringBuilder var i = 0 while (i < rep.length) { - // match $group_index or \group_index - if (Seq('$', '\\').contains(rep.charAt(i)) + if (rep.charAt(i) == '$' && i + 1 < rep.length && rep.charAt(i + 1).isDigit) { b.append("${") @@ -1100,7 +1106,7 @@ object GpuRegExpUtils { b.append("}") i = j } else if (rep.charAt(i) == '\\' && i + 1 < rep.length) { - // skip potential \$group_index or \\group_index + // \X pair: keep verbatim; unescapeReplaceString will strip the \ b.append('\\').append(rep.charAt(i + 1)) i += 2 } else { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index b76ad5679e9..0791f7402c2 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2025, NVIDIA CORPORATION. + * Copyright (c) 2021-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -283,6 +283,104 @@ class RegularExpressionParserSuite extends AnyFunSuite { RegexChar('$')))) } + // ------------------------------------------------------------------ + // Replacement-string parser: Java appendReplacement spec (issue #14742) + // ------------------------------------------------------------------ + + test("issue-14742-subbug1: \\N in replacement is the literal character N, not a backref") { + // Java Matcher#appendReplacement: \X for any X (digit or not) emits X literally. + // The AST keeps the backslash so the downstream cuDF replacement pipeline + // (backrefConversion + unescapeReplaceString) collapses \X -> X without + // converting \digit into a numbered backref. 
+ val repl = new RegexParser("\\1").parseReplacement(numCaptureGroups = 1) + assert(repl.parts.toList === List(RegexChar('\\'), RegexChar('1'))) + } + + test("issue-14742-subbug1: \\a in replacement is the literal character a") { + val repl = new RegexParser("\\a").parseReplacement(numCaptureGroups = 0) + assert(repl.parts.toList === List(RegexChar('\\'), RegexChar('a'))) + } + + test("issue-14742-subbug2: trailing \\ in replacement throws") { + val ex = intercept[RegexUnsupportedException] { + new RegexParser("\\").parseReplacement(numCaptureGroups = 0) + } + assert(ex.getMessage.contains("character to be escaped is missing")) + } + + test("issue-14742-subbug3: bare $X for non-digit X throws") { + val ex = intercept[RegexUnsupportedException] { + new RegexParser("$x").parseReplacement(numCaptureGroups = 0) + } + assert(ex.getMessage.contains("Illegal group reference")) + } + + test("issue-14742-subbug3: trailing bare $ throws") { + val ex = intercept[RegexUnsupportedException] { + new RegexParser("$").parseReplacement(numCaptureGroups = 0) + } + assert(ex.getMessage.contains("Illegal group reference")) + } + + test("issue-14742-subbug4: dollar-brace-digit-brace throws") { + val ex = intercept[RegexUnsupportedException] { + new RegexParser("$" + "{1}").parseReplacement(numCaptureGroups = 1) + } + assert(ex.getMessage.contains("Illegal group reference")) + assert(ex.getMessage.contains("digit")) + } + + test("issue-14742-subbug5: dollar-brace-name-brace for named group is not supported on GPU") { + val ex = intercept[RegexUnsupportedException] { + new RegexParser("$" + "{name}").parseReplacement(numCaptureGroups = 1) + } + assert(ex.getMessage.contains("named-group reference")) + } + + test("issue-14742-subbug5: dollar-brace-name with missing closing brace throws") { + val ex = intercept[RegexUnsupportedException] { + new RegexParser("$" + "{name").parseReplacement(numCaptureGroups = 0) + } + assert(ex.getMessage.contains("Illegal group reference")) + } + + 
test("issue-14742: dollar-brace with empty body throws") { + val ex = intercept[RegexUnsupportedException] { + new RegexParser("$" + "{}").parseReplacement(numCaptureGroups = 0) + } + assert(ex.getMessage.contains("Illegal group reference")) + } + + test("issue-14742: numbered backref $1 still works") { + val repl = new RegexParser("$1").parseReplacement(numCaptureGroups = 1) + assert(repl.parts.toList === List(RegexBackref(1))) + } + + test("issue-14742: numbered backref $12 still consumes max digits") { + val repl = new RegexParser("$12").parseReplacement(numCaptureGroups = 12) + assert(repl.parts.toList === List(RegexBackref(12))) + } + + test("issue-14742: escaped metachar \\$ in replacement keeps the \\ pair") { + val repl = new RegexParser("\\$").parseReplacement(numCaptureGroups = 0) + assert(repl.parts.toList === List(RegexChar('\\'), RegexChar('$'))) + } + + test("issue-14742: escaped backslash \\\\ in replacement keeps the \\ pair") { + val repl = new RegexParser("\\\\").parseReplacement(numCaptureGroups = 0) + assert(repl.parts.toList === List(RegexChar('\\'), RegexChar('\\'))) + } + + test("issue-14742: non-ASCII Unicode digit after `$` triggers GPU fallback") { + for (rep <- Seq("$٢", "$१", "$۱")) { + val e = intercept[RegexUnsupportedException] { + new RegexParser(rep).parseReplacement(numCaptureGroups = 4) + } + assert(e.getMessage.startsWith("Illegal group reference"), + s"unexpected message for replacement '$rep': ${e.getMessage}") + } + } + private def parse(pattern: String): RegexAST = { new RegexParser(pattern).parse() }