Skip to content

Commit 1ac9f6c

Browse files
committed
Same wildcard should match the same subtree
fixes #15
1 parent dd7dc92 commit 1ac9f6c

File tree

2 files changed

+71
-5
lines changed

2 files changed

+71
-5
lines changed

lua/ssr/search.lua

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,33 @@ function ExtmarkRange:get()
4141
return extmark[1], extmark[2], extmark[3].end_row, extmark[3].end_col
4242
end
4343

44+
ts.query.add_predicate("is-same-tree?", function(match, _pattern, buf, pred)
45+
---@param node1 TSNode
46+
---@param node2 TSNode
47+
---@return boolean
48+
local function is_same_tree(node1, node2)
49+
if node1:type() ~= node2:type() then
50+
return false
51+
end
52+
if node1:child_count() ~= node2:child_count() then
53+
return false
54+
end
55+
if node1:child_count() == 0 then
56+
if ts.get_node_text(node1, buf) ~= ts.get_node_text(node2, buf) then
57+
return false
58+
end
59+
else
60+
for i = 0, node1:child_count() - 1 do
61+
if not is_same_tree(node1:child(i), node2:child(i)) then
62+
return false
63+
end
64+
end
65+
end
66+
return true
67+
end
68+
return is_same_tree(match[pred[2]], match[pred[3]])
69+
end, true)
70+
4471
-- Build a TS sexpr represting the node.
4572
---@param node TSNode
4673
---@param source string
@@ -56,9 +83,16 @@ local function build_sexpr(node, source)
5683
-- Special identifier __ssr_var_name is a named wildcard.
5784
local var = text:match("^" .. wildcard_prefix .. "([_%a%d]+)$")
5885
if var then
59-
wildcards[var] = next_idx
60-
next_idx = next_idx + 1
61-
return "(_) @" .. var
86+
if not wildcards[var] then
87+
wildcards[var] = next_idx
88+
next_idx = next_idx + 1
89+
return "(_) @" .. var
90+
else
91+
-- Same wildcard should match the same subtree.
92+
local sexpr = string.format("(_) @_%d (#is-same-tree? @_%d @%s)", next_idx, next_idx, var)
93+
next_idx = next_idx + 1
94+
return sexpr
95+
end
6296
end
6397

6498
-- Leaf nodes (keyword, identifier, literal and symbol) should match text.

tests/ssr_spec.lua

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ String::from((y + 5).foo(z))
141141
t [[ go parsed correctly
142142
func main() {
143143
<commit, _ := os.LookupEnv("GITHUB_SHA")>
144-
print(commit)
145144
}
146145
====
147146
$a, _ := os.LookupEnv($b)
@@ -150,10 +149,43 @@ $a := os.Getenv($b)
150149
====
151150
func main() {
152151
commit := os.Getenv("GITHUB_SHA")
153-
print(commit)
154152
}
155153
]]
156154

155+
t [[ rust match same tree
156+
<idx = idx + 1>;
157+
bar = foo + idx;
158+
*foo.bar() = * foo . bar () + 1;
159+
(foo + bar) = (foo + bar) + 1;
160+
(foo + bar) = (foo - bar) + 1;
161+
====
162+
$a = $a + $b ==>> $a += $b
163+
====
164+
idx += 1;
165+
bar = foo + idx;
166+
*foo.bar() += 1;
167+
(foo + bar) += 1;
168+
(foo + bar) = (foo - bar) + 1;
169+
]]
170+
171+
t [[ python correct indent for same tree
172+
def f():
173+
<if await foo.bar(baz):
174+
if await foo.bar(baz):
175+
pass>
176+
====
177+
if $foo:
178+
if $foo:
179+
$body
180+
==>>
181+
if $foo:
182+
$body
183+
====
184+
def f():
185+
if await foo.bar(baz):
186+
pass
187+
]]
188+
157189
describe("", function()
158190
for _, s in ipairs(tests) do
159191
local ft, desc, content, pattern, template, expected =

0 commit comments

Comments
 (0)