diff --git a/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java b/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java index 9502e1d5bd6f..af24ce40cac8 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java +++ b/api/src/main/java/org/apache/iceberg/expressions/ExpressionUtil.java @@ -24,6 +24,7 @@ import java.time.temporal.ChronoUnit; import java.util.List; import java.util.Locale; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.regex.Pattern; @@ -145,21 +146,28 @@ public static String toSanitizedString( } /** - * Extracts an expression that references only the given column IDs from the given expression. + * Returns an expression that retains only predicates which reference one of the given field IDs. * - *

The result is inclusive. If a row would match the original filter, it must match the result - * filter. - * - * @param expression a filter Expression - * @param schema a Schema + * @param expression a filter expression + * @param schema schema for binding references * @param caseSensitive whether binding is case sensitive - * @param ids field IDs used to match predicates to extract from the expression - * @return an Expression that selects at least the same rows as the original using only the IDs + * @param ids field IDs to retain predicates for + * @return expression containing only predicates that reference the given IDs */ public static Expression extractByIdInclusive( Expression expression, Schema schema, boolean caseSensitive, int... ids) { - PartitionSpec spec = identitySpec(schema, ids); - return Projections.inclusive(spec, caseSensitive).project(Expressions.rewriteNot(expression)); + if (ids == null || ids.length == 0) { + return Expressions.alwaysTrue(); + } + + ImmutableSet.Builder retainIds = ImmutableSet.builder(); + for (int id : ids) { + retainIds.add(id); + } + + return ExpressionVisitors.visit( + Expressions.rewriteNot(expression), + new RetainPredicatesByFieldIdVisitor(schema, caseSensitive, retainIds.build())); } /** @@ -262,6 +270,61 @@ public static UnboundTerm unbind(Term term) { throw new UnsupportedOperationException("Cannot unbind unsupported term: " + term); } + private static class RetainPredicatesByFieldIdVisitor + extends ExpressionVisitors.ExpressionVisitor { + private final Schema schema; + private final boolean caseSensitive; + private final Set retainFieldIds; + + RetainPredicatesByFieldIdVisitor( + Schema schema, boolean caseSensitive, Set retainFieldIds) { + this.schema = schema; + this.caseSensitive = caseSensitive; + this.retainFieldIds = retainFieldIds; + } + + @Override + public Expression alwaysTrue() { + return Expressions.alwaysTrue(); + } + + @Override + public Expression alwaysFalse() { + return Expressions.alwaysFalse(); + } + + @Override + public Expression not(Expression result) { + return Expressions.not(result); + } + + @Override + public Expression and(Expression leftResult, Expression rightResult) { + return Expressions.and(leftResult, rightResult); + } + + @Override + public Expression or(Expression leftResult, Expression rightResult) { + return Expressions.or(leftResult, rightResult); + } + + @Override + public Expression predicate(BoundPredicate pred) { + return retainFieldIds.contains(pred.ref().fieldId()) ? pred : Expressions.alwaysTrue(); + } + + @Override + public Expression predicate(UnboundPredicate pred) { + Expression bound = Binder.bind(schema.asStruct(), pred, caseSensitive); + if (bound instanceof BoundPredicate) { + return retainFieldIds.contains(((BoundPredicate) bound).ref().fieldId()) + ? pred + : Expressions.alwaysTrue(); + } + return Expressions.alwaysTrue(); + } + } + private static class ExpressionSanitizer extends ExpressionVisitors.ExpressionVisitor { private final long now; @@ -697,14 +760,4 @@ private static String sanitizeVariantValue( } return builder.toString(); } - - private static PartitionSpec identitySpec(Schema schema, int... ids) { - PartitionSpec.Builder specBuilder = PartitionSpec.builderFor(schema); - - for (int id : ids) { - specBuilder.identity(schema.findColumnName(id)); - } - - return specBuilder.build(); - } } diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionUtil.java b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionUtil.java index a5281421888f..fdf3d9dcd1b0 100644 --- a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionUtil.java +++ b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionUtil.java @@ -69,6 +69,24 @@ public class TestExpressionUtil { private static final Types.StructType FLOAT_TEST = Types.StructType.of(Types.NestedField.optional(1, "test", Types.FloatType.get())); + /** Schema with struct, list, and map columns for {@link #testExtractByIdInclusiveNestedTypes}. */ + private static final Schema NESTED_EXTRACT_SCHEMA = + new Schema( + Types.NestedField.required(1, "top_id", Types.LongType.get()), + Types.NestedField.optional( + 2, + "st", + Types.StructType.of( + Types.NestedField.required(3, "inner_i", Types.IntegerType.get()))), + Types.NestedField.optional( + 4, "arr", Types.ListType.ofRequired(5, Types.IntegerType.get())), + Types.NestedField.optional( + 6, + "mp", + Types.MapType.ofRequired(7, 8, Types.StringType.get(), Types.IntegerType.get()))); + + private static final Types.StructType NESTED_EXTRACT_STRUCT = NESTED_EXTRACT_SCHEMA.asStruct(); + @Test public void testUnchangedUnaryPredicates() { for (Expression unary : @@ -825,6 +843,146 @@ public void testSanitizeStringFallback() { } } + @Test + public void testExtractByIdInclusive() { + Expression alwaysTrue = Expressions.alwaysTrue(); + Expression idEq = Expressions.equal("id", 5L); + Expression valEq = Expressions.equal("val", 5); + + assertThat( + ExpressionUtil.equivalent( + alwaysTrue, + ExpressionUtil.extractByIdInclusive( + Expressions.and(idEq, valEq), SCHEMA, true, new int[0]), + STRUCT, + true)) + .isTrue(); + + assertThat( + ExpressionUtil.equivalent( + alwaysTrue, + ExpressionUtil.extractByIdInclusive( + Expressions.and(idEq, valEq), SCHEMA, true, (int[]) null), + STRUCT, + true)) + .isTrue(); + + assertThat( + ExpressionUtil.equivalent( + idEq, ExpressionUtil.extractByIdInclusive(idEq, SCHEMA, true, 1), STRUCT, true)) + .isTrue(); + + assertThat( + ExpressionUtil.equivalent( + alwaysTrue, + ExpressionUtil.extractByIdInclusive(valEq, SCHEMA, true, 1), + STRUCT, + true)) + .isTrue(); + + assertThat( + ExpressionUtil.equivalent( + idEq, + ExpressionUtil.extractByIdInclusive(Expressions.and(idEq, valEq), SCHEMA, true, 1), + STRUCT, + true)) + .isTrue(); + + Expression orBothId = Expressions.or(Expressions.equal("id", 1L), Expressions.equal("id", 2L)); + assertThat( + ExpressionUtil.equivalent( + orBothId, + ExpressionUtil.extractByIdInclusive(orBothId, SCHEMA, true, 1), + STRUCT, + true)) + .isTrue(); + } + + @Test + public void testExtractByIdInclusiveNestedTypes() { + Expression alwaysTrue = Expressions.alwaysTrue(); + Expression structPred = Expressions.equal("st.inner_i", 1); + Expression listPred = Expressions.equal("arr.element", 42); + Expression mapKeyPred = Expressions.equal("mp.key", "k"); + Expression mapValuePred = Expressions.equal("mp.value", 7); + Expression topPred = Expressions.equal("top_id", 9L); + + assertThat( + ExpressionUtil.equivalent( + structPred, + ExpressionUtil.extractByIdInclusive(structPred, NESTED_EXTRACT_SCHEMA, true, 3), + NESTED_EXTRACT_STRUCT, + true)) + .isTrue(); + assertThat( + ExpressionUtil.equivalent( + alwaysTrue, + ExpressionUtil.extractByIdInclusive(structPred, NESTED_EXTRACT_SCHEMA, true, 1), + NESTED_EXTRACT_STRUCT, + true)) + .isTrue(); + + assertThat( + ExpressionUtil.equivalent( + listPred, + ExpressionUtil.extractByIdInclusive(listPred, NESTED_EXTRACT_SCHEMA, true, 5), + NESTED_EXTRACT_STRUCT, + true)) + .isTrue(); + assertThat( + ExpressionUtil.equivalent( + alwaysTrue, + ExpressionUtil.extractByIdInclusive(listPred, NESTED_EXTRACT_SCHEMA, true, 1), + NESTED_EXTRACT_STRUCT, + true)) + .isTrue(); + + assertThat( + ExpressionUtil.equivalent( + mapKeyPred, + ExpressionUtil.extractByIdInclusive(mapKeyPred, NESTED_EXTRACT_SCHEMA, true, 7), + NESTED_EXTRACT_STRUCT, + true)) + .isTrue(); + assertThat( + ExpressionUtil.equivalent( + mapValuePred, + ExpressionUtil.extractByIdInclusive(mapValuePred, NESTED_EXTRACT_SCHEMA, true, 8), + NESTED_EXTRACT_STRUCT, + true)) + .isTrue(); + assertThat( + ExpressionUtil.equivalent( + alwaysTrue, + ExpressionUtil.extractByIdInclusive(mapKeyPred, NESTED_EXTRACT_SCHEMA, true, 8), + NESTED_EXTRACT_STRUCT, + true)) + .isTrue(); + + Expression mixed = Expressions.and(structPred, Expressions.and(listPred, topPred)); + assertThat( + ExpressionUtil.equivalent( + structPred, + ExpressionUtil.extractByIdInclusive(mixed, NESTED_EXTRACT_SCHEMA, true, 3), + NESTED_EXTRACT_STRUCT, + true)) + .isTrue(); + assertThat( + ExpressionUtil.equivalent( + listPred, + ExpressionUtil.extractByIdInclusive(mixed, NESTED_EXTRACT_SCHEMA, true, 5), + NESTED_EXTRACT_STRUCT, + true)) + .isTrue(); + assertThat( + ExpressionUtil.equivalent( + topPred, + ExpressionUtil.extractByIdInclusive(mixed, NESTED_EXTRACT_SCHEMA, true, 1), + NESTED_EXTRACT_STRUCT, + true)) + .isTrue(); + } + @Test public void testIdenticalExpressionIsEquivalent() { Expression[] exprs = diff --git a/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFilesProcedure.java b/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFilesProcedure.java index 311cf763eeef..6c744c8df4fb 100644 --- a/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFilesProcedure.java +++ b/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewritePositionDeleteFilesProcedure.java @@ -245,6 +245,68 @@ public void testRewriteSummary() throws Exception { EnvironmentContext.ENGINE_VERSION, v -> assertThat(v).startsWith("4.1")); } + @TestTemplate + public void testRewritePositionDeletesWithArrayColumns() throws Exception { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, items ARRAY>) " + + "USING iceberg TBLPROPERTIES " + + "('format-version'='2', 'write.delete.mode'='merge-on-read', 'write.update.mode'='merge-on-read')", + tableName); + + sql( + "INSERT INTO %s VALUES " + + "(1, 'a', array(named_struct('value', cast(10 as bigint), 'count', 1))), " + + "(2, 'b', array(named_struct('value', cast(20 as bigint), 'count', 2))), " + + "(3, 'c', array(named_struct('value', cast(30 as bigint), 'count', 3))), " + + "(4, 'd', array(named_struct('value', cast(40 as bigint), 'count', 4))), " + + "(5, 'e', array(named_struct('value', cast(50 as bigint), 'count', 5))), " + + "(6, 'f', array(named_struct('value', cast(60 as bigint), 'count', 6)))", + tableName); + + sql("DELETE FROM %s WHERE id = 1", tableName); + sql("DELETE FROM %s WHERE id = 2", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(TestHelpers.deleteFiles(table)).hasSizeGreaterThanOrEqualTo(1); + + sql( + "CALL %s.system.rewrite_position_delete_files(" + + "table => '%s'," + + "options => map('rewrite-all','true'))", + catalogName, tableIdent); + } + + @TestTemplate + public void testRewritePositionDeletesWithMapColumns() throws Exception { + sql( + "CREATE TABLE %s (id BIGINT, data STRING, props MAP) " + + "USING iceberg TBLPROPERTIES " + + "('format-version'='2', 'write.delete.mode'='merge-on-read', 'write.update.mode'='merge-on-read')", + tableName); + + sql( + "INSERT INTO %s VALUES " + + "(1, 'a', map('x', cast(10 as bigint))), " + + "(2, 'b', map('y', cast(20 as bigint))), " + + "(3, 'c', map('z', cast(30 as bigint))), " + + "(4, 'd', map('w', cast(40 as bigint))), " + + "(5, 'e', map('v', cast(50 as bigint))), " + + "(6, 'f', map('u', cast(60 as bigint)))", + tableName); + + sql("DELETE FROM %s WHERE id = 1", tableName); + sql("DELETE FROM %s WHERE id = 2", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + assertThat(TestHelpers.deleteFiles(table)).hasSizeGreaterThanOrEqualTo(1); + + sql( + "CALL %s.system.rewrite_position_delete_files(" + + "table => '%s'," + + "options => map('rewrite-all','true'))", + catalogName, tableIdent); + } + private Map snapshotSummary() { return validationCatalog.loadTable(tableIdent).currentSnapshot().summary(); } diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java index 1a45facba6c6..f310f1830c25 100644 --- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java +++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java @@ -19,8 +19,6 @@ package org.apache.iceberg.spark.source; import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.iceberg.ContentFile; import org.apache.iceberg.PositionDeletesScanTask; @@ -32,7 +30,6 @@ import org.apache.iceberg.io.CloseableIterator; import org.apache.iceberg.io.InputFile; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; -import org.apache.iceberg.relocated.com.google.common.primitives.Ints; import org.apache.iceberg.util.ContentFileUtil; import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.catalyst.InternalRow; @@ -84,12 +81,16 @@ protected CloseableIterator open(PositionDeletesScanTask task) { InputFile inputFile = getInputFile(task.file().location()); Preconditions.checkNotNull(inputFile, "Could not find InputFile associated with %s", task); - // select out constant fields when pushing down filter to row reader + // Retain predicates on non-constant fields for row reader filter Map idToConstant = constantsMap(task, expectedSchema()); - Set nonConstantFieldIds = nonConstantFieldIds(idToConstant); + int[] nonConstantFieldIds = + expectedSchema().idToName().keySet().stream() + .filter(id -> !idToConstant.containsKey(id)) + .mapToInt(Integer::intValue) + .toArray(); Expression residualWithoutConstants = ExpressionUtil.extractByIdInclusive( - task.residual(), expectedSchema(), caseSensitive(), Ints.toArray(nonConstantFieldIds)); + task.residual(), expectedSchema(), caseSensitive(), nonConstantFieldIds); if (ContentFileUtil.isDV(task.file())) { return new DVIterator(inputFile, task.file(), expectedSchema(), idToConstant); @@ -105,12 +106,4 @@ protected CloseableIterator open(PositionDeletesScanTask task) { idToConstant) .iterator(); } - - private Set nonConstantFieldIds(Map idToConstant) { - Set fields = expectedSchema().idToName().keySet(); - return fields.stream() - .filter(id -> expectedSchema().findField(id).type().isPrimitiveType()) - .filter(id -> !idToConstant.containsKey(id)) - .collect(Collectors.toSet()); - } } diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/source/TestPositionDeletesTable.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/source/TestPositionDeletesTable.java index 0e77e70e696d..14ad107e50e3 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/source/TestPositionDeletesTable.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/source/TestPositionDeletesTable.java @@ -209,6 +209,60 @@ public void testPartitionedTable() throws IOException { dropTable(tableName); } + @TestTemplate + public void testArrayColumnFilter() throws IOException { + assumeThat(formatVersion) + .as("Row content in position_deletes is required for array column filter test") + .isEqualTo(2); + String tableName = "array_column_filter"; + Schema schemaWithArray = + new Schema( + Types.NestedField.required(1, "id", Types.IntegerType.get()), + Types.NestedField.required(2, "data", Types.StringType.get()), + Types.NestedField.optional( + 3, "arr_col", Types.ListType.ofOptional(4, Types.IntegerType.get()))); + Table tab = createTable(tableName, schemaWithArray, PartitionSpec.unpartitioned()); + + GenericRecord record1 = GenericRecord.create(schemaWithArray); + record1.set(0, 1); + record1.set(1, "a"); + record1.set(2, ImmutableList.of(1, 2)); + GenericRecord record2 = GenericRecord.create(schemaWithArray); + record2.set(0, 2); + record2.set(1, "b"); + record2.set(2, ImmutableList.of(3, 4)); + List dataRecords = ImmutableList.of(record1, record2); + DataFile dFile = + FileHelpers.writeDataFile( + tab, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(), + dataRecords); + tab.newAppend().appendFile(dFile).commit(); + + List> deletes = + ImmutableList.of( + positionDelete(schemaWithArray, dFile.location(), 0L, 1, "a", ImmutableList.of(1, 2)), + positionDelete(schemaWithArray, dFile.location(), 1L, 2, "b", ImmutableList.of(3, 4))); + DeleteFile posDeletes = + FileHelpers.writePosDeleteFile( + tab, + Files.localOutput(File.createTempFile("junit", null, temp.toFile())), + TestHelpers.Row.of(), + deletes, + formatVersion); + tab.newRowDelta().addDeletes(posDeletes).commit(); + + // Filter directly on array column: row.arr_col = array(1, 2) + StructLikeSet actual = actual(tableName, tab, "row.arr_col = array(1, 2)"); + StructLikeSet expected = expected(tab, ImmutableList.of(deletes.get(0)), null, posDeletes); + + assertThat(actual) + .as("Filtering position_deletes by row.arr_col = array(1, 2) should return matching row") + .isEqualTo(expected); + dropTable(tableName); + } + @TestTemplate public void testSelect() throws IOException { assumeThat(formatVersion).as("DVs don't have row info in PositionDeletesTable").isEqualTo(2);