Skip to content

Commit 4e59087

Browse files
committed
replace_all
1 parent 386503d commit 4e59087

File tree

8 files changed

+571
-33
lines changed

8 files changed

+571
-33
lines changed

gopls/internal/golang/codeaction.go

+10
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ var codeActionProducers = [...]codeActionProducer{
237237
{kind: settings.RefactorExtractMethod, fn: refactorExtractMethod},
238238
{kind: settings.RefactorExtractToNewFile, fn: refactorExtractToNewFile},
239239
{kind: settings.RefactorExtractVariable, fn: refactorExtractVariable},
240+
{kind: settings.RefactorReplaceAllOccursOfExpr, fn: refactorReplaceAllOccursOfExpr},
240241
{kind: settings.RefactorInlineCall, fn: refactorInlineCall, needPkg: true},
241242
{kind: settings.RefactorRewriteChangeQuote, fn: refactorRewriteChangeQuote},
242243
{kind: settings.RefactorRewriteFillStruct, fn: refactorRewriteFillStruct, needPkg: true},
@@ -458,6 +459,15 @@ func refactorExtractVariable(ctx context.Context, req *codeActionsRequest) error
458459
return nil
459460
}
460461

462+
// refactorReplaceAllOccursOfExpr produces "Replace all occcurrances of expr" code action.
463+
// See [replaceAllOccursOfExpr] for command implementation.
464+
func refactorReplaceAllOccursOfExpr(ctx context.Context, req *codeActionsRequest) error {
465+
if _, ok, _ := allOccurs(req.start, req.end, req.pgf.File); ok {
466+
req.addApplyFixAction(fmt.Sprintf("Replace all occcurrances of expression"), fixReplaceAllOccursOfExpr, req.loc)
467+
}
468+
return nil
469+
}
470+
461471
// refactorExtractToNewFile produces "Extract declarations to new file" code actions.
462472
// See [server.commandHandler.ExtractToNewFile] for command implementation.
463473
func refactorExtractToNewFile(ctx context.Context, req *codeActionsRequest) error {

gopls/internal/golang/extract.go

+340-3
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,22 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
3636
// TODO: stricter rules for selectorExpr.
3737
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr,
3838
*ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
39-
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0)
39+
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", 0)
4040
lhsNames = append(lhsNames, lhsName)
4141
case *ast.CallExpr:
4242
tup, ok := info.TypeOf(expr).(*types.Tuple)
4343
if !ok {
4444
// If the call expression only has one return value, we can treat it the
4545
// same as our standard extract variable case.
46-
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0)
46+
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", 0)
4747
lhsNames = append(lhsNames, lhsName)
4848
break
4949
}
5050
idx := 0
5151
for i := 0; i < tup.Len(); i++ {
5252
// Generate a unique variable for each return value.
5353
var lhsName string
54-
lhsName, idx = generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", idx)
54+
lhsName, idx = generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", idx)
5555
lhsNames = append(lhsNames, lhsName)
5656
}
5757
default:
@@ -105,6 +105,343 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
105105
}, nil
106106
}
107107

