@@ -36,22 +36,22 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
36
36
// TODO: stricter rules for selectorExpr.
37
37
case * ast.BasicLit , * ast.CompositeLit , * ast.IndexExpr , * ast.SliceExpr ,
38
38
* ast.UnaryExpr , * ast.BinaryExpr , * ast.SelectorExpr :
39
- lhsName , _ := generateAvailableIdentifier (expr .Pos (), path , pkg , info , "x " , 0 )
39
+ lhsName , _ := generateAvailableIdentifier (expr .Pos (), path , pkg , info , "newVar " , 0 )
40
40
lhsNames = append (lhsNames , lhsName )
41
41
case * ast.CallExpr :
42
42
tup , ok := info .TypeOf (expr ).(* types.Tuple )
43
43
if ! ok {
44
44
// If the call expression only has one return value, we can treat it the
45
45
// same as our standard extract variable case.
46
- lhsName , _ := generateAvailableIdentifier (expr .Pos (), path , pkg , info , "x " , 0 )
46
+ lhsName , _ := generateAvailableIdentifier (expr .Pos (), path , pkg , info , "newVar " , 0 )
47
47
lhsNames = append (lhsNames , lhsName )
48
48
break
49
49
}
50
50
idx := 0
51
51
for i := 0 ; i < tup .Len (); i ++ {
52
52
// Generate a unique variable for each return value.
53
53
var lhsName string
54
- lhsName , idx = generateAvailableIdentifier (expr .Pos (), path , pkg , info , "x " , idx )
54
+ lhsName , idx = generateAvailableIdentifier (expr .Pos (), path , pkg , info , "newVar " , idx )
55
55
lhsNames = append (lhsNames , lhsName )
56
56
}
57
57
default :
@@ -105,6 +105,343 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
105
105
}, nil
106
106
}
107
107
108
+ func replaceAllOccursOfExpr (fset * token.FileSet , start , end token.Pos , src []byte , file * ast.File , pkg * types.Package , info * types.Info ) (* token.FileSet , * analysis.SuggestedFix , error ) {
109
+ tokFile := fset .File (file .Pos ())
110
+ exprs , _ , err := allOccurs (start , end , file )
111
+ if err != nil {
112
+ return nil , nil , fmt .Errorf ("extractVariable: cannot extract %s: %v" , safetoken .StartPosition (fset , start ), err )
113
+ }
114
+
115
+ scopes := make ([][]* types.Scope , len (exprs ))
116
+ for i , e := range exprs {
117
+ path , _ := astutil .PathEnclosingInterval (file , e .Pos (), e .End ())
118
+ scopes [i ] = CollectScopes (info , path , e .Pos ())
119
+ }
120
+
121
+ // Find the deepest common scope among all expressions.
122
+ commonScope , err := findDeepestCommonScope (scopes )
123
+ if err != nil {
124
+ return nil , nil , fmt .Errorf ("extractVariable: %v" , err )
125
+ }
126
+
127
+ var innerScopes []* types.Scope
128
+ for _ , scope := range scopes {
129
+ for _ , s := range scope {
130
+ if s != nil {
131
+ innerScopes = append (innerScopes , s )
132
+ break
133
+ }
134
+ }
135
+ }
136
+ if len (innerScopes ) != len (exprs ) {
137
+ return nil , nil , fmt .Errorf ("extractVariable: nil scope" )
138
+ }
139
+ // So the largest scope's name won't conflict.
140
+ innerScopes = append (innerScopes , commonScope )
141
+
142
+ // Create new AST node for extracted code.
143
+ var lhsNames []string
144
+ switch expr := exprs [0 ].(type ) {
145
+ // TODO: stricter rules for selectorExpr.
146
+ case * ast.BasicLit , * ast.CompositeLit , * ast.IndexExpr , * ast.SliceExpr ,
147
+ * ast.UnaryExpr , * ast.BinaryExpr , * ast.SelectorExpr :
148
+ lhsName , _ := generateAvailableIdentifierForAllScopes (innerScopes , "newVar" , 0 )
149
+ lhsNames = append (lhsNames , lhsName )
150
+ case * ast.CallExpr :
151
+ tup , ok := info .TypeOf (expr ).(* types.Tuple )
152
+ if ! ok {
153
+ // If the call expression only has one return value, we can treat it the
154
+ // same as our standard extract variable case.
155
+ lhsName , _ := generateAvailableIdentifierForAllScopes (innerScopes , "newVar" , 0 )
156
+ lhsNames = append (lhsNames , lhsName )
157
+ break
158
+ }
159
+ idx := 0
160
+ for i := 0 ; i < tup .Len (); i ++ {
161
+ // Generate a unique variable for each return value.
162
+ var lhsName string
163
+ lhsName , idx = generateAvailableIdentifierForAllScopes (innerScopes , "newVar" , idx )
164
+ lhsNames = append (lhsNames , lhsName )
165
+ }
166
+ default :
167
+ return nil , nil , fmt .Errorf ("cannot extract %T" , expr )
168
+ }
169
+
170
+ var validPath []ast.Node
171
+ if commonScope != innerScopes [0 ] {
172
+ // This means the first expr within function body is not the largest scope,
173
+ // we need to find the scope immediately follow the common
174
+ // scope where we will insert the statement before.
175
+ child := innerScopes [0 ]
176
+ for p := child ; p != nil ; p = p .Parent () {
177
+ if p == commonScope {
178
+ break
179
+ }
180
+ child = p
181
+ }
182
+ validPath , _ = astutil .PathEnclosingInterval (file , child .Pos (), child .End ())
183
+ } else {
184
+ // Just insert before the first expr.
185
+ validPath , _ = astutil .PathEnclosingInterval (file , exprs [0 ].Pos (), exprs [0 ].End ())
186
+ }
187
+ //
188
+ // TODO: There is a bug here: for a variable declared in a labeled
189
+ // switch/for statement it returns the for/switch statement itself
190
+ // which produces the below code which is a compiler error e.g.
191
+ // label:
192
+ // switch r1 := r() { ... break label ... }
193
+ // On extracting "r()" to a variable
194
+ // label:
195
+ // x := r()
196
+ // switch r1 := x { ... break label ... } // compiler error
197
+ //
198
+ insertBeforeStmt := analysisinternal .StmtToInsertVarBefore (validPath )
199
+ if insertBeforeStmt == nil {
200
+ return nil , nil , fmt .Errorf ("cannot find location to insert extraction" )
201
+ }
202
+ indent , err := calculateIndentation (src , tokFile , insertBeforeStmt )
203
+ if err != nil {
204
+ return nil , nil , err
205
+ }
206
+ newLineIndent := "\n " + indent
207
+
208
+ lhs := strings .Join (lhsNames , ", " )
209
+ assignStmt := & ast.AssignStmt {
210
+ Lhs : []ast.Expr {ast .NewIdent (lhs )},
211
+ Tok : token .DEFINE ,
212
+ Rhs : []ast.Expr {exprs [0 ]},
213
+ }
214
+ var buf bytes.Buffer
215
+ if err := format .Node (& buf , fset , assignStmt ); err != nil {
216
+ return nil , nil , err
217
+ }
218
+ assignment := strings .ReplaceAll (buf .String (), "\n " , newLineIndent ) + newLineIndent
219
+ var textEdits []analysis.TextEdit
220
+ textEdits = append (textEdits , analysis.TextEdit {
221
+ Pos : insertBeforeStmt .Pos (),
222
+ End : insertBeforeStmt .Pos (),
223
+ NewText : []byte (assignment ),
224
+ })
225
+ for _ , e := range exprs {
226
+ textEdits = append (textEdits , analysis.TextEdit {
227
+ Pos : e .Pos (),
228
+ End : e .End (),
229
+ NewText : []byte (lhs ),
230
+ })
231
+ }
232
+ return fset , & analysis.SuggestedFix {
233
+ TextEdits : textEdits ,
234
+ }, nil
235
+ }
236
+
237
+ // findDeepestCommonScope finds the deepest (innermost) scope that is common to all provided scope chains.
238
+ // Each scope chain represents the scopes of an expression from innermost to outermost.
239
+ // If no common scope is found, it returns an error.
240
+ func findDeepestCommonScope (scopeChains [][]* types.Scope ) (* types.Scope , error ) {
241
+ if len (scopeChains ) == 0 {
242
+ return nil , fmt .Errorf ("no scopes provided" )
243
+ }
244
+ // Get the first scope chain as the reference.
245
+ referenceChain := scopeChains [0 ]
246
+
247
+ // Iterate from innermost to outermost scope.
248
+ for i := 0 ; i < len (referenceChain ); i ++ {
249
+ candidateScope := referenceChain [i ]
250
+ if candidateScope == nil {
251
+ continue
252
+ }
253
+ isCommon := true
254
+ // See if other exprs' chains all have candidateScope as a common ancestor.
255
+ for _ , chain := range scopeChains [1 :] {
256
+ found := false
257
+ for j := 0 ; j < len (chain ); j ++ {
258
+ if chain [j ] == candidateScope {
259
+ found = true
260
+ break
261
+ }
262
+ }
263
+ if ! found {
264
+ isCommon = false
265
+ break
266
+ }
267
+ }
268
+ if isCommon {
269
+ return candidateScope , nil
270
+ }
271
+ }
272
+ return nil , fmt .Errorf ("no common scope found" )
273
+ }
274
+
275
+ // allOccurs finds all occurrences of an expression identical to the one
276
+ // specified by the start and end positions within the same function.
277
+ // It returns at least one ast.Expr.
278
+ func allOccurs (start , end token.Pos , file * ast.File ) ([]ast.Expr , bool , error ) {
279
+ if start == end {
280
+ return nil , false , fmt .Errorf ("start and end are equal" )
281
+ }
282
+ path , _ := astutil .PathEnclosingInterval (file , start , end )
283
+ if len (path ) == 0 {
284
+ return nil , false , fmt .Errorf ("no path enclosing interval" )
285
+ }
286
+ for _ , n := range path {
287
+ if _ , ok := n .(* ast.ImportSpec ); ok {
288
+ return nil , false , fmt .Errorf ("cannot extract variable in an import block" )
289
+ }
290
+ }
291
+ node := path [0 ]
292
+ if start != node .Pos () || end != node .End () {
293
+ return nil , false , fmt .Errorf ("range does not map to an AST node" )
294
+ }
295
+ expr , ok := node .(ast.Expr )
296
+ if ! ok {
297
+ return nil , false , fmt .Errorf ("node is not an expression" )
298
+ }
299
+
300
+ var exprs []ast.Expr
301
+ exprs = append (exprs , expr )
302
+ if funcDecl , ok := path [len (path )- 2 ].(* ast.FuncDecl ); ok {
303
+ ast .Inspect (funcDecl , func (n ast.Node ) bool {
304
+ if e , ok := n .(ast.Expr ); ok && e != expr {
305
+ if exprIdentical (e , expr ) {
306
+ exprs = append (exprs , e )
307
+ }
308
+ }
309
+ return true
310
+ })
311
+ }
312
+ sort .Slice (exprs , func (i , j int ) bool {
313
+ return exprs [i ].Pos () < exprs [j ].Pos ()
314
+ })
315
+
316
+ switch expr .(type ) {
317
+ case * ast.BasicLit , * ast.CompositeLit , * ast.IndexExpr , * ast.CallExpr ,
318
+ * ast.SliceExpr , * ast.UnaryExpr , * ast.BinaryExpr , * ast.SelectorExpr :
319
+ return exprs , true , nil
320
+ }
321
+ return nil , false , fmt .Errorf ("cannot extract an %T to a variable" , expr )
322
+ }
323
+
324
+ // generateAvailableIdentifierForAllScopes adjusts the new identifier name
325
+ // until there are no collisions in any of the provided scopes.
326
+ func generateAvailableIdentifierForAllScopes (scopes []* types.Scope , prefix string , idx int ) (string , int ) {
327
+ name := prefix
328
+ for {
329
+ collision := false
330
+ for _ , scope := range scopes {
331
+ if scope .Lookup (name ) != nil {
332
+ collision = true
333
+ break
334
+ }
335
+ }
336
+ if ! collision {
337
+ return name , idx
338
+ }
339
+ idx ++
340
+ name = fmt .Sprintf ("%s%d" , prefix , idx )
341
+ }
342
+ }
343
+
344
+ // exprIdentical recursively compares two ast.Expr nodes for structural equality,
345
+ // ignoring position fields.
346
+ func exprIdentical (x , y ast.Expr ) bool {
347
+ if x == nil || y == nil {
348
+ return x == y
349
+ }
350
+ switch x := x .(type ) {
351
+ case * ast.BasicLit :
352
+ y , ok := y .(* ast.BasicLit )
353
+ return ok && x .Kind == y .Kind && x .Value == y .Value
354
+ case * ast.CompositeLit :
355
+ y , ok := y .(* ast.CompositeLit )
356
+ if ! ok || len (x .Elts ) != len (y .Elts ) || ! exprIdentical (x .Type , y .Type ) {
357
+ return false
358
+ }
359
+ for i := range x .Elts {
360
+ if ! exprIdentical (x .Elts [i ], y .Elts [i ]) {
361
+ return false
362
+ }
363
+ }
364
+ return true
365
+ case * ast.ArrayType :
366
+ y , ok := y .(* ast.ArrayType )
367
+ return ok && exprIdentical (x .Len , y .Len ) && exprIdentical (x .Elt , y .Elt )
368
+ case * ast.Ellipsis :
369
+ y , ok := y .(* ast.Ellipsis )
370
+ return ok && exprIdentical (x .Elt , y .Elt )
371
+ case * ast.FuncLit :
372
+ y , ok := y .(* ast.FuncLit )
373
+ return ok && exprIdentical (x .Type , y .Type )
374
+ case * ast.IndexExpr :
375
+ y , ok := y .(* ast.IndexExpr )
376
+ return ok && exprIdentical (x .X , y .X ) && exprIdentical (x .Index , y .Index )
377
+ case * ast.IndexListExpr :
378
+ y , ok := y .(* ast.IndexListExpr )
379
+ if ! ok || len (x .Indices ) != len (y .Indices ) || ! exprIdentical (x .X , y .X ) {
380
+ return false
381
+ }
382
+ for i := range x .Indices {
383
+ if ! exprIdentical (x .Indices [i ], y .Indices [i ]) {
384
+ return false
385
+ }
386
+ }
387
+ return true
388
+ case * ast.SliceExpr :
389
+ y , ok := y .(* ast.SliceExpr )
390
+ return ok && exprIdentical (x .X , y .X ) && exprIdentical (x .Low , y .Low ) && exprIdentical (x .High , y .High ) && exprIdentical (x .Max , y .Max ) && x .Slice3 == y .Slice3
391
+ case * ast.TypeAssertExpr :
392
+ y , ok := y .(* ast.TypeAssertExpr )
393
+ return ok && exprIdentical (x .X , y .X ) && exprIdentical (x .Type , y .Type )
394
+ case * ast.StarExpr :
395
+ y , ok := y .(* ast.StarExpr )
396
+ return ok && exprIdentical (x .X , y .X )
397
+ case * ast.KeyValueExpr :
398
+ y , ok := y .(* ast.KeyValueExpr )
399
+ return ok && exprIdentical (x .Key , y .Key ) && exprIdentical (x .Value , y .Value )
400
+ case * ast.UnaryExpr :
401
+ y , ok := y .(* ast.UnaryExpr )
402
+ return ok && x .Op == y .Op && exprIdentical (x .X , y .X )
403
+ case * ast.MapType :
404
+ y , ok := y .(* ast.MapType )
405
+ return ok && exprIdentical (x .Value , y .Value ) && exprIdentical (x .Key , y .Key )
406
+ case * ast.ChanType :
407
+ y , ok := y .(* ast.ChanType )
408
+ return ok && exprIdentical (x .Value , y .Value ) && x .Dir == y .Dir
409
+ case * ast.BinaryExpr :
410
+ y , ok := y .(* ast.BinaryExpr )
411
+ return ok && x .Op == y .Op &&
412
+ exprIdentical (x .X , y .X ) &&
413
+ exprIdentical (x .Y , y .Y )
414
+ case * ast.Ident :
415
+ y , ok := y .(* ast.Ident )
416
+ return ok && x .Name == y .Name
417
+ case * ast.ParenExpr :
418
+ y , ok := y .(* ast.ParenExpr )
419
+ return ok && exprIdentical (x .X , y .X )
420
+ case * ast.SelectorExpr :
421
+ y , ok := y .(* ast.SelectorExpr )
422
+ return ok &&
423
+ exprIdentical (x .X , y .X ) &&
424
+ exprIdentical (x .Sel , y .Sel )
425
+ case * ast.CallExpr :
426
+ y , ok := y .(* ast.CallExpr )
427
+ if ! ok || len (x .Args ) != len (y .Args ) {
428
+ return false
429
+ }
430
+ if ! exprIdentical (x .Fun , y .Fun ) {
431
+ return false
432
+ }
433
+ for i := range x .Args {
434
+ if ! exprIdentical (x .Args [i ], y .Args [i ]) {
435
+ return false
436
+ }
437
+ }
438
+ return true
439
+ default :
440
+ // For unhandled expression types, consider them unequal.
441
+ return false
442
+ }
443
+ }
444
+
108
445
// canExtractVariable reports whether the code in the given range can be
109
446
// extracted to a variable.
110
447
func canExtractVariable (start , end token.Pos , file * ast.File ) (ast.Expr , []ast.Node , bool , error ) {
0 commit comments