Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion::execution::context::SessionContext;
use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::parquet::arrow::ArrowWriter;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::expressions::{CastColumnExpr, CastExpr};
use datafusion::physical_expr::expressions::CastExpr;
use datafusion::prelude::SessionConfig;
use datafusion_physical_expr_adapter::{
DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
Expand All @@ -43,9 +43,10 @@ use object_store::path::Path;
use object_store::{ObjectStore, ObjectStoreExt, PutPayload};

// Example showing how to implement custom casting rules to adapt file schemas.
// This example enforces that casts must be strictly widening: if the file type is Int64 and the table type is Int32, it will error
// before even reading the data.
// Without this custom cast rule DataFusion would happily do the narrowing cast, potentially erroring only if it found a row with data it could not cast.
// This example enforces strictly widening casts: if the file type is Int64 and
// the table type is Int32, it errors before reading the data. Without this
// custom cast rule DataFusion would apply the narrowing cast and might only
// error after reading a row that it could not cast.
pub async fn custom_file_casts() -> Result<()> {
println!("=== Creating example data ===");

Expand Down Expand Up @@ -139,7 +140,7 @@ async fn write_data(
Ok(())
}

/// Factory for creating DefaultValuePhysicalExprAdapter instances
/// Factory for creating custom cast physical expression adapters
#[derive(Debug)]
struct CustomCastPhysicalExprAdapterFactory {
inner: Arc<dyn PhysicalExprAdapterFactory>,
Expand Down Expand Up @@ -167,8 +168,8 @@ impl PhysicalExprAdapterFactory for CustomCastPhysicalExprAdapterFactory {
}
}

/// Custom PhysicalExprAdapter that handles missing columns with default values from metadata
/// and wraps DefaultPhysicalExprAdapter for standard schema adaptation
/// Custom `PhysicalExprAdapter` that wraps the default adapter and rejects
/// narrowing file-schema casts.
#[derive(Debug, Clone)]
struct CustomCastsPhysicalExprAdapter {
physical_file_schema: SchemaRef,
Expand All @@ -177,34 +178,23 @@ struct CustomCastsPhysicalExprAdapter {

impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter {
fn rewrite(&self, mut expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
// First delegate to the inner adapter to handle missing columns and discover any necessary casts
// First delegate to the inner adapter to handle standard schema adaptation
// and discover any necessary casts.
expr = self.inner.rewrite(expr)?;
// Now we can apply custom casting rules or even swap out all CastExprs for a custom cast kernel / expression
// For example, [DataFusion Comet](https://github.com/apache/datafusion-comet) has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138).
// Now apply custom casting rules or swap CastExprs for a custom cast
// kernel / expression. For example, DataFusion Comet has a custom cast
// kernel in its native Spark expression implementation.
expr.transform(|expr| {
if let Some(cast) = expr.as_any().downcast_ref::<CastExpr>() {
let input_data_type =
cast.expr().data_type(&self.physical_file_schema)?;
let output_data_type = cast.data_type(&self.physical_file_schema)?;
let output_data_type = cast.target_field().data_type();
if !cast.is_bigger_cast(&input_data_type) {
return not_impl_err!(
"Unsupported CAST from {input_data_type} to {output_data_type}"
);
}
}
if let Some(cast) = expr.as_any().downcast_ref::<CastColumnExpr>() {
let input_data_type =
cast.expr().data_type(&self.physical_file_schema)?;
let output_data_type = cast.data_type(&self.physical_file_schema)?;
if !CastExpr::check_bigger_cast(
cast.target_field().data_type(),
&input_data_type,
) {
return not_impl_err!(
"Unsupported CAST from {input_data_type} to {output_data_type}"
);
}
}
Ok(Transformed::no(expr))
})
.data()
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/custom_data_source/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
//!
//! ## Usage
//! ```bash
//! cargo run --example custom_data_source -- [all|csv_json_opener|csv_sql_streaming|custom_datasource|custom_file_casts|custom_file_format|default_column_values|file_stream_provider]
//! cargo run --example custom_data_source -- [all|adapter_serialization|csv_json_opener|csv_sql_streaming|custom_datasource|custom_file_casts|custom_file_format|default_column_values|file_stream_provider]
//! ```
//!
//! Each subcommand runs a corresponding example:
Expand Down
8 changes: 4 additions & 4 deletions datafusion/physical-expr-adapter/src/schema_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ mod tests {
.expect("Expected CastExpr")
}

fn assert_cast_column(cast_expr: &CastExpr, name: &str, index: usize) {
fn assert_cast_input_column(cast_expr: &CastExpr, name: &str, index: usize) {
let inner_col = cast_expr
.expr()
.as_any()
Expand Down Expand Up @@ -771,7 +771,7 @@ mod tests {

let left_cast = assert_cast_expr(left.left());
assert_eq!(left_cast.target_field().data_type(), &DataType::Int64);
assert_cast_column(left_cast, "a", 0);
assert_cast_input_column(left_cast, "a", 0);

let right = outer
.right()
Expand Down Expand Up @@ -1672,7 +1672,7 @@ mod tests {
let cast_expr = assert_cast_expr(&result);

// Verify the inner column points to the correct physical index (1)
assert_cast_column(cast_expr, "a", 1);
assert_cast_input_column(cast_expr, "a", 1);

// Verify cast types
assert_eq!(
Expand All @@ -1692,7 +1692,7 @@ mod tests {
// Regression: this must still resolve against physical field `a` by name.
let rewritten = adapter.rewrite(Arc::new(Column::new("a", 0))).unwrap();
let cast_expr = assert_cast_expr(&rewritten);
assert_cast_column(cast_expr, "a", 1);
assert_cast_input_column(cast_expr, "a", 1);
assert_eq!(cast_expr.target_field().data_type(), &DataType::Int64);
}
}
Loading
Loading