108+
func replaceAllOccursOfExpr(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) {
109+
tokFile := fset.File(file.Pos())
110+
exprs, _, err := allOccurs(start, end, file)
111+
if err != nil {
112+
return nil, nil, fmt.Errorf("extractVariable: cannot extract %s: %v", safetoken.StartPosition(fset, start), err)
113+
}
114+
115+
scopes := make([][]*types.Scope, len(exprs))
116+
for i, e := range exprs {
117+
path, _ := astutil.PathEnclosingInterval(file, e.Pos(), e.End())
118+
scopes[i] = CollectScopes(info, path, e.Pos())
119+
}
120+
121+
// Find the deepest common scope among all expressions.
122+
commonScope, err := findDeepestCommonScope(scopes)
123+
if err != nil {
124+
return nil, nil, fmt.Errorf("extractVariable: %v", err)
125+
}
126+
127+
var innerScopes []*types.Scope
128+
for _, scope := range scopes {
129+
for _, s := range scope {
130+
if s != nil {
131+
innerScopes = append(innerScopes, s)
132+
break
133+
}
134+
}
135+
}
136+
if len(innerScopes) != len(exprs) {
137+
return nil, nil, fmt.Errorf("extractVariable: nil scope")
138+
}
139+
// So the largest scope's name won't conflict.
140+
innerScopes = append(innerScopes, commonScope)
141+
142+
// Create new AST node for extracted code.
143+
var lhsNames []string
144+
switch expr := exprs[0].(type) {
145+
// TODO: stricter rules for selectorExpr.
146+
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr,
147+
*ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
148+
lhsName, _ := generateAvailableIdentifierForAllScopes(innerScopes, "newVar", 0)
149+
lhsNames = append(lhsNames, lhsName)
150+
case *ast.CallExpr:
151+
tup, ok := info.TypeOf(expr).(*types.Tuple)
152+
if !ok {
153+
// If the call expression only has one return value, we can treat it the
154+
// same as our standard extract variable case.
155+
lhsName, _ := generateAvailableIdentifierForAllScopes(innerScopes, "newVar", 0)
156+
lhsNames = append(lhsNames, lhsName)
157+
break
158+
}
159+
idx := 0
160+
for i := 0; i < tup.Len(); i++ {
161+
// Generate a unique variable for each return value.
162+
var lhsName string
163+
lhsName, idx = generateAvailableIdentifierForAllScopes(innerScopes, "newVar", idx)
164+
lhsNames = append(lhsNames, lhsName)
165+
}
166+
default:
167+
return nil, nil, fmt.Errorf("cannot extract %T", expr)
168+
}
169+
170+
var validPath []ast.Node
171+
if commonScope != innerScopes[0] {
172+
// This means the first expr within function body is not the largest scope,
173+
// we need to find the scope immediately follow the common
174+
// scope where we will insert the statement before.
175+
child := innerScopes[0]
176+
for p := child; p != nil; p = p.Parent() {
177+
if p == commonScope {
178+
break
179+
}
180+
child = p
181+
}
182+
validPath, _ = astutil.PathEnclosingInterval(file, child.Pos(), child.End())
183+
} else {
184+
// Just insert before the first expr.
185+
validPath, _ = astutil.PathEnclosingInterval(file, exprs[0].Pos(), exprs[0].End())
186+
}
187+
//
188+
// TODO: There is a bug here: for a variable declared in a labeled
189+
// switch/for statement it returns the for/switch statement itself
190+
// which produces the below code which is a compiler error e.g.
191+
// label:
192+
// switch r1 := r() { ... break label ... }
193+
// On extracting "r()" to a variable
194+
// label:
195+
// x := r()
196+
// switch r1 := x { ... break label ... } // compiler error
197+
//
198+
insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(validPath)
199+
if insertBeforeStmt == nil {
200+
return nil, nil, fmt.Errorf("cannot find location to insert extraction")
201+
}
202+
indent, err := calculateIndentation(src, tokFile, insertBeforeStmt)
203+
if err != nil {
204+
return nil, nil, err
205+
}
206+
newLineIndent := "\n" + indent
207+
208+
lhs := strings.Join(lhsNames, ", ")
209+
assignStmt := &ast.AssignStmt{
210+
Lhs: []ast.Expr{ast.NewIdent(lhs)},
211+
Tok: token.DEFINE,
212+
Rhs: []ast.Expr{exprs[0]},
213+
}
214+
var buf bytes.Buffer
215+
if err := format.Node(&buf, fset, assignStmt); err != nil {
216+
return nil, nil, err
217+
}
218+
assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent
219+
var textEdits []analysis.TextEdit
220+
textEdits = append(textEdits, analysis.TextEdit{
221+
Pos: insertBeforeStmt.Pos(),
222+
End: insertBeforeStmt.Pos(),
223+
NewText: []byte(assignment),
224+
})
225+
for _, e := range exprs {
226+
textEdits = append(textEdits, analysis.TextEdit{
227+
Pos: e.Pos(),
228+
End: e.End(),
229+
NewText: []byte(lhs),
230+
})
231+
}
232+
return fset, &analysis.SuggestedFix{
233+
TextEdits: textEdits,
234+
}, nil
235+
}
236+
237+
// findDeepestCommonScope finds the deepest (innermost) scope that is common to all provided scope chains.
238+
// Each scope chain represents the scopes of an expression from innermost to outermost.
239+
// If no common scope is found, it returns an error.
240+
func findDeepestCommonScope(scopeChains [][]*types.Scope) (*types.Scope, error) {
241+
if len(scopeChains) == 0 {
242+
return nil, fmt.Errorf("no scopes provided")
243+
}
244+
// Get the first scope chain as the reference.
245+
referenceChain := scopeChains[0]
246+
247+
// Iterate from innermost to outermost scope.
248+
for i := 0; i < len(referenceChain); i++ {
249+
candidateScope := referenceChain[i]
250+
if candidateScope == nil {
251+
continue
252+
}
253+
isCommon := true
254+
// See if other exprs' chains all have candidateScope as a common ancestor.
255+
for _, chain := range scopeChains[1:] {
256+
found := false
257+
for j := 0; j < len(chain); j++ {
258+
if chain[j] == candidateScope {
259+
found = true
260+
break
261+
}
262+
}
263+
if !found {
264+
isCommon = false
265+
break
266+
}
267+
}
268+
if isCommon {
269+
return candidateScope, nil
270+
}
271+
}
272+
return nil, fmt.Errorf("no common scope found")
273+
}
274+
275+
// allOccurs finds all occurrences of an expression identical to the one
276+
// specified by the start and end positions within the same function.
277+
// It returns at least one ast.Expr.
278+
func allOccurs(start, end token.Pos, file *ast.File) ([]ast.Expr, bool, error) {
279+
if start == end {
280+
return nil, false, fmt.Errorf("start and end are equal")
281+
}
282+
path, _ := astutil.PathEnclosingInterval(file, start, end)
283+
if len(path) == 0 {
284+
return nil, false, fmt.Errorf("no path enclosing interval")
285+
}
286+
for _, n := range path {
287+
if _, ok := n.(*ast.ImportSpec); ok {
288+
return nil, false, fmt.Errorf("cannot extract variable in an import block")
289+
}
290+
}
291+
node := path[0]
292+
if start != node.Pos() || end != node.End() {
293+
return nil, false, fmt.Errorf("range does not map to an AST node")
294+
}
295+
expr, ok := node.(ast.Expr)
296+
if !ok {
297+
return nil, false, fmt.Errorf("node is not an expression")
298+
}
299+
300+
var exprs []ast.Expr
301+
exprs = append(exprs, expr)
302+
if funcDecl, ok := path[len(path)-2].(*ast.FuncDecl); ok {
303+
ast.Inspect(funcDecl, func(n ast.Node) bool {
304+
if e, ok := n.(ast.Expr); ok && e != expr {
305+
if exprIdentical(e, expr) {
306+
exprs = append(exprs, e)
307+
}
308+
}
309+
return true
310+
})
311+
}
312+
sort.Slice(exprs, func(i, j int) bool {
313+
return exprs[i].Pos() < exprs[j].Pos()
314+
})
315+
316+
switch expr.(type) {
317+
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr,
318+
*ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
319+
return exprs, true, nil
320+
}
321+
return nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
322+
}
323+
324+
// generateAvailableIdentifierForAllScopes adjusts the new identifier name
325+
// until there are no collisions in any of the provided scopes.
326+
func generateAvailableIdentifierForAllScopes(scopes []*types.Scope, prefix string, idx int) (string, int) {
327+
name := prefix
328+
for {
329+
collision := false
330+
for _, scope := range scopes {
331+
if scope.Lookup(name) != nil {
332+
collision = true
333+
break
334+
}
335+
}
336+
if !collision {
337+
return name, idx
338+
}
339+
idx++
340+
name = fmt.Sprintf("%s%d", prefix, idx)
341+
}
342+
}
343+
344+
// exprIdentical recursively compares two ast.Expr nodes for structural equality,
345+
// ignoring position fields.
346+
func exprIdentical(x, y ast.Expr) bool {
347+
if x == nil || y == nil {
348+
return x == y
349+
}
350+
switch x := x.(type) {
351+
case *ast.BasicLit:
352+
y, ok := y.(*ast.BasicLit)
353+
return ok && x.Kind == y.Kind && x.Value == y.Value
354+
case *ast.CompositeLit:
355+
y, ok := y.(*ast.CompositeLit)
356+
if !ok || len(x.Elts) != len(y.Elts) || !exprIdentical(x.Type, y.Type) {
357+
return false
358+
}
359+
for i := range x.Elts {
360+
if !exprIdentical(x.Elts[i], y.Elts[i]) {
361+
return false
362+
}
363+
}
364+
return true
365+
case *ast.ArrayType:
366+
y, ok := y.(*ast.ArrayType)
367+
return ok && exprIdentical(x.Len, y.Len) && exprIdentical(x.Elt, y.Elt)
368+
case *ast.Ellipsis:
369+
y, ok := y.(*ast.Ellipsis)
370+
return ok && exprIdentical(x.Elt, y.Elt)
371+
case *ast.FuncLit:
372+
y, ok := y.(*ast.FuncLit)
373+
return ok && exprIdentical(x.Type, y.Type)
374+
case *ast.IndexExpr:
375+
y, ok := y.(*ast.IndexExpr)
376+
return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Index, y.Index)
377+
case *ast.IndexListExpr:
378+
y, ok := y.(*ast.IndexListExpr)
379+
if !ok || len(x.Indices) != len(y.Indices) || !exprIdentical(x.X, y.X) {
380+
return false
381+
}
382+
for i := range x.Indices {
383+
if !exprIdentical(x.Indices[i], y.Indices[i]) {
384+
return false
385+
}
386+
}
387+
return true
388+
case *ast.SliceExpr:
389+
y, ok := y.(*ast.SliceExpr)
390+
return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Low, y.Low) && exprIdentical(x.High, y.High) && exprIdentical(x.Max, y.Max) && x.Slice3 == y.Slice3
391+
case *ast.TypeAssertExpr:
392+
y, ok := y.(*ast.TypeAssertExpr)
393+
return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Type, y.Type)
394+
case *ast.StarExpr:
395+
y, ok := y.(*ast.StarExpr)
396+
return ok && exprIdentical(x.X, y.X)
397+
case *ast.KeyValueExpr:
398+
y, ok := y.(*ast.KeyValueExpr)
399+
return ok && exprIdentical(x.Key, y.Key) && exprIdentical(x.Value, y.Value)
400+
case *ast.UnaryExpr:
401+
y, ok := y.(*ast.UnaryExpr)
402+
return ok && x.Op == y.Op && exprIdentical(x.X, y.X)
403+
case *ast.MapType:
404+
y, ok := y.(*ast.MapType)
405+
return ok && exprIdentical(x.Value, y.Value) && exprIdentical(x.Key, y.Key)
406+
case *ast.ChanType:
407+
y, ok := y.(*ast.ChanType)
408+
return ok && exprIdentical(x.Value, y.Value) && x.Dir == y.Dir
409+
case *ast.BinaryExpr:
410+
y, ok := y.(*ast.BinaryExpr)
411+
return ok && x.Op == y.Op &&
412+
exprIdentical(x.X, y.X) &&
413+
exprIdentical(x.Y, y.Y)
414+
case *ast.Ident:
415+
y, ok := y.(*ast.Ident)
416+
return ok && x.Name == y.Name
417+
case *ast.ParenExpr:
418+
y, ok := y.(*ast.ParenExpr)
419+
return ok && exprIdentical(x.X, y.X)
420+
case *ast.SelectorExpr:
421+
y, ok := y.(*ast.SelectorExpr)
422+
return ok &&
423+
exprIdentical(x.X, y.X) &&
424+
exprIdentical(x.Sel, y.Sel)
425+
case *ast.CallExpr:
426+
y, ok := y.(*ast.CallExpr)
427+
if !ok || len(x.Args) != len(y.Args) {
428+
return false
429+
}
430+
if !exprIdentical(x.Fun, y.Fun) {
431+
return false
432+
}
433+
for i := range x.Args {
434+
if !exprIdentical(x.Args[i], y.Args[i]) {
435+
return false
436+
}
437+
}
438+
return true
439+
default:
440+
// For unhandled expression types, consider them unequal.
441+
return false
442+
}
443+
}
444+
108445
// canExtractVariable reports whether the code in the given range can be
109446
// extracted to a variable.
110447
func canExtractVariable(start, end token.Pos, file *ast.File) (ast.Expr, []ast.Node, bool, error) {

0 commit comments

Comments
 (0)