diff --git a/.config/typos.toml b/.config/typos.toml index 62ee56c73..9b16edb89 100644 --- a/.config/typos.toml +++ b/.config/typos.toml @@ -28,4 +28,5 @@ updat = "updat" # Used for stem matching extend-ignore-re = [ "baNAna", "eXIst", + "Hel", ] diff --git a/integration/test_fulltext.py b/integration/test_fulltext.py index dcd8ebde9..5c3188423 100644 --- a/integration/test_fulltext.py +++ b/integration/test_fulltext.py @@ -28,10 +28,10 @@ ] text_query_term = ["FT.SEARCH", "products", '@desc:"wonder"'] text_query_term_nomatch = ["FT.SEARCH", "products", '@desc:"nomatch"'] -text_query_prefix = ["FT.SEARCH", "products", '@desc:"wond*"'] -text_query_prefix2 = ["FT.SEARCH", "products", '@desc:"wond*"'] -text_query_prefix_nomatch = ["FT.SEARCH", "products", '@desc:"nomatch*"'] -text_query_prefix_multimatch = ["FT.SEARCH", "products", '@desc:"grea*"'] +text_query_prefix = ["FT.SEARCH", "products", '@desc:wond*'] +text_query_prefix2 = ["FT.SEARCH", "products", '@desc:wond*'] +text_query_prefix_nomatch = ["FT.SEARCH", "products", '@desc:nomatch*'] +text_query_prefix_multimatch = ["FT.SEARCH", "products", '@desc:grea*'] text_query_exact_phrase1 = ["FT.SEARCH", "products", '@desc:"word wonder"'] text_query_exact_phrase2 = ["FT.SEARCH", "products", '@desc:"random word wonder"'] @@ -55,9 +55,9 @@ # Search queries for specific fields text_query_desc_field = ["FT.SEARCH", "products2", '@desc:"wonder"'] -text_query_desc_prefix = ["FT.SEARCH", "products2", '@desc:"wonde*"'] +text_query_desc_prefix = ["FT.SEARCH", "products2", '@desc:wonde*'] text_query_desc2_field = ["FT.SEARCH", "products2", '@desc2:"wonder"'] -text_query_desc2_prefix = ["FT.SEARCH", "products2", '@desc2:"wonde*"'] +text_query_desc2_prefix = ["FT.SEARCH", "products2", '@desc2:wonde*'] # Expected results for desc field search expected_desc_hash_key = b'product:4' @@ -124,18 +124,18 @@ def test_text_search(self): result3 = client.execute_command("FT.SEARCH", "products", '@desc:xpe*') assert result1[0] == 1 and 
result2[0] == 1 and result3[0] == 0 assert result1[1] == b"product:3" and result2[1] == b"product:3" - # TODO: Update these queries to non stemmed versions after queries are stemmed. + # TODO: Update these queries to non stemmed versions once the stem tree is supported and ingestion is updated. # Perform an exact phrase search operation on a unique phrase (exists in one doc). result1 = client.execute_command("FT.SEARCH", "products", '@desc:"great oak from littl"') result2 = client.execute_command("FT.SEARCH", "products", '@desc:"great oak from littl grey acorn grow"') assert result1[0] == 1 and result2[0] == 1 assert result1[1] == b"product:1" and result2[1] == b"product:1" - result3 = client.execute_command("FT.SEARCH", "products", '@desc:great @desc:oa* @desc:from @desc:lit* @desc:gr* @desc:acorn @desc:gr*') + result3 = client.execute_command("FT.SEARCH", "products", 'great oa* from lit* gr* acorn gr*') assert result3[0] == 1 assert result3[1] == b"product:1" - result3 = client.execute_command("FT.SEARCH", "products", '@desc:great @desc:oa* @desc:from @desc:lit* @desc:gr* @desc:acorn @desc:grea*') + result3 = client.execute_command("FT.SEARCH", "products", 'great oa* from lit* gr* acorn grea*') assert result3[0] == 0 - result3 = client.execute_command("FT.SEARCH", "products", '@desc:great @desc:oa* @desc:from @desc:lit* @desc:gr* @desc:acorn @desc:great') + result3 = client.execute_command("FT.SEARCH", "products", 'great oa* from lit* gr* acorn great') assert result3[0] == 0 # Perform an exact phrase search operation on a phrase existing in 2 documents. result = client.execute_command("FT.SEARCH", "products", '@desc:"interest desc"') @@ -173,7 +173,6 @@ def test_text_search(self): result = client.execute_command("FT.SEARCH", "products", '@desc:"1 2 3 4 5 6 7 8 9 0"') assert result[0] == 1 assert result[1] == b"product:1" - # TODO: We can test this once the queries are tokenized with punctuation applied. 
# result = client.execute_command("FT.SEARCH", "products", '@desc:"inspector\'s palm"') # TODO: We can test this once the queries are tokenized with punctuation and stopword removal applied. @@ -364,21 +363,22 @@ def test_default_tokenization(self): client: Valkey = self.server.get_new_client() client.execute_command("FT.CREATE idx ON HASH SCHEMA content TEXT") client.execute_command("HSET", "doc:1", "content", "The quick-running searches are finding EFFECTIVE results!") - - # List of queries with pass/fail expectations + client.execute_command("HSET", "doc:2", "content", "But slow searches aren't working...") + # List of queries with match / no match expectations test_cases = [ ("quick*", True, "Punctuation tokenization - hyphen creates word boundaries"), ("effect*", True, "Case insensitivity - lowercase matches uppercase"), - ("the", False, "Stop word filtering - common words filtered out"), + ("\"The quick-running searches are finding EFFECTIVE results!\"", False, "Stop word cannot be used in exact phrase searches"), + # TODO: Change to True once the stem tree is supported and ingestion is updated. 
+ ("\"quick-running searches finding EFFECTIVE results!\"", False, "Exact phrase without stopwords"), + ("\"quick-run search find EFFECT result!\"", True, "Exact Phrase Query without stopwords and using stemmed words"), ("find*", True, "Prefix wildcard - matches 'finding'"), ("nonexistent", False, "Non-existent terms return no results") ] - expected_key = b'doc:1' expected_fields = [b'content', b"The quick-running searches are finding EFFECTIVE results!"] - for query_term, should_match, description in test_cases: - result = client.execute_command("FT.SEARCH", "idx", f'@content:"{query_term}"') + result = client.execute_command("FT.SEARCH", "idx", f'@content:{query_term}') if should_match: assert result[0] == 1 and result[1] == expected_key and result[2] == expected_fields, f"Failed: {description}" else: @@ -412,16 +412,44 @@ def test_custom_stopwords(self): client: Valkey = self.server.get_new_client() client.execute_command("FT.CREATE idx ON HASH STOPWORDS 2 the and SCHEMA content TEXT") client.execute_command("HSET", "doc:1", "content", "the cat and dog are good") + # non stop words should be findable + result = client.execute_command("FT.SEARCH", "idx", '@content:"cat dog are good"') + assert result[0] == 1 # Regular word indexed + assert result[1] == b'doc:1' + assert result[2] == [b'content', b"the cat and dog are good"] # Stop words should not be findable result = client.execute_command("FT.SEARCH", "idx", '@content:"and"') assert result[0] == 0 # Stop word "and" filtered out - # non stop words should be findable result = client.execute_command("FT.SEARCH", "idx", '@content:"are"') assert result[0] == 1 # Regular word indexed assert result[1] == b'doc:1' assert result[2] == [b'content', b"the cat and dog are good"] + # Stop words should not be findable + result = client.execute_command("FT.SEARCH", "idx", '@content:"and"') + assert result[0] == 0 # Stop word "and" filtered out + + def test_nostem(self): + """ + End-to-end test: FT.CREATE NOSTEM config 
actually affects stemming in search +        """ +        client: Valkey = self.server.get_new_client() +        client.execute_command("FT.CREATE idx ON HASH NOSTEM SCHEMA content TEXT") +        client.execute_command("HSET", "doc:1", "content", "running quickly") +        # With NOSTEM, exact tokens should be findable with exact phrase +        result = client.execute_command("FT.SEARCH", "idx", '@content:"running"') +        assert result[0] == 1  # Exact form "running" found +        assert result[1] == b'doc:1' +        assert result[2] == [b'content', b"running quickly"] +        # With NOSTEM, repeating the same exact-phrase query should return the same result +        result = client.execute_command("FT.SEARCH", "idx", '@content:"running"') +        assert result[0] == 1  # Exact form "running" found +        assert result[1] == b'doc:1' +        assert result[2] == [b'content', b"running quickly"] +        # With NOSTEM, stemmed tokens should not be findable +        result = client.execute_command("FT.SEARCH", "idx", '@content:"run"') +        assert result[0] == 0      def test_custom_punctuation(self):         """ @@ -430,16 +458,18 @@ def test_custom_punctuation(self):         client: Valkey = self.server.get_new_client()         client.execute_command("FT.CREATE idx ON HASH PUNCTUATION . 
SCHEMA content TEXT")         client.execute_command("HSET", "doc:1", "content", "hello.world test@email") -        # Dot configured as separator - should find split words         result = client.execute_command("FT.SEARCH", "idx", '@content:"hello"')         assert result[0] == 1  # Found "hello" as separate token         assert result[1] == b'doc:1'         assert result[2] == [b'content', b"hello.world test@email"] -        # @ NOT configured as separator - should not be able with split words         result = client.execute_command("FT.SEARCH", "idx", '@content:"test"')         assert result[0] == 0 +        result = client.execute_command("FT.SEARCH", "idx", '@content:"test@email"') +        assert result[0] == 1  # Found "test@email" as a single token (@ is not a separator) +        assert result[1] == b'doc:1' +        assert result[2] == [b'content', b"hello.world test@email"]      def test_add_update_delete_documents_single_client(self):         """ @@ -637,8 +667,29 @@ def delete_documents(client_id):         perform_concurrent_searches(clients, num_clients, delete_searches, "DELETE")      def test_suffix_search(self): -        # TODO -        pass +        """Test suffix search functionality using *suffix pattern""" +        # Create index +        self.client.execute_command("FT.CREATE", "idx", "ON", "HASH", "PREFIX", "1", "doc:", "SCHEMA", "content", "TEXT", "WITHSUFFIXTRIE", "NOSTEM") +        # Add test documents +        self.client.execute_command("HSET", "doc:1", "content", "running jumping walking") +        self.client.execute_command("HSET", "doc:2", "content", "testing debugging coding") +        self.client.execute_command("HSET", "doc:3", "content", "reading writing speaking") +        self.client.execute_command("HSET", "doc:4", "content", "swimming diving surfing") +        # Test suffix search with *ing +        result = self.client.execute_command("FT.SEARCH", "idx", "@content:*ing") +        assert result[0] == 4  # All documents contain words ending with 'ing' +        # Test suffix search with *ning (should match only "running") 
+        result = self.client.execute_command("FT.SEARCH", "idx", "@content:*ning") +        assert result[0] == 1  # Only doc:1 has "running" +        # Test suffix search with *ping +        result = self.client.execute_command("FT.SEARCH", "idx", "@content:*ping") +        assert result[0] == 1  # Only doc:1 has "jumping" +        # Test suffix search with *ding +        result = self.client.execute_command("FT.SEARCH", "idx", "@content:*ding") +        assert result[0] == 2  # doc:2 has "coding", doc:3 has "reading" +        # Test non-matching suffix +        result = self.client.execute_command("FT.SEARCH", "idx", "@content:*xyz") +        assert result[0] == 0  # No matches  class TestFullTextDebugMode(ValkeySearchTestCaseDebugMode):     """ diff --git a/src/commands/filter_parser.cc b/src/commands/filter_parser.cc index 6a6f9453e..cbb28a2b7 100644 --- a/src/commands/filter_parser.cc +++ b/src/commands/filter_parser.cc @@ -149,28 +149,27 @@ void PrintPredicate(const query::Predicate* pred, int depth, bool last,   } else if (auto term = dynamic_cast(pred)) {     VMSDK_LOG(WARNING, nullptr)         << prefix << "TERM(" << term->GetTextString() << ")_" -        << term->GetIdentifier() << "\n"; +        << term->GetFieldMask() << "\n";   } else if (auto pre = dynamic_cast(pred)) {     VMSDK_LOG(WARNING, nullptr)         << prefix << "PREFIX(" << pre->GetTextString() << ")_" -        << pre->GetIdentifier() << "\n"; +        << pre->GetFieldMask() << "\n";   } else if (auto pre = dynamic_cast(pred)) { -    valid = false;     VMSDK_LOG(WARNING, nullptr)         << prefix << "Suffix(" << pre->GetTextString() << ")_" -        << pre->GetIdentifier() << "\n"; +        << pre->GetFieldMask() << "\n";   } else if (auto pre = dynamic_cast(pred)) {     valid = false;     VMSDK_LOG(WARNING, nullptr)         << prefix << "Infix(" << pre->GetTextString() << ")_" -        << pre->GetIdentifier() << "\n"; +        << pre->GetFieldMask() << "\n";   } else if (auto fuzzy = dynamic_cast(pred)) {     valid = false;     VMSDK_LOG(WARNING, nullptr)         << prefix << "FUZZY(" << fuzzy->GetTextString()         << ", distance=" << fuzzy->GetDistance() << ")_" -        << fuzzy->GetIdentifier() << "\n"; +        << 
fuzzy->GetFieldMask() << "\n"; } else { valid = false; VMSDK_LOG(WARNING, nullptr) << prefix << "UNKNOWN TEXT\n"; @@ -200,9 +199,11 @@ void PrintPredicate(const query::Predicate* pred, int depth, bool last, } FilterParser::FilterParser(const IndexSchema& index_schema, - absl::string_view expression) + absl::string_view expression, + const TextParsingOptions& options) : index_schema_(index_schema), - expression_(absl::StripAsciiWhitespace(expression)) {} + expression_(absl::StripAsciiWhitespace(expression)), + options_(options) {} bool FilterParser::Match(char expected, bool skip_whitespace) { if (skip_whitespace) { @@ -398,7 +399,7 @@ absl::StatusOr FilterParser::IsMatchAllExpression() { } return absl::InvalidArgumentError("Missing `)`"); } - return UnexpectedChar(expression_, pos_); + return false; } absl::StatusOr FilterParser::Parse() { @@ -449,218 +450,325 @@ std::unique_ptr WrapPredicate( static const uint32_t FUZZY_MAX_DISTANCE = 3; -absl::StatusOr> -FilterParser::BuildSingleTextPredicate(const std::string& field_name, - absl::string_view raw_token) { - // --- Validate the field is a text index --- - auto index = index_schema_.GetIndex(field_name); - if (!index.ok() || - index.value()->GetIndexerType() != indexes::IndexerType::kText) { - return absl::InvalidArgumentError( - absl::StrCat("`", field_name, "` is not indexed as a text field")); +// Handles backslash escaping for both quoted and unquoted text +// Escape Syntax: +// \\ -> \ +// \ -> +// \ -> (break to new token)... +// \ -> Return error +absl::StatusOr FilterParser::HandleBackslashEscape( + const indexes::text::Lexer& lexer, std::string& processed_content) { + if (!Match('\\', false)) { + // No backslash, continue normal processing of the same token. 
+ return true; } - auto identifier = index_schema_.GetIdentifier(field_name).value(); - filter_identifiers_.insert(identifier); - auto* text_index = dynamic_cast(index.value().get()); - absl::string_view token = absl::StripAsciiWhitespace(raw_token); - if (token.empty()) { - return absl::InvalidArgumentError("Empty text token"); - } - // --- Fuzzy --- - size_t lead_pct = 0; - while (lead_pct < token.size() && token[lead_pct] == '%') { - ++lead_pct; - if (lead_pct > FUZZY_MAX_DISTANCE) { - return absl::InvalidArgumentError("Too many leading '%' markers"); - } - } - size_t tail_pct = 0; - while (tail_pct < token.size() && token[token.size() - 1 - tail_pct] == '%') { - ++tail_pct; - if (tail_pct > FUZZY_MAX_DISTANCE) { - return absl::InvalidArgumentError("Too many trailing '%' markers"); - } - } - if (lead_pct || tail_pct) { - if (lead_pct != tail_pct) { - return absl::InvalidArgumentError("Mismatched fuzzy '%' markers"); - } - absl::string_view core = token; - core.remove_prefix(lead_pct); - core.remove_suffix(tail_pct); - if (core.empty()) { - return absl::InvalidArgumentError("Empty fuzzy token"); - } - return std::make_unique( - text_index, identifier, field_name, std::string(core), lead_pct); - } - // --- Wildcard --- - bool starts_star = !token.empty() && token.front() == '*'; - bool ends_star = !token.empty() && token.back() == '*'; - if (starts_star || ends_star) { - absl::string_view core = token; - if (starts_star) core.remove_prefix(1); - if (ends_star) core.remove_suffix(1); - if (core.empty()) { - return absl::InvalidArgumentError( - "Wildcard token must contain at least one character besides '*'"); - } - if (starts_star && ends_star) { - return std::make_unique( - text_index, identifier, field_name, std::string(core)); + if (!IsEnd()) { + char next_ch = Peek(); + if (next_ch == '\\' || lexer.IsPunctuation(next_ch)) { + // If Double backslash, retain the double backslash + // If Single backslash with punct on right, retain the char on right + 
processed_content.push_back(next_ch); + ++pos_; + // Continue parsing the same token. + return true; + } else { + // Single backslash with non-punct on right, consume the backslash and + // break into a new token. + return false; } - if (starts_star) { - return std::make_unique( - text_index, identifier, field_name, std::string(core)); + } else { + // Unescaped backslash at end of input is invalid. + return absl::InvalidArgumentError( + "Invalid escape sequence: backslash at end of input"); + } +} + +// Returns a token within an exact phrase parsing it until reaching the +// token boundary while handling escape chars. +// Quoted Text Syntax: +// word1 word2" word3 -> word1 +// word2" word3 -> word2 +// Token boundaries (separated by space): " \ +absl::StatusOr FilterParser::ParseQuotedTextToken( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::optional min_stem_size) { + const auto& lexer = text_index_schema->GetLexer(); + std::string processed_content; + while (!IsEnd()) { + VMSDK_ASSIGN_OR_RETURN(bool should_continue, + HandleBackslashEscape(lexer, processed_content)); + if (!should_continue) { + break; } - return std::make_unique( - text_index, identifier, field_name, std::string(core)); + // Break to complete an exact phrase or start a new exact phrase. 
+ char ch = Peek(); + if (ch == '"') break; + if (lexer.IsPunctuation(ch)) break; + processed_content.push_back(ch); + ++pos_; + } + if (processed_content.empty()) { + return FilterParser::TokenResult{nullptr, false}; } - // --- Term --- - return std::make_unique(text_index, identifier, - field_name, std::string(token)); + std::string token = absl::AsciiStrToLower(processed_content); + return FilterParser::TokenResult{ + std::make_unique(text_index_schema, field_mask, + std::move(token), true), + false}; } -// TODO: Needs punctuation handing -absl::StatusOr>> -FilterParser::ParseOneTextAtomIntoTerms(const std::string& field_for_default) { - std::vector> terms; - SkipWhitespace(); - auto push_token = [&](std::string& tok) -> absl::Status { - if (tok.empty()) return absl::OkStatus(); - VMSDK_ASSIGN_OR_RETURN(auto t, - BuildSingleTextPredicate(field_for_default, tok)); - terms.push_back(std::move(t)); - tok.clear(); - return absl::OkStatus(); - }; - if (Match('"')) { - std::string curr; - while (!IsEnd()) { - char c = Peek(); - if (c == '"') { - ++pos_; +// Returns a token after parsing it until the token boundary while handling +// escape chars. 
+// Unquoted Text Syntax: +// Term: word +// Prefix: word* +// Suffix: *word +// Infix: *word* +// Fuzzy: %word% | %%word%% | %%%word%%% +// Token boundaries: +// ( ) | @ " - { } [ ] : ; $ +// Reserved chars: +// { } [ ] : ; $ -> error +absl::StatusOr FilterParser::ParseUnquotedTextToken( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::optional min_stem_size) { + const auto& lexer = text_index_schema->GetLexer(); + std::string processed_content; + bool starts_with_star = false; + bool ends_with_star = false; + size_t leading_percent_count = 0; + size_t trailing_percent_count = 0; + bool break_on_query_syntax = false; + while (!IsEnd()) { + VMSDK_ASSIGN_OR_RETURN(bool should_continue, + HandleBackslashEscape(lexer, processed_content)); + if (!should_continue) { + break; + } + char ch = Peek(); + // Break on non text specific query syntax characters. + if (ch == ')' || ch == '|' || ch == '(' || ch == '@') { + break_on_query_syntax = true; + break; + } + // Reject reserved characters in unquoted text + if (ch == '{' || ch == '}' || ch == '[' || ch == ']' || ch == ':' || + ch == ';' || ch == '$') { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected character at position ", pos_ + 1, ": `", + expression_.substr(pos_, 1), "`")); + } + // - characters in the middle of text tokens are not negate. If they are in + // the beginning, break. + if (ch == '-' && processed_content.empty()) { + break_on_query_syntax = true; + break; + } + // Break to complete an exact phrase or start a new exact phrase. + if (ch == '"') break; + // Handle fuzzy token boundary detection + if (ch == '%') { + if (processed_content.empty()) { + // Leading percent + while (Match('%', false)) { + leading_percent_count++; + if (leading_percent_count > FUZZY_MAX_DISTANCE) break; + } + continue; + } else { + // If there was no leading percent, we break. 
+ // Else, we keep consuming trailing percent (to match the leading count) + // - count them + while (trailing_percent_count < leading_percent_count && + Match('%', false)) { + trailing_percent_count++; + } break; } - if (std::isspace(static_cast(c))) { - VMSDK_RETURN_IF_ERROR(push_token(curr)); - ++pos_; + } + // Handle wildcard token boundary detection + if (Match('*', false)) { + if (processed_content.empty() && !starts_with_star) { + starts_with_star = true; + continue; } else { - curr.push_back(c); - ++pos_; + // Trailing star + ends_with_star = true; + break; } } - VMSDK_RETURN_IF_ERROR(push_token(curr)); - if (terms.empty()) return absl::InvalidArgumentError("Empty quoted string"); - return terms; // exact phrase realized later by proximity (slop=0, - // inorder=true) - } - // Reads one raw token (unquoted) stopping on space, ')', '|', '{', '[', or - // start of '@field' - std::string tok; - bool seen_nonwildcard = false; - while (pos_ < expression_.size()) { - char c = expression_[pos_]; - if (std::isspace(static_cast(c)) || c == ')' || c == '|' || - c == '{' || c == '[' || c == '@') - break; - tok.push_back(c); + // Break on all punctuation characters. + if (lexer.IsPunctuation(ch)) break; + // Regular character + processed_content.push_back(ch); ++pos_; - // If we encounter a tailing * (wildcard) after content, break to split into - // a new predicate. 
- if (c == '*' && seen_nonwildcard) { - break; + } + std::string token = absl::AsciiStrToLower(processed_content); + // Build predicate directly based on detected pattern + if (leading_percent_count > 0) { + if (trailing_percent_count == leading_percent_count && + leading_percent_count <= FUZZY_MAX_DISTANCE) { + if (token.empty()) return absl::InvalidArgumentError("Empty fuzzy token"); + return FilterParser::TokenResult{ + std::make_unique(text_index_schema, field_mask, + std::move(token), + leading_percent_count), + break_on_query_syntax}; + } else { + return absl::InvalidArgumentError("Invalid fuzzy '%' markers"); + } + } else if (starts_with_star) { + if (token.empty()) + return absl::InvalidArgumentError("Invalid wildcard '*' markers"); + if (!text_index_schema->GetTextIndex()->suffix_.has_value()) { + return absl::InvalidArgumentError("Index created without Suffix Trie"); + } + if (ends_with_star) { + return FilterParser::TokenResult{ + std::make_unique(text_index_schema, field_mask, + std::move(token)), + break_on_query_syntax}; + } else { + return FilterParser::TokenResult{ + std::make_unique( + text_index_schema, field_mask, std::move(token)), + break_on_query_syntax}; + } + } else if (ends_with_star) { + if (token.empty()) + return absl::InvalidArgumentError("Invalid wildcard '*' markers"); + return FilterParser::TokenResult{ + std::make_unique(text_index_schema, field_mask, + std::move(token)), + break_on_query_syntax}; + } else { + // Term predicate handling: + bool exact = options_.verbatim; + if (lexer.IsStopWord(token) || token.empty()) { + // Skip stop words and empty words. 
+ return FilterParser::TokenResult{nullptr, break_on_query_syntax}; } - if (c != '*') { - seen_nonwildcard = true; + if (!exact && min_stem_size.has_value()) { + token = lexer.StemWord(token, true, *min_stem_size, lexer.GetStemmer()); } + return FilterParser::TokenResult{ + std::make_unique(text_index_schema, field_mask, + std::move(token), exact), + break_on_query_syntax}; } - if (tok.empty()) return absl::InvalidArgumentError("Empty text token"); - VMSDK_ASSIGN_OR_RETURN(auto t, - BuildSingleTextPredicate(field_for_default, tok)); - terms.push_back(std::move(t)); - return terms; } -absl::StatusOr FilterParser::ResolveTextFieldOrDefault( - const std::optional& maybe_field) { - if (maybe_field.has_value()) return *maybe_field; - // Placeholder for default text field - return std::string("__default__"); +absl::Status FilterParser::SetupTextFieldConfiguration( + FieldMaskPredicate& field_mask, std::optional& min_stem_size, + const std::optional& field_name) { + if (field_name.has_value()) { + auto index = index_schema_.GetIndex(*field_name); + if (!index.ok() || + index.value()->GetIndexerType() != indexes::IndexerType::kText) { + return absl::InvalidArgumentError("Index does not have any text field"); + } + auto* text_index = dynamic_cast(index.value().get()); + auto identifier = index_schema_.GetIdentifier(*field_name).value(); + filter_identifiers_.insert(identifier); + field_mask = 1ULL << text_index->GetTextFieldNumber(); + if (text_index->IsStemmingEnabled()) { + min_stem_size = text_index->GetMinStemSize(); + } + } else { + // Set identifiers to include all text fields in the index schema. + auto text_identifiers = index_schema_.GetAllTextIdentifiers(); + for (const auto& identifier : text_identifiers) { + filter_identifiers_.insert(identifier); + } + // Set field mask to include all text fields in the index schema. + field_mask = ~0ULL; + // When no field was specified, we use the min stem across all text fields + // in the index schema. 
This helps ensure the root of the text token can be + // searched for. + min_stem_size = index_schema_.MinStemSizeAcrossTextIndexes(); + } + return absl::OkStatus(); } -// TODO: -// - Handle negation -// - Handle parenthesis by including terms in the proximity predicate. This -// requires folding this fn in the caller site. -// - Handle parsing and setup of default text field predicates -// - Try to move out nested standard operations (negate/numeric/tag/parenthesis) -// back to the caller site and reduce responsibilities of the text parser -// - Handle escaped characters in text tokens -absl::StatusOr> FilterParser::ParseTextGroup( - const std::string& initial_field) { - std::vector> all_terms; - std::vector> extra_terms; - std::string current_field = initial_field; +// This function is called when the characters detected are potentially those of +// a text predicate. +// Text Parsing Syntax: +// Quoted: "word1 word2" -> ProximityPredicate(exact, slop=0, inorder=true) +// Unquoted: word1 word2 -> TermPredicate(word1) - stops at first token +// Token boundaries for unquoted text: ( ) | @ " - { } [ ] : ; $ +// Quoted phrases (Exact Phrase) parse all tokens within quotes, unquoted +// parsing stops after first token. +// TODO: Update ProximityPredicate to ComposedAND. +absl::StatusOr> FilterParser::ParseTextTokens( + const std::optional& field_or_default) { + auto text_index_schema = index_schema_.GetTextIndexSchema(); + if (!text_index_schema) { + return absl::InvalidArgumentError("Index does not have any text field"); + } + std::vector> terms; + // Handle default / every field (no field specifier) and specific + // field query cases. 
+ FieldMaskPredicate field_mask; + std::optional min_stem_size = std::nullopt; + VMSDK_RETURN_IF_ERROR( + SetupTextFieldConfiguration(field_mask, min_stem_size, field_or_default)); + bool in_quotes = false; + bool exact_phrase = false; while (!IsEnd()) { - SkipWhitespace(); - if (IsEnd()) break; - bool negate = Match('-'); char c = Peek(); - // Stop text group if next is OR - if (c == '|') break; - // Currently, parenthesis is not included in Proximity predicate. This needs - // to be addressed. - if (c == '(' || c == ')') break; - std::optional field_for_atom; - if (!current_field.empty()) { - field_for_atom = current_field; - } - // Field override or numeric/tag - if (c == '@') { - VMSDK_ASSIGN_OR_RETURN(current_field, ParseFieldName()); - field_for_atom = current_field; - SkipWhitespace(); - if (!IsEnd()) { - if (Match('[')) { - VMSDK_ASSIGN_OR_RETURN(auto numeric, - ParseNumericPredicate(current_field)); - extra_terms.push_back(std::move(numeric)); - continue; - } else if (Match('{')) { - VMSDK_ASSIGN_OR_RETURN(auto tag, ParseTagPredicate(current_field)); - extra_terms.push_back(std::move(tag)); - continue; - } - } else { - return absl::InvalidArgumentError("Invalid query string"); + if (c == '"') { + in_quotes = !in_quotes; + ++pos_; + if (in_quotes && terms.empty()) { + exact_phrase = true; + continue; } + break; + } + size_t token_start = pos_; + VMSDK_ASSIGN_OR_RETURN( + auto result, + in_quotes + ? ParseQuotedTextToken(text_index_schema, field_mask, min_stem_size) + : ParseUnquotedTextToken(text_index_schema, field_mask, + min_stem_size)); + if (result.predicate) { + terms.push_back(std::move(result.predicate)); + // TODO: Uncomment this once we have ComposedAND evaluation functional for + // handling proximity checks. Until the, we handle unquoted text tokens + // by building a proximity predicate containing them. 
+ // if (!exact_phrase) break; + } + if (result.break_on_query_syntax) { + break; + } + // If this happens, we are either done (at the end of the prefilter string) + // or were on a punctuation character which should be consumed. + if (token_start == pos_) { + ++pos_; } - // Parse next text atom (first or subsequent) - VMSDK_ASSIGN_OR_RETURN(auto resolved, - ResolveTextFieldOrDefault(field_for_atom)); - VMSDK_ASSIGN_OR_RETURN(auto terms, ParseOneTextAtomIntoTerms(resolved)); - for (auto& t : terms) all_terms.push_back(std::move(t)); - // Only use initial_field for first atom - current_field.clear(); - } - // Build main predicate from text terms - std::unique_ptr prox; - if (all_terms.size() == 1) { - prox = std::move(all_terms[0]); - } else if (!all_terms.empty()) { - prox = std::make_unique( - std::move(all_terms), /*slop=*/0, /*inorder=*/true); - } else { - return absl::InvalidArgumentError("Invalid query string"); } - // Append numeric/tag predicates - for (auto& extra : extra_terms) { - bool neg = false; - prox = WrapPredicate(std::move(prox), std::move(extra), neg, - query::LogicalOperator::kAnd); + std::unique_ptr pred; + if (terms.size() > 1) { + uint32_t slop = options_.slop.value_or(0); + bool inorder = options_.inorder; + if (exact_phrase) { + slop = 0; + inorder = true; + } + // TODO: Swap ProximityPredicate with ComposedANDPredicate once it is + // flattened. Once that happens, we need to add slop and inorder properties + // to ComposedANDPredicate. 
+ pred = std::make_unique(std::move(terms), slop, + inorder); + node_count_ += terms.size(); + } else { + if (terms.empty()) { + return absl::InvalidArgumentError("Invalid Query Syntax"); + } + pred = std::move(terms[0]); } - return prox; + return pred; } // Parsing rules: @@ -721,16 +829,25 @@ absl::StatusOr> FilterParser::ParseExpression( WrapPredicate(std::move(prev_predicate), std::move(predicate), negate, query::LogicalOperator::kOr); } else { - VMSDK_ASSIGN_OR_RETURN(auto field_name, ParseFieldName()); - if (Match('[')) { - node_count_++; // Count the NumericPredicate Node - VMSDK_ASSIGN_OR_RETURN(predicate, ParseNumericPredicate(field_name)); - } else if (Match('{')) { - node_count_++; // Count the TagPredicate Node - VMSDK_ASSIGN_OR_RETURN(predicate, ParseTagPredicate(field_name)); - } else { - node_count_++; // Count the TextPredicate Node - VMSDK_ASSIGN_OR_RETURN(predicate, ParseTextGroup(field_name)); + std::optional field_name; + bool non_text = false; + if (Peek() == '@') { + std::string parsed_field; + VMSDK_ASSIGN_OR_RETURN(parsed_field, ParseFieldName()); + field_name = parsed_field; + if (Match('[')) { + node_count_++; + VMSDK_ASSIGN_OR_RETURN(predicate, ParseNumericPredicate(*field_name)); + non_text = true; + } else if (Match('{')) { + node_count_++; + VMSDK_ASSIGN_OR_RETURN(predicate, ParseTagPredicate(*field_name)); + non_text = true; + } + } + if (!non_text) { + node_count_++; + VMSDK_ASSIGN_OR_RETURN(predicate, ParseTextTokens(field_name)); } if (prev_predicate) { node_count_++; // Count the ComposedPredicate Node diff --git a/src/commands/filter_parser.h b/src/commands/filter_parser.h index 77bea7370..52488dda8 100644 --- a/src/commands/filter_parser.h +++ b/src/commands/filter_parser.h @@ -16,6 +16,7 @@ #include "absl/strings/string_view.h" #include "src/index_schema.h" #include "src/indexes/tag.h" +#include "src/indexes/text/lexer.h" #include "src/query/predicate.h" #include "vmsdk/src/module_config.h" @@ -23,32 +24,49 @@ namespace 
valkey_search { namespace indexes { class Tag; } // namespace indexes +using FieldMaskPredicate = uint64_t; +struct TextParsingOptions { + bool verbatim = false; + bool inorder = false; + std::optional slop = std::nullopt; +}; struct FilterParseResults { std::unique_ptr root_predicate; absl::flat_hash_set filter_identifiers; }; class FilterParser { public: - FilterParser(const IndexSchema& index_schema, absl::string_view expression); + FilterParser(const IndexSchema& index_schema, absl::string_view expression, + const TextParsingOptions& options); absl::StatusOr Parse(); private: + const TextParsingOptions& options_; const IndexSchema& index_schema_; absl::string_view expression_; size_t pos_{0}; size_t node_count_{0}; absl::flat_hash_set filter_identifiers_; - absl::StatusOr ResolveTextFieldOrDefault( - const std::optional& maybe_field); - absl::StatusOr> - BuildSingleTextPredicate(const std::string& field_name, - absl::string_view raw_token); - absl::StatusOr>> - ParseOneTextAtomIntoTerms(const std::string& field_for_default); - absl::StatusOr> ParseTextGroup( - const std::string& initial_field); + absl::StatusOr HandleBackslashEscape(const indexes::text::Lexer& lexer, + std::string& processed_content); + struct TokenResult { + std::unique_ptr predicate; + bool break_on_query_syntax; + }; + absl::StatusOr ParseQuotedTextToken( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::optional min_stem_size); + + absl::StatusOr ParseUnquotedTextToken( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::optional min_stem_size); + absl::Status SetupTextFieldConfiguration( + FieldMaskPredicate& field_mask, std::optional& min_stem_size, + const std::optional& field_name = std::nullopt); + absl::StatusOr> ParseTextTokens( + const std::optional& field_for_default); absl::StatusOr IsMatchAllExpression(); absl::StatusOr> ParseExpression( uint32_t level); diff --git a/src/commands/ft_create_parser.cc 
b/src/commands/ft_create_parser.cc index b158a1901..88e72e2ae 100644 --- a/src/commands/ft_create_parser.cc +++ b/src/commands/ft_create_parser.cc @@ -593,7 +593,7 @@ absl::StatusOr ParseFTCreateArgs( PerIndexTextParams schema_text_defaults; // Initialize with defaults for each parse call schema_text_defaults.punctuation = kDefaultPunctuation; - schema_text_defaults.min_stem_size = 4; + schema_text_defaults.min_stem_size = kDefaultMinStemSize; schema_text_defaults.with_offsets = true; schema_text_defaults.no_stem = false; schema_text_defaults.language = data_model::LANGUAGE_ENGLISH; diff --git a/src/commands/ft_create_parser.h b/src/commands/ft_create_parser.h index dc217dfb0..fcd6313f1 100644 --- a/src/commands/ft_create_parser.h +++ b/src/commands/ft_create_parser.h @@ -26,6 +26,7 @@ namespace valkey_search { static constexpr absl::string_view kDefaultPunctuation = ",.<>{}[]\"':;!@#$%^&*()-+=~/\\|"; +static constexpr uint32_t kDefaultMinStemSize = 4; // Default stop words set const std::vector kDefaultStopWords{ diff --git a/src/commands/ft_search_parser.cc b/src/commands/ft_search_parser.cc index d7f3861f0..38e5c93f6 100644 --- a/src/commands/ft_search_parser.cc +++ b/src/commands/ft_search_parser.cc @@ -177,8 +177,12 @@ absl::StatusOr FindCloseSquareBracket(absl::string_view input) { } absl::StatusOr ParsePreFilter( - const IndexSchema &index_schema, absl::string_view pre_filter) { - FilterParser parser(index_schema, pre_filter); + const IndexSchema &index_schema, absl::string_view pre_filter, + const query::SearchParameters &search_params) { + TextParsingOptions options{.verbatim = search_params.verbatim, + .inorder = search_params.inorder, + .slop = search_params.slop}; + FilterParser parser(index_schema, pre_filter, options); return parser.Parse(); } @@ -385,7 +389,7 @@ absl::Status PreParseQueryString(query::SearchParameters &parameters) { } VMSDK_ASSIGN_OR_RETURN( parameters.filter_parse_results, -
ParsePreFilter(*parameters.index_schema, pre_filter, parameters), _.SetPrepend() << "Invalid filter expression: `" << pre_filter << "`. "); if (!parameters.filter_parse_results.root_predicate && vector_filter.empty()) { diff --git a/src/index_schema.cc b/src/index_schema.cc index f4ae484eb..977e1e19d 100644 --- a/src/index_schema.cc +++ b/src/index_schema.cc @@ -267,6 +267,42 @@ absl::StatusOr> IndexSchema::GetIndex( return itr->second.GetIndex(); } +// Returns a vector of all the text (field) identifiers within the text +// index schema. This is intended to be used by queries where there +// is no field specification, and we want to include results from all +// text fields. +std::vector IndexSchema::GetAllTextIdentifiers() const { + std::vector identifiers; + for (const auto &[alias, attribute] : attributes_) { + auto index = attribute.GetIndex(); + if (index->GetIndexerType() == indexes::IndexerType::kText) { + identifiers.push_back(attribute.GetIdentifier()); + } + } + return identifiers; +} + +// Find the min stem size across all text fields in the text index schema. +// If stemming is disabled across all text field indexes, return `nullopt`. 
+std::optional IndexSchema::MinStemSizeAcrossTextIndexes() const { + uint32_t min_stem_size = kDefaultMinStemSize; + bool is_stemming_enabled = false; + for (const auto &[alias, attribute] : attributes_) { + auto index = attribute.GetIndex(); + if (index->GetIndexerType() == indexes::IndexerType::kText) { + auto *text_index = dynamic_cast(index.get()); + min_stem_size = std::min(min_stem_size, text_index->GetMinStemSize()); + if (text_index->IsStemmingEnabled()) { + is_stemming_enabled = true; + } + } + } + if (!is_stemming_enabled) { + return std::nullopt; + } + return min_stem_size; +} + absl::StatusOr IndexSchema::GetIdentifier( absl::string_view attribute_alias) const { auto itr = attributes_.find(std::string{attribute_alias}); diff --git a/src/index_schema.h b/src/index_schema.h index 1dfd98d2d..3360b3db8 100644 --- a/src/index_schema.h +++ b/src/index_schema.h @@ -28,6 +28,7 @@ #include "gtest/gtest_prod.h" #include "src/attribute.h" #include "src/attribute_data_type.h" +#include "src/commands/ft_create_parser.h" #include "src/index_schema.pb.h" #include "src/indexes/index_base.h" #include "src/indexes/text/text_index.h" @@ -95,6 +96,8 @@ class IndexSchema : public KeyspaceEventSubscription, ~IndexSchema() override; absl::StatusOr> GetIndex( absl::string_view attribute_alias) const; + std::vector GetAllTextIdentifiers() const; + std::optional MinStemSizeAcrossTextIndexes() const; virtual absl::StatusOr GetIdentifier( absl::string_view attribute_alias) const; absl::StatusOr GetAlias(absl::string_view identifier) const; diff --git a/src/indexes/text.cc b/src/indexes/text.cc index bec225415..20267c672 100644 --- a/src/indexes/text.cc +++ b/src/indexes/text.cc @@ -103,18 +103,6 @@ size_t Text::CalculateSize(const query::TextPredicate& predicate) const { return 0; } -std::unique_ptr Text::Search( - const query::TextPredicate& predicate, bool negate) const { - auto fetcher = std::make_unique( - CalculateSize(predicate), text_index_schema_->GetTextIndex(), - negate 
? &untracked_keys_ : nullptr); - fetcher->predicate_ = &predicate; - // TODO : Update for the default search case (all fields). - // The TextPredicate needs to support a GetFieldMask API to indicate this. - fetcher->field_mask_ = 1ULL << text_field_number_; - return fetcher; -} - size_t Text::EntriesFetcher::Size() const { return size_; } std::unique_ptr Text::EntriesFetcher::Begin() { @@ -127,6 +115,14 @@ std::unique_ptr Text::EntriesFetcher::Begin() { // Implement the TextPredicate BuildTextIterator virtual method namespace valkey_search::query { +void* TextPredicate::Search(bool negate) const { + // TODO: Add logic to calculate the size based on number of keys estimated. + auto fetcher = std::make_unique( + 0, GetTextIndexSchema()->GetTextIndex(), nullptr, GetFieldMask()); + fetcher->predicate_ = this; + return fetcher.release(); +} + std::unique_ptr TermPredicate::BuildTextIterator( const void* fetcher_ptr) const { const auto* fetcher = diff --git a/src/indexes/text.h b/src/indexes/text.h index 4f10b38a2..409b6ed6b 100644 --- a/src/indexes/text.h +++ b/src/indexes/text.h @@ -39,6 +39,11 @@ class Text : public IndexBase { explicit Text(const data_model::TextIndex& text_index_proto, std::shared_ptr text_index_schema); + std::shared_ptr GetTextIndexSchema() const { + return text_index_schema_; + } + uint32_t GetMinStemSize() const { return min_stem_size_; } + bool IsStemmingEnabled() const { return !no_stem_; } absl::StatusOr AddRecord(const InternedStringPtr& key, absl::string_view data) override ABSL_LOCKS_EXCLUDED(index_mutex_); @@ -72,8 +77,8 @@ class Text : public IndexBase { public: EntriesFetcher(size_t size, const std::shared_ptr& text_index, - const InternedStringSet* untracked_keys = nullptr, - text::FieldMaskPredicate field_mask = ~0ULL) + const InternedStringSet* untracked_keys, + text::FieldMaskPredicate field_mask) : size_(size), text_index_(text_index), untracked_keys_(untracked_keys), @@ -92,17 +97,13 @@ class Text : public IndexBase { const 
InternedStringSet* untracked_keys_; std::shared_ptr text_index_; const query::TextPredicate* predicate_; - absl::string_view data_; - bool no_field_{false}; text::FieldMaskPredicate field_mask_; }; // Calculate size based on the predicate. size_t CalculateSize(const query::TextPredicate& predicate) const; - virtual std::unique_ptr Search( - const query::TextPredicate& predicate, - bool negate) const ABSL_NO_THREAD_SAFETY_ANALYSIS; + size_t GetTextFieldNumber() const { return text_field_number_; } private: // Each text field index within the schema is assigned a unique number, this @@ -116,7 +117,7 @@ class Text : public IndexBase { bool with_suffix_trie_; bool no_stem_; - int32_t min_stem_size_; + uint32_t min_stem_size_; // TODO: Map to track which keys are indexed and their raw data diff --git a/src/indexes/text/lexer.h b/src/indexes/text/lexer.h index 0c77cc158..4ca1f6416 100644 --- a/src/indexes/text/lexer.h +++ b/src/indexes/text/lexer.h @@ -46,16 +46,8 @@ struct Lexer { absl::string_view text, bool stemming_enabled, uint32_t min_stem_size) const; - private: - data_model::Language language_; - std::bitset<256> punct_bitmap_; - absl::flat_hash_set stop_words_set_; - - sb_stemmer* GetStemmer() const; - std::string StemWord(const std::string& word, bool stemming_enabled, uint32_t min_stem_size, sb_stemmer* stemmer) const; - bool IsPunctuation(char c) const { return punct_bitmap_[static_cast(c)]; } @@ -63,6 +55,12 @@ struct Lexer { bool IsStopWord(const std::string& lowercase_word) const { return stop_words_set_.contains(lowercase_word); } + sb_stemmer* GetStemmer() const; + + private: + data_model::Language language_; + std::bitset<256> punct_bitmap_; + absl::flat_hash_set stop_words_set_; // UTF-8 processing helpers bool IsValidUtf8(absl::string_view text) const; diff --git a/src/indexes/text/text_index.h b/src/indexes/text/text_index.h index a679449a4..81afa3ad3 100644 --- a/src/indexes/text/text_index.h +++ b/src/indexes/text/text_index.h @@ -83,6 +83,7 @@ 
class TextIndexSchema { uint8_t GetNumTextFields() const { return num_text_fields_; } std::shared_ptr GetTextIndex() const { return text_index_; } + Lexer GetLexer() const { return lexer_; } // Access to metadata for memory pool usage TextIndexMetadata& GetMetadata() { return metadata_; } diff --git a/src/query/predicate.cc b/src/query/predicate.cc index 2a5410326..0312ddd08 100644 --- a/src/query/predicate.cc +++ b/src/query/predicate.cc @@ -25,14 +25,14 @@ bool NegatePredicate::Evaluate(Evaluator& evaluator) const { return !predicate_->Evaluate(evaluator); } -TermPredicate::TermPredicate(const indexes::Text* index, - absl::string_view identifier, - absl::string_view alias, std::string term) +TermPredicate::TermPredicate( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::string term, bool exact_) : TextPredicate(), - index_(index), - identifier_(vmsdk::MakeUniqueValkeyString(identifier)), - alias_(alias), - term_(term) {} + text_index_schema_(text_index_schema), + field_mask_(field_mask), + term_(term), + exact_(exact_) {} bool TermPredicate::Evaluate(Evaluator& evaluator) const { // call dynamic dispatch on the evaluator @@ -44,13 +44,12 @@ bool TermPredicate::Evaluate(const std::string_view& text) const { return text == term_; // exact match } -PrefixPredicate::PrefixPredicate(const indexes::Text* index, - absl::string_view identifier, - absl::string_view alias, std::string term) +PrefixPredicate::PrefixPredicate( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::string term) : TextPredicate(), - index_(index), - identifier_(vmsdk::MakeUniqueValkeyString(identifier)), - alias_(alias), + text_index_schema_(text_index_schema), + field_mask_(field_mask), term_(term) {} bool PrefixPredicate::Evaluate(Evaluator& evaluator) const { @@ -62,13 +61,12 @@ bool PrefixPredicate::Evaluate(const std::string_view& text) const { return absl::StartsWith(text, term_); } -SuffixPredicate::SuffixPredicate(const indexes::Text* index, 
- absl::string_view identifier, - absl::string_view alias, std::string term) +SuffixPredicate::SuffixPredicate( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::string term) : TextPredicate(), - index_(index), - identifier_(vmsdk::MakeUniqueValkeyString(identifier)), - alias_(alias), + text_index_schema_(text_index_schema), + field_mask_(field_mask), term_(term) {} bool SuffixPredicate::Evaluate(Evaluator& evaluator) const { @@ -80,13 +78,12 @@ bool SuffixPredicate::Evaluate(const std::string_view& text) const { return absl::EndsWith(text, term_); } -InfixPredicate::InfixPredicate(const indexes::Text* index, - absl::string_view identifier, - absl::string_view alias, std::string term) +InfixPredicate::InfixPredicate( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::string term) : TextPredicate(), - index_(index), - identifier_(vmsdk::MakeUniqueValkeyString(identifier)), - alias_(alias), + text_index_schema_(text_index_schema), + field_mask_(field_mask), term_(term) {} bool InfixPredicate::Evaluate(Evaluator& evaluator) const { @@ -98,14 +95,12 @@ bool InfixPredicate::Evaluate(const std::string_view& text) const { return absl::StrContains(text, term_); } -FuzzyPredicate::FuzzyPredicate(const indexes::Text* index, - absl::string_view identifier, - absl::string_view alias, std::string term, - uint32_t distance) +FuzzyPredicate::FuzzyPredicate( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::string term, uint32_t distance) : TextPredicate(), - index_(index), - identifier_(vmsdk::MakeUniqueValkeyString(identifier)), - alias_(alias), + text_index_schema_(text_index_schema), + field_mask_(field_mask), term_(term), distance_(distance) {} diff --git a/src/query/predicate.h b/src/query/predicate.h index c65f088c9..604e67719 100644 --- a/src/query/predicate.h +++ b/src/query/predicate.h @@ -26,7 +26,8 @@ class Tag; namespace valkey_search::indexes::text { class TextIterator; -} +class 
TextIndexSchema; +} // namespace valkey_search::indexes::text namespace valkey_search::query { @@ -136,128 +137,114 @@ class TagPredicate : public Predicate { absl::flat_hash_set tags_; }; +using FieldMaskPredicate = uint64_t; + class TextPredicate : public Predicate { public: TextPredicate() : Predicate(PredicateType::kText) {} virtual ~TextPredicate() = default; virtual bool Evaluate(Evaluator& evaluator) const = 0; virtual bool Evaluate(const std::string_view& text) const = 0; - virtual const indexes::Text* GetIndex() const = 0; + virtual std::shared_ptr GetTextIndexSchema() + const = 0; + virtual const FieldMaskPredicate GetFieldMask() const = 0; + virtual void* Search(bool negate) const; virtual std::unique_ptr BuildTextIterator( const void* fetcher) const = 0; }; class TermPredicate : public TextPredicate { public: - TermPredicate(const indexes::Text* index, absl::string_view identifier, - absl::string_view alias, std::string term); - const indexes::Text* GetIndex() const { return index_; } - absl::string_view GetAlias() const { return alias_; } - absl::string_view GetIdentifier() const { - return vmsdk::ToStringView(identifier_.get()); - } - vmsdk::UniqueValkeyString GetRetainedIdentifier() const { - return vmsdk::RetainUniqueValkeyString(identifier_.get()); + TermPredicate( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::string term, bool exact); + std::shared_ptr GetTextIndexSchema() const { + return text_index_schema_; } absl::string_view GetTextString() const { return term_; } bool Evaluate(Evaluator& evaluator) const override; bool Evaluate(const std::string_view& text) const override; std::unique_ptr BuildTextIterator( const void* fetcher) const override; + const FieldMaskPredicate GetFieldMask() const override { return field_mask_; } private: - const indexes::Text* index_; - vmsdk::UniqueValkeyString identifier_; - absl::string_view alias_; + std::shared_ptr text_index_schema_; + FieldMaskPredicate field_mask_; std::string 
term_; + bool exact_; }; class PrefixPredicate : public TextPredicate { public: - PrefixPredicate(const indexes::Text* index, absl::string_view identifier, - absl::string_view alias, std::string term); - const indexes::Text* GetIndex() const { return index_; } - absl::string_view GetAlias() const { return alias_; } - absl::string_view GetIdentifier() const { - return vmsdk::ToStringView(identifier_.get()); - } - vmsdk::UniqueValkeyString GetRetainedIdentifier() const { - return vmsdk::RetainUniqueValkeyString(identifier_.get()); + PrefixPredicate( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::string term); + std::shared_ptr GetTextIndexSchema() const { + return text_index_schema_; } absl::string_view GetTextString() const { return term_; } bool Evaluate(Evaluator& evaluator) const override; bool Evaluate(const std::string_view& text) const override; std::unique_ptr BuildTextIterator( const void* fetcher) const override; + const FieldMaskPredicate GetFieldMask() const override { return field_mask_; } private: - const indexes::Text* index_; - vmsdk::UniqueValkeyString identifier_; - absl::string_view alias_; + std::shared_ptr text_index_schema_; + FieldMaskPredicate field_mask_; std::string term_; }; class SuffixPredicate : public TextPredicate { public: - SuffixPredicate(const indexes::Text* index, absl::string_view identifier, - absl::string_view alias, std::string term); - const indexes::Text* GetIndex() const { return index_; } - absl::string_view GetAlias() const { return alias_; } - absl::string_view GetIdentifier() const { - return vmsdk::ToStringView(identifier_.get()); - } - vmsdk::UniqueValkeyString GetRetainedIdentifier() const { - return vmsdk::RetainUniqueValkeyString(identifier_.get()); + SuffixPredicate( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::string term); + std::shared_ptr GetTextIndexSchema() const { + return text_index_schema_; } absl::string_view GetTextString() const { return term_; 
} bool Evaluate(Evaluator& evaluator) const override; bool Evaluate(const std::string_view& text) const override; std::unique_ptr BuildTextIterator( const void* fetcher) const override; + const FieldMaskPredicate GetFieldMask() const override { return field_mask_; } private: - const indexes::Text* index_; - vmsdk::UniqueValkeyString identifier_; - absl::string_view alias_; + std::shared_ptr text_index_schema_; + FieldMaskPredicate field_mask_; std::string term_; }; class InfixPredicate : public TextPredicate { public: - InfixPredicate(const indexes::Text* index, absl::string_view identifier, - absl::string_view alias, std::string term); - const indexes::Text* GetIndex() const { return index_; } - absl::string_view GetAlias() const { return alias_; } - absl::string_view GetIdentifier() const { - return vmsdk::ToStringView(identifier_.get()); - } - vmsdk::UniqueValkeyString GetRetainedIdentifier() const { - return vmsdk::RetainUniqueValkeyString(identifier_.get()); + InfixPredicate( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::string term); + std::shared_ptr GetTextIndexSchema() const { + return text_index_schema_; } absl::string_view GetTextString() const { return term_; } bool Evaluate(Evaluator& evaluator) const override; bool Evaluate(const std::string_view& text) const override; std::unique_ptr BuildTextIterator( const void* fetcher) const override; + const FieldMaskPredicate GetFieldMask() const override { return field_mask_; } private: - const indexes::Text* index_; - vmsdk::UniqueValkeyString identifier_; - absl::string_view alias_; + std::shared_ptr text_index_schema_; + FieldMaskPredicate field_mask_; std::string term_; }; class FuzzyPredicate : public TextPredicate { public: - FuzzyPredicate(const indexes::Text* index, absl::string_view identifier, - absl::string_view alias, std::string term, uint32_t distance); - const indexes::Text* GetIndex() const { return index_; } - absl::string_view GetAlias() const { return alias_; } - 
absl::string_view GetIdentifier() const { - return vmsdk::ToStringView(identifier_.get()); - } - vmsdk::UniqueValkeyString GetRetainedIdentifier() const { - return vmsdk::RetainUniqueValkeyString(identifier_.get()); + FuzzyPredicate( + std::shared_ptr text_index_schema, + FieldMaskPredicate field_mask, std::string term, uint32_t distance); + std::shared_ptr GetTextIndexSchema() const { + return text_index_schema_; } absl::string_view GetTextString() const { return term_; } uint32_t GetDistance() const { return distance_; } @@ -265,11 +252,11 @@ class FuzzyPredicate : public TextPredicate { bool Evaluate(const std::string_view& text) const override; std::unique_ptr BuildTextIterator( const void* fetcher) const override; + const FieldMaskPredicate GetFieldMask() const override { return field_mask_; } private: - const indexes::Text* index_; - vmsdk::UniqueValkeyString identifier_; - absl::string_view alias_; + std::shared_ptr text_index_schema_; + FieldMaskPredicate field_mask_; std::string term_; uint32_t distance_; }; @@ -284,8 +271,11 @@ class ProximityPredicate : public TextPredicate { bool Evaluate(const std::string_view& text) const override { return false; } std::unique_ptr BuildTextIterator( const void* fetcher) const override; - const indexes::Text* GetIndex() const override { - return terms_[0]->GetIndex(); + std::shared_ptr GetTextIndexSchema() const { + return terms_[0]->GetTextIndexSchema(); + } + const FieldMaskPredicate GetFieldMask() const override { + return terms_[0]->GetFieldMask(); } const std::vector>& Terms() const { return terms_; diff --git a/src/query/search.cc b/src/query/search.cc index 9fb4f20ef..ce8a88dea 100644 --- a/src/query/search.cc +++ b/src/query/search.cc @@ -169,7 +169,9 @@ size_t EvaluateFilterAsPrimary( } if (predicate->GetType() == PredicateType::kText) { auto text_predicate = dynamic_cast(predicate); - auto fetcher = text_predicate->GetIndex()->Search(*text_predicate, negate); + auto fetcher = std::unique_ptr( + static_cast( + 
text_predicate->Search(negate))); size_t size = fetcher->Size(); entries_fetchers.push(std::move(fetcher)); return size; diff --git a/testing/common.cc b/testing/common.cc index 018e34005..5d956f233 100644 --- a/testing/common.cc +++ b/testing/common.cc @@ -104,12 +104,16 @@ absl::StatusOr> CreateIndexSchema( .WillByDefault(testing::Return(index_schema_db_num)); EXPECT_CALL(*kMockValkeyModule, GetDetachedThreadSafeContext(testing::_)) .WillRepeatedly(testing::Return(fake_ctx)); + data_model::Language language = data_model::LANGUAGE_ENGLISH; + std::string punctuation = ",.<>{}[]\"':;!@#$%^&*()-+=~/\\|"; + bool with_offsets = false; + std::vector stop_words = {}; VMSDK_ASSIGN_OR_RETURN( auto test_index_schema, - valkey_search::MockIndexSchema::Create( + MockIndexSchema::Create( fake_ctx, index_schema_key, *key_prefixes, std::make_unique(), - writer_thread_pool)); + writer_thread_pool, language, punctuation, with_offsets, stop_words)); VMSDK_RETURN_IF_ERROR( SchemaManager::Instance().ImportIndexSchema(test_index_schema)); return test_index_schema; diff --git a/testing/filter_test.cc b/testing/filter_test.cc index efcbbb33b..abb0d2e29 100644 --- a/testing/filter_test.cc +++ b/testing/filter_test.cc @@ -91,11 +91,9 @@ void InitIndexSchema(MockIndexSchema *index_schema) { "tag_field_case_insensitive", tag_field_case_insensitive)); - data_model::TextIndex text_index_proto; - auto text_index_schema = - std::make_shared( - data_model::LANGUAGE_ENGLISH, std::string(kDefaultPunctuation), true, - kDefaultStopWords); + index_schema->CreateTextIndexSchema(); + auto text_index_schema = index_schema->GetTextIndexSchema(); + data_model::TextIndex text_index_proto = CreateTextIndexProto(true, false, 4); auto text_index_1 = std::make_shared(text_index_proto, text_index_schema); auto text_index_2 = @@ -112,7 +110,7 @@ TEST_P(FilterTest, ParseParams) { InitIndexSchema(index_schema.get()); EXPECT_CALL(*index_schema, GetIdentifier(::testing::_)) .Times(::testing::AnyNumber()); - 
FilterParser parser(*index_schema, test_case.filter); + FilterParser parser(*index_schema, test_case.filter, {}); auto parse_results = parser.Parse(); EXPECT_EQ(test_case.create_success, parse_results.ok()); if (!test_case.create_success) { @@ -496,13 +494,15 @@ INSTANTIATE_TEST_SUITE_P( .test_name = "exact_suffix", .filter = "@text_field1:*word", .create_success = false, - .create_expected_error_message = "Unsupported query operation", + .create_expected_error_message = + "Index created without Suffix Trie", }, { .test_name = "exact_inffix", .filter = "@text_field1:*word*", .create_success = false, - .create_expected_error_message = "Unsupported query operation", + .create_expected_error_message = + "Index created without Suffix Trie", }, { .test_name = "exact_fuzzy1", @@ -535,6 +535,76 @@ INSTANTIATE_TEST_SUITE_P( .create_success = true, .evaluate_success = true, }, + { + .test_name = "default_field_text", + .filter = "Hello, how are you doing?", + .create_success = true, + .evaluate_success = true, + }, + { + .test_name = "default_field_exact_phrase", + .filter = "\"Hello, how are you doing?\"", + .create_success = true, + .evaluate_success = true, + }, + { + .test_name = "default_field_exact_phrase_with_punct", + .filter = "\"Hello, h(ow a)re yo#u doi_n$g?\"", + .create_success = true, + .evaluate_success = true, + }, + { + .test_name = "default_field_with_escape1", + .filter = + "\"\\\\\\\\\\Hello, \\how \\\\are \\\\\\you \\\\\\\\doing?\"", + .create_success = true, + .evaluate_success = true, + }, + { + .test_name = "default_field_with_escape2", + .filter = "\\\\\\\\\\Hello, \\how \\\\are \\\\\\you \\\\\\\\doing?", + .create_success = true, + .evaluate_success = true, + }, + { + .test_name = "default_field_with_escape3", + .filter = "Hel\\(lo, ho\\$w a\\*re yo\\{u do\\|ing?", + .create_success = true, + .evaluate_success = true, + }, + { + .test_name = "default_field_with_escape4", + .filter = "\\\\\\\\\\(Hello, \\$how \\\\\\*are \\\\\\-you " + 
"\\\\\\\\\\%doing?", + .create_success = true, + .evaluate_success = true, + }, + { + .test_name = "default_field_with_escape5", + .filter = "Hello, how are you\\% doing", + .create_success = true, + .evaluate_success = true, + }, + { + .test_name = "default_field_with_escape6", + .filter = "Hello, how are you\\\\\\\\\\% doing", + .create_success = true, + .evaluate_success = true, + }, + { + .test_name = "default_field_with_escape_query_syntax", + .filter = + "Hello, how are you\\]\\[\\$\\}\\{\\;\\:\\)\\(\\| \\-doing", + .create_success = true, + .evaluate_success = true, + }, + { + .test_name = "default_field_with_all_operations", + .filter = "%Hllo%, how are *ou do* *oda*", + .create_success = false, + .create_expected_error_message = + "Index created without Suffix Trie", + }, { .test_name = "proximity3", .filter = @@ -544,7 +614,53 @@ INSTANTIATE_TEST_SUITE_P( "@tag_field_1:{books} @text_field2:Neural | " "@text_field1:%%%word%%% @text_field2:network", .create_success = false, - .create_expected_error_message = "Unsupported query operation", + .create_expected_error_message = + "Invalid range: Value above maximum; Query string is too " + "complex: max number of terms can't exceed 16", + }, + { + .test_name = "invalid_fuzzy1", + .filter = "Hello, how are you% doing", + .create_success = false, + .create_expected_error_message = "Invalid fuzzy '%' markers", + }, + { + .test_name = "invalid_fuzzy2", + .filter = "Hello, how are %you%% doing", + .create_success = false, + .create_expected_error_message = "Invalid fuzzy '%' markers", + }, + { + .test_name = "invalid_fuzzy3", + .filter = "Hello, how are %%you% doing", + .create_success = false, + .create_expected_error_message = "Invalid fuzzy '%' markers", + }, + { + .test_name = "invalid_fuzzy4", + .filter = "Hello, how are %%%you%%%doing%%%", + .create_success = false, + .create_expected_error_message = "Invalid fuzzy '%' markers", + }, + { + .test_name = "invalid_escape1", + .filter = + "\\\\\\\\\\(Hello, \\$how 
\\\\*are \\\\\\-you \\\\\\\\%doing?", + .create_success = false, + .create_expected_error_message = "Invalid fuzzy '%' markers", + }, + { + .test_name = "invalid_wildcard1", + .filter = "Hello, how are **you* doing", + .create_success = false, + .create_expected_error_message = "Invalid wildcard '*' markers", + }, + { + .test_name = "invalid_wildcard2", + .filter = "Hello, how are *you** doing", + .create_success = false, + .create_expected_error_message = + "Index created without Suffix Trie", }, { .test_name = "bad_filter_1", @@ -565,7 +681,7 @@ INSTANTIATE_TEST_SUITE_P( .filter = "@num_field_2.0 : [23 25] | num_field_2.0:[0 2.5] ", .create_success = false, .create_expected_error_message = - "Unexpected character at position 28: `n`, expecting `@`", + "Unexpected character at position 41: `:`", }, { .test_name = "bad_filter_4", @@ -579,7 +695,7 @@ INSTANTIATE_TEST_SUITE_P( .filter = "@num_field_2.0 : [23 25] $ @num_field_2.0:[0 2.5] ", .create_success = false, .create_expected_error_message = - "Unexpected character at position 26: `$`, expecting `@`", + "Unexpected character at position 26: `$`", }, { .test_name = "bad_filter_6", @@ -629,6 +745,55 @@ INSTANTIATE_TEST_SUITE_P( .create_success = false, .create_expected_error_message = "Missing closing TAG bracket, '}'", }, + { + .test_name = "bad_filter_13", + .filter = "hello{world", + .create_success = false, + .create_expected_error_message = + "Unexpected character at position 6: `{`", + }, + { + .test_name = "bad_filter_14", + .filter = "hello}world", + .create_success = false, + .create_expected_error_message = + "Unexpected character at position 6: `}`", + }, + { + .test_name = "bad_filter_15", + .filter = "hello$world", + .create_success = false, + .create_expected_error_message = + "Unexpected character at position 6: `$`", + }, + { + .test_name = "bad_filter_16", + .filter = "hello[world", + .create_success = false, + .create_expected_error_message = + "Unexpected character at position 6: `[`", + }, + { 
+ .test_name = "bad_filter_17", + .filter = "hello]world", + .create_success = false, + .create_expected_error_message = + "Unexpected character at position 6: `]`", + }, + { + .test_name = "bad_filter_18", + .filter = "hello:world", + .create_success = false, + .create_expected_error_message = + "Unexpected character at position 6: `:`", + }, + { + .test_name = "bad_filter_19", + .filter = "hello;world", + .create_success = false, + .create_expected_error_message = + "Unexpected character at position 6: `;`", + }, }), [](const TestParamInfo &info) { return info.param.test_name; diff --git a/testing/search_test.cc b/testing/search_test.cc index 3a78f3137..a4e9ed718 100644 --- a/testing/search_test.cc +++ b/testing/search_test.cc @@ -215,7 +215,7 @@ TEST_P(EvaluateFilterAsPrimaryTest, ParseParams) { const EvaluateFilterAsPrimaryTestCase &test_case = GetParam(); auto index_schema = CreateIndexSchema(kIndexSchemaName).value(); InitIndexSchema(index_schema.get()); - FilterParser parser(*index_schema, test_case.filter); + FilterParser parser(*index_schema, test_case.filter, {}); auto filter_parse_results = parser.Parse(); std::queue> entries_fetchers; EXPECT_EQ( @@ -410,7 +410,7 @@ TEST_P(LocalSearchTest, LocalSearchTest) { params.ef = kEfRuntime; std::vector query_vector(kVectorDimensions, 1.0); params.query = VectorToStr(query_vector); - FilterParser parser(*index_schema, test_case.filter); + FilterParser parser(*index_schema, test_case.filter, {}); params.filter_parse_results = std::move(parser.Parse().value()); params.index_schema = index_schema; auto time_slice_queries = Metrics::GetStats().time_slice_queries.load(); @@ -505,7 +505,7 @@ TEST_P(FetchFilteredKeysTest, ParseParams) { index_schema->GetIndex(kVectorAttributeAlias)->get()); const FetchFilteredKeysTestCase &test_case = GetParam(); query::SearchParameters params(100000, nullptr); - FilterParser parser(*index_schema, test_case.filter); + FilterParser parser(*index_schema, test_case.filter, {}); 
params.filter_parse_results = std::move(parser.Parse().value()); params.k = 100; auto vectors = DeterministicallyGenerateVectors(1, kVectorDimensions, 10.0); @@ -593,7 +593,7 @@ TEST_P(SearchTest, ParseParams) { std::vector query_vector(kVectorDimensions, 0.0); params.query = VectorToStr(query_vector); if (!test_case.filter.empty()) { - FilterParser parser(*params.index_schema, test_case.filter); + FilterParser parser(*params.index_schema, test_case.filter, {}); params.filter_parse_results = std::move(parser.Parse().value()); } auto neighbors = Search(params, query::SearchMode::kLocal);