Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION.
# Copyright (c) 2022-2026, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -422,6 +422,73 @@ def test_regexp_replace():
'regexp_replace(a, "a|b|c", "A")'),
conf=_regexp_conf)


# https://github.com/NVIDIA/spark-rapids/issues/14742
# Replacement-string parser must match java.util.regex.Matcher#appendReplacement.
# We deliberately use the DataFrame API (not selectExpr) because Spark SQL's
# Hive-inherited variable substitution (spark.sql.variable.substitute=true by
# default) silently expands ${...} inside SQL string literals at parse time,
# which would mask sub-bugs 4 and 5.
# The pyspark Spark-3.3 signature is regexp_replace(col, pattern: str, replacement: str)
# where pattern and replacement are passed verbatim as Python strings.
def test_regexp_replace_subbug1_backslash_digit_is_literal_14742():
    # Java's appendReplacement treats a backslash-escaped digit as that literal
    # digit rather than a group reference, so replacing "(a)" in "abc" with \1
    # should give "1bc" on both the CPU and the GPU.
    from pyspark.sql.functions import regexp_replace, col

    def replace_with_escaped_digit(spark):
        single_row_df = spark.createDataFrame([("abc",)], ["a"])
        return single_row_df.select(regexp_replace(col("a"), "(a)", "\\1"))

    assert_gpu_and_cpu_are_equal_collect(replace_with_escaped_digit, conf=_regexp_conf)


@allow_non_gpu('ProjectExec', 'RegExpReplace')
def test_regexp_replace_subbug2_trailing_backslash_throws_14742():
    # A replacement that ends in a lone backslash has nothing left to escape,
    # so both engines must fail. The plugin's replacement parser rejects it
    # while tagging, the expression falls back to the CPU, and the CPU Matcher
    # raises the error checked below.
    from pyspark.sql.functions import regexp_replace, col

    def collect_with_dangling_backslash(spark):
        single_row_df = spark.createDataFrame([("a",)], ["a"])
        return single_row_df.select(regexp_replace(col("a"), "a", "\\")).collect()

    assert_gpu_and_cpu_error(
        collect_with_dangling_backslash,
        conf=_regexp_conf,
        error_message="character to be escaped is missing")


@allow_non_gpu('ProjectExec', 'RegExpReplace')
def test_regexp_replace_subbug3_dollar_non_digit_throws_14742():
    # A dollar sign followed by a non-digit character is not a valid group
    # reference, so both engines must fail; the GPU plan falls back to the CPU.
    from pyspark.sql.functions import regexp_replace, col

    def collect_with_dollar_letter(spark):
        single_row_df = spark.createDataFrame([("a",)], ["a"])
        return single_row_df.select(regexp_replace(col("a"), "a", "$x")).collect()

    assert_gpu_and_cpu_error(
        collect_with_dollar_letter,
        conf=_regexp_conf,
        error_message="Illegal group reference")


@allow_non_gpu('ProjectExec', 'RegExpReplace')
def test_regexp_replace_subbug4_digit_leading_named_group_throws_14742():
    # In the braced form the text between the braces is a group *name*, and a
    # name may not begin with a digit. Both engines must fail: the GPU plan
    # falls back to the CPU, which reports Java's own message.
    from pyspark.sql.functions import regexp_replace, col

    def collect_with_digit_leading_name(spark):
        single_row_df = spark.createDataFrame([("a",)], ["a"])
        return single_row_df.select(regexp_replace(col("a"), "(a)", "${1}")).collect()

    assert_gpu_and_cpu_error(
        collect_with_digit_leading_name,
        conf=_regexp_conf,
        error_message="capturing group name {1} starts with digit character")


@allow_non_gpu('ProjectExec', 'RegExpReplace')
def test_regexp_replace_subbug5_unknown_named_group_throws_14742():
    # Referencing a named group that the pattern never defines must fail on
    # both engines. spark-rapids does not support named-group references, so
    # the expression falls back to the CPU, which then raises
    # "No group with name {name}".
    from pyspark.sql.functions import regexp_replace, col

    def collect_with_unknown_group_name(spark):
        single_row_df = spark.createDataFrame([("a",)], ["a"])
        return single_row_df.select(regexp_replace(col("a"), "(a)", "${name}")).collect()

    assert_gpu_and_cpu_error(
        collect_with_unknown_group_name,
        conf=_regexp_conf,
        error_message="No group with name")

@pytest.mark.skipif(is_before_spark_320(), reason='regexp is synonym for RLike starting in Spark 3.2.0')
def test_regexp():
gen = mk_str_gen('[abcd]{1,3}')
Expand Down
116 changes: 82 additions & 34 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -363,54 +363,101 @@ class RegexParser(pattern: String) {
}
}

/**
* Parse the character that follows a `\` in a replacement string.
*
* Matches the semantics of `java.util.regex.Matcher#appendReplacement`:
* - `\X` for any character X (including digits) emits the literal character X.
* - A trailing bare `\` (end-of-string after `\`) raises an error matching
* Java's `IllegalArgumentException("character to be escaped is missing")`.
*
* The AST shape returned for `\X` is a two-node sequence containing
* `RegexChar('\\') + RegexChar(X)`. Keeping the leading backslash in the AST
* preserves the existing transpile pipeline contract: downstream
* `GpuRegExpUtils.backrefConversion` plus `unescapeReplaceString` strip one
* backslash per `\X` pair when producing the cuDF replacement string. The
* key correctness change is that we no longer emit `RegexBackref` for
* `\digit`, which is what diverged from Java's spec (Java treats `\digit`
* as the literal digit in the replacement, not a backref).
*
* Note: the leading `\` has already been consumed by `parseReplacementBase`.
*/
private def parseBackrefOrEscaped(): RegexAST = {
  // The caller has already consumed the leading `\`, so `peek()` is the
  // character being escaped (or None when the `\` was the last character).
  peek() match {
    case None =>
      // Dangling `\`: nothing left to escape. Raising here makes the
      // replacement unsupported instead of silently emitting a literal `\`.
      throw new RegexUnsupportedException(
        "character to be escaped is missing", Some(pos))
    case Some(_) =>
      // `\X` is the literal X for every X, digits included, so `\1` is never
      // turned into a RegexBackref. The backslash is kept in the AST as its
      // own RegexChar next to the escaped character.
      RegexSequence(ListBuffer(RegexChar('\\'), RegexChar(consume())))
  }
}

/**
* Parse the character(s) that follow a `$` in a replacement string.
*
* Matches the semantics of `java.util.regex.Matcher#appendReplacement`:
* - `$` followed by one or more digits is a numbered backref (longest run).
* - `${name}` with name matching `[A-Za-z][A-Za-z0-9]*` is a named-group
* backref. spark-rapids' regex layer does not support named groups in the
* pattern, so well-formed `${name}` references raise an unsupported
* exception rather than executing on the GPU.
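* - Any other shape (`$` followed by EOF, by a non-digit non-`{` character,
* by `${` with a missing, digit-leading or malformed name, or by `${name`
* with no closing `}`) raises a `RegexUnsupportedException` whose message
* starts with "Illegal group reference". Java itself rejects these shapes
* with an `IllegalArgumentException`; the exception raised here only forces
* the CPU fallback that produces the user-visible error.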
*
* Note: the leading `$` has already been consumed by `parseReplacementBase`.
*/
private def parseBackrefOrLiteralDollar(): RegexAST = {
  // The caller has already consumed the leading `$`; `peek()` is whatever
  // follows it. A `$` is never treated as a literal dollar here: every shape
  // other than `$digits` raises RegexUnsupportedException.
  peek() match {
    case Some('{') =>
      consumeExpected('{')
      peek() match {
        case Some(ch) if isAsciiDigit(ch) =>
          // Java: name starts with a digit -> throw.
          throw new RegexUnsupportedException(
            "Illegal group reference: group name starts with digit character",
            Some(pos))
        case Some(ch) if isLetter(ch) =>
          // Named-group reference. Consume the name then require `}`.
          val nameStart = pos
          while (!eof() && peek().exists(c => isLetter(c) || isAsciiDigit(c))) {
            skip()
          }
          val name = pattern.substring(nameStart, pos)
          if (!peek().contains('}')) {
            throw new RegexUnsupportedException(
              "Illegal group reference: malformed " + "$" + "{name} reference",
              Some(pos))
          }
          consumeExpected('}')
          // cuDF has no support for named groups in the pattern, so any
          // syntactically valid `${name}` reference is unsupported on the GPU.
          // Throwing RegexUnsupportedException causes a CPU fallback (where
          // Java will then raise "No group with name {name}" if the name is
          // unknown), preserving the user-visible error.
          throw new RegexUnsupportedException(
            s"named-group reference $${$name} is not supported on the GPU",
            Some(pos))
        case _ =>
          // `${` followed by EOF, `}` or any non-alphanumeric character.
          throw new RegexUnsupportedException(
            "Illegal group reference: empty or malformed " + "$" + "{name}",
            Some(pos))
      }
    case Some(ch) if isAsciiDigit(ch) =>
      // `$` followed by ASCII digits: numbered backref over the longest digit
      // run. The peek above guarantees at least one digit, so None is not
      // expected; it is still mapped to an error rather than calling `.get`.
      consumeInt() match {
        case Some(groupIndex) =>
          RegexBackref(groupIndex)
        case None =>
          throw new RegexUnsupportedException(
            "Illegal group reference", Some(pos))
      }
    case _ =>
      // Java: bare `$` not followed by a digit or `{` -> throw.
      throw new RegexUnsupportedException(
        "Illegal group reference", Some(pos))
  }
}

// True only for the 52 unaccented ASCII letters; any other character is rejected.
private def isLetter(ch: Char): Boolean = {
  val isUpper = 'A' <= ch && ch <= 'Z'
  val isLower = 'a' <= ch && ch <= 'z'
  isUpper || isLower
}

// True only for '0' through '9'; digits from other Unicode scripts are rejected.
private def isAsciiDigit(ch: Char): Boolean = {
  val offsetFromZero = ch - '0'
  offsetFromZero >= 0 && offsetFromZero <= 9
}

private def parseEscapedCharacter(): RegexAST = {
peek() match {
case None =>
Expand Down Expand Up @@ -617,7 +664,8 @@ class RegexParser(pattern: String) {

private def consumeInt(): Option[Int] = {
val start = pos
while (!eof() && peek().exists(_.isDigit)) {
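// ASCII digits only: `Char.isDigit` also accepts non-ASCII Unicode digits, which
// must not be consumed as a group index (and which `toInt` would convert, not reject).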
while (!eof() && peek().exists(c => c >= '0' && c <= '9')) {
skip()
}
if (start == pos) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1073,22 +1073,28 @@ object GpuRegExpUtils {
}

/**
* Convert symbols of back-references if input string contains any.
* In spark's regex rule, there are two patterns of back-references:
* \group_index and \$group_index
* This method transforms above two patterns into cuDF pattern \${group_index}, except they are
* preceded by escape character.
* Convert numbered back-reference symbols into cuDF's `${group_index}` form.
*
* @param rep replacement string
* @return A pair consists of a boolean indicating whether containing any backref and the
* converted replacement.
* Java `Matcher#appendReplacement` recognizes only the `$digit` form as a
* numbered backref. The `\digit` form is the literal digit character
* (Java's `\X` escape just emits X). The escaped `\$digit` is the literal
* `$digit`.
*
* This method is called after `RegexParser.parseReplacement`, which emits
* `\X` pairs verbatim into the replacement-string AST. We therefore:
* - rewrite an unescaped `$digit...` to `${digits}` (a cuDF backref);
* - leave `\digit` alone (it is a literal digit, not a backref) — the
* following `unescapeReplaceString` will strip the leading `\`;
* - leave any other `\X` pair alone (same reason).
*
* @param rep replacement string (already normalized by RegexParser)
* @return A pair (containsBackref, converted-replacement).
*/
def backrefConversion(rep: String): (Boolean, String) = {
val b = new StringBuilder
var i = 0
while (i < rep.length) {
// match $group_index or \group_index
if (Seq('$', '\\').contains(rep.charAt(i))
if (rep.charAt(i) == '$'
&& i + 1 < rep.length && rep.charAt(i + 1).isDigit) {

b.append("${")
Expand All @@ -1100,7 +1106,7 @@ object GpuRegExpUtils {
b.append("}")
i = j
} else if (rep.charAt(i) == '\\' && i + 1 < rep.length) {
// skip potential \$group_index or \\group_index
// \X pair: keep verbatim; unescapeReplaceString will strip the \
b.append('\\').append(rep.charAt(i + 1))
i += 2
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
* Copyright (c) 2021-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -283,6 +283,104 @@ class RegularExpressionParserSuite extends AnyFunSuite {
RegexChar('$'))))
}

// ------------------------------------------------------------------
// Replacement-string parser: Java appendReplacement spec (issue #14742)
// ------------------------------------------------------------------

test("issue-14742-subbug1: \\N in replacement is the literal character N, not a backref") {
  // An escaped digit must come back as two plain characters (the backslash
  // and the digit) rather than as a RegexBackref, so that the later
  // replacement-string stages reduce \1 to a literal 1.
  val expectedParts = List(RegexChar('\\'), RegexChar('1'))
  val parsed = new RegexParser("\\1").parseReplacement(numCaptureGroups = 1)
  assert(parsed.parts.toList === expectedParts)
}

test("issue-14742-subbug1: \\a in replacement is the literal character a") {
  // A non-digit escape is kept as the backslash followed by the character.
  val expectedParts = List(RegexChar('\\'), RegexChar('a'))
  val parsed = new RegexParser("\\a").parseReplacement(numCaptureGroups = 0)
  assert(parsed.parts.toList === expectedParts)
}

test("issue-14742-subbug2: trailing \\ in replacement throws") {
  // A replacement consisting of a single backslash has nothing to escape.
  val message = intercept[RegexUnsupportedException] {
    new RegexParser("\\").parseReplacement(numCaptureGroups = 0)
  }.getMessage
  assert(message.contains("character to be escaped is missing"))
}

test("issue-14742-subbug3: bare $X for non-digit X throws") {
  // A dollar sign followed by a letter is rejected by the replacement parser.
  val message = intercept[RegexUnsupportedException] {
    new RegexParser("$x").parseReplacement(numCaptureGroups = 0)
  }.getMessage
  assert(message.contains("Illegal group reference"))
}

test("issue-14742-subbug3: trailing bare $ throws") {
  // A dollar sign at the very end of the replacement is rejected as well.
  val message = intercept[RegexUnsupportedException] {
    new RegexParser("$").parseReplacement(numCaptureGroups = 0)
  }.getMessage
  assert(message.contains("Illegal group reference"))
}

test("issue-14742-subbug4: dollar-brace-digit-brace throws") {
  // Inside braces the reference is a group name, and a name may not start
  // with a digit; the message must mention both facts.
  val message = intercept[RegexUnsupportedException] {
    new RegexParser("$" + "{1}").parseReplacement(numCaptureGroups = 1)
  }.getMessage
  assert(message.contains("Illegal group reference"))
  assert(message.contains("digit"))
}

test("issue-14742-subbug5: dollar-brace-name-brace for named group is not supported on GPU") {
  // A well-formed named-group reference is reported as unsupported.
  val message = intercept[RegexUnsupportedException] {
    new RegexParser("$" + "{name}").parseReplacement(numCaptureGroups = 1)
  }.getMessage
  assert(message.contains("named-group reference"))
}

test("issue-14742-subbug5: dollar-brace-name with missing closing brace throws") {
  // A named reference that is never closed with a brace is rejected.
  val message = intercept[RegexUnsupportedException] {
    new RegexParser("$" + "{name").parseReplacement(numCaptureGroups = 0)
  }.getMessage
  assert(message.contains("Illegal group reference"))
}

test("issue-14742: dollar-brace with empty body throws") {
  // Braces with no name between them are rejected.
  val message = intercept[RegexUnsupportedException] {
    new RegexParser("$" + "{}").parseReplacement(numCaptureGroups = 0)
  }.getMessage
  assert(message.contains("Illegal group reference"))
}

test("issue-14742: numbered backref $1 still works") {
  // The dollar-digit form remains a numbered back-reference.
  val expectedParts = List(RegexBackref(1))
  val parsed = new RegexParser("$1").parseReplacement(numCaptureGroups = 1)
  assert(parsed.parts.toList === expectedParts)
}

test("issue-14742: numbered backref $12 still consumes max digits") {
  // Both digits are read as one group index (12), not as group 1 followed by "2".
  val expectedParts = List(RegexBackref(12))
  val parsed = new RegexParser("$12").parseReplacement(numCaptureGroups = 12)
  assert(parsed.parts.toList === expectedParts)
}

test("issue-14742: escaped metachar \\$ in replacement keeps the \\ pair") {
  // An escaped dollar sign is kept as the backslash followed by the dollar sign.
  val expectedParts = List(RegexChar('\\'), RegexChar('$'))
  val parsed = new RegexParser("\\$").parseReplacement(numCaptureGroups = 0)
  assert(parsed.parts.toList === expectedParts)
}

test("issue-14742: escaped backslash \\\\ in replacement keeps the \\ pair") {
  // An escaped backslash is kept as two backslash characters in the AST.
  val expectedParts = List(RegexChar('\\'), RegexChar('\\'))
  val parsed = new RegexParser("\\\\").parseReplacement(numCaptureGroups = 0)
  assert(parsed.parts.toList === expectedParts)
}

test("issue-14742: non-ASCII Unicode digit after `$` triggers GPU fallback") {
  // Each replacement is a dollar sign followed by a digit from a non-ASCII
  // script; none of them may be accepted as a group index.
  Seq("$٢", "$१", "$۱").foreach { rep =>
    val thrown = intercept[RegexUnsupportedException] {
      new RegexParser(rep).parseReplacement(numCaptureGroups = 4)
    }
    val message = thrown.getMessage
    assert(message.startsWith("Illegal group reference"),
      s"unexpected message for replacement '$rep': $message")
  }
}

// Convenience wrapper: builds a RegexParser for `pattern` and returns the result of parse().
private def parse(pattern: String): RegexAST =
  new RegexParser(pattern).parse()
Expand Down