Skip to content

Commit

Permalink
completions/inference: resolve CL feedback and add tests
Browse files Browse the repository at this point in the history
Implements cleanups from CL feedback.

Fixes unintuitive behavior with partially instantiated calls.
Fixes reverse inference to take lower precedence than generic type
constraints.
Fixes no inference completions for out of bounds parameters.

chore: better comment

small format change
  • Loading branch information
jacobzim-stl committed Oct 21, 2024
1 parent 152e295 commit 05a7636
Show file tree
Hide file tree
Showing 5 changed files with 314 additions and 147 deletions.
174 changes: 95 additions & 79 deletions gopls/internal/golang/completion/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,25 +134,28 @@ func (i *CompletionItem) Snippet() string {
return i.InsertText
}

func (i *CompletionItem) addPrefixSuffix(c *completer, prefix string, suffix string) error {
if prefix != "" {
// addConversion wraps the existing completionItem in a conversion expression.
// Only affects the receiver's InsertText and snippet fields, not the Label.
// An empty conv argument has no effect.
func (i *CompletionItem) addConversion(c *completer, conv conversionEdits) error {
if conv.prefix != "" {
// If we are in a selector, add an edit to place prefix before selector.
if sel := enclosingSelector(c.path, c.pos); sel != nil {
edits, err := c.editText(sel.Pos(), sel.Pos(), prefix)
edits, err := c.editText(sel.Pos(), sel.Pos(), conv.prefix)
if err != nil {
return err
}
i.AdditionalTextEdits = append(i.AdditionalTextEdits, edits...)
} else {
// If there is no selector, just stick the prefix at the start.
i.InsertText = prefix + i.InsertText
i.snippet.PrependText(prefix)
i.InsertText = conv.prefix + i.InsertText
i.snippet.PrependText(conv.prefix)
}
}

if suffix != "" {
i.InsertText += suffix
i.snippet.WriteText(suffix)
if conv.suffix != "" {
i.InsertText += conv.suffix
i.snippet.WriteText(conv.suffix)
}

return nil
Expand Down Expand Up @@ -2189,6 +2192,10 @@ type candidateInference struct {
// convertibleTo is a type our candidate type must be convertible to.
convertibleTo types.Type

// needsExactType is true if the candidate type must be exactly the type of
// the objType, e.g. an interface rather than it's implementors.
needsExactType bool

// typeName holds information about the expected type name at
// position, if any.
typeName typeNameInference
Expand Down Expand Up @@ -2324,26 +2331,32 @@ Nodes:
break Nodes
}

sig, _ := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature)
if sig, ok := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature); ok {
// Out of bounds arguments get no inference completion.
if !sig.Variadic() && exprAtPos(c.pos, node.Args) >= sig.Params().Len() {
return inf
}

if sig != nil && sig.TypeParams().Len() > 0 && len(c.path) > i+1 {
// infer an instantiation for the CallExpr from it's context
switch c.path[i+1].(type) {
case *ast.AssignStmt, *ast.ReturnStmt, *ast.SendStmt, *ast.ValueSpec:
// Defer call to reverseInferExpectedCallParamType so we can provide it the
// inferences about its parent node.
defer func(sig *types.Signature) {
// Call has uninstantiated type parameters.
// Skips inference for partially instantiated calls until partially
// instantiating signatures has stronger support.
if sig.TypeParams().Len() > 0 && c.numTypeArgs(node) == 0 && len(c.path) > i+1 {
// Use the CallExpr's surroundings to infer an instantiation from its return types.
switch c.path[i+1].(type) {
case *ast.AssignStmt, *ast.ReturnStmt, *ast.SendStmt, *ast.ValueSpec:
// Defer call to reverseInferExpectedCallParamType so we can provide it the
// inferences about its parent node.
defer func(sig *types.Signature) {
inf = c.reverseInferExpectedCallParam(inf, node, sig)
}(sig)
continue Nodes
case *ast.KeyValueExpr:
c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[i:], node.Pos(), c.pkg.TypesInfo())
inf = c.reverseInferExpectedCallParam(inf, node, sig)
}(sig)
continue Nodes
case *ast.KeyValueExpr:
c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[i:], node.Pos(), c.pkg.TypesInfo())
inf = c.reverseInferExpectedCallParam(inf, node, sig)
break Nodes
break Nodes
}
}
}

if sig != nil {
inf = c.expectedCallParamType(inf, node, sig)
}

Expand Down Expand Up @@ -2531,79 +2544,82 @@ Nodes:
return inf
}

func reverseInferSignature(sig *types.Signature, targetType []types.Type) []types.Type {
if sig.Results().Len() != len(targetType) {
return nil
func (c *completer) numTypeArgs(callExpr *ast.CallExpr) int {
switch fun := callExpr.Fun.(type) {
case *ast.IndexListExpr:
return len(fun.Indices)
case *ast.IndexExpr:
if typ, ok := c.pkg.TypesInfo().Types[fun.Index]; ok && typeIsValid(typ.Type) {
return 1
}
}
return 0
}

// reverseInferSignature instantiates the call site of a generic function
// based on the expected return types. Returns false if inference fails or is invalid.
func reverseInferSignature(sig *types.Signature, targetType []types.Type) ([]types.Type, bool) {
if sig.Results().Len() != len(targetType) || len(targetType) == 0 {
return nil, false
}

tparams := []*types.TypeParam{}
targs := []types.Type{}
tparams := make([]*types.TypeParam, sig.TypeParams().Len())
for i := range sig.TypeParams().Len() {
tparams = append(tparams, sig.TypeParams().At(i))
targs = append(targs, nil)
tparams[i] = sig.TypeParams().At(i)
}

u := newUnifier(tparams, targs)
u := newUnifier(tparams, make([]types.Type, sig.TypeParams().Len()))
for i, assignee := range targetType {
// reverseInferSignature instantiates the call site of a generic function
// based on the expected return types. Returns nil if inference fails or is invalid.
//
// targetType is the expected return types of the function after instantiation.
// This unification does not check the constraints of the type parameters.
if !u.unify(sig.Results().At(i).Type(), assignee, unifyMode(unifyModeExact)) {
return nil
return nil, false
}
}

substs := []types.Type{}
substs := make([]types.Type, sig.TypeParams().Len())
for i := 0; i < sig.TypeParams().Len(); i++ {
if v := u.handles[sig.TypeParams().At(i)]; v != nil && *v != nil {
substs = append(substs, *v)
} else {
substs = append(substs, nil)
if sub := u.handles[sig.TypeParams().At(i)]; sub != nil && *sub != nil {
// Ensure the inferred subst is assignable to the type parameter's constraint.
if !assignableTo(*sub, sig.TypeParams().At(i).Constraint()) {
return nil, false
}
substs[i] = *sub
}
}

return substs
return substs, true
}

func (c *completer) reverseInferredSubstitions(inf candidateInference, sig *types.Signature) []types.Type {
targetType := []types.Type{}
// resetExpectedType resets the inference and returns the previous
// expected inference type.
func (c *completer) resetExpectedType(inf *candidateInference) []types.Type {
var targetType []types.Type
if inf.assignees != nil {
targetType = inf.assignees
inf.assignees = nil
} else if c.enclosingCompositeLiteral != nil && !c.wantStructFieldCompletions() {
targetType = append(targetType, c.expectedCompositeLiteralType())
} else if t := inf.objType; t != nil {
inf.objType = nil
targetType = append(targetType, t)
} else {
return nil
inf.objType = nil
}
return reverseInferSignature(sig, targetType)
return targetType
}

// reverseInferExpectedTypeParam uses inferences and completion parameters from the parent scope
// to instantiate the generalized signature of the call node.
//
// inf is expected to contain inferences based on the parent of the CallExpr node.
func (c *completer) reverseInferExpectedTypeParam(inf candidateInference, expectedConstraint types.Type, typeParamIdx int, sig *types.Signature) candidateInference {
if typeParamIdx >= sig.TypeParams().Len() {
inf.objType = nil
inf.assignees = nil
return inf
}

substs := c.reverseInferredSubstitions(inf, sig)
if substs != nil && len(substs) > 0 {
if substs[typeParamIdx] != nil {
inf.objType = substs[typeParamIdx]
} else {
// default to the constraint if no viable substition
inf.objType = expectedConstraint
}
inf.typeName.wantTypeName = true
inf.typeName.isTypeParam = true
substs, ok := reverseInferSignature(sig, c.resetExpectedType(&inf))
if ok && len(substs) > 0 && substs[typeParamIdx] != nil {
inf.objType = substs[typeParamIdx]
} else {
// Default to the constraint if no viable substition.
inf.objType = expectedConstraint
}
inf.typeName.wantTypeName = true
inf.typeName.isTypeParam = true
return inf
}

Expand All @@ -2612,9 +2628,9 @@ func (c *completer) reverseInferExpectedTypeParam(inf candidateInference, expect
//
// inf is expected to contain inferences based on the parent of the CallExpr node.
func (c *completer) reverseInferExpectedCallParam(inf candidateInference, node *ast.CallExpr, sig *types.Signature) candidateInference {
substs := c.reverseInferredSubstitions(inf, sig)
if substs == nil {
return inf
substs, ok := reverseInferSignature(sig, c.resetExpectedType(&inf))
if !ok {
return c.expectedCallParamType(inf, node, sig)
}

for i := range substs {
Expand All @@ -2627,17 +2643,15 @@ func (c *completer) reverseInferExpectedCallParam(inf candidateInference, node *
if inst, ok := inst.(*types.Signature); ok {
inf = c.expectedCallParamType(inf, node, inst)

// Interface type variants shouldn't be candidates as arguments if the caller isn't
// explicitly instantiated
// Interface variants can't be passed if the caller isn't
// explicitly instantiated. Completions must be exactly the interface type.
//
// func generic[T any](x T) T { return x }
// var x someInterface = generic(someImplementor{})
// ^^ wanted generic[someInterface] but got generic[someImplementor]
// When offering completions, add a conversion if necessary.
// generic(someInterface(someImplementor{}))
if types.IsInterface(inf.objType) {
inf.convertibleTo = inf.objType
}
inf.needsExactType = true
}
}
return inf
Expand Down Expand Up @@ -3124,12 +3138,6 @@ func (ci *candidateInference) candTypeMatches(cand *candidate) bool {
}

if ci.convertibleTo != nil && convertibleTo(candType, ci.convertibleTo) {
// Candidate implements an interface, but needs explicit conversion to the interface
// type. This happens when passing arguments to a generic function.
if ci.objType != nil && types.IsInterface(ci.objType) && !types.Identical(candType, ci.convertibleTo) {
cand.score *= 0.95 // should rank barely lower if it needs a conversion, even though it's perfectly valid
cand.convertTo = ci.objType
}
return true
}

Expand Down Expand Up @@ -3164,6 +3172,14 @@ func (ci *candidateInference) candTypeMatches(cand *candidate) bool {
cand.mods = append(cand.mods, takeDotDotDot)
}

// Candidate matches, but isn't exactly identical to the expected type.
// Apply a conversion to allow it to match.
if ci.needsExactType && !types.Identical(candType, expType) {
cand.convertTo = expType
// Ranks barely lower if it needs a conversion, even though it's perfectly valid.
cand.score *= 0.95
}

// Lower candidate score for untyped conversions. This avoids
// ranking untyped constants above candidates with an exact type
// match. Don't lower score of builtin constants, e.g. "true".
Expand Down
24 changes: 19 additions & 5 deletions gopls/internal/golang/completion/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ Suffixes:
}

if cand.convertTo != nil {
p, s := c.formatConvertTo(cand.convertTo)
prefix = p + prefix
suffix = s
conv := c.formatConversion(cand.convertTo)
prefix = conv.prefix + prefix
suffix = conv.suffix
}

if prefix != "" {
Expand Down Expand Up @@ -273,7 +273,21 @@ Suffixes:
return item, nil
}

func (c *completer) formatConvertTo(convertTo types.Type) (prefix string, suffix string) {
// conversionEdits represents the string edits needed to make a type conversion
// of an expression.
type conversionEdits struct {
prefix, suffix string
}

// formatConversion returns the edits needed to make a type conversion
// expression, including parentheses if necessary.
//
// Returns empty conversionEdits if convertTo is nil.
func (c *completer) formatConversion(convertTo types.Type) conversionEdits {
if convertTo == nil {
return conversionEdits{}
}

typeName := types.TypeString(convertTo, c.qf)
switch t := convertTo.(type) {
// We need extra parens when casting to these types. For example,
Expand All @@ -288,7 +302,7 @@ func (c *completer) formatConvertTo(convertTo types.Type) (prefix string, suffix
typeName = types.TypeString(types.Default(convertTo), c.qf)
}
}
return typeName + "(", ")"
return conversionEdits{prefix: typeName + "(", suffix: ")"}
}

// importEdits produces the text edits necessary to add the given import to the current file.
Expand Down
Loading

0 comments on commit 05a7636

Please sign in to comment.