diff --git a/Cargo.lock b/Cargo.lock index 79ec52f0..2b5f227f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,6 +45,12 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.18" @@ -733,6 +739,12 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.2.3" @@ -766,6 +778,33 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -898,6 +937,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam" version = "0.8.4" @@ -954,6 +1029,12 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crunchy" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" + [[package]] name = "crypto-common" version = "0.1.6" @@ -1513,6 +1594,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "half" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1572,6 +1663,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +[[package]] +name = "hermit-abi" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbd780fe5cc30f81464441920d82ac8740e2e46b29a6fad543ddd075229ce37e" + [[package]] name = "hex" version = "0.4.3" @@ -1821,6 +1918,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "is-terminal" +version = "0.4.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" +dependencies = [ + "hermit-abi 0.5.0", + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "is_ci" version = "1.2.0" @@ -1842,6 +1950,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.14" @@ -2229,6 +2346,12 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "option-ext" version = "0.2.0" @@ -2361,7 +2484,7 @@ dependencies = [ "cc", "fs_extra", "glob", - "itertools", + "itertools 0.10.5", "prost", "prost-build", "serde", @@ -2443,6 +2566,7 @@ name = "pgt_completions" version = "0.0.0" dependencies = [ "async-std", + "criterion", "pgt_schema_cache", "pgt_test_utils", "pgt_text_size", @@ -2452,6 +2576,7 @@ dependencies = [ "serde_json", "sqlx", "tokio", + "tracing", "tree-sitter", "tree_sitter_sql", ] @@ -2853,6 +2978,34 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "polling" version = "2.8.0" @@ -2998,7 +3151,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ "heck", - "itertools", + "itertools 0.14.0", "log", "multimap", "once_cell", @@ -3018,7 +3171,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools", + "itertools 0.14.0", "proc-macro2", "quote", "syn 2.0.90", @@ -4068,6 +4221,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.8.0" diff --git a/Cargo.toml b/Cargo.toml index e4472250..ef29bcaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,6 @@ syn = "1.0.109" termcolor = "1.4.1" test-log = "0.2.17" tokio = { version = "1.40.0", features = ["full"] } -tower-lsp = "0.20.0" tracing = { version = "0.1.40", default-features = false, features = ["std"] } tracing-bunyan-formatter = { version = "0.3.10 " } tracing-subscriber = "0.3.18" @@ -57,7 +56,6 @@ unicode-width = "0.1.12" # postgres specific crates pgt_analyse = { path = "./crates/pgt_analyse", version = "0.0.0" } pgt_analyser = { path = "./crates/pgt_analyser", version = "0.0.0" } -pgt_base_db = { path = "./crates/pgt_base_db", version = "0.0.0" } pgt_cli = { path = "./crates/pgt_cli", version = "0.0.0" } pgt_completions = { path = "./crates/pgt_completions", version = "0.0.0" } pgt_configuration = { path = "./crates/pgt_configuration", version = "0.0.0" } @@ -69,9 +67,7 @@ pgt_flags = { path = "./crates/pgt_flags", version = "0.0.0" } pgt_fs = { path = "./crates/pgt_fs", version = "0.0.0" } pgt_lexer = { path = "./crates/pgt_lexer", version = "0.0.0" } pgt_lexer_codegen = { path = "./crates/pgt_lexer_codegen", version = "0.0.0" } -pgt_lint = { path = "./crates/pgt_lint", version = "0.0.0" } pgt_lsp = { path = "./crates/pgt_lsp", version = "0.0.0" } -pgt_lsp_converters = { path = "./crates/pgt_lsp_converters", version = "0.0.0" } pgt_markup = { path = "./crates/pgt_markup", version = "0.0.0" } pgt_query_ext = { path = "./crates/pgt_query_ext", version = "0.0.0" } pgt_query_ext_codegen = { path = "./crates/pgt_query_ext_codegen", version = "0.0.0" } @@ -81,14 +77,11 @@ pgt_statement_splitter = { path = "./crates/pgt_statement_splitter", version pgt_text_edit = { path = "./crates/pgt_text_edit", version = "0.0.0" } pgt_text_size = { path = "./crates/pgt_text_size", version = "0.0.0" } pgt_treesitter_queries = { path = "./crates/pgt_treesitter_queries", version = "0.0.0" } -pgt_type_resolver = { path = "./crates/pgt_type_resolver", version = "0.0.0" } pgt_typecheck = { path = "./crates/pgt_typecheck", version = "0.0.0" } pgt_workspace = { path = "./crates/pgt_workspace", version = "0.0.0" } pgt_test_macros = { path = "./crates/pgt_test_macros" } pgt_test_utils = { path = "./crates/pgt_test_utils" } -docs_codegen = { path = "./docs/codegen", version = "0.0.0" } - [profile.dev.package] insta.opt-level = 3 diff --git a/crates/pgt_completions/Cargo.toml b/crates/pgt_completions/Cargo.toml index dba88f41..a69ee75a 100644 --- a/crates/pgt_completions/Cargo.toml +++ b/crates/pgt_completions/Cargo.toml @@ -22,6 +22,7 @@ pgt_treesitter_queries.workspace = true schemars = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +tracing = { workspace = true } tree-sitter.workspace = true tree_sitter_sql.workspace = true @@ -30,6 +31,7 @@ sqlx.workspace = true tokio = { version = "1.41.1", features = ["full"] } [dev-dependencies] +criterion = "0.5.1" pgt_test_utils.workspace = true [lib] @@ -37,3 +39,7 @@ doctest = false [features] schema = ["dep:schemars"] + +[[bench]] +harness = false +name = "sanitization" diff --git a/crates/pgt_completions/benches/sanitization.rs b/crates/pgt_completions/benches/sanitization.rs new file mode 100644 index 00000000..c21538de --- /dev/null +++ b/crates/pgt_completions/benches/sanitization.rs @@ -0,0 +1,249 @@ +use criterion::{Criterion, black_box, criterion_group, criterion_main}; +use pgt_completions::{CompletionParams, benchmark_sanitization}; +use pgt_schema_cache::SchemaCache; +use pgt_text_size::TextSize; + +static CURSOR_POS: &str = "€"; + +fn sql_and_pos(sql: &str) -> (String, usize) { + let pos = sql.find(CURSOR_POS).unwrap(); + (sql.replace(CURSOR_POS, ""), pos) +} + +fn get_tree(sql: &str) -> tree_sitter::Tree { + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + parser.parse(sql.to_string(), None).unwrap() +} + +fn to_params<'a>( + text: String, + tree: &'a tree_sitter::Tree, + pos: usize, + cache: &'a SchemaCache, +) -> CompletionParams<'a> { + let pos: u32 = pos.try_into().unwrap(); + CompletionParams { + position: TextSize::new(pos), + schema: &cache, + text, + tree: Some(tree), + } +} + +pub fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("small sql, adjusted", |b| { + let content = format!("select {} from users;", CURSOR_POS); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); + + c.bench_function("mid sql, adjusted", |b| { + let content = format!( + r#"select + n.oid :: int8 as "id!", + n.nspname as name, + u.rolname as "owner!" +from + pg_namespace n, + {} +where + n.nspowner = u.oid + and ( + pg_has_role(n.nspowner, 'USAGE') + or has_schema_privilege(n.oid, 'CREATE, USAGE') + ) + and not pg_catalog.starts_with(n.nspname, 'pg_temp_') + and not pg_catalog.starts_with(n.nspname, 'pg_toast_temp_');"#, + CURSOR_POS + ); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); + + c.bench_function("large sql, adjusted", |b| { + let content = format!( + r#"with + available_tables as ( + select + c.relname as table_name, + c.oid as table_oid, + c.relkind as class_kind, + n.nspname as schema_name + from + pg_catalog.pg_class c + join pg_catalog.pg_namespace n on n.oid = c.relnamespace + where + -- r: normal tables + -- v: views + -- m: materialized views + -- f: foreign tables + -- p: partitioned tables + c.relkind in ('r', 'v', 'm', 'f', 'p') + ), + available_indexes as ( + select + unnest (ix.indkey) as attnum, + ix.indisprimary as is_primary, + ix.indisunique as is_unique, + ix.indrelid as table_oid + from + {} + where + c.relkind = 'i' + ) +select + atts.attname as name, + ts.table_name, + ts.table_oid :: int8 as "table_oid!", + ts.class_kind :: char as "class_kind!", + ts.schema_name, + atts.atttypid :: int8 as "type_id!", + not atts.attnotnull as "is_nullable!", + nullif( + information_schema._pg_char_max_length (atts.atttypid, atts.atttypmod), + -1 + ) as varchar_length, + pg_get_expr (def.adbin, def.adrelid) as default_expr, + coalesce(ix.is_primary, false) as "is_primary_key!", + coalesce(ix.is_unique, false) as "is_unique!", + pg_catalog.col_description (ts.table_oid, atts.attnum) as comment +from + pg_catalog.pg_attribute atts + join available_tables ts on atts.attrelid = ts.table_oid + left join available_indexes ix on atts.attrelid = ix.table_oid + and atts.attnum = ix.attnum + left join pg_catalog.pg_attrdef def on atts.attrelid = def.adrelid + and atts.attnum = def.adnum +where + -- system columns, such as `cmax` or `tableoid`, have negative `attnum`s + atts.attnum >= 0; +"#, + CURSOR_POS + ); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); + + c.bench_function("small sql, unadjusted", |b| { + let content = format!("select e{} from users;", CURSOR_POS); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); + + c.bench_function("mid sql, unadjusted", |b| { + let content = format!( + r#"select + n.oid :: int8 as "id!", + n.nspname as name, + u.rolname as "owner!" +from + pg_namespace n, + pg_r{} +where + n.nspowner = u.oid + and ( + pg_has_role(n.nspowner, 'USAGE') + or has_schema_privilege(n.oid, 'CREATE, USAGE') + ) + and not pg_catalog.starts_with(n.nspname, 'pg_temp_') + and not pg_catalog.starts_with(n.nspname, 'pg_toast_temp_');"#, + CURSOR_POS + ); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); + + c.bench_function("large sql, unadjusted", |b| { + let content = format!( + r#"with + available_tables as ( + select + c.relname as table_name, + c.oid as table_oid, + c.relkind as class_kind, + n.nspname as schema_name + from + pg_catalog.pg_class c + join pg_catalog.pg_namespace n on n.oid = c.relnamespace + where + -- r: normal tables + -- v: views + -- m: materialized views + -- f: foreign tables + -- p: partitioned tables + c.relkind in ('r', 'v', 'm', 'f', 'p') + ), + available_indexes as ( + select + unnest (ix.indkey) as attnum, + ix.indisprimary as is_primary, + ix.indisunique as is_unique, + ix.indrelid as table_oid + from + pg_catalog.pg_class c + join pg_catalog.pg_index ix on c.oid = ix.indexrelid + where + c.relkind = 'i' + ) +select + atts.attname as name, + ts.table_name, + ts.table_oid :: int8 as "table_oid!", + ts.class_kind :: char as "class_kind!", + ts.schema_name, + atts.atttypid :: int8 as "type_id!", + not atts.attnotnull as "is_nullable!", + nullif( + information_schema._pg_char_max_length (atts.atttypid, atts.atttypmod), + -1 + ) as varchar_length, + pg_get_expr (def.adbin, def.adrelid) as default_expr, + coalesce(ix.is_primary, false) as "is_primary_key!", + coalesce(ix.is_unique, false) as "is_unique!", + pg_catalog.col_description (ts.table_oid, atts.attnum) as comment +from + pg_catalog.pg_attribute atts + join available_tables ts on atts.attrelid = ts.table_oid + left join available_indexes ix on atts.attrelid = ix.table_oid + and atts.attnum = ix.attnum + left join pg_catalog.pg_attrdef def on atts.attrelid = def.adrelid + and atts.attnum = def.adnum +where + -- system columns, such as `cmax` or `tableoid`, have negative `attnum`s + atts.attnum >= 0 +order by + sch{} "#, + CURSOR_POS + ); + + let cache = SchemaCache::default(); + let (sql, pos) = sql_and_pos(content.as_str()); + let tree = get_tree(sql.as_str()); + + b.iter(|| benchmark_sanitization(black_box(to_params(sql.clone(), &tree, pos, &cache)))); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index ed51c653..ec1232a5 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -5,6 +5,7 @@ use crate::{ context::CompletionContext, item::CompletionItem, providers::{complete_columns, complete_functions, complete_tables}, + sanitization::SanitizedCompletionParams, }; pub const LIMIT: usize = 50; @@ -17,8 +18,14 @@ pub struct CompletionParams<'a> { pub tree: &'a tree_sitter::Tree, } +#[tracing::instrument(level = "debug", skip_all, fields( + text = params.text, + position = params.position.to_string() +))] pub fn complete(params: CompletionParams) -> Vec<CompletionItem> { - let ctx = CompletionContext::new(¶ms); + let sanitized_params = SanitizedCompletionParams::from(params); + + let ctx = CompletionContext::new(&sanitized_params); let mut builder = CompletionBuilder::new(); diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 775b8870..a4578df8 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -6,7 +6,7 @@ use pgt_treesitter_queries::{ queries::{self, QueryResult}, }; -use crate::CompletionParams; +use crate::sanitization::SanitizedCompletionParams; #[derive(Debug, PartialEq, Eq)] pub enum ClauseType { @@ -17,6 +17,12 @@ pub enum ClauseType { Delete, } +#[derive(PartialEq, Eq, Debug)] +pub(crate) enum NodeText<'a> { + Replaced, + Original(&'a str), +} + impl TryFrom<&str> for ClauseType { type Error = String; @@ -49,7 +55,8 @@ impl TryFrom<String> for ClauseType { } pub(crate) struct CompletionContext<'a> { - pub ts_node: Option<tree_sitter::Node<'a>>, + pub node_under_cursor: Option<tree_sitter::Node<'a>>, + pub tree: &'a tree_sitter::Tree, pub text: &'a str, pub schema_cache: &'a SchemaCache, @@ -64,13 +71,13 @@ pub(crate) struct CompletionContext<'a> { } impl<'a> CompletionContext<'a> { - pub fn new(params: &'a CompletionParams) -> Self { + pub fn new(params: &'a SanitizedCompletionParams) -> Self { let mut ctx = Self { - tree: params.tree, + tree: params.tree.as_ref(), text: ¶ms.text, schema_cache: params.schema, position: usize::from(params.position), - ts_node: None, + node_under_cursor: None, schema_name: None, wrapping_clause_type: None, wrapping_statement_range: None, @@ -85,12 +92,10 @@ impl<'a> CompletionContext<'a> { } fn gather_info_from_ts_queries(&mut self) { - let tree = self.tree; - let stmt_range = self.wrapping_statement_range.as_ref(); let sql = self.text; - let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + let mut executor = TreeSitterQueriesExecutor::new(self.tree.root_node(), sql); executor.add_query_results::<queries::RelationMatch>(); @@ -117,9 +122,15 @@ impl<'a> CompletionContext<'a> { } } - pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> { + pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<NodeText<'a>> { let source = self.text; - ts_node.utf8_text(source.as_bytes()).ok() + ts_node.utf8_text(source.as_bytes()).ok().map(|txt| { + if SanitizedCompletionParams::is_sanitized_token(txt) { + NodeText::Replaced + } else { + NodeText::Original(txt) + } + }) } fn gather_tree_context(&mut self) { @@ -148,20 +159,20 @@ impl<'a> CompletionContext<'a> { fn gather_context_from_node( &mut self, mut cursor: tree_sitter::TreeCursor<'a>, - previous_node: tree_sitter::Node<'a>, + parent_node: tree_sitter::Node<'a>, ) { let current_node = cursor.node(); // prevent infinite recursion – this can happen if we only have a PROGRAM node - if current_node.kind() == previous_node.kind() { - self.ts_node = Some(current_node); + if current_node.kind() == parent_node.kind() { + self.node_under_cursor = Some(current_node); return; } - match previous_node.kind() { + match parent_node.kind() { "statement" | "subquery" => { self.wrapping_clause_type = current_node.kind().try_into().ok(); - self.wrapping_statement_range = Some(previous_node.range()); + self.wrapping_statement_range = Some(parent_node.range()); } "invocation" => self.is_invocation = true, @@ -170,11 +181,16 @@ impl<'a> CompletionContext<'a> { match current_node.kind() { "object_reference" => { - let txt = self.get_ts_node_content(current_node); - if let Some(txt) = txt { - let parts: Vec<&str> = txt.split('.').collect(); - if parts.len() == 2 { - self.schema_name = Some(parts[0].to_string()); + let content = self.get_ts_node_content(current_node); + if let Some(node_txt) = content { + match node_txt { + NodeText::Original(txt) => { + let parts: Vec<&str> = txt.split('.').collect(); + if parts.len() == 2 { + self.schema_name = Some(parts[0].to_string()); + } + } + NodeText::Replaced => {} } } } @@ -193,7 +209,14 @@ impl<'a> CompletionContext<'a> { // We have arrived at the leaf node if current_node.child_count() == 0 { - self.ts_node = Some(current_node); + if matches!( + self.get_ts_node_content(current_node).unwrap(), + NodeText::Replaced + ) { + self.node_under_cursor = None; + } else { + self.node_under_cursor = Some(current_node); + } return; } @@ -205,7 +228,8 @@ impl<'a> CompletionContext<'a> { #[cfg(test)] mod tests { use crate::{ - context::{ClauseType, CompletionContext}, + context::{ClauseType, CompletionContext, NodeText}, + sanitization::SanitizedCompletionParams, test_helper::{CURSOR_POS, get_text_and_position}, }; @@ -252,10 +276,10 @@ mod tests { let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: &tree, + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -284,10 +308,10 @@ mod tests { let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: &tree, + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -318,10 +342,10 @@ mod tests { let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: &tree, + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; @@ -343,18 +367,21 @@ mod tests { let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: &tree, + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; let ctx = CompletionContext::new(¶ms); - let node = ctx.ts_node.unwrap(); + let node = ctx.node_under_cursor.unwrap(); - assert_eq!(ctx.get_ts_node_content(node), Some("select")); + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("select")) + ); assert_eq!( ctx.wrapping_clause_type, @@ -371,18 +398,21 @@ mod tests { let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: &tree, + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; let ctx = CompletionContext::new(¶ms); - let node = ctx.ts_node.unwrap(); + let node = ctx.node_under_cursor.unwrap(); - assert_eq!(ctx.get_ts_node_content(node), Some("from")); + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("from")) + ); assert_eq!( ctx.wrapping_clause_type, Some(crate::context::ClauseType::From) @@ -397,18 +427,18 @@ mod tests { let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: &tree, + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; let ctx = CompletionContext::new(¶ms); - let node = ctx.ts_node.unwrap(); + let node = ctx.node_under_cursor.unwrap(); - assert_eq!(ctx.get_ts_node_content(node), Some("")); + assert_eq!(ctx.get_ts_node_content(node), Some(NodeText::Original(""))); assert_eq!(ctx.wrapping_clause_type, None); } @@ -422,18 +452,21 @@ mod tests { let tree = get_tree(text.as_str()); - let params = crate::CompletionParams { + let params = SanitizedCompletionParams { position: (position as u32).into(), text, - tree: &tree, + tree: std::borrow::Cow::Owned(tree), schema: &pgt_schema_cache::SchemaCache::default(), }; let ctx = CompletionContext::new(¶ms); - let node = ctx.ts_node.unwrap(); + let node = ctx.node_under_cursor.unwrap(); - assert_eq!(ctx.get_ts_node_content(node), Some("fro")); + assert_eq!( + ctx.get_ts_node_content(node), + Some(NodeText::Original("fro")) + ); assert_eq!(ctx.wrapping_clause_type, Some(ClauseType::Select)); } } diff --git a/crates/pgt_completions/src/lib.rs b/crates/pgt_completions/src/lib.rs index 62470ff4..f8ca1a55 100644 --- a/crates/pgt_completions/src/lib.rs +++ b/crates/pgt_completions/src/lib.rs @@ -4,9 +4,11 @@ mod context; mod item; mod providers; mod relevance; +mod sanitization; #[cfg(test)] mod test_helper; pub use complete::*; pub use item::*; +pub use sanitization::*; diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 3f1c5bb9..2898b63f 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -143,12 +143,86 @@ mod tests { let params = get_test_params(&tree, &cache, case.get_input_query()); let mut items = complete(params); - let _ = items.split_off(3); + let _ = items.split_off(6); - items.sort_by(|a, b| a.label.cmp(&b.label)); + #[derive(Eq, PartialEq, Debug)] + struct LabelAndDesc { + label: String, + desc: String, + } + + let labels: Vec<LabelAndDesc> = items + .into_iter() + .map(|c| LabelAndDesc { + label: c.label, + desc: c.description, + }) + .collect(); + + let expected = vec![ + ("name", "Table: public.users"), + ("narrator", "Table: public.audio_books"), + ("narrator_id", "Table: private.audio_books"), + ("name", "Schema: pg_catalog"), + ("nameconcatoid", "Schema: pg_catalog"), + ("nameeq", "Schema: pg_catalog"), + ] + .into_iter() + .map(|(label, schema)| LabelAndDesc { + label: label.into(), + desc: schema.into(), + }) + .collect::<Vec<LabelAndDesc>>(); + + assert_eq!(labels, expected); + } + + #[tokio::test] + async fn suggests_relevant_columns_without_letters() { + let setup = r#" + create table users ( + id serial primary key, + name text, + address text, + email text + ); + "#; + + let test_case = TestCase { + message: "suggests user created tables first", + query: format!(r#"select {} from users"#, CURSOR_POS), + label: "", + description: "", + }; + + let (tree, cache) = get_test_deps(setup, test_case.get_input_query()).await; + let params = get_test_params(&tree, &cache, test_case.get_input_query()); + let results = complete(params); - let labels: Vec<String> = items.into_iter().map(|c| c.label).collect(); + let (first_four, _rest) = results.split_at(4); + + let has_column_in_first_four = |col: &'static str| { + first_four + .iter() + .find(|compl_item| compl_item.label.as_str() == col) + .is_some() + }; - assert_eq!(labels, vec!["name", "narrator", "narrator_id"]); + assert!( + has_column_in_first_four("id"), + "`id` not present in first four completion items." + ); + assert!( + has_column_in_first_four("name"), + "`name` not present in first four completion items." + ); + assert!( + has_column_in_first_four("address"), + "`address` not present in first four completion items." + ); + assert!( + has_column_in_first_four("email"), + "`email` not present in first four completion items." + ); } } diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index 6a1e00c9..2074a4f1 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -73,9 +73,9 @@ mod tests { "#; let test_cases = vec![ - (format!("select * from us{}", CURSOR_POS), "users"), - (format!("select * from em{}", CURSOR_POS), "emails"), - (format!("select * from {}", CURSOR_POS), "addresses"), + (format!("select * from u{}", CURSOR_POS), "users"), + (format!("select * from e{}", CURSOR_POS), "emails"), + (format!("select * from a{}", CURSOR_POS), "addresses"), ]; for (query, expected_label) in test_cases { diff --git a/crates/pgt_completions/src/relevance.rs b/crates/pgt_completions/src/relevance.rs index ffe6cb22..9650a94d 100644 --- a/crates/pgt_completions/src/relevance.rs +++ b/crates/pgt_completions/src/relevance.rs @@ -1,4 +1,4 @@ -use crate::context::{ClauseType, CompletionContext}; +use crate::context::{ClauseType, CompletionContext, NodeText}; #[derive(Debug)] pub(crate) enum CompletionRelevanceData<'a> { @@ -33,7 +33,6 @@ impl CompletionRelevance<'_> { self.check_is_user_defined(); self.check_matches_schema(ctx); self.check_matches_query_input(ctx); - self.check_if_catalog(ctx); self.check_is_invocation(ctx); self.check_matching_clause_type(ctx); self.check_relations_in_stmt(ctx); @@ -42,10 +41,16 @@ impl CompletionRelevance<'_> { } fn check_matches_query_input(&mut self, ctx: &CompletionContext) { - let node = ctx.ts_node.unwrap(); + let node = match ctx.node_under_cursor { + Some(node) => node, + None => return, + }; let content = match ctx.get_ts_node_content(node) { - Some(c) => c, + Some(c) => match c { + NodeText::Original(s) => s, + NodeText::Replaced => return, + }, None => return, }; @@ -61,7 +66,7 @@ impl CompletionRelevance<'_> { .try_into() .expect("The length of the input exceeds i32 capacity"); - self.score += len * 5; + self.score += len * 10; }; } @@ -135,14 +140,6 @@ impl CompletionRelevance<'_> { } } - fn check_if_catalog(&mut self, ctx: &CompletionContext) { - if ctx.schema_name.as_ref().is_some_and(|n| n == "pg_catalog") { - return; - } - - self.score -= 5; // unlikely that the user wants schema data - } - fn check_relations_in_stmt(&mut self, ctx: &CompletionContext) { match self.data { CompletionRelevanceData::Table(_) | CompletionRelevanceData::Function(_) => return, @@ -182,5 +179,11 @@ impl CompletionRelevance<'_> { if system_schemas.contains(&schema.as_str()) { self.score -= 10; } + + // "public" is the default postgres schema where users + // create objects. Prefer it by a slight bit. + if schema.as_str() == "public" { + self.score += 2; + } } } diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs new file mode 100644 index 00000000..5ad8ba0e --- /dev/null +++ b/crates/pgt_completions/src/sanitization.rs @@ -0,0 +1,283 @@ +use std::borrow::Cow; + +use pgt_text_size::TextSize; + +use crate::CompletionParams; + +pub(crate) struct SanitizedCompletionParams<'a> { + pub position: TextSize, + pub text: String, + pub schema: &'a pgt_schema_cache::SchemaCache, + pub tree: Cow<'a, tree_sitter::Tree>, +} + +pub fn benchmark_sanitization(params: CompletionParams) -> String { + let params: SanitizedCompletionParams = params.try_into().unwrap(); + params.text +} + +impl<'larger, 'smaller> From<CompletionParams<'larger>> for SanitizedCompletionParams<'smaller> +where + 'larger: 'smaller, +{ + fn from(params: CompletionParams<'larger>) -> Self { + if cursor_inbetween_nodes(params.tree, params.position) + || cursor_prepared_to_write_token_after_last_node(params.tree, params.position) + || cursor_before_semicolon(params.tree, params.position) + { + SanitizedCompletionParams::with_adjusted_sql(params) + } else { + SanitizedCompletionParams::unadjusted(params) + } + } +} + +static SANITIZED_TOKEN: &str = "REPLACED_TOKEN"; + +impl<'larger, 'smaller> SanitizedCompletionParams<'smaller> +where + 'larger: 'smaller, +{ + fn with_adjusted_sql(params: CompletionParams<'larger>) -> Self { + let cursor_pos: usize = params.position.into(); + let mut sql = String::new(); + + for (idx, c) in params.text.chars().enumerate() { + if idx == cursor_pos { + sql.push_str(SANITIZED_TOKEN); + sql.push(' '); + } + sql.push(c); + } + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + let tree = parser.parse(sql.clone(), None).unwrap(); + + Self { + position: params.position, + text: sql, + schema: params.schema, + tree: Cow::Owned(tree), + } + } + fn unadjusted(params: CompletionParams<'larger>) -> Self { + Self { + position: params.position, + text: params.text.clone(), + schema: params.schema, + tree: Cow::Borrowed(params.tree), + } + } + + pub fn is_sanitized_token(txt: &str) -> bool { + txt == SANITIZED_TOKEN + } +} + +/// Checks if the cursor is positioned inbetween two SQL nodes. +/// +/// ```sql +/// select| from users; -- cursor "touches" select node. returns false. +/// select |from users; -- cursor "touches" from node. returns false. +/// select | from users; -- cursor is between select and from nodes. returns true. +/// ``` +fn cursor_inbetween_nodes(tree: &tree_sitter::Tree, position: TextSize) -> bool { + let mut cursor = tree.walk(); + let mut leaf_node = tree.root_node(); + + let byte = position.into(); + + // if the cursor escapes the root node, it can't be between nodes. + if byte < leaf_node.start_byte() || byte >= leaf_node.end_byte() { + return false; + } + + /* + * Get closer and closer to the leaf node, until + * a) there is no more child *for the node* or + * b) there is no more child *under the cursor*. + */ + loop { + let child_idx = cursor.goto_first_child_for_byte(position.into()); + if child_idx.is_none() { + break; + } + leaf_node = cursor.node(); + } + + let cursor_on_leafnode = byte >= leaf_node.start_byte() && leaf_node.end_byte() >= byte; + + /* + * The cursor is inbetween nodes if it is not within the range + * of a leaf node. + */ + !cursor_on_leafnode +} + +/// Checks if the cursor is positioned after the last node, +/// ready to write the next token: +/// +/// ```sql +/// select * from | -- ready to write! +/// select * from| -- user still needs to type a space +/// select * from | -- too far off. +/// ``` +fn cursor_prepared_to_write_token_after_last_node( + tree: &tree_sitter::Tree, + position: TextSize, +) -> bool { + let cursor_pos: usize = position.into(); + cursor_pos == tree.root_node().end_byte() + 1 +} + +fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool { + let mut cursor = tree.walk(); + let mut leaf_node = tree.root_node(); + + let byte: usize = position.into(); + + // if the cursor escapes the root node, it can't be between nodes. + if byte < leaf_node.start_byte() || byte >= leaf_node.end_byte() { + return false; + } + + loop { + let child_idx = cursor.goto_first_child_for_byte(position.into()); + if child_idx.is_none() { + break; + } + leaf_node = cursor.node(); + } + + // The semicolon node is on the same level as the statement: + // + // program [0..26] + // statement [0..19] + // ; [25..26] + // + // However, if we search for position 21, we'll still land on the semi node. + // We must manually verify that the cursor is between the statement and the semi nodes. + + // if the last node is not a semi, the statement is not completed. + if leaf_node.kind() != ";" { + return false; + } + + // not okay to be on the semi. + if byte == leaf_node.start_byte() { + return false; + } + + leaf_node + .prev_named_sibling() + .map(|n| n.end_byte() < byte) + .unwrap_or(false) +} + +#[cfg(test)] +mod tests { + use pgt_text_size::TextSize; + + use crate::sanitization::{ + cursor_before_semicolon, cursor_inbetween_nodes, + cursor_prepared_to_write_token_after_last_node, + }; + + #[test] + fn test_cursor_inbetween_nodes() { + // note: two spaces between select and from. + let input = "select from users;"; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let mut tree = parser.parse(input.to_string(), None).unwrap(); + + // select | from users; <-- just right, one space after select token, one space before from + assert!(cursor_inbetween_nodes(&mut tree, TextSize::new(7))); + + // select| from users; <-- still on select token + assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(6))); + + // select |from users; <-- already on from token + assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(8))); + + // select from users;| + assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(19))); + } + + #[test] + fn test_cursor_after_nodes() { + let input = "select * from"; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let mut tree = parser.parse(input.to_string(), None).unwrap(); + + // select * from| <-- still on previous token + assert!(!cursor_prepared_to_write_token_after_last_node( + &mut tree, + TextSize::new(13) + )); + + // select * from | <-- too far off, two spaces afterward + assert!(!cursor_prepared_to_write_token_after_last_node( + &mut tree, + TextSize::new(15) + )); + + // select * |from <-- it's within + assert!(!cursor_prepared_to_write_token_after_last_node( + &mut tree, + TextSize::new(9) + )); + + // select * from | <-- just right + assert!(cursor_prepared_to_write_token_after_last_node( + &mut tree, + TextSize::new(14) + )); + } + + #[test] + fn test_cursor_before_semicolon() { + // Idx "13" is the exlusive end of `select * from` (first space after from) + // Idx "18" is right where the semi is + let input = "select * from ;"; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(tree_sitter_sql::language()) + .expect("Error loading sql language"); + + let mut tree = parser.parse(input.to_string(), None).unwrap(); + + // select * from ;| <-- it's after the statement + assert!(!cursor_before_semicolon(&mut tree, TextSize::new(19))); + + // select * from| ; <-- still touches the from + assert!(!cursor_before_semicolon(&mut tree, TextSize::new(13))); + + // not okay to be ON the semi. + // select * from |; + assert!(!cursor_before_semicolon(&mut tree, TextSize::new(18))); + + // anything is fine here + // select * from | ; + // select * from | ; + // select * from | ; + // select * from |; + assert!(cursor_before_semicolon(&mut tree, TextSize::new(14))); + assert!(cursor_before_semicolon(&mut tree, TextSize::new(15))); + assert!(cursor_before_semicolon(&mut tree, TextSize::new(16))); + assert!(cursor_before_semicolon(&mut tree, TextSize::new(17))); + } +} diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index 58e9baf7..4339688e 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -15,7 +15,6 @@ impl From<&str> for InputQuery { fn from(value: &str) -> Self { let position = value .find(CURSOR_POS) - .map(|p| p.saturating_sub(1)) .expect("Insert Cursor Position into your Query."); InputQuery { @@ -74,3 +73,43 @@ pub(crate) fn get_test_params<'a>( text, } } + +#[cfg(test)] +mod tests { + use crate::test_helper::CURSOR_POS; + + use super::InputQuery; + + #[test] + fn input_query_should_extract_correct_position() { + struct TestCase { + query: String, + expected_pos: usize, + expected_sql_len: usize, + } + + let cases = vec![ + TestCase { + query: format!("select * from{}", CURSOR_POS), + expected_pos: 13, + expected_sql_len: 13, + }, + TestCase { + query: format!("{}select * from", CURSOR_POS), + expected_pos: 0, + expected_sql_len: 13, + }, + TestCase { + query: format!("select {} from", CURSOR_POS), + expected_pos: 7, + expected_sql_len: 12, + }, + ]; + + for case in cases { + let query = InputQuery::from(case.query.as_str()); + assert_eq!(query.position, case.expected_pos); + assert_eq!(query.sql.len(), case.expected_sql_len); + } + } +} diff --git a/crates/pgt_text_size/src/range.rs b/crates/pgt_text_size/src/range.rs index 95b0db58..3cfc3c96 100644 --- a/crates/pgt_text_size/src/range.rs +++ b/crates/pgt_text_size/src/range.rs @@ -281,6 +281,24 @@ impl TextRange { }) } + /// Expand the range's end by the given offset. + /// + /// # Examples + /// + /// ```rust + /// # use pgt_text_size::*; + /// assert_eq!( + /// TextRange::new(2.into(), 4.into()).checked_expand_end(16.into()).unwrap(), + /// TextRange::new(2.into(), 20.into()), + /// ); + /// ``` + #[inline] + pub fn checked_expand_end(self, offset: TextSize) -> Option<TextRange> { + Some(TextRange { + start: self.start, + end: self.end.checked_add(offset)?, + }) + } /// Subtract an offset from this range. /// /// Note that this is not appropriate for changing where a `TextRange` is diff --git a/crates/pgt_workspace/src/features/completions.rs b/crates/pgt_workspace/src/features/completions.rs index 8fb13313..4a5c5e29 100644 --- a/crates/pgt_workspace/src/features/completions.rs +++ b/crates/pgt_workspace/src/features/completions.rs @@ -1,6 +1,10 @@ +use std::sync::Arc; + use pgt_completions::CompletionItem; use pgt_fs::PgTPath; -use pgt_text_size::TextSize; +use pgt_text_size::{TextRange, TextSize}; + +use crate::workspace::{GetCompletionsFilter, GetCompletionsMapper, ParsedDocument, StatementId}; #[derive(Debug, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] @@ -24,3 +28,167 @@ impl IntoIterator for CompletionsResult { self.items.into_iter() } } + +pub(crate) fn get_statement_for_completions<'a>( + doc: &'a ParsedDocument, + position: TextSize, +) -> Option<(StatementId, TextRange, String, Arc<tree_sitter::Tree>)> { + let count = doc.count(); + // no arms no cookies + if count == 0 { + return None; + } + + let mut eligible_statements = doc.iter_with_filter( + GetCompletionsMapper, + GetCompletionsFilter { + cursor_position: position, + }, + ); + + if count == 1 { + eligible_statements.next() + } else { + let mut prev_stmt = None; + + for current_stmt in eligible_statements { + /* + * If we have multiple statements, we want to make sure that we do not overlap + * with the next one. + * + * select 1 |select 1; + */ + if prev_stmt.is_some_and(|_| current_stmt.1.contains(position)) { + return None; + } + prev_stmt = Some(current_stmt) + } + + prev_stmt + } +} + +#[cfg(test)] +mod tests { + use pgt_fs::PgTPath; + use pgt_text_size::TextSize; + + use crate::workspace::ParsedDocument; + + use super::get_statement_for_completions; + + static CURSOR_POSITION: &str = "€"; + + fn get_doc_and_pos(sql: &str) -> (ParsedDocument, TextSize) { + let pos = sql + .find(CURSOR_POSITION) + .expect("Please add cursor position to test sql"); + + let pos: u32 = pos.try_into().unwrap(); + + ( + ParsedDocument::new( + PgTPath::new("test.sql"), + sql.replace(CURSOR_POSITION, "").into(), + 5, + ), + TextSize::new(pos), + ) + } + + #[test] + fn finds_matching_statement() { + let sql = format!( + r#" + select * from users; + + update {}users set email = 'myemail@com'; + + select 1; + "#, + CURSOR_POSITION + ); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + let (_, _, text, _) = + get_statement_for_completions(&doc, position).expect("Expected Statement"); + + assert_eq!(text, "update users set email = 'myemail@com';") + } + + #[test] + fn does_not_break_when_no_statements_exist() { + let sql = format!("{}", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + assert!(matches!( + get_statement_for_completions(&doc, position), + None + )); + } + + #[test] + fn does_not_return_overlapping_statements_if_too_close() { + let sql = format!("select * from {}select 1;", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + // make sure these are parsed as two + assert_eq!(doc.count(), 2); + + assert!(matches!( + get_statement_for_completions(&doc, position), + None + )); + } + + #[test] + fn is_fine_with_spaces() { + let sql = format!("select * from {} ;", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + let (_, _, text, _) = + get_statement_for_completions(&doc, position).expect("Expected Statement"); + + assert_eq!(text, "select * from ;") + } + + #[test] + fn considers_offset() { + let sql = format!("select * from {}", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + let (_, _, text, _) = + get_statement_for_completions(&doc, position).expect("Expected Statement"); + + assert_eq!(text, "select * from") + } + + #[test] + fn does_not_consider_too_far_offset() { + let sql = format!("select * from {}", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + assert!(matches!( + get_statement_for_completions(&doc, position), + None + )); + } + + #[test] + fn does_not_consider_offset_if_statement_terminated_by_semi() { + let sql = format!("select * from users;{}", CURSOR_POSITION); + + let (doc, position) = get_doc_and_pos(sql.as_str()); + + assert!(matches!( + get_statement_for_completions(&doc, position), + None + )); + } +} diff --git a/crates/pgt_workspace/src/workspace.rs b/crates/pgt_workspace/src/workspace.rs index 681ab95f..54f7200b 100644 --- a/crates/pgt_workspace/src/workspace.rs +++ b/crates/pgt_workspace/src/workspace.rs @@ -22,6 +22,7 @@ mod client; mod server; pub use server::StatementId; +pub(crate) use server::parsed_document::*; #[derive(Debug, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index 27f5e8be..aab2333f 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -8,7 +8,7 @@ use document::Document; use futures::{StreamExt, stream}; use parsed_document::{ AsyncDiagnosticsMapper, CursorPositionFilter, DefaultMapper, ExecuteStatementMapper, - GetCompletionsMapper, ParsedDocument, SyncDiagnosticsMapper, + ParsedDocument, SyncDiagnosticsMapper, }; use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext}; @@ -29,7 +29,7 @@ use crate::{ self, CodeAction, CodeActionKind, CodeActionsResult, CommandAction, CommandActionCategory, ExecuteStatementParams, ExecuteStatementResult, }, - completions::{CompletionsResult, GetCompletionsParams}, + completions::{CompletionsResult, GetCompletionsParams, get_statement_for_completions}, diagnostics::{PullDiagnosticsParams, PullDiagnosticsResult}, }, settings::{Settings, SettingsHandle, SettingsHandleMut}, @@ -46,9 +46,9 @@ mod analyser; mod async_helper; mod change; mod db_connection; -mod document; +pub(crate) mod document; mod migration; -mod parsed_document; +pub(crate) mod parsed_document; mod pg_query; mod schema_cache_manager; mod sql_function; @@ -469,37 +469,36 @@ impl Workspace for WorkspaceServer { &self, params: GetCompletionsParams, ) -> Result<CompletionsResult, WorkspaceError> { - let parser = self + let parsed_doc = self .parsed_documents .get(¶ms.path) .ok_or(WorkspaceError::not_found())?; let pool = match self.connection.read().unwrap().get_pool() { Some(pool) => pool, - None => return Ok(CompletionsResult::default()), + None => { + tracing::debug!("No connection to database. Skipping completions."); + return Ok(CompletionsResult::default()); + } }; let schema_cache = self.schema_cache.load(pool)?; - let items = parser - .iter_with_filter( - GetCompletionsMapper, - CursorPositionFilter::new(params.position), - ) - .flat_map(|(_id, range, content, cst)| { - // `offset` is the position in the document, - // but we need the position within the *statement*. + match get_statement_for_completions(&parsed_doc, params.position) { + None => Ok(CompletionsResult::default()), + Some((_id, range, content, cst)) => { let position = params.position - range.start(); - pgt_completions::complete(pgt_completions::CompletionParams { + + let items = pgt_completions::complete(pgt_completions::CompletionParams { position, schema: schema_cache.as_ref(), tree: &cst, text: content, - }) - }) - .collect(); + }); - Ok(CompletionsResult { items }) + Ok(CompletionsResult { items }) + } + } } } diff --git a/crates/pgt_workspace/src/workspace/server/parsed_document.rs b/crates/pgt_workspace/src/workspace/server/parsed_document.rs index a110fb1f..f752c79c 100644 --- a/crates/pgt_workspace/src/workspace/server/parsed_document.rs +++ b/crates/pgt_workspace/src/workspace/server/parsed_document.rs @@ -130,7 +130,7 @@ pub trait StatementMapper<'a> { fn map( &self, - parser: &'a ParsedDocument, + parsed: &'a ParsedDocument, id: StatementId, range: TextRange, content: &str, @@ -138,7 +138,7 @@ pub trait StatementMapper<'a> { } pub trait StatementFilter<'a> { - fn predicate(&self, id: &StatementId, range: &TextRange) -> bool; + fn predicate(&self, id: &StatementId, range: &TextRange, content: &str) -> bool; } pub struct ParseIterator<'a, M, F> { @@ -171,7 +171,7 @@ where fn next(&mut self) -> Option<Self::Item> { // First check if we have any pending sub-statements to process if let Some((id, range, content)) = self.pending_sub_statements.pop() { - if self.filter.predicate(&id, &range) { + if self.filter.predicate(&id, &range, content.as_str()) { return Some(self.mapper.map(self.parser, id, range, &content)); } // If the sub-statement doesn't pass the filter, continue to the next item @@ -207,7 +207,7 @@ where } // Return the current statement if it passes the filter - if self.filter.predicate(&root_id, &range) { + if self.filter.predicate(&root_id, &range, content) { return Some(self.mapper.map(self.parser, root_id, range, content)); } @@ -329,14 +329,40 @@ impl<'a> StatementMapper<'a> for GetCompletionsMapper { range: TextRange, content: &str, ) -> Self::Output { - let cst_result = parser.cst_db.get_or_cache_tree(&id, content); - (id, range, content.to_string(), cst_result) + let tree = parser.cst_db.get_or_cache_tree(&id, content); + (id, range, content.into(), tree) + } +} + +/* + * We allow an offset of two for the statement: + * + * select * from | <-- we want to suggest items for the next token. + * + * However, if the current statement is terminated by a semicolon, we don't apply any + * offset. + * + * select * from users; | <-- no autocompletions here. + */ +pub struct GetCompletionsFilter { + pub cursor_position: TextSize, +} +impl<'a> StatementFilter<'a> for GetCompletionsFilter { + fn predicate(&self, _id: &StatementId, range: &TextRange, content: &str) -> bool { + let is_terminated_by_semi = content.chars().last().is_some_and(|c| c == ';'); + + let measuring_range = if is_terminated_by_semi { + *range + } else { + range.checked_expand_end(2.into()).unwrap_or(*range) + }; + measuring_range.contains(self.cursor_position) } } pub struct NoFilter; impl<'a> StatementFilter<'a> for NoFilter { - fn predicate(&self, _id: &StatementId, _range: &TextRange) -> bool { + fn predicate(&self, _id: &StatementId, _range: &TextRange, _content: &str) -> bool { true } } @@ -352,7 +378,7 @@ impl CursorPositionFilter { } impl<'a> StatementFilter<'a> for CursorPositionFilter { - fn predicate(&self, _id: &StatementId, range: &TextRange) -> bool { + fn predicate(&self, _id: &StatementId, range: &TextRange, _content: &str) -> bool { range.contains(self.pos) } } @@ -368,7 +394,7 @@ impl IdFilter { } impl<'a> StatementFilter<'a> for IdFilter { - fn predicate(&self, id: &StatementId, _range: &TextRange) -> bool { + fn predicate(&self, id: &StatementId, _range: &TextRange, _content: &str) -> bool { *id == self.id } }