-
Notifications
You must be signed in to change notification settings - Fork 29.2k
[SPARK-56438][SQL][CORE] Optimize VectorizedPlainValuesReader.readBinary for direct ByteBuffer by eliminating intermediate byte[] copy
#55296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
901fd0f
75012bb
d2e8bb3
76f2f1f
5b9e29f
f176e50
c22329b
074c50d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -203,6 +203,22 @@ public void putBytes(int rowId, int count, byte[] src, int srcIndex) { | |
| Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId, count); | ||
| } | ||
|
|
||
| @Override | ||
| public void putBytes(int rowId, int count, ByteBuffer src, int srcIndex) { | ||
| if (src.hasArray()) { | ||
| Platform.copyMemory(src.array(), Platform.BYTE_ARRAY_OFFSET + src.arrayOffset() + srcIndex, | ||
| null, data + rowId, count); | ||
| } else if (src.isDirect()) { | ||
| long srcAddr = Platform.getDirectBufferAddress(src) + srcIndex; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. d2e8bb3 address this |
||
| Platform.copyMemory(null, srcAddr, null, data + rowId, count); | ||
| } else { | ||
| // Fallback for non-heap, non-direct buffers (e.g., read-only wrappers). | ||
| byte[] tmp = new byte[count]; | ||
| src.get(srcIndex, tmp, 0, count); | ||
| Platform.copyMemory(tmp, Platform.BYTE_ARRAY_OFFSET, null, data + rowId, count); | ||
| } | ||
| } | ||
|
|
||
| @Override | ||
| public byte getByte(int rowId) { | ||
| if (dictionary == null) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -286,6 +286,12 @@ public void putBooleans(int rowId, int count, byte src, int srcIndex) { | |
| */ | ||
| public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); | ||
|
|
||
| /** | ||
| * Copies {@code count} bytes from a {@link ByteBuffer} starting at absolute position | ||
| * {@code srcIndex} into this column at {@code rowId}. Does not modify the buffer's position. | ||
| */ | ||
| public abstract void putBytes(int rowId, int count, ByteBuffer src, int srcIndex); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Anyway we can add some test cases for this new API? |
||
|
|
||
| /** | ||
| * Sets `value` to the value at rowId. | ||
| */ | ||
|
|
@@ -435,6 +441,25 @@ public final int putByteArray(int rowId, byte[] value) { | |
| return putByteArray(rowId, value, 0, value.length); | ||
| } | ||
|
|
||
| /** | ||
| * Stores bytes from a {@link ByteBuffer} as a variable-length byte array at {@code rowId}. | ||
| * Copies {@code length} bytes starting at absolute position {@code srcPosition} in the buffer. | ||
| * Does not modify the buffer's position. | ||
| */ | ||
| public final int putByteArray(int rowId, ByteBuffer src, int srcPosition, int length) { | ||
| int result = arrayData().appendBytes(length, src, srcPosition); | ||
| putArray(rowId, result, length); | ||
| return result; | ||
| } | ||
|
|
||
| final int appendBytes(int length, ByteBuffer src, int srcPosition) { | ||
| reserve(elementsAppended + length); | ||
| int result = elementsAppended; | ||
| putBytes(elementsAppended, length, src, srcPosition); | ||
| elementsAppended += length; | ||
| return result; | ||
| } | ||
|
|
||
| @Override | ||
| public Decimal getDecimal(int rowId, int precision, int scale) { | ||
| if (isNullAt(rowId)) return null; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,8 @@ | |
|
|
||
| package org.apache.spark.sql.execution.vectorized | ||
|
|
||
| import java.nio.ByteBuffer | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.sql.YearUDT | ||
| import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow | ||
|
|
@@ -262,6 +264,71 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { | |
| } | ||
| } | ||
|
|
||
| testVectors("putByteArray from ByteBuffer", 10, BinaryType) { testVector => | ||
| def verifyPutByteArray(testVector: WritableColumnVector): Unit = { | ||
| (0 until 10).foreach { i => | ||
| assert(testVector.getBinary(i) === s"str$i".getBytes("utf8")) | ||
| } | ||
| } | ||
|
|
||
| // Heap ByteBuffer | ||
| (0 until 10).foreach { i => | ||
| val bytes = s"str$i".getBytes("utf8") | ||
| testVector.putByteArray(i, ByteBuffer.wrap(bytes), 0, bytes.length) | ||
| } | ||
| verifyPutByteArray(testVector) | ||
|
|
||
| // Direct ByteBuffer | ||
| testVector.reset() | ||
| (0 until 10).foreach { i => | ||
| val bytes = s"str$i".getBytes("utf8") | ||
| val buf = ByteBuffer.allocateDirect(bytes.length) | ||
| buf.put(bytes) | ||
| testVector.putByteArray(i, buf, 0, bytes.length) | ||
| } | ||
| verifyPutByteArray(testVector) | ||
|
|
||
| // Read-only ByteBuffer (hasArray=false, isDirect=false) | ||
| testVector.reset() | ||
| (0 until 10).foreach { i => | ||
| val bytes = s"str$i".getBytes("utf8") | ||
| val buf = ByteBuffer.wrap(bytes).asReadOnlyBuffer() | ||
| testVector.putByteArray(i, buf, 0, bytes.length) | ||
| } | ||
| verifyPutByteArray(testVector) | ||
| } | ||
|
|
||
| testVectors("putBytes from ByteBuffer", 16, ByteType) { testVector => | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sunchao Added new tests for both new APIs. |
||
| val data = Array[Byte](10, 20, 30, 40, 50, 60, 70, 80) | ||
|
|
||
| // Heap ByteBuffer | ||
| testVector.putBytes(0, data.length, ByteBuffer.wrap(data), 0) | ||
| (0 until data.length).foreach { i => | ||
| assert(testVector.getByte(i) === data(i)) | ||
| } | ||
|
|
||
| // Direct ByteBuffer | ||
| val directBuf = ByteBuffer.allocateDirect(data.length) | ||
| directBuf.put(data) | ||
| testVector.putBytes(0, data.length, directBuf, 0) | ||
| (0 until data.length).foreach { i => | ||
| assert(testVector.getByte(i) === data(i)) | ||
| } | ||
|
|
||
| // Read-only ByteBuffer (hasArray=false, isDirect=false) | ||
| val readOnlyBuf = ByteBuffer.wrap(data).asReadOnlyBuffer() | ||
| testVector.putBytes(0, data.length, readOnlyBuf, 0) | ||
| (0 until data.length).foreach { i => | ||
| assert(testVector.getByte(i) === data(i)) | ||
| } | ||
|
|
||
| // With srcIndex offset | ||
| testVector.putBytes(0, 4, ByteBuffer.wrap(data), 4) | ||
| (0 until 4).foreach { i => | ||
| assert(testVector.getByte(i) === data(i + 4)) | ||
| } | ||
| } | ||
|
|
||
| DataTypeTestUtils.yearMonthIntervalTypes.foreach { | ||
| dt => | ||
| testVectors(dt.typeName, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this pr is accepted, a corresponding benchmark can be added to PlatformBenchmark at a later time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1. It'd be great to have some benchmark results!