Skip to content

Commit

Permalink
completions/inference: perform type inference in completions with gen…
Browse files Browse the repository at this point in the history
…eric funcs

When using generic functions, the LSP now tries to infer an instantiation
based on the surroundings of the call expression. If successful, it improves
completions for parameters of generic functions. It shouldn't collide
with any pre-existing code paths.

Fixes #69754
  • Loading branch information
jacobzim-stl committed Oct 8, 2024
1 parent 813e3c7 commit 152e295
Show file tree
Hide file tree
Showing 5 changed files with 627 additions and 393 deletions.
227 changes: 215 additions & 12 deletions gopls/internal/golang/completion/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,30 @@ func (i *CompletionItem) Snippet() string {
return i.InsertText
}

func (i *CompletionItem) addPrefixSuffix(c *completer, prefix string, suffix string) error {
if prefix != "" {
// If we are in a selector, add an edit to place prefix before selector.
if sel := enclosingSelector(c.path, c.pos); sel != nil {
edits, err := c.editText(sel.Pos(), sel.Pos(), prefix)
if err != nil {
return err
}
i.AdditionalTextEdits = append(i.AdditionalTextEdits, edits...)
} else {
// If there is no selector, just stick the prefix at the start.
i.InsertText = prefix + i.InsertText
i.snippet.PrependText(prefix)
}
}

if suffix != "" {
i.InsertText += suffix
i.snippet.WriteText(suffix)
}

return nil
}

// Scoring constants are used for weighting the relevance of different candidates.
const (
// stdScore is the base score for all completion items.
Expand Down Expand Up @@ -659,6 +683,7 @@ func Completion(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, p
c.addStatementCandidates()

c.sortItems()

return c.items, c.getSurrounding(), nil
}

Expand Down Expand Up @@ -2301,17 +2326,21 @@ Nodes:

sig, _ := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature)

if sig != nil && sig.TypeParams().Len() > 0 {
// If we are completing a generic func call, re-check the call expression.
// This allows type param inference to work in cases like:
//
// func foo[T any](T) {}
// foo[int](<>) // <- get "int" completions instead of "T"
//
// TODO: remove this after https://go.dev/issue/52503
info := &types.Info{Types: make(map[ast.Expr]types.TypeAndValue)}
types.CheckExpr(c.pkg.FileSet(), c.pkg.Types(), node.Fun.Pos(), node.Fun, info)
sig, _ = info.Types[node.Fun].Type.(*types.Signature)
if sig != nil && sig.TypeParams().Len() > 0 && len(c.path) > i+1 {
// infer an instantiation for the CallExpr from it's context
switch c.path[i+1].(type) {
case *ast.AssignStmt, *ast.ReturnStmt, *ast.SendStmt, *ast.ValueSpec:
// Defer call to reverseInferExpectedCallParamType so we can provide it the
// inferences about its parent node.
defer func(sig *types.Signature) {
inf = c.reverseInferExpectedCallParam(inf, node, sig)
}(sig)
continue Nodes
case *ast.KeyValueExpr:
c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[i:], node.Pos(), c.pkg.TypesInfo())
inf = c.reverseInferExpectedCallParam(inf, node, sig)
break Nodes
}
}

if sig != nil {
Expand Down Expand Up @@ -2395,6 +2424,28 @@ Nodes:
}

if ct := expectedConstraint(tv.Type, 0); ct != nil {
// Infer the type parameters in a function call based on it's context
if len(c.path) > i+2 {
if node, ok := c.path[i+1].(*ast.CallExpr); ok {
if sig, ok := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature); ok && sig.TypeParams().Len() != 0 {
// skip again to get the parent of the call expression
i++
switch c.path[i+1].(type) {
case *ast.AssignStmt, *ast.ValueSpec, *ast.ReturnStmt, *ast.SendStmt:
// Defer call to reverseInferExpectedCallParamType so we can provide it the
// inferences about its parent node.
defer func() {
inf = c.reverseInferExpectedTypeParam(inf, ct, 0, sig)
}()
continue Nodes
case *ast.KeyValueExpr:
c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[i+2:], node.Pos(), c.pkg.TypesInfo())
inf = c.reverseInferExpectedTypeParam(inf, ct, 0, sig)
break Nodes
}
}
}
}
inf.objType = ct
inf.typeName.wantTypeName = true
inf.typeName.isTypeParam = true
Expand All @@ -2405,7 +2456,30 @@ Nodes:
case *ast.IndexListExpr:
if node.Lbrack < c.pos && c.pos <= node.Rbrack {
if tv, ok := c.pkg.TypesInfo().Types[node.X]; ok {
if ct := expectedConstraint(tv.Type, exprAtPos(c.pos, node.Indices)); ct != nil {
typeParamIdx := exprAtPos(c.pos, node.Indices)
if ct := expectedConstraint(tv.Type, typeParamIdx); ct != nil {
// Infer the type parameters in a function call based on it's context
if len(c.path) > i+2 {
if callnode, ok := c.path[i+1].(*ast.CallExpr); ok {
if sig, ok := c.pkg.TypesInfo().Types[callnode.Fun].Type.(*types.Signature); ok && sig.TypeParams().Len() != 0 {
// skip again to get the parent of the call expression
i++
switch c.path[i+1].(type) {
case *ast.AssignStmt, *ast.ValueSpec, *ast.ReturnStmt, *ast.SendStmt:
// Defer call to reverseInferExpectedCallParamType so we can provide it the
// inferences about its parent node.
defer func() {
inf = c.reverseInferExpectedTypeParam(inf, ct, typeParamIdx, sig)
}()
continue Nodes
case *ast.KeyValueExpr:
c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[i+2:], callnode.Pos(), c.pkg.TypesInfo())
inf = c.reverseInferExpectedTypeParam(inf, ct, typeParamIdx, sig)
break Nodes
}
}
}
}
inf.objType = ct
inf.typeName.wantTypeName = true
inf.typeName.isTypeParam = true
Expand Down Expand Up @@ -2457,6 +2531,118 @@ Nodes:
return inf
}

func reverseInferSignature(sig *types.Signature, targetType []types.Type) []types.Type {
if sig.Results().Len() != len(targetType) {
return nil
}

tparams := []*types.TypeParam{}
targs := []types.Type{}
for i := range sig.TypeParams().Len() {
tparams = append(tparams, sig.TypeParams().At(i))
targs = append(targs, nil)
}

u := newUnifier(tparams, targs)
for i, assignee := range targetType {
// reverseInferSignature instantiates the call site of a generic function
// based on the expected return types. Returns nil if inference fails or is invalid.
//
// targetType is the expected return types of the function after instantiation.
if !u.unify(sig.Results().At(i).Type(), assignee, unifyMode(unifyModeExact)) {
return nil
}
}

substs := []types.Type{}
for i := 0; i < sig.TypeParams().Len(); i++ {
if v := u.handles[sig.TypeParams().At(i)]; v != nil && *v != nil {
substs = append(substs, *v)
} else {
substs = append(substs, nil)
}
}

return substs
}

func (c *completer) reverseInferredSubstitions(inf candidateInference, sig *types.Signature) []types.Type {
targetType := []types.Type{}
if inf.assignees != nil {
targetType = inf.assignees
inf.assignees = nil
} else if c.enclosingCompositeLiteral != nil && !c.wantStructFieldCompletions() {
targetType = append(targetType, c.expectedCompositeLiteralType())
} else if t := inf.objType; t != nil {
inf.objType = nil
targetType = append(targetType, t)
} else {
return nil
}
return reverseInferSignature(sig, targetType)
}

// reverseInferExpectedTypeParam uses inferences and completion parameters from the parent scope
// to instantiate the generalized signature of the call node.
//
// inf is expected to contain inferences based on the parent of the CallExpr node.
func (c *completer) reverseInferExpectedTypeParam(inf candidateInference, expectedConstraint types.Type, typeParamIdx int, sig *types.Signature) candidateInference {
if typeParamIdx >= sig.TypeParams().Len() {
inf.objType = nil
inf.assignees = nil
return inf
}

substs := c.reverseInferredSubstitions(inf, sig)
if substs != nil && len(substs) > 0 {
if substs[typeParamIdx] != nil {
inf.objType = substs[typeParamIdx]
} else {
// default to the constraint if no viable substition
inf.objType = expectedConstraint
}
inf.typeName.wantTypeName = true
inf.typeName.isTypeParam = true
}
return inf
}

// reverseInferExpectedCallParam uses inferences and completion parameters from the parent scope
// to instantiate the generalized signature of the call node.
//
// inf is expected to contain inferences based on the parent of the CallExpr node.
func (c *completer) reverseInferExpectedCallParam(inf candidateInference, node *ast.CallExpr, sig *types.Signature) candidateInference {
substs := c.reverseInferredSubstitions(inf, sig)
if substs == nil {
return inf
}

for i := range substs {
if substs[i] == nil {
substs[i] = sig.TypeParams().At(i)
}
}

if inst, err := types.Instantiate(nil, sig, substs, true); err == nil {
if inst, ok := inst.(*types.Signature); ok {
inf = c.expectedCallParamType(inf, node, inst)

// Interface type variants shouldn't be candidates as arguments if the caller isn't
// explicitly instantiated
//
// func generic[T any](x T) T { return x }
// var x someInterface = generic(someImplementor{})
// ^^ wanted generic[someInterface] but got generic[someImplementor]
// When offering completions, add a conversion if necessary.
// generic(someInterface(someImplementor{}))
if types.IsInterface(inf.objType) {
inf.convertibleTo = inf.objType
}
}
}
return inf
}

func (c *completer) expectedCallParamType(inf candidateInference, node *ast.CallExpr, sig *types.Signature) candidateInference {
numParams := sig.Params().Len()
if numParams == 0 {
Expand Down Expand Up @@ -2938,6 +3124,12 @@ func (ci *candidateInference) candTypeMatches(cand *candidate) bool {
}

if ci.convertibleTo != nil && convertibleTo(candType, ci.convertibleTo) {
// Candidate implements an interface, but needs explicit conversion to the interface
// type. This happens when passing arguments to a generic function.
if ci.objType != nil && types.IsInterface(ci.objType) && !types.Identical(candType, ci.convertibleTo) {
cand.score *= 0.95 // should rank barely lower if it needs a conversion, even though it's perfectly valid
cand.convertTo = ci.objType
}
return true
}

Expand Down Expand Up @@ -3161,6 +3353,10 @@ func (c *completer) matchingTypeName(cand *candidate) bool {
return false
}

wantInterfaceTypeParam := c.inference.typeName.isTypeParam &&
c.inference.typeName.wantTypeName && c.inference.objType != nil &&
types.IsInterface(c.inference.objType)

typeMatches := func(candType types.Type) bool {
// Take into account any type name modifier prefixes.
candType = c.inference.applyTypeNameModifiers(candType)
Expand All @@ -3179,6 +3375,13 @@ func (c *completer) matchingTypeName(cand *candidate) bool {
}
}

// When performing reverse type inference
// x = Foo[<>]()
// Where x is an interface, only suggest the interface rather than its implementors
if wantInterfaceTypeParam && types.Identical(candType, c.inference.objType) {
return true
}

if c.inference.typeName.wantComparable && !types.Comparable(candType) {
return false
}
Expand Down
39 changes: 21 additions & 18 deletions gopls/internal/golang/completion/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,24 +196,9 @@ Suffixes:
}

if cand.convertTo != nil {
typeName := types.TypeString(cand.convertTo, c.qf)

switch t := cand.convertTo.(type) {
// We need extra parens when casting to these types. For example,
// we need "(*int)(foo)", not "*int(foo)".
case *types.Pointer, *types.Signature:
typeName = "(" + typeName + ")"
case *types.Basic:
// If the types are incompatible (as determined by typeMatches), then we
// must need a conversion here. However, if the target type is untyped,
// don't suggest converting to e.g. "untyped float" (golang/go#62141).
if t.Info()&types.IsUntyped != 0 {
typeName = types.TypeString(types.Default(cand.convertTo), c.qf)
}
}

prefix = typeName + "(" + prefix
suffix = ")"
p, s := c.formatConvertTo(cand.convertTo)
prefix = p + prefix
suffix = s
}

if prefix != "" {
Expand Down Expand Up @@ -288,6 +273,24 @@ Suffixes:
return item, nil
}

func (c *completer) formatConvertTo(convertTo types.Type) (prefix string, suffix string) {
typeName := types.TypeString(convertTo, c.qf)
switch t := convertTo.(type) {
// We need extra parens when casting to these types. For example,
// we need "(*int)(foo)", not "*int(foo)".
case *types.Pointer, *types.Signature:
typeName = "(" + typeName + ")"
case *types.Basic:
// If the types are incompatible (as determined by typeMatches), then we
// must need a conversion here. However, if the target type is untyped,
// don't suggest converting to e.g. "untyped float" (golang/go#62141).
if t.Info()&types.IsUntyped != 0 {
typeName = types.TypeString(types.Default(convertTo), c.qf)
}
}
return typeName + "(", ")"
}

// importEdits produces the text edits necessary to add the given import to the current file.
func (c *completer) importEdits(imp *importInfo) ([]protocol.TextEdit, error) {
if imp == nil {
Expand Down
Loading

0 comments on commit 152e295

Please sign in to comment.