Skip to content

Commit 5252b54

Browse files
committed
Same wildcard should match the same subtree
fixes #15
1 parent 4dbd5d7 commit 5252b54

File tree

3 files changed

+115
-14
lines changed

3 files changed

+115
-14
lines changed

lua/ssr/parse.lua

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
local ts = vim.treesitter
22
local parsers = require "nvim-treesitter.parsers"
3-
local u = require "ssr.utils"
3+
local wildcard_prefix = require("ssr.search").wildcard_prefix
44

55
local M = {}
66

7-
M.wildcard_prefix = "__ssr_var_"
8-
97
---@class ParseContext
108
---@field lang string
119
---@field before string
@@ -66,7 +64,7 @@ end
6664
---@return TSNode, string
6765
function ParseContext:parse(pattern)
6866
-- Replace named wildcard $name to identifier __ssr_var_name to avoid syntax error.
69-
pattern = pattern:gsub("%$([_%a%d]+)", M.wildcard_prefix .. "%1")
67+
pattern = pattern:gsub("%$([_%a%d]+)", wildcard_prefix .. "%1")
7068
local context_text = self.before .. pattern .. self.after
7169
local root = ts.get_string_parser(context_text, self.lang):parse()[1]:root()
7270
local lines = vim.split(pattern, "\n")

lua/ssr/search.lua

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ local api = vim.api
22
local ts = vim.treesitter
33
local parsers = require "nvim-treesitter.parsers"
44
local u = require "ssr.utils"
5-
local wildcard_prefix = require("ssr.parse").wildcard_prefix
65

76
local M = {}
87

8+
M.wildcard_prefix = "__ssr_var_"
9+
910
---@class Match
1011
---@field range ExtmarkRange
1112
---@field captures ExtmarkRange[]
@@ -41,6 +42,33 @@ function ExtmarkRange:get()
4142
return extmark[1], extmark[2], extmark[3].end_row, extmark[3].end_col
4243
end
4344

45+
-- Compare if two captured trees can match.
46+
-- The check is loose because users want to match different types of node.
47+
-- e.g. converting `{ foo: foo }` to shorthand `{ foo }`.
48+
ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred)
49+
---@param node1 TSNode
50+
---@param node2 TSNode
51+
---@return boolean
52+
local function tree_match(node1, node2)
53+
if node1:named() ~= node2:named() then
54+
return false
55+
end
56+
if node1:child_count() == 0 or node2:child_count() == 0 then
57+
return ts.get_node_text(node1, buf) == ts.get_node_text(node2, buf)
58+
end
59+
if node1:child_count() ~= node2:child_count() then
60+
return false
61+
end
62+
for i = 0, node1:child_count() - 1 do
63+
if not tree_match(node1:child(i), node2:child(i)) then
64+
return false
65+
end
66+
end
67+
return true
68+
end
69+
return tree_match(match[pred[2]], match[pred[3]])
70+
end, true)
71+
4472
-- Build a TS sexpr represting the node.
4573
---@param node TSNode
4674
---@param source string
@@ -54,11 +82,19 @@ local function build_sexpr(node, source)
5482
local text = ts.get_node_text(node, source)
5583

5684
-- Special identifier __ssr_var_name is a named wildcard.
57-
local var = text:match("^" .. wildcard_prefix .. "([_%a%d]+)$")
85+
-- Handle this early to make sure wildcard captures largest node.
86+
local var = text:match("^" .. M.wildcard_prefix .. "([_%a%d]+)$")
5887
if var then
59-
wildcards[var] = next_idx
60-
next_idx = next_idx + 1
61-
return "(_) @" .. var
88+
if not wildcards[var] then
89+
wildcards[var] = next_idx
90+
next_idx = next_idx + 1
91+
return "(_) @" .. var
92+
else
93+
-- Same wildcard should match the same subtree.
94+
local sexpr = string.format("(_) @_%d (#ssr-tree-match? @_%d @%s)", next_idx, next_idx, var)
95+
next_idx = next_idx + 1
96+
return sexpr
97+
end
6298
end
6399

64100
-- Leaf nodes (keyword, identifier, literal and symbol) should match text.

tests/ssr_spec.lua

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,22 +138,89 @@ foo($a, $b) ==>> ($a).foo($b)
138138
String::from((y + 5).foo(z))
139139
]]
140140

141-
t [[ go parsed correctly
141+
t [[ go parse Go := in function
142142
func main() {
143-
<commit, _ := os.LookupEnv("GITHUB_SHA")>
144-
print(commit)
143+
<commit, _ := os.LookupEnv("GITHUB_SHA")>
145144
}
146145
====
147146
$a, _ := os.LookupEnv($b)
148147
==>>
149148
$a := os.Getenv($b)
150149
====
151150
func main() {
152-
commit := os.Getenv("GITHUB_SHA")
153-
print(commit)
151+
commit := os.Getenv("GITHUB_SHA")
154152
}
155153
]]
156154

155+
t [[ go match Go if err
156+
fn main() {
157+
<if err != nil {
158+
panic(err)
159+
}>
160+
}
161+
====
162+
if err != nil { panic(err) } ==>> x
163+
====
164+
fn main() {
165+
x
166+
}
167+
]]
168+
169+
t [[ rust reused wildcard: compound assignments
170+
<idx = idx + 1>;
171+
bar = foo + idx;
172+
*foo.bar() = * foo . bar () + 1;
173+
(foo + bar) = (foo + bar) + 1;
174+
(foo + bar) = (foo - bar) + 1;
175+
====
176+
$a = $a + $b ==>> $a += $b
177+
====
178+
idx += 1;
179+
bar = foo + idx;
180+
*foo.bar() += 1;
181+
(foo + bar) += 1;
182+
(foo + bar) = (foo - bar) + 1;
183+
]]
184+
185+
t [[ python reused wildcard: indent
186+
def f():
187+
<if await foo.bar(baz):
188+
if await foo.bar(baz):
189+
pass>
190+
====
191+
if $foo:
192+
if $foo:
193+
$body
194+
==>>
195+
if $foo:
196+
$body
197+
====
198+
def f():
199+
if await foo.bar(baz):
200+
pass
201+
]]
202+
203+
-- two `foo`s have different type: `property_identifier` and `identifier`
204+
t [[ javascript reused wildcard: match different node types 1
205+
<{ foo: foo }>
206+
{ foo: bar }
207+
====
208+
{ $a: $a } ==>> { $a }
209+
====
210+
{ foo }
211+
{ foo: bar }
212+
]]
213+
214+
t [[ lua reused wildcard: match different node types 2
215+
<local api = vim.api>
216+
local a = vim.api
217+
====
218+
local $a = vim.$a ==>> x
219+
====
220+
x
221+
local a = vim.api
222+
]]
223+
157224
describe("", function()
158225
for _, s in ipairs(tests) do
159226
local ft, desc, content, pattern, template, expected =

0 commit comments

Comments
 (0)