
Commit d3cdb0e (authored Mar 7, 2024)
Implement BloomFilter query rewrite (without pushdown optimization) (#248)

* Add bloom filter might contain expression
* Fix IT to cover both codegen and eval execution test
* Add UT
* Address PR comment

Signed-off-by: Chen Dai <daichen@amazon.com>
1 parent 05af470 commit d3cdb0e

File tree

6 files changed: +219 -13 lines changed
 

docs/index.md (+1 -1)

@@ -25,7 +25,7 @@ Please see the following example in which Index Building Logic and Query Rewrite

 | Partition | CREATE SKIPPING INDEX<br>ON alb_logs<br> (<br>&nbsp;&nbsp;year PARTITION,<br>&nbsp;&nbsp;month PARTITION,<br>&nbsp;&nbsp;day PARTITION,<br>&nbsp;&nbsp;hour PARTITION<br>) | INSERT INTO flint_alb_logs_skipping_index<br>SELECT<br>&nbsp;&nbsp;FIRST(year) AS year,<br>&nbsp;&nbsp;FIRST(month) AS month,<br>&nbsp;&nbsp;FIRST(day) AS day,<br>&nbsp;&nbsp;FIRST(hour) AS hour,<br>&nbsp;&nbsp;input_file_name() AS file_path<br>FROM alb_logs<br>GROUP BY<br>&nbsp;&nbsp;input_file_name() | SELECT *<br>FROM alb_logs<br>WHERE year = 2023 AND month = 4<br>=><br>SELECT *<br>FROM alb_logs (input_files = <br>&nbsp;&nbsp;SELECT file_path<br>&nbsp;&nbsp;FROM flint_alb_logs_skipping_index<br>&nbsp;&nbsp;WHERE year = 2023 AND month = 4<br>)<br>WHERE year = 2023 AND month = 4 |
 | ValueSet | CREATE SKIPPING INDEX<br>ON alb_logs<br> (<br>&nbsp;&nbsp;elb_status_code VALUE_SET<br>) | INSERT INTO flint_alb_logs_skipping_index<br>SELECT<br>&nbsp;&nbsp;COLLECT_SET(elb_status_code) AS elb_status_code,<br>&nbsp;&nbsp;input_file_name() AS file_path<br>FROM alb_logs<br>GROUP BY<br>&nbsp;&nbsp;input_file_name() | SELECT *<br>FROM alb_logs<br>WHERE elb_status_code = 404<br>=><br>SELECT *<br>FROM alb_logs (input_files = <br>&nbsp;&nbsp;SELECT file_path<br>&nbsp;&nbsp;FROM flint_alb_logs_skipping_index<br>&nbsp;&nbsp;WHERE ARRAY_CONTAINS(elb_status_code, 404)<br>)<br>WHERE elb_status_code = 404 |
 | MinMax | CREATE SKIPPING INDEX<br>ON alb_logs<br> (<br>&nbsp;&nbsp;request_processing_time MIN_MAX<br>) | INSERT INTO flint_alb_logs_skipping_index<br>SELECT<br>&nbsp;&nbsp;MIN(request_processing_time) AS request_processing_time_min,<br>&nbsp;&nbsp;MAX(request_processing_time) AS request_processing_time_max,<br>&nbsp;&nbsp;input_file_name() AS file_path<br>FROM alb_logs<br>GROUP BY<br>&nbsp;&nbsp;input_file_name() | SELECT *<br>FROM alb_logs<br>WHERE request_processing_time = 100<br>=><br>SELECT *<br>FROM alb_logs (input_files = <br> SELECT file_path<br>&nbsp;&nbsp;FROM flint_alb_logs_skipping_index<br>&nbsp;&nbsp;WHERE request_processing_time_min <= 100<br>&nbsp;&nbsp;&nbsp;&nbsp;AND 100 <= request_processing_time_max<br>)<br>WHERE request_processing_time = 100 |
-| BloomFilter | CREATE SKIPPING INDEX<br>ON alb_logs<br> (<br>&nbsp;&nbsp;client_ip BLOOM_FILTER<br>) | INSERT INTO flint_alb_logs_skipping_index<br>SELECT<br>&nbsp;&nbsp;BLOOM_FILTER_AGG(client_ip) AS client_ip,<br>&nbsp;&nbsp;input_file_name() AS file_path<br>FROM alb_logs<br>GROUP BY<br>&nbsp;&nbsp;input_file_name() | SELECT *<br>FROM alb_logs<br>WHERE client_ip = '127.0.0.1'<br>=><br>SELECT *<br>FROM alb_logs (input_files = <br>&nbsp;&nbsp;SELECT file_path<br>&nbsp;&nbsp;FROM flint_alb_logs_skipping_index<br>&nbsp;&nbsp;WHERE BLOOM_FILTER_MIGHT_CONTAIN(client_ip, '127.0.0.1') = true<br>)<br>WHERE client_ip = '127.0.0.1' |
+| BloomFilter | CREATE SKIPPING INDEX<br>ON alb_logs<br> (<br>&nbsp;&nbsp;client_ip BLOOM_FILTER<br>) | INSERT INTO flint_alb_logs_skipping_index<br>SELECT<br>&nbsp;&nbsp;BLOOM_FILTER_AGG(client_ip) AS client_ip,<br>&nbsp;&nbsp;input_file_name() AS file_path<br>FROM alb_logs<br>GROUP BY<br>&nbsp;&nbsp;input_file_name() | SELECT *<br>FROM alb_logs<br>WHERE client_ip = '127.0.0.1'<br>=><br>SELECT *<br>FROM alb_logs (input_files = <br>&nbsp;&nbsp;SELECT file_path<br>&nbsp;&nbsp;FROM flint_alb_logs_skipping_index<br>&nbsp;&nbsp;WHERE BLOOM_FILTER_MIGHT_CONTAIN(client_ip, '127.0.0.1')<br>)<br>WHERE client_ip = '127.0.0.1' |

 ### Flint Index Refresh
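The BloomFilter row added above completes the skipping index table: build time aggregates each file's values into a per-file bloom filter via BLOOM_FILTER_AGG, and query time probes it with BLOOM_FILTER_MIGHT_CONTAIN. A minimal end-to-end sketch in Spark Scala, assuming a SparkSession named `spark` with the Flint extensions enabled and an existing `alb_logs` table (both assumptions, not part of this commit):

```scala
// Create a bloom filter skipping index on client_ip (syntax from the table above).
spark.sql("CREATE SKIPPING INDEX ON alb_logs (client_ip BLOOM_FILTER)")

// Once the index is refreshed (see Flint Index Refresh below), an equality
// predicate on client_ip can be rewritten to probe the per-file bloom filter
// and skip files that definitely do not contain the value.
spark.sql("SELECT * FROM alb_logs WHERE client_ip = '127.0.0.1'").show()
```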

flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/ClassicBloomFilter.java (+15 -8)

@@ -132,16 +132,23 @@ public void writeTo(OutputStream out) throws IOException {
    * @param in input stream
    * @return bloom filter
    */
-  public static BloomFilter readFrom(InputStream in) throws IOException {
-    DataInputStream dis = new DataInputStream(in);
+  public static BloomFilter readFrom(InputStream in) {
+    try {
+      DataInputStream dis = new DataInputStream(in);
+
+      // Check version compatibility
+      int version = dis.readInt();
+      if (version != Version.V1.getVersionNumber()) {
+        throw new IllegalStateException("Unexpected Bloom filter version number (" + version + ")");
+      }

-    int version = dis.readInt();
-    if (version != Version.V1.getVersionNumber()) {
-      throw new IOException("Unexpected Bloom filter version number (" + version + ")");
+      // Read bloom filter content
+      int numHashFunctions = dis.readInt();
+      BitArray bits = BitArray.readFrom(dis);
+      return new ClassicBloomFilter(bits, numHashFunctions);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
     }
-    int numHashFunctions = dis.readInt();
-    BitArray bits = BitArray.readFrom(dis);
-    return new ClassicBloomFilter(bits, numHashFunctions);
   }

   private static int optimalNumOfHashFunctions(long n, long m) {
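Note the exception change: readFrom now wraps IOException in an unchecked RuntimeException (and reports a bad version as IllegalStateException), so callers, including the generated Java in the expression below, need no checked exception handling. A minimal caller sketch in Scala; `filterBytes` is a hypothetical byte array previously produced by writeTo:

```scala
import java.io.ByteArrayInputStream

import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter

// Deserialize a bloom filter from bytes and test membership of a hashed value.
// No try/catch is required now that readFrom throws unchecked exceptions only.
def mightContain(filterBytes: Array[Byte], hashedValue: Long): Boolean =
  ClassicBloomFilter.readFrom(new ByteArrayInputStream(filterBytes))
    .mightContain(hashedValue)
```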
flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala (new file, +125 -0)

@@ -0,0 +1,125 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.flint.spark.skipping.bloomfilter

import java.io.ByteArrayInputStream

import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Expression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types._

/**
 * Bloom filter function that returns the membership check result for values of `valueExpression`
 * in the bloom filter represented by `bloomFilterExpression`.
 *
 * @param bloomFilterExpression
 *   binary expression that represents bloom filter data
 * @param valueExpression
 *   Long value expression to be tested
 */
case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpression: Expression)
    extends BinaryComparison {

  override def nullable: Boolean = true

  override def left: Expression = bloomFilterExpression

  override def right: Expression = valueExpression

  override def prettyName: String = "bloom_filter_might_contain"

  override def dataType: DataType = BooleanType

  override def symbol: String = "BLOOM_FILTER_MIGHT_CONTAIN"

  override def checkInputDataTypes(): TypeCheckResult = {
    (left.dataType, right.dataType) match {
      case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) |
          (BinaryType, LongType) =>
        TypeCheckResult.TypeCheckSuccess
      case _ =>
        TypeCheckResult.TypeCheckFailure(s"""
             | Input to function $prettyName should be Binary expression followed by a Long value,
             | but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].
             | """.stripMargin)
    }
  }

  override protected def withNewChildrenInternal(
      newBloomFilterExpression: Expression,
      newValueExpression: Expression): BloomFilterMightContain =
    copy(bloomFilterExpression = newBloomFilterExpression, valueExpression = newValueExpression)

  override def eval(input: InternalRow): Any = {
    val value = valueExpression.eval(input)
    if (value == null) {
      null
    } else {
      val bytes = bloomFilterExpression.eval(input).asInstanceOf[Array[Byte]]
      val bloomFilter = ClassicBloomFilter.readFrom(new ByteArrayInputStream(bytes))
      bloomFilter.mightContain(value.asInstanceOf[Long])
    }
  }

  /**
   * Generate expression code for Spark codegen execution. Sample result code:
   * ```
   *   boolean filter_isNull_0 = true;
   *   boolean filter_value_0 = false;
   *   if (!right_isNull) {
   *     filter_isNull_0 = false;
   *     filter_value_0 =
   *       org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter.readFrom(
   *         new java.io.ByteArrayInputStream(left_value)
   *       ).mightContain(right_value);
   *   }
   * ```
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val leftGen = left.genCode(ctx)
    val rightGen = right.genCode(ctx)
    val bloomFilterEncoder = classOf[ClassicBloomFilter].getCanonicalName.stripSuffix("$")
    val bf = s"$bloomFilterEncoder.readFrom(new java.io.ByteArrayInputStream(${leftGen.value}))"
    val result = s"$bf.mightContain(${rightGen.value})"
    val resultCode =
      s"""
         |if (!(${rightGen.isNull})) {
         |  ${leftGen.code}
         |  ${ev.isNull} = false;
         |  ${ev.value} = $result;
         |}
       """.stripMargin
    ev.copy(code = code"""
      ${rightGen.code}
      boolean ${ev.isNull} = true;
      boolean ${ev.value} = false;
      $resultCode""")
  }
}

object BloomFilterMightContain {

  /**
   * Generate bloom filter might contain function given the bloom filter column and value.
   *
   * @param colName
   *   column name
   * @param value
   *   value
   * @return
   *   bloom filter might contain expression
   */
  def bloom_filter_might_contain(colName: String, value: Any): Column = {
    new Column(BloomFilterMightContain(col(colName).expr, lit(value).expr))
  }
}
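The companion helper above is what the integration test below uses to describe the expected index filter. A usage sketch showing how it composes with xxhash64 (the column name and literal are illustrative):

```scala
import org.apache.spark.sql.functions.{lit, xxhash64}

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.bloom_filter_might_contain

// Probe the serialized bloom filter stored in column "client_ip" with the
// xxhash64 of the query literal, matching the hash applied at index build time.
val indexFilter = bloom_filter_might_contain("client_ip", xxhash64(lit("127.0.0.1")))
```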

flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala (+11 -2)

@@ -6,10 +6,12 @@
 package org.opensearch.flint.spark.skipping.bloomfilter

 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.IndexColumnExtractor
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, SkippingKind}
 import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, DEFAULT_CLASSIC_BLOOM_FILTER_FPP, DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS}

-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
 import org.apache.spark.sql.functions.{col, xxhash64}

 /**
@@ -37,7 +39,14 @@ case class BloomFilterSkippingStrategy(
     ) // TODO: use xxhash64() for now
   }

-  override def rewritePredicate(predicate: Expression): Option[Expression] = None
+  override def rewritePredicate(predicate: Expression): Option[Expression] = {
+    val IndexColumn = IndexColumnExtractor(columnName)
+    predicate match {
+      case EqualTo(IndexColumn(indexCol), value: Literal) =>
+        Some(BloomFilterMightContain(indexCol.expr, xxhash64(new Column(value)).expr))
+      case _ => None
+    }
+  }

   private def expectedNumItems: Int = {
     params
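The rewrite fires only for a plain equality between the indexed column and a literal, and it hashes the literal with xxhash64 because the aggregation side (per the TODO above) feeds xxhash64(col) into the filter at build time; both sides must apply the same hash for the membership test to be meaningful. A sketch of that symmetry (names are illustrative):

```scala
import org.apache.spark.sql.functions.{col, lit, xxhash64}

// Build side: values are hashed before being aggregated into the bloom filter.
val buildSide = xxhash64(col("client_ip"))

// Probe side: the query literal goes through the same hash before the
// bloom_filter_might_contain membership test.
val probeSide = xxhash64(lit("127.0.0.1"))
```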
flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContainSuite.scala (new file, +45 -0)

@@ -0,0 +1,45 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.flint.spark.skipping.bloomfilter

import org.apache.spark.FlintSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{BinaryType, DoubleType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class BloomFilterMightContainSuite extends FlintSuite {

  test("checkInputDataTypes should succeed for valid input types") {
    val binaryExpression = Literal(Array[Byte](1, 2, 3), BinaryType)
    val longExpression = Literal(42L, LongType)

    val bloomFilterMightContain = BloomFilterMightContain(binaryExpression, longExpression)
    assert(bloomFilterMightContain.checkInputDataTypes() == TypeCheckSuccess)
  }

  test("checkInputDataTypes should succeed for valid input types with nulls") {
    val binaryExpression = Literal.create(null, BinaryType)
    val longExpression = Literal.create(null, LongType)

    val bloomFilterMightContain = BloomFilterMightContain(binaryExpression, longExpression)
    assert(bloomFilterMightContain.checkInputDataTypes() == TypeCheckSuccess)
  }

  test("checkInputDataTypes should fail for invalid input types") {
    val stringExpression = Literal(UTF8String.fromString("invalid"), StringType)
    val doubleExpression = Literal(3.14, DoubleType)

    val bloomFilterMightContain = BloomFilterMightContain(stringExpression, doubleExpression)
    val expectedErrorMsg =
      s"""
         | Input to function bloom_filter_might_contain should be Binary expression followed by a Long value,
         | but it's [${stringExpression.dataType.catalogString}, ${doubleExpression.dataType.catalogString}].
         | """.stripMargin

    assert(bloomFilterMightContain.checkInputDataTypes() == TypeCheckFailure(expectedErrorMsg))
  }
}

integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala (+22 -2)

@@ -12,17 +12,20 @@ import org.opensearch.flint.core.FlintVersion.current
 import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingFileIndex
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
+import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.bloom_filter_might_contain
 import org.opensearch.index.query.QueryBuilders
 import org.opensearch.index.reindex.DeleteByQueryRequest
 import org.scalatest.matchers.{Matcher, MatchResult}
 import org.scalatest.matchers.must.Matchers._
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

 import org.apache.spark.sql.{Column, Row}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
 import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.HadoopFsRelation
 import org.apache.spark.sql.flint.config.FlintSparkConf._
-import org.apache.spark.sql.functions.{col, isnull}
+import org.apache.spark.sql.functions.{col, isnull, lit, xxhash64}
+import org.apache.spark.sql.internal.SQLConf

 class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {

@@ -390,7 +393,24 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
     // Assert index data
     flint.queryIndex(testIndex).collect() should have size 2

-    // TODO: Assert query rewrite result
+    // Assert query result and rewrite
+    def assertQueryRewrite(): Unit = {
+      val query = sql(s"SELECT name FROM $testTable WHERE age = 50")
+      checkAnswer(query, Row("Java"))
+      query.queryExecution.executedPlan should
+        useFlintSparkSkippingFileIndex(
+          hasIndexFilter(bloom_filter_might_contain("age", xxhash64(lit(50)))))
+    }
+
+    // Test expression with codegen enabled by default
+    assertQueryRewrite()
+
+    // Test expression evaluation with codegen disabled
+    withSQLConf(
+      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+      SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString) {
+      assertQueryRewrite()
+    }
   }

   test("should rewrite applicable query with table name without database specified") {
