diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml index ea075d7a2dc..ec15acdd66d 100644 --- a/integration_tests/pom.xml +++ b/integration_tests/pom.xml @@ -132,6 +132,8 @@ parquet-hadoop*.jar spark-avro*.jar + spark-protobuf*.jar + protobuf-java-*.jar @@ -166,6 +168,31 @@ + + + copy-spark-protobuf + package + + copy + + + ${spark.protobuf.skipCopy} + true + + + org.apache.spark + spark-protobuf_${scala.binary.version} + ${spark.version} + + + com.google.protobuf + protobuf-java + 3.25.5 + + + + diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh index 2f1b96d901d..db7b3126f4d 100755 --- a/integration_tests/run_pyspark_from_build.sh +++ b/integration_tests/run_pyspark_from_build.sh @@ -46,6 +46,9 @@ # To run all tests, including Avro tests: # INCLUDE_SPARK_AVRO_JAR=true ./run_pyspark_from_build.sh # +# To run tests WITHOUT Protobuf tests (protobuf is included by default): +# INCLUDE_SPARK_PROTOBUF_JAR=false ./run_pyspark_from_build.sh +# # To run a specific test: # TEST=my_test ./run_pyspark_from_build.sh # @@ -100,6 +103,7 @@ else # support alternate local jars NOT building from the source code if [ -d "$LOCAL_JAR_PATH" ]; then AVRO_JARS=$(echo "$LOCAL_JAR_PATH"/spark-avro*.jar) + PROTOBUF_JARS=$(echo "$LOCAL_JAR_PATH"/spark-protobuf*.jar "$LOCAL_JAR_PATH"/protobuf-java-*.jar) PLUGIN_JAR=$(echo "$LOCAL_JAR_PATH"/rapids-4-spark_*.jar) if [ -f $(echo $LOCAL_JAR_PATH/parquet-hadoop*.jar) ]; then export INCLUDE_PARQUET_HADOOP_TEST_JAR=true @@ -116,6 +120,7 @@ else else [[ "$SCALA_VERSION" != "2.12" ]] && TARGET_DIR=${TARGET_DIR/integration_tests/scala$SCALA_VERSION\/integration_tests} AVRO_JARS=$(echo "$TARGET_DIR"/dependency/spark-avro*.jar) + PROTOBUF_JARS=$(echo "$TARGET_DIR"/dependency/spark-protobuf*.jar "$TARGET_DIR"/dependency/protobuf-java-*.jar) PARQUET_HADOOP_TESTS=$(echo "$TARGET_DIR"/dependency/parquet-hadoop*.jar) # remove the log4j.properties file so it doesn't conflict with ours, ignore errors # if it isn't present or already removed @@ -141,9 +146,24 @@ else AVRO_JARS="" fi - # ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar if they exist + # spark-protobuf shades `com.google.protobuf.*` internally and Spark does not bundle the + # unshaded jar, so we must ship both jars to the test classpath. + INCLUDE_SPARK_PROTOBUF_JAR_REQUESTED=$(echo "${INCLUDE_SPARK_PROTOBUF_JAR}" | tr '[:upper:]' '[:lower:]') + if [[ "$INCLUDE_SPARK_PROTOBUF_JAR_REQUESTED" != "false" \ + && $(readlink -e $PROTOBUF_JARS 2>/dev/null | wc -l) -eq 2 ]]; + then + export INCLUDE_SPARK_PROTOBUF_JAR=true + else + if [[ "$INCLUDE_SPARK_PROTOBUF_JAR_REQUESTED" == "true" ]]; then + >&2 echo "WARNING: INCLUDE_SPARK_PROTOBUF_JAR=true was requested but spark-protobuf/protobuf-java jars were not found under $TARGET_DIR/dependency; disabling protobuf tests." + fi + export INCLUDE_SPARK_PROTOBUF_JAR=false + PROTOBUF_JARS="" + fi + + # ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar protobuf.jar if they exist # Remove non-existing paths and canonicalize the paths including get rid of links and `..` - ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS || true) + ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS $PROTOBUF_JARS || true) # `:` separated jars ALL_JARS="${ALL_JARS//$'\n'/:}" diff --git a/integration_tests/src/main/python/protobuf_test.py b/integration_tests/src/main/python/protobuf_test.py new file mode 100644 index 00000000000..6b2606f06b6 --- /dev/null +++ b/integration_tests/src/main/python/protobuf_test.py @@ -0,0 +1,130 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os + +import pytest + +from asserts import assert_gpu_fallback_collect +from marks import allow_non_gpu +from spark_session import is_before_spark_340, with_cpu_session +import pyspark.sql.functions as f + +if os.environ.get('INCLUDE_SPARK_PROTOBUF_JAR', 'true').lower() == 'false': + pytestmark = pytest.mark.skip(reason="INCLUDE_SPARK_PROTOBUF_JAR is disabled") +else: + pytestmark = pytest.mark.skipif( + is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") + + +def _try_import_from_protobuf(): + try: + from pyspark.sql.protobuf.functions import from_protobuf + return from_protobuf + except Exception: + return None + + +@pytest.fixture(scope="module") +def from_protobuf_fn(): + fn = _try_import_from_protobuf() + if fn is None: + pytest.skip("from_protobuf not available") + return fn + + +def _encode_varint(value): + out = bytearray() + value &= 0xFFFFFFFFFFFFFFFF + while True: + bits = value & 0x7F + value >>= 7 + if value: + out.append(bits | 0x80) + else: + out.append(bits) + return bytes(out) + + +def _encode_simple_message(i32_value, s_value): + buf = bytearray() + buf += _encode_varint((1 << 3) | 0) # field 1, VARINT + buf += _encode_varint(i32_value) + s_bytes = s_value.encode("utf-8") + buf += _encode_varint((2 << 3) | 2) # field 2, LENGTH-DELIMITED + buf += _encode_varint(len(s_bytes)) + buf += s_bytes + return bytes(buf) + + +def _build_simple_descriptor_bytes(spark): + D = spark.sparkContext._jvm.com.google.protobuf.DescriptorProtos + i32_field = D.FieldDescriptorProto.newBuilder() \ + .setName("i32").setNumber(1) \ + .setLabel(D.FieldDescriptorProto.Label.LABEL_OPTIONAL) \ + .setType(D.FieldDescriptorProto.Type.TYPE_INT32).build() + s_field = D.FieldDescriptorProto.newBuilder() \ + .setName("s").setNumber(2) \ + .setLabel(D.FieldDescriptorProto.Label.LABEL_OPTIONAL) \ + .setType(D.FieldDescriptorProto.Type.TYPE_STRING).build() + msg = D.DescriptorProto.newBuilder() \ + .setName("Simple").addField(i32_field).addField(s_field).build() + file_builder = D.FileDescriptorProto.newBuilder() \ + .setName("simple.proto").setPackage("test").addMessageType(msg) \ + .setSyntax("proto2") + fds = D.FileDescriptorSet.newBuilder().addFile(file_builder.build()).build() + return bytes(fds.toByteArray()) + + +@pytest.fixture +def simple_desc(spark_tmp_path): + desc_path = spark_tmp_path + "/simple.desc" + desc_bytes = with_cpu_session(_build_simple_descriptor_bytes) + with open(desc_path, "wb") as fp: + fp.write(desc_bytes) + return desc_path, desc_bytes + + +_smoke_rows = [(1, "a"), (-2, "bb"), (0, ""), (12345, "hello")] + + +def _make_smoke_df(spark): + encoded = [(_encode_simple_message(i, s),) for (i, s) in _smoke_rows] + return spark.createDataFrame(encoded, ["bin"]) + + +@allow_non_gpu("ProjectExec", "ProtobufDataToCatalyst") +def test_from_protobuf_smoke_path_api(simple_desc, from_protobuf_fn): + desc_path, _ = simple_desc + + def run(spark): + return _make_smoke_df(spark).select( + from_protobuf_fn(f.col("bin"), "test.Simple", desc_path).alias("d")) + + assert_gpu_fallback_collect(run, "ProtobufDataToCatalyst") + + +@allow_non_gpu("ProjectExec", "ProtobufDataToCatalyst") +def test_from_protobuf_smoke_binary_descriptor_api(simple_desc, from_protobuf_fn): + if "binaryDescriptorSet" not in inspect.signature(from_protobuf_fn).parameters: + pytest.skip("binaryDescriptorSet kwarg is Spark 3.5+ only") + _, desc_bytes = simple_desc + + def run(spark): + return _make_smoke_df(spark).select( + from_protobuf_fn(f.col("bin"), "test.Simple", + binaryDescriptorSet=bytearray(desc_bytes)).alias("d")) + + assert_gpu_fallback_collect(run, "ProtobufDataToCatalyst") diff --git a/pom.xml b/pom.xml index 450211bcc4a..0dc708fba65 100644 --- a/pom.xml +++ b/pom.xml @@ -94,6 +94,7 @@ rapids-4-spark-delta-21x rapids-4-spark-delta-22x rapids-4-spark-delta-23x + true delta-lake/delta-21x @@ -118,6 +119,7 @@ rapids-4-spark-delta-21x rapids-4-spark-delta-22x rapids-4-spark-delta-23x + true delta-lake/delta-21x @@ -142,6 +144,7 @@ rapids-4-spark-delta-21x rapids-4-spark-delta-22x rapids-4-spark-delta-23x + true delta-lake/delta-21x @@ -166,6 +169,7 @@ rapids-4-spark-delta-21x rapids-4-spark-delta-22x rapids-4-spark-delta-23x + true delta-lake/delta-21x @@ -190,6 +194,7 @@ rapids-4-spark-delta-21x rapids-4-spark-delta-22x rapids-4-spark-delta-23x + true delta-lake/delta-21x @@ -806,6 +811,8 @@ ${spark.rapids.project.basedir}/target/${spark.version.classifier}/.sbt/1.0/zinc/org.scala-sbt false 330 + + false 1.8 8 ${java.major.version} diff --git a/scala2.13/integration_tests/pom.xml b/scala2.13/integration_tests/pom.xml index e1d2e133210..77081a9ea47 100644 --- a/scala2.13/integration_tests/pom.xml +++ b/scala2.13/integration_tests/pom.xml @@ -132,6 +132,8 @@ parquet-hadoop*.jar spark-avro*.jar + spark-protobuf*.jar + protobuf-java-*.jar @@ -166,6 +168,31 @@ + + + copy-spark-protobuf + package + + copy + + + ${spark.protobuf.skipCopy} + true + + + org.apache.spark + spark-protobuf_${scala.binary.version} + ${spark.version} + + + com.google.protobuf + protobuf-java + 3.25.5 + + + + diff --git a/scala2.13/pom.xml b/scala2.13/pom.xml index 6b9a9aa8d68..09d9e74cf68 100644 --- a/scala2.13/pom.xml +++ b/scala2.13/pom.xml @@ -94,6 +94,7 @@ rapids-4-spark-delta-21x rapids-4-spark-delta-22x rapids-4-spark-delta-23x + true delta-lake/delta-21x @@ -118,6 +119,7 @@ rapids-4-spark-delta-21x rapids-4-spark-delta-22x rapids-4-spark-delta-23x + true delta-lake/delta-21x @@ -142,6 +144,7 @@ rapids-4-spark-delta-21x rapids-4-spark-delta-22x rapids-4-spark-delta-23x + true delta-lake/delta-21x @@ -166,6 +169,7 @@ rapids-4-spark-delta-21x rapids-4-spark-delta-22x rapids-4-spark-delta-23x + true delta-lake/delta-21x @@ -190,6 +194,7 @@ rapids-4-spark-delta-21x rapids-4-spark-delta-22x rapids-4-spark-delta-23x + true delta-lake/delta-21x @@ -806,6 +811,8 @@ ${spark.rapids.project.basedir}/target/${spark.version.classifier}/.sbt/1.0/zinc/org.scala-sbt false 330 + + false 1.8 8 ${java.major.version}