Skip to content

Commit 848e3b3

Browse files
committed
Merge remote-tracking branch 'upstream/main' into update_fuzzer
2 parents d14ea87 + 5f39006 commit 848e3b3

File tree

16 files changed

+297
-346
lines changed

16 files changed

+297
-346
lines changed

dev/benchmarks/comet-tpch.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,5 @@ $SPARK_HOME/bin/spark-submit \
5050
--data $TPCH_DATA \
5151
--queries $TPCH_QUERIES \
5252
--output . \
53+
--write /tmp \
5354
--iterations 1

dev/benchmarks/tpcbench.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pyspark.sql import SparkSession
2222
import time
2323

24-
def main(benchmark: str, data_path: str, query_path: str, iterations: int, output: str, name: str, query_num: int = None):
24+
def main(benchmark: str, data_path: str, query_path: str, iterations: int, output: str, name: str, query_num: int = None, write_path: str = None):
2525

2626
# Initialize a SparkSession
2727
spark = SparkSession.builder \
@@ -89,10 +89,16 @@ def main(benchmark: str, data_path: str, query_path: str, iterations: int, outpu
8989
print(f"Executing: {sql}")
9090
df = spark.sql(sql)
9191
df.explain()
92-
rows = df.collect()
92+
93+
if write_path is not None:
94+
output_path = f"{write_path}/q{query}"
95+
df.coalesce(1).write.mode("overwrite").parquet(output_path)
96+
print(f"Query {query} results written to {output_path}")
97+
else:
98+
rows = df.collect()
99+
print(f"Query {query} returned {len(rows)} rows")
93100
df.explain()
94101

95-
print(f"Query {query} returned {len(rows)} rows")
96102

97103
end_time = time.time()
98104
print(f"Query {query} took {end_time - start_time} seconds")
@@ -123,6 +129,7 @@ def main(benchmark: str, data_path: str, query_path: str, iterations: int, outpu
123129
parser.add_argument("--output", required=True, help="Path to write output")
124130
parser.add_argument("--name", required=True, help="Prefix for result file e.g. spark/comet/gluten")
125131
parser.add_argument("--query", required=False, type=int, help="Specific query number to run (1-based). If not specified, all queries will be run.")
132+
parser.add_argument("--write", required=False, help="Path to save query results to, in Parquet format.")
126133
args = parser.parse_args()
127134

128-
main(args.benchmark, args.data, args.queries, int(args.iterations), args.output, args.name, args.query)
135+
main(args.benchmark, args.data, args.queries, int(args.iterations), args.output, args.name, args.query, args.write)

docs/source/user-guide/latest/configs.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ These settings can be used to determine which parts of the plan are accelerated
268268
| `spark.comet.expression.Reverse.enabled` | Enable Comet acceleration for `Reverse` | true |
269269
| `spark.comet.expression.Round.enabled` | Enable Comet acceleration for `Round` | true |
270270
| `spark.comet.expression.Second.enabled` | Enable Comet acceleration for `Second` | true |
271+
| `spark.comet.expression.Sha1.enabled` | Enable Comet acceleration for `Sha1` | true |
271272
| `spark.comet.expression.Sha2.enabled` | Enable Comet acceleration for `Sha2` | true |
272273
| `spark.comet.expression.ShiftLeft.enabled` | Enable Comet acceleration for `ShiftLeft` | true |
273274
| `spark.comet.expression.ShiftRight.enabled` | Enable Comet acceleration for `ShiftRight` | true |

fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.rogach.scallop.ScallopOption
2626

2727
import org.apache.spark.sql.SparkSession
2828

29-
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
29+
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
3030

3131
class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
3232
object generateData extends Subcommand("data") {
@@ -78,19 +78,19 @@ object Main {
7878
case Some(seed) => new Random(seed)
7979
case None => new Random()
8080
}
81-
val options = DataGenOptions(
82-
allowNull = true,
83-
generateArray = conf.generateData.generateArrays(),
84-
generateStruct = conf.generateData.generateStructs(),
85-
generateMap = conf.generateData.generateMaps(),
86-
generateNegativeZero = !conf.generateData.excludeNegativeZero())
8781
for (i <- 0 until conf.generateData.numFiles()) {
8882
ParquetGenerator.makeParquetFile(
8983
r,
9084
spark,
9185
s"test$i.parquet",
9286
numRows = conf.generateData.numRows(),
93-
options)
87+
SchemaGenOptions(
88+
generateArray = conf.generateData.generateArrays(),
89+
generateStruct = conf.generateData.generateStructs(),
90+
generateMap = conf.generateData.generateMaps()),
91+
DataGenOptions(
92+
allowNull = true,
93+
generateNegativeZero = !conf.generateData.excludeNegativeZero()))
9494
}
9595
case Some(conf.generateQueries) =>
9696
val r = conf.generateQueries.randomSeed.toOption match {

native/core/src/execution/jni_api.rs

Lines changed: 56 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,16 @@ use datafusion_spark::function::string::char::CharFunc;
5050
use futures::poll;
5151
use futures::stream::StreamExt;
5252
use jni::objects::JByteBuffer;
53-
use jni::sys::JNI_FALSE;
53+
use jni::sys::{jlongArray, JNI_FALSE};
5454
use jni::{
5555
errors::Result as JNIResult,
5656
objects::{
57-
JByteArray, JClass, JIntArray, JLongArray, JObject, JObjectArray, JPrimitiveArray, JString,
57+
GlobalRef, JByteArray, JClass, JIntArray, JLongArray, JObject, JObjectArray, JString,
5858
ReleaseMode,
5959
},
60-
sys::{jbyteArray, jint, jlong, jlongArray},
60+
sys::{jboolean, jdouble, jint, jlong},
6161
JNIEnv,
6262
};
63-
use jni::{
64-
objects::GlobalRef,
65-
sys::{jboolean, jdouble, jintArray, jobjectArray, jstring},
66-
};
6763
use std::collections::HashMap;
6864
use std::path::PathBuf;
6965
use std::time::{Duration, Instant};
@@ -159,26 +155,25 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
159155
e: JNIEnv,
160156
_class: JClass,
161157
id: jlong,
162-
iterators: jobjectArray,
163-
serialized_query: jbyteArray,
164-
serialized_spark_configs: jbyteArray,
158+
iterators: JObjectArray,
159+
serialized_query: JByteArray,
160+
serialized_spark_configs: JByteArray,
165161
partition_count: jint,
166162
metrics_node: JObject,
167163
metrics_update_interval: jlong,
168164
comet_task_memory_manager_obj: JObject,
169-
local_dirs: jobjectArray,
165+
local_dirs: JObjectArray,
170166
batch_size: jint,
171167
off_heap_mode: jboolean,
172-
memory_pool_type: jstring,
168+
memory_pool_type: JString,
173169
memory_limit: jlong,
174170
memory_limit_per_task: jlong,
175171
task_attempt_id: jlong,
176172
key_unwrapper_obj: JObject,
177173
) -> jlong {
178174
try_unwrap_or_throw(&e, |mut env| {
179175
// Deserialize Spark configs
180-
let array = unsafe { JPrimitiveArray::from_raw(serialized_spark_configs) };
181-
let bytes = env.convert_byte_array(array)?;
176+
let bytes = env.convert_byte_array(serialized_spark_configs)?;
182177
let spark_configs = serde::deserialize_config(bytes.as_slice())?;
183178
let spark_config: HashMap<String, String> = spark_configs.entries.into_iter().collect();
184179

@@ -196,18 +191,16 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
196191
let start = Instant::now();
197192

198193
// Deserialize query plan
199-
let array = unsafe { JPrimitiveArray::from_raw(serialized_query) };
200-
let bytes = env.convert_byte_array(array)?;
194+
let bytes = env.convert_byte_array(serialized_query)?;
201195
let spark_plan = serde::deserialize_op(bytes.as_slice())?;
202196

203197
let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?);
204198

205199
// Get the global references of input sources
206200
let mut input_sources = vec![];
207-
let iter_array = JObjectArray::from_raw(iterators);
208-
let num_inputs = env.get_array_length(&iter_array)?;
201+
let num_inputs = env.get_array_length(&iterators)?;
209202
for i in 0..num_inputs {
210-
let input_source = env.get_object_array_element(&iter_array, i)?;
203+
let input_source = env.get_object_array_element(&iterators, i)?;
211204
let input_source = Arc::new(jni_new_global_ref!(env, input_source)?);
212205
input_sources.push(input_source);
213206
}
@@ -216,7 +209,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
216209
let task_memory_manager =
217210
Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?);
218211

219-
let memory_pool_type = env.get_string(&JString::from_raw(memory_pool_type))?.into();
212+
let memory_pool_type = env.get_string(&memory_pool_type)?.into();
220213
let memory_pool_config = parse_memory_pool_config(
221214
off_heap_mode != JNI_FALSE,
222215
memory_pool_type,
@@ -227,13 +220,12 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
227220
create_memory_pool(&memory_pool_config, task_memory_manager, task_attempt_id);
228221

229222
// Get local directories for storing spill files
230-
let local_dirs_array = JObjectArray::from_raw(local_dirs);
231-
let num_local_dirs = env.get_array_length(&local_dirs_array)?;
232-
let mut local_dirs = vec![];
223+
let num_local_dirs = env.get_array_length(&local_dirs)?;
224+
let mut local_dirs_vec = vec![];
233225
for i in 0..num_local_dirs {
234-
let local_dir: JString = env.get_object_array_element(&local_dirs_array, i)?.into();
226+
let local_dir: JString = env.get_object_array_element(&local_dirs, i)?.into();
235227
let local_dir = env.get_string(&local_dir)?;
236-
local_dirs.push(local_dir.into());
228+
local_dirs_vec.push(local_dir.into());
237229
}
238230

239231
// We need to keep the session context alive. Some session state like temporary
@@ -242,7 +234,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
242234
let session = prepare_datafusion_session_context(
243235
batch_size as usize,
244236
memory_pool,
245-
local_dirs,
237+
local_dirs_vec,
246238
max_temp_directory_size,
247239
)?;
248240

@@ -344,21 +336,17 @@ fn prepare_datafusion_session_context(
344336
/// Prepares arrow arrays for output.
345337
fn prepare_output(
346338
env: &mut JNIEnv,
347-
array_addrs: jlongArray,
348-
schema_addrs: jlongArray,
339+
array_addrs: JLongArray,
340+
schema_addrs: JLongArray,
349341
output_batch: RecordBatch,
350342
validate: bool,
351343
) -> CometResult<jlong> {
352-
let array_address_array = unsafe { JLongArray::from_raw(array_addrs) };
353-
let num_cols = env.get_array_length(&array_address_array)? as usize;
344+
let num_cols = env.get_array_length(&array_addrs)? as usize;
354345

355-
let array_addrs =
356-
unsafe { env.get_array_elements(&array_address_array, ReleaseMode::NoCopyBack)? };
346+
let array_addrs = unsafe { env.get_array_elements(&array_addrs, ReleaseMode::NoCopyBack)? };
357347
let array_addrs = &*array_addrs;
358348

359-
let schema_address_array = unsafe { JLongArray::from_raw(schema_addrs) };
360-
let schema_addrs =
361-
unsafe { env.get_array_elements(&schema_address_array, ReleaseMode::NoCopyBack)? };
349+
let schema_addrs = unsafe { env.get_array_elements(&schema_addrs, ReleaseMode::NoCopyBack)? };
362350
let schema_addrs = &*schema_addrs;
363351

364352
let results = output_batch.columns();
@@ -441,8 +429,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
441429
stage_id: jint,
442430
partition: jint,
443431
exec_context: jlong,
444-
array_addrs: jlongArray,
445-
schema_addrs: jlongArray,
432+
array_addrs: JLongArray,
433+
schema_addrs: JLongArray,
446434
) -> jlong {
447435
try_unwrap_or_throw(&e, |mut env| {
448436
// Retrieve the query
@@ -599,24 +587,21 @@ fn update_metrics(env: &mut JNIEnv, exec_context: &mut ExecutionContext) -> Come
599587

600588
fn convert_datatype_arrays(
601589
env: &'_ mut JNIEnv<'_>,
602-
serialized_datatypes: jobjectArray,
590+
serialized_datatypes: JObjectArray,
603591
) -> JNIResult<Vec<ArrowDataType>> {
604-
unsafe {
605-
let obj_array = JObjectArray::from_raw(serialized_datatypes);
606-
let array_len = env.get_array_length(&obj_array)?;
607-
let mut res: Vec<ArrowDataType> = Vec::new();
608-
609-
for i in 0..array_len {
610-
let inner_array = env.get_object_array_element(&obj_array, i)?;
611-
let inner_array: JByteArray = inner_array.into();
612-
let bytes = env.convert_byte_array(inner_array)?;
613-
let data_type = serde::deserialize_data_type(bytes.as_slice()).unwrap();
614-
let arrow_dt = to_arrow_datatype(&data_type);
615-
res.push(arrow_dt);
616-
}
617-
618-
Ok(res)
592+
let array_len = env.get_array_length(&serialized_datatypes)?;
593+
let mut res: Vec<ArrowDataType> = Vec::new();
594+
595+
for i in 0..array_len {
596+
let inner_array = env.get_object_array_element(&serialized_datatypes, i)?;
597+
let inner_array: JByteArray = inner_array.into();
598+
let bytes = env.convert_byte_array(inner_array)?;
599+
let data_type = serde::deserialize_data_type(bytes.as_slice()).unwrap();
600+
let arrow_dt = to_arrow_datatype(&data_type);
601+
res.push(arrow_dt);
619602
}
603+
604+
Ok(res)
620605
}
621606

622607
fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext {
@@ -634,16 +619,16 @@ fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext {
634619
pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative(
635620
e: JNIEnv,
636621
_class: JClass,
637-
row_addresses: jlongArray,
638-
row_sizes: jintArray,
639-
serialized_datatypes: jobjectArray,
640-
file_path: jstring,
622+
row_addresses: JLongArray,
623+
row_sizes: JIntArray,
624+
serialized_datatypes: JObjectArray,
625+
file_path: JString,
641626
prefer_dictionary_ratio: jdouble,
642627
batch_size: jlong,
643628
checksum_enabled: jboolean,
644629
checksum_algo: jint,
645630
current_checksum: jlong,
646-
compression_codec: jstring,
631+
compression_codec: JString,
647632
compression_level: jint,
648633
tracing_enabled: jboolean,
649634
) -> jlongArray {
@@ -654,21 +639,16 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative
654639
|| {
655640
let data_types = convert_datatype_arrays(&mut env, serialized_datatypes)?;
656641

657-
let row_address_array = JLongArray::from_raw(row_addresses);
658-
let row_num = env.get_array_length(&row_address_array)? as usize;
642+
let row_num = env.get_array_length(&row_addresses)? as usize;
659643
let row_addresses =
660-
env.get_array_elements(&row_address_array, ReleaseMode::NoCopyBack)?;
644+
env.get_array_elements(&row_addresses, ReleaseMode::NoCopyBack)?;
661645

662-
let row_size_array = JIntArray::from_raw(row_sizes);
663-
let row_sizes = env.get_array_elements(&row_size_array, ReleaseMode::NoCopyBack)?;
646+
let row_sizes = env.get_array_elements(&row_sizes, ReleaseMode::NoCopyBack)?;
664647

665648
let row_addresses_ptr = row_addresses.as_ptr();
666649
let row_sizes_ptr = row_sizes.as_ptr();
667650

668-
let output_path: String = env
669-
.get_string(&JString::from_raw(file_path))
670-
.unwrap()
671-
.into();
651+
let output_path: String = env.get_string(&file_path).unwrap().into();
672652

673653
let checksum_enabled = checksum_enabled == 1;
674654
let current_checksum = if current_checksum == i64::MIN {
@@ -678,10 +658,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative
678658
Some(current_checksum as u32)
679659
};
680660

681-
let compression_codec: String = env
682-
.get_string(&JString::from_raw(compression_codec))
683-
.unwrap()
684-
.into();
661+
let compression_codec: String = env.get_string(&compression_codec).unwrap().into();
685662

686663
let compression_codec = match compression_codec.as_str() {
687664
"zstd" => CompressionCodec::Zstd(compression_level),
@@ -754,8 +731,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
754731
_class: JClass,
755732
byte_buffer: JByteBuffer,
756733
length: jint,
757-
array_addrs: jlongArray,
758-
schema_addrs: jlongArray,
734+
array_addrs: JLongArray,
735+
schema_addrs: JLongArray,
759736
tracing_enabled: jboolean,
760737
) -> jlong {
761738
try_unwrap_or_throw(&e, |mut env| {
@@ -775,10 +752,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
775752
pub unsafe extern "system" fn Java_org_apache_comet_Native_traceBegin(
776753
e: JNIEnv,
777754
_class: JClass,
778-
event: jstring,
755+
event: JString,
779756
) {
780757
try_unwrap_or_throw(&e, |mut env| {
781-
let name: String = env.get_string(&JString::from_raw(event)).unwrap().into();
758+
let name: String = env.get_string(&event).unwrap().into();
782759
trace_begin(&name);
783760
Ok(())
784761
})
@@ -790,10 +767,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_traceBegin(
790767
pub unsafe extern "system" fn Java_org_apache_comet_Native_traceEnd(
791768
e: JNIEnv,
792769
_class: JClass,
793-
event: jstring,
770+
event: JString,
794771
) {
795772
try_unwrap_or_throw(&e, |mut env| {
796-
let name: String = env.get_string(&JString::from_raw(event)).unwrap().into();
773+
let name: String = env.get_string(&event).unwrap().into();
797774
trace_end(&name);
798775
Ok(())
799776
})
@@ -805,11 +782,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_traceEnd(
805782
pub unsafe extern "system" fn Java_org_apache_comet_Native_logMemoryUsage(
806783
e: JNIEnv,
807784
_class: JClass,
808-
name: jstring,
785+
name: JString,
809786
value: jlong,
810787
) {
811788
try_unwrap_or_throw(&e, |mut env| {
812-
let name: String = env.get_string(&JString::from_raw(name)).unwrap().into();
789+
let name: String = env.get_string(&name).unwrap().into();
813790
log_memory_usage(&name, value as u64);
814791
Ok(())
815792
})

0 commit comments

Comments
 (0)