diff --git a/gopls/internal/golang/completion/completion.go b/gopls/internal/golang/completion/completion.go index cf398693113..9f371571acd 100644 --- a/gopls/internal/golang/completion/completion.go +++ b/gopls/internal/golang/completion/completion.go @@ -134,6 +134,30 @@ func (i *CompletionItem) Snippet() string { return i.InsertText } +func (i *CompletionItem) addPrefixSuffix(c *completer, prefix string, suffix string) error { + if prefix != "" { + // If we are in a selector, add an edit to place prefix before selector. + if sel := enclosingSelector(c.path, c.pos); sel != nil { + edits, err := c.editText(sel.Pos(), sel.Pos(), prefix) + if err != nil { + return err + } + i.AdditionalTextEdits = append(i.AdditionalTextEdits, edits...) + } else { + // If there is no selector, just stick the prefix at the start. + i.InsertText = prefix + i.InsertText + i.snippet.PrependText(prefix) + } + } + + if suffix != "" { + i.InsertText += suffix + i.snippet.WriteText(suffix) + } + + return nil +} + // Scoring constants are used for weighting the relevance of different candidates. const ( // stdScore is the base score for all completion items. @@ -659,6 +683,7 @@ func Completion(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, p c.addStatementCandidates() c.sortItems() + return c.items, c.getSurrounding(), nil } @@ -2301,17 +2326,21 @@ Nodes: sig, _ := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature) - if sig != nil && sig.TypeParams().Len() > 0 { - // If we are completing a generic func call, re-check the call expression. - // This allows type param inference to work in cases like: - // - // func foo[T any](T) {} - // foo[int](<>) // <- get "int" completions instead of "T" - // - // TODO: remove this after https://go.dev/issue/52503 - info := &types.Info{Types: make(map[ast.Expr]types.TypeAndValue)} - types.CheckExpr(c.pkg.FileSet(), c.pkg.Types(), node.Fun.Pos(), node.Fun, info) - sig, _ = info.Types[node.Fun].Type.(*types.Signature) + if sig != nil && sig.TypeParams().Len() > 0 && len(c.path) > i+1 { + // infer an instantiation for the CallExpr from it's context + switch c.path[i+1].(type) { + case *ast.AssignStmt, *ast.ReturnStmt, *ast.SendStmt, *ast.ValueSpec: + // Defer call to reverseInferExpectedCallParamType so we can provide it the + // inferences about its parent node. + defer func(sig *types.Signature) { + inf = c.reverseInferExpectedCallParam(inf, node, sig) + }(sig) + continue Nodes + case *ast.KeyValueExpr: + c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[i:], node.Pos(), c.pkg.TypesInfo()) + inf = c.reverseInferExpectedCallParam(inf, node, sig) + break Nodes + } } if sig != nil { @@ -2395,6 +2424,28 @@ Nodes: } if ct := expectedConstraint(tv.Type, 0); ct != nil { + // Infer the type parameters in a function call based on it's context + if len(c.path) > i+2 { + if node, ok := c.path[i+1].(*ast.CallExpr); ok { + if sig, ok := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature); ok && sig.TypeParams().Len() != 0 { + // skip again to get the parent of the call expression + i++ + switch c.path[i+1].(type) { + case *ast.AssignStmt, *ast.ValueSpec, *ast.ReturnStmt, *ast.SendStmt: + // Defer call to reverseInferExpectedCallParamType so we can provide it the + // inferences about its parent node. + defer func() { + inf = c.reverseInferExpectedTypeParam(inf, ct, 0, sig) + }() + continue Nodes + case *ast.KeyValueExpr: + c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[i+2:], node.Pos(), c.pkg.TypesInfo()) + inf = c.reverseInferExpectedTypeParam(inf, ct, 0, sig) + break Nodes + } + } + } + } inf.objType = ct inf.typeName.wantTypeName = true inf.typeName.isTypeParam = true @@ -2405,7 +2456,30 @@ Nodes: case *ast.IndexListExpr: if node.Lbrack < c.pos && c.pos <= node.Rbrack { if tv, ok := c.pkg.TypesInfo().Types[node.X]; ok { - if ct := expectedConstraint(tv.Type, exprAtPos(c.pos, node.Indices)); ct != nil { + typeParamIdx := exprAtPos(c.pos, node.Indices) + if ct := expectedConstraint(tv.Type, typeParamIdx); ct != nil { + // Infer the type parameters in a function call based on it's context + if len(c.path) > i+2 { + if callnode, ok := c.path[i+1].(*ast.CallExpr); ok { + if sig, ok := c.pkg.TypesInfo().Types[callnode.Fun].Type.(*types.Signature); ok && sig.TypeParams().Len() != 0 { + // skip again to get the parent of the call expression + i++ + switch c.path[i+1].(type) { + case *ast.AssignStmt, *ast.ValueSpec, *ast.ReturnStmt, *ast.SendStmt: + // Defer call to reverseInferExpectedCallParamType so we can provide it the + // inferences about its parent node. + defer func() { + inf = c.reverseInferExpectedTypeParam(inf, ct, typeParamIdx, sig) + }() + continue Nodes + case *ast.KeyValueExpr: + c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[i+2:], callnode.Pos(), c.pkg.TypesInfo()) + inf = c.reverseInferExpectedTypeParam(inf, ct, typeParamIdx, sig) + break Nodes + } + } + } + } inf.objType = ct inf.typeName.wantTypeName = true inf.typeName.isTypeParam = true @@ -2457,6 +2531,118 @@ Nodes: return inf } +func reverseInferSignature(sig *types.Signature, targetType []types.Type) []types.Type { + if sig.Results().Len() != len(targetType) { + return nil + } + + tparams := []*types.TypeParam{} + targs := []types.Type{} + for i := range sig.TypeParams().Len() { + tparams = append(tparams, sig.TypeParams().At(i)) + targs = append(targs, nil) + } + + u := newUnifier(tparams, targs) + for i, assignee := range targetType { + // reverseInferSignature instantiates the call site of a generic function + // based on the expected return types. Returns nil if inference fails or is invalid. + // + // targetType is the expected return types of the function after instantiation. + if !u.unify(sig.Results().At(i).Type(), assignee, unifyMode(unifyModeExact)) { + return nil + } + } + + substs := []types.Type{} + for i := 0; i < sig.TypeParams().Len(); i++ { + if v := u.handles[sig.TypeParams().At(i)]; v != nil && *v != nil { + substs = append(substs, *v) + } else { + substs = append(substs, nil) + } + } + + return substs +} + +func (c *completer) reverseInferredSubstitions(inf candidateInference, sig *types.Signature) []types.Type { + targetType := []types.Type{} + if inf.assignees != nil { + targetType = inf.assignees + inf.assignees = nil + } else if c.enclosingCompositeLiteral != nil && !c.wantStructFieldCompletions() { + targetType = append(targetType, c.expectedCompositeLiteralType()) + } else if t := inf.objType; t != nil { + inf.objType = nil + targetType = append(targetType, t) + } else { + return nil + } + return reverseInferSignature(sig, targetType) +} + +// reverseInferExpectedTypeParam uses inferences and completion parameters from the parent scope +// to instantiate the generalized signature of the call node. +// +// inf is expected to contain inferences based on the parent of the CallExpr node. +func (c *completer) reverseInferExpectedTypeParam(inf candidateInference, expectedConstraint types.Type, typeParamIdx int, sig *types.Signature) candidateInference { + if typeParamIdx >= sig.TypeParams().Len() { + inf.objType = nil + inf.assignees = nil + return inf + } + + substs := c.reverseInferredSubstitions(inf, sig) + if substs != nil && len(substs) > 0 { + if substs[typeParamIdx] != nil { + inf.objType = substs[typeParamIdx] + } else { + // default to the constraint if no viable substition + inf.objType = expectedConstraint + } + inf.typeName.wantTypeName = true + inf.typeName.isTypeParam = true + } + return inf +} + +// reverseInferExpectedCallParam uses inferences and completion parameters from the parent scope +// to instantiate the generalized signature of the call node. +// +// inf is expected to contain inferences based on the parent of the CallExpr node. +func (c *completer) reverseInferExpectedCallParam(inf candidateInference, node *ast.CallExpr, sig *types.Signature) candidateInference { + substs := c.reverseInferredSubstitions(inf, sig) + if substs == nil { + return inf + } + + for i := range substs { + if substs[i] == nil { + substs[i] = sig.TypeParams().At(i) + } + } + + if inst, err := types.Instantiate(nil, sig, substs, true); err == nil { + if inst, ok := inst.(*types.Signature); ok { + inf = c.expectedCallParamType(inf, node, inst) + + // Interface type variants shouldn't be candidates as arguments if the caller isn't + // explicitly instantiated + // + // func generic[T any](x T) T { return x } + // var x someInterface = generic(someImplementor{}) + // ^^ wanted generic[someInterface] but got generic[someImplementor] + // When offering completions, add a conversion if necessary. + // generic(someInterface(someImplementor{})) + if types.IsInterface(inf.objType) { + inf.convertibleTo = inf.objType + } + } + } + return inf +} + func (c *completer) expectedCallParamType(inf candidateInference, node *ast.CallExpr, sig *types.Signature) candidateInference { numParams := sig.Params().Len() if numParams == 0 { @@ -2938,6 +3124,12 @@ func (ci *candidateInference) candTypeMatches(cand *candidate) bool { } if ci.convertibleTo != nil && convertibleTo(candType, ci.convertibleTo) { + // Candidate implements an interface, but needs explicit conversion to the interface + // type. This happens when passing arguments to a generic function. + if ci.objType != nil && types.IsInterface(ci.objType) && !types.Identical(candType, ci.convertibleTo) { + cand.score *= 0.95 // should rank barely lower if it needs a conversion, even though it's perfectly valid + cand.convertTo = ci.objType + } return true } @@ -3161,6 +3353,10 @@ func (c *completer) matchingTypeName(cand *candidate) bool { return false } + wantInterfaceTypeParam := c.inference.typeName.isTypeParam && + c.inference.typeName.wantTypeName && c.inference.objType != nil && + types.IsInterface(c.inference.objType) + typeMatches := func(candType types.Type) bool { // Take into account any type name modifier prefixes. candType = c.inference.applyTypeNameModifiers(candType) @@ -3179,6 +3375,13 @@ func (c *completer) matchingTypeName(cand *candidate) bool { } } + // When performing reverse type inference + // x = Foo[<>]() + // Where x is an interface, only suggest the interface rather than its implementors + if wantInterfaceTypeParam && types.Identical(candType, c.inference.objType) { + return true + } + if c.inference.typeName.wantComparable && !types.Comparable(candType) { return false } diff --git a/gopls/internal/golang/completion/format.go b/gopls/internal/golang/completion/format.go index c2b955ca7e9..388f3f2c9b1 100644 --- a/gopls/internal/golang/completion/format.go +++ b/gopls/internal/golang/completion/format.go @@ -196,24 +196,9 @@ Suffixes: } if cand.convertTo != nil { - typeName := types.TypeString(cand.convertTo, c.qf) - - switch t := cand.convertTo.(type) { - // We need extra parens when casting to these types. For example, - // we need "(*int)(foo)", not "*int(foo)". - case *types.Pointer, *types.Signature: - typeName = "(" + typeName + ")" - case *types.Basic: - // If the types are incompatible (as determined by typeMatches), then we - // must need a conversion here. However, if the target type is untyped, - // don't suggest converting to e.g. "untyped float" (golang/go#62141). - if t.Info()&types.IsUntyped != 0 { - typeName = types.TypeString(types.Default(cand.convertTo), c.qf) - } - } - - prefix = typeName + "(" + prefix - suffix = ")" + p, s := c.formatConvertTo(cand.convertTo) + prefix = p + prefix + suffix = s } if prefix != "" { @@ -288,6 +273,24 @@ Suffixes: return item, nil } +func (c *completer) formatConvertTo(convertTo types.Type) (prefix string, suffix string) { + typeName := types.TypeString(convertTo, c.qf) + switch t := convertTo.(type) { + // We need extra parens when casting to these types. For example, + // we need "(*int)(foo)", not "*int(foo)". + case *types.Pointer, *types.Signature: + typeName = "(" + typeName + ")" + case *types.Basic: + // If the types are incompatible (as determined by typeMatches), then we + // must need a conversion here. However, if the target type is untyped, + // don't suggest converting to e.g. "untyped float" (golang/go#62141). + if t.Info()&types.IsUntyped != 0 { + typeName = types.TypeString(types.Default(convertTo), c.qf) + } + } + return typeName + "(", ")" +} + // importEdits produces the text edits necessary to add the given import to the current file. func (c *completer) importEdits(imp *importInfo) ([]protocol.TextEdit, error) { if imp == nil { diff --git a/gopls/internal/golang/completion/literal.go b/gopls/internal/golang/completion/literal.go index 7427d559e94..47021142f10 100644 --- a/gopls/internal/golang/completion/literal.go +++ b/gopls/internal/golang/completion/literal.go @@ -73,15 +73,26 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im cand.addressable = true } - if !c.matchingCandidate(&cand) || cand.convertTo != nil { + if !c.matchingCandidate(&cand) { return } var ( - qf = c.qf - sel = enclosingSelector(c.path, c.pos) + qf = c.qf + sel = enclosingSelector(c.path, c.pos) + prefix = "" + suffix = "" ) + // Only convert literals which are an interface type value to a + // parameterized argument in a function call. + if cand.convertTo != nil { + if !types.IsInterface(cand.convertTo) { + return + } + prefix, suffix = c.formatConvertTo(cand.convertTo) + } + // Don't qualify the type name if we are in a selector expression // since the package name is already present. if sel != nil { @@ -129,13 +140,18 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im switch t := literalType.Underlying().(type) { case *types.Struct, *types.Array, *types.Slice, *types.Map: - c.compositeLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + item := c.compositeLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + item.addPrefixSuffix(c, prefix, suffix) + c.items = append(c.items, item) case *types.Signature: // Add a literal completion for a signature type that implements // an interface. For example, offer "http.HandlerFunc()" when // expected type is "http.Handler". if expType != nil && types.IsInterface(expType) { - c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + if item, ok := c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits); ok { + item.addPrefixSuffix(c, prefix, suffix) + c.items = append(c.items, item) + } } case *types.Basic: // Add a literal completion for basic types that implement our @@ -143,7 +159,10 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im // implements http.FileSystem), or are identical to our expected // type (i.e. yielding a type conversion such as "float64()"). if expType != nil && (types.IsInterface(expType) || types.Identical(expType, literalType)) { - c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + if item, ok := c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits); ok { + item.addPrefixSuffix(c, prefix, suffix) + c.items = append(c.items, item) + } } } } @@ -155,11 +174,15 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im switch literalType.Underlying().(type) { case *types.Slice: // The second argument to "make()" for slices is required, so default to "0". - c.makeCall(snip.Clone(), typeName, "0", float64(score), addlEdits) + item := c.makeCall(snip.Clone(), typeName, "0", float64(score), addlEdits) + item.addPrefixSuffix(c, prefix, suffix) + c.items = append(c.items, item) case *types.Map, *types.Chan: // Maps and channels don't require the second argument, so omit // to keep things simple for now. - c.makeCall(snip.Clone(), typeName, "", float64(score), addlEdits) + item := c.makeCall(snip.Clone(), typeName, "", float64(score), addlEdits) + item.addPrefixSuffix(c, prefix, suffix) + c.items = append(c.items, item) } } @@ -167,7 +190,10 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im if score := c.matcher.Score("func"); !cand.hasMod(reference) && score > 0 && (expType == nil || !types.IsInterface(expType)) { switch t := literalType.Underlying().(type) { case *types.Signature: - c.functionLiteral(ctx, t, float64(score)) + if item, ok := c.functionLiteral(ctx, t, float64(score)); ok { + item.addPrefixSuffix(c, prefix, suffix) + c.items = append(c.items, item) + } } } } @@ -180,7 +206,7 @@ const literalCandidateScore = highScore / 2 // functionLiteral adds a function literal completion item for the // given signature. -func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, matchScore float64) { +func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, matchScore float64) (CompletionItem, bool) { snip := &snippet.Builder{} snip.WriteText("func(") @@ -216,7 +242,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m if ctx.Err() == nil { event.Error(ctx, "formatting var type", err) } - return + return CompletionItem{}, false } name = abbreviateTypeName(typeName) } @@ -284,7 +310,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m if ctx.Err() == nil { event.Error(ctx, "formatting var type", err) } - return + return CompletionItem{}, false } if sig.Variadic() && i == sig.Params().Len()-1 { typeStr = strings.Replace(typeStr, "[]", "...", 1) @@ -342,7 +368,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m if ctx.Err() == nil { event.Error(ctx, "formatting var type", err) } - return + return CompletionItem{}, false } if tp, ok := types.Unalias(r.Type()).(*types.TypeParam); ok && !c.typeParamInScope(tp) { snip.WritePlaceholder(func(snip *snippet.Builder) { @@ -360,12 +386,12 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m snip.WriteFinalTabstop() snip.WriteText("}") - c.items = append(c.items, CompletionItem{ + return CompletionItem{ Label: "func(...) {}", Score: matchScore * literalCandidateScore, Kind: protocol.VariableCompletion, snippet: snip, - }) + }, true } // conventionalAcronyms contains conventional acronyms for type names @@ -432,7 +458,7 @@ func abbreviateTypeName(s string) string { // compositeLiteral adds a composite literal completion item for the given typeName. // T is an (unnamed, unaliased) struct, array, slice, or map type. -func (c *completer) compositeLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) { +func (c *completer) compositeLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) CompletionItem { snip.WriteText("{") // Don't put the tab stop inside the composite literal curlies "{}" // for structs that have no accessible fields. @@ -443,22 +469,22 @@ func (c *completer) compositeLiteral(T types.Type, snip *snippet.Builder, typeNa nonSnippet := typeName + "{}" - c.items = append(c.items, CompletionItem{ + return CompletionItem{ Label: nonSnippet, InsertText: nonSnippet, Score: matchScore * literalCandidateScore, Kind: protocol.VariableCompletion, AdditionalTextEdits: edits, snippet: snip, - }) + } } // basicLiteral adds a literal completion item for the given basic // type name typeName. -func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) { +func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) (CompletionItem, bool) { // Never give type conversions like "untyped int()". if isUntyped(T) { - return + return CompletionItem{}, false } snip.WriteText("(") @@ -467,7 +493,7 @@ func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName s nonSnippet := typeName + "()" - c.items = append(c.items, CompletionItem{ + return CompletionItem{ Label: nonSnippet, InsertText: nonSnippet, Detail: T.String(), @@ -475,11 +501,11 @@ func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName s Kind: protocol.VariableCompletion, AdditionalTextEdits: edits, snippet: snip, - }) + }, true } -// makeCall adds a completion item for a "make()" call given a specific type. -func (c *completer) makeCall(snip *snippet.Builder, typeName string, secondArg string, matchScore float64, edits []protocol.TextEdit) { +// makeCall returns a completion item for a "make()" call given a specific type. +func (c *completer) makeCall(snip *snippet.Builder, typeName string, secondArg string, matchScore float64, edits []protocol.TextEdit) CompletionItem { // Keep it simple and don't add any placeholders for optional "make()" arguments. snip.PrependText("make(") @@ -501,14 +527,15 @@ func (c *completer) makeCall(snip *snippet.Builder, typeName string, secondArg s } nonSnippet.WriteByte(')') - c.items = append(c.items, CompletionItem{ - Label: nonSnippet.String(), - InsertText: nonSnippet.String(), - Score: matchScore * literalCandidateScore, + return CompletionItem{ + Label: nonSnippet.String(), + InsertText: nonSnippet.String(), + // make() should be just below other literal completions + Score: matchScore * literalCandidateScore * 0.99, Kind: protocol.FunctionCompletion, AdditionalTextEdits: edits, snippet: snip, - }) + } } // Create a snippet for a type name where type params become placeholders. diff --git a/gopls/internal/golang/completion/unify.go b/gopls/internal/golang/completion/unify.go index 1c611a3e2a4..7eadf05e420 100644 --- a/gopls/internal/golang/completion/unify.go +++ b/gopls/internal/golang/completion/unify.go @@ -1,3 +1,16 @@ +// Below is copied from go/types/unify.go on September 21, 2021, with +// snippets from other files as well. It is copied to implement +// unification for autocompletion inferences, in lieu of an official +// type unification API. +// +// When such an API is available, the code below should deleted. +// Due to complexity of extracting private types from the go/types package, +// +// The unifier does not fully interface unification. +// +// The code has been modified to compile without introducing key any functionality changes. +// + // Copyright 2020 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -30,12 +43,11 @@ // Whether they indeed are identical or assignable is determined // upon instantiation and function argument passing. -package types2 +package completion import ( - "bytes" "fmt" - "sort" + "go/types" "strings" ) @@ -52,15 +64,6 @@ const ( // If enableCoreTypeUnification is set, unification will consider // the core types, if any, of non-local (unbound) type parameters. enableCoreTypeUnification = true - - // If traceInference is set, unification will print a trace of its operation. - // Interpretation of trace: - // x ≡ y attempt to unify types x and y - // p ➞ y type parameter p is set to type y (p is inferred to be y) - // p ⇄ q type parameters p and q match (p is inferred to be q and vice versa) - // x ≢ y types x and y cannot be unified - // [p, q, ...] ➞ [x, y, ...] mapping from type parameters to types - traceInference = false ) // A unifier maintains a list of type parameters and @@ -76,115 +79,61 @@ type unifier struct { // that inferring the type for a given type parameter P will // automatically infer the same type for all other parameters // unified (joined) with P. - handles map[*TypeParam]*Type - depth int // recursion depth during unification - enableInterfaceInference bool // use shared methods for better inference + handles map[*types.TypeParam]*types.Type + depth int // recursion depth during unification } // newUnifier returns a new unifier initialized with the given type parameter // and corresponding type argument lists. The type argument list may be shorter // than the type parameter list, and it may contain nil types. Matching type // parameters and arguments must have the same index. -func newUnifier(tparams []*TypeParam, targs []Type, enableInterfaceInference bool) *unifier { - assert(len(tparams) >= len(targs)) - handles := make(map[*TypeParam]*Type, len(tparams)) +func newUnifier(tparams []*types.TypeParam, targs []types.Type) *unifier { + handles := make(map[*types.TypeParam]*types.Type, len(tparams)) // Allocate all handles up-front: in a correct program, all type parameters // must be resolved and thus eventually will get a handle. // Also, sharing of handles caused by unified type parameters is rare and // so it's ok to not optimize for that case (and delay handle allocation). for i, x := range tparams { - var t Type + var t types.Type if i < len(targs) { t = targs[i] } handles[x] = &t } - return &unifier{handles, 0, enableInterfaceInference} + return &unifier{handles, 0} } // unifyMode controls the behavior of the unifier. type unifyMode uint const ( - // If assign is set, we are unifying types involved in an assignment: + // If unifyModeAssign is set, we are unifying types involved in an assignment: // they may match inexactly at the top, but element types must match // exactly. - assign unifyMode = 1 << iota + unifyModeAssign unifyMode = 1 << iota - // If exact is set, types unify if they are identical (or can be + // If unifyModeExact is set, types unify if they are identical (or can be // made identical with suitable arguments for type parameters). // Otherwise, a named type and a type literal unify if their // underlying types unify, channel directions are ignored, and // if there is an interface, the other type must implement the // interface. - exact + unifyModeExact ) -func (m unifyMode) String() string { - switch m { - case 0: - return "inexact" - case assign: - return "assign" - case exact: - return "exact" - case assign | exact: - return "assign, exact" - } - return fmt.Sprintf("mode %d", m) -} - // unify attempts to unify x and y and reports whether it succeeded. // As a side-effect, types may be inferred for type parameters. // The mode parameter controls how types are compared. -func (u *unifier) unify(x, y Type, mode unifyMode) bool { - return u.nify(x, y, mode, nil) -} - -func (u *unifier) tracef(format string, args ...interface{}) { - fmt.Println(strings.Repeat(". ", u.depth) + sprintf(nil, true, format, args...)) -} - -// String returns a string representation of the current mapping -// from type parameters to types. -func (u *unifier) String() string { - // sort type parameters for reproducible strings - tparams := make(typeParamsById, len(u.handles)) - i := 0 - for tpar := range u.handles { - tparams[i] = tpar - i++ - } - sort.Sort(tparams) - - var buf bytes.Buffer - w := newTypeWriter(&buf, nil) - w.byte('[') - for i, x := range tparams { - if i > 0 { - w.string(", ") - } - w.typ(x) - w.string(": ") - w.typ(u.at(x)) - } - w.byte(']') - return buf.String() +func (u *unifier) unify(x, y types.Type, mode unifyMode) bool { + return u.nify(x, y, mode) } -type typeParamsById []*TypeParam - -func (s typeParamsById) Len() int { return len(s) } -func (s typeParamsById) Less(i, j int) bool { return s[i].id < s[j].id } -func (s typeParamsById) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +type typeParamsById []*types.TypeParam // join unifies the given type parameters x and y. // If both type parameters already have a type associated with them // and they are not joined, join fails and returns false. -func (u *unifier) join(x, y *TypeParam) bool { - if traceInference { - u.tracef("%s ⇄ %s", x, y) - } +func (u *unifier) join(x, y *types.TypeParam) bool { switch hx, hy := u.handles[x], u.handles[y]; { case hx == hy: // Both type parameters already share the same handle. Nothing to do. @@ -205,10 +154,10 @@ func (u *unifier) join(x, y *TypeParam) bool { return true } -// asBoundTypeParam returns x.(*TypeParam) if x is a type parameter recorded with u. +// asBoundTypeParam returns x.(*types.TypeParam) if x is a type parameter recorded with u. // Otherwise, the result is nil. -func (u *unifier) asBoundTypeParam(x Type) *TypeParam { - if x, _ := Unalias(x).(*TypeParam); x != nil { +func (u *unifier) asBoundTypeParam(x types.Type) *types.TypeParam { + if x, _ := types.Unalias(x).(*types.TypeParam); x != nil { if _, found := u.handles[x]; found { return x } @@ -218,9 +167,8 @@ func (u *unifier) asBoundTypeParam(x Type) *TypeParam { // setHandle sets the handle for type parameter x // (and all its joined type parameters) to h. -func (u *unifier) setHandle(x *TypeParam, h *Type) { +func (u *unifier) setHandle(x *types.TypeParam, h *types.Type) { hx := u.handles[x] - assert(hx != nil) for y, hy := range u.handles { if hy == hx { u.handles[y] = h @@ -229,17 +177,13 @@ func (u *unifier) setHandle(x *TypeParam, h *Type) { } // at returns the (possibly nil) type for type parameter x. -func (u *unifier) at(x *TypeParam) Type { +func (u *unifier) at(x *types.TypeParam) types.Type { return *u.handles[x] } // set sets the type t for type parameter x; // t must not be nil. -func (u *unifier) set(x *TypeParam, t Type) { - assert(t != nil) - if traceInference { - u.tracef("%s ➞ %s", x, t) - } +func (u *unifier) set(x *types.TypeParam, t types.Type) { *u.handles[x] = t } @@ -258,8 +202,8 @@ func (u *unifier) unknowns() int { // The result is never nil and has the same length as tparams; result types that // could not be inferred are nil. Corresponding type parameters and result types // have identical indices. -func (u *unifier) inferred(tparams []*TypeParam) []Type { - list := make([]Type, len(tparams)) +func (u *unifier) inferred(tparams []*types.TypeParam) []types.Type { + list := make([]types.Type, len(tparams)) for i, x := range tparams { list[i] = u.at(x) } @@ -268,39 +212,111 @@ func (u *unifier) inferred(tparams []*TypeParam) []Type { // asInterface returns the underlying type of x as an interface if // it is a non-type parameter interface. Otherwise it returns nil. -func asInterface(x Type) (i *Interface) { - if _, ok := Unalias(x).(*TypeParam); !ok { - i, _ = under(x).(*Interface) +func asInterface(x types.Type) (i *types.Interface) { + if _, ok := types.Unalias(x).(*types.TypeParam); !ok { + i, _ = x.Underlying().(*types.Interface) } return i } +func isTypeParam(t types.Type) bool { + _, ok := types.Unalias(t).(*types.TypeParam) + return ok +} + +func asNamed(t types.Type) *types.Named { + n, _ := types.Unalias(t).(*types.Named) + return n +} + +func isTypeLit(t types.Type) bool { + switch types.Unalias(t).(type) { + case *types.Named, *types.TypeParam: + return false + } + return true +} + +// identicalOrigin reports whether x and y originated in the same declaration. +func identicalOrigin(x, y *types.Named) bool { + // TODO(gri) is this correct? + return x.Origin().Obj() == y.Origin().Obj() +} + +func match(x, y types.Type) types.Type { + // Common case: we don't have channels. + if types.Identical(x, y) { + return x + } + + // We may have channels that differ in direction only. + if x, _ := x.(*types.Chan); x != nil { + if y, _ := y.(*types.Chan); y != nil && types.Identical(x.Elem(), y.Elem()) { + // We have channels that differ in direction only. + // If there's an unrestricted channel, select the restricted one. + switch { + case x.Dir() == types.SendRecv: + return y + case y.Dir() == types.SendRecv: + return x + } + } + } + + // types are different + return nil +} + +func coreType(t types.Type) types.Type { + t = types.Unalias(t) + tpar, _ := t.(*types.TypeParam) + if tpar == nil { + return t.Underlying() + } + + return nil +} + +func sameId(obj *types.Var, pkg *types.Package, name string, foldCase bool) bool { + // If we don't care about capitalization, we also ignore packages. + if foldCase && strings.EqualFold(obj.Name(), name) { + return true + } + // spec: + // "Two identifiers are different if they are spelled differently, + // or if they appear in different packages and are not exported. + // Otherwise, they are the same." + if obj.Name() != name { + return false + } + // obj.Name == name + if obj.Exported() { + return true + } + // not exported, so packages must be the same + if obj.Pkg() != nil && pkg != nil { + return obj.Pkg() == pkg + } + return obj.Pkg().Path() == pkg.Path() +} + // nify implements the core unification algorithm which is an // adapted version of Checker.identical. For changes to that // code the corresponding changes should be made here. // Must not be called directly from outside the unifier. -func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { +func (u *unifier) nify(x, y types.Type, mode unifyMode) (result bool) { u.depth++ - if traceInference { - u.tracef("%s ≡ %s\t// %s", x, y, mode) - } defer func() { - if traceInference && !result { - u.tracef("%s ≢ %s", x, y) - } u.depth-- }() // nothing to do if x == y - if x == y || Unalias(x) == Unalias(y) { + if x == y || types.Unalias(x) == types.Unalias(y) { return true } // Stop gap for cases where unification fails. if u.depth > unificationDepthLimit { - if traceInference { - u.tracef("depth %d >= %d", u.depth, unificationDepthLimit) - } if panicAtUnificationDepthLimit { panic("unification reached recursion depth limit") } @@ -312,9 +328,6 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { // - defined type, make sure one is in y // - type parameter recorded with u, make sure one is in x if asNamed(x) != nil || u.asBoundTypeParam(y) != nil { - if traceInference { - u.tracef("%s ≡ %s\t// swap", y, x) - } x, y = y, x } @@ -335,16 +348,12 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { // we will fail at function instantiation or argument assignment time. // // If we have at least one defined type, there is one in y. - if ny := asNamed(y); mode&exact == 0 && ny != nil && isTypeLit(x) && !(u.enableInterfaceInference && IsInterface(x)) { - if traceInference { - u.tracef("%s ≡ under %s", x, ny) - } - y = ny.under() + if ny := asNamed(y); mode&unifyModeExact == 0 && ny != nil && isTypeLit(x) { + y = ny.Underlying() // Per the spec, a defined type cannot have an underlying type // that is a type parameter. - assert(!isTypeParam(y)) // x and y may be identical now - if x == y || Unalias(x) == Unalias(y) { + if x == y || types.Unalias(x) == types.Unalias(y) { return true } } @@ -362,13 +371,13 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { return true } // both x and y have an inferred type - they must match - return u.nify(u.at(px), u.at(py), mode, p) + return u.nify(u.at(px), u.at(py), mode) case px != nil: // x is a type parameter, y is not if x := u.at(px); x != nil { // x has an inferred type which must match y - if u.nify(x, y, mode, p) { + if u.nify(x, y, mode) { // We have a match, possibly through underlying types. xi := asInterface(x) yi := asInterface(y) @@ -381,8 +390,11 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { // If both types are defined types, they must be identical // because unification doesn't know which type has the "right" name. if xn && yn { - return Identical(x, y) + return types.Identical(x, y) } + return false + // Below is the original code for reference + // In all other cases, the method sets must match. // The types unified so we know that corresponding methods // match and we can simply compare the number of methods. @@ -391,9 +403,9 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { // type, it's not clear how to choose and whether we introduce // an order dependency or not. Requiring the same method set // is conservative. - if len(xi.typeSet().methods) != len(yi.typeSet().methods) { - return false - } + // if len(xi.typeSet().methods) != len(yi.typeSet().methods) { + // return false + // } } else if xi != nil || yi != nil { // One but not both of them are interfaces. // In this case, either x or y could be viable matches for the corresponding @@ -421,7 +433,7 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { // // If we have defined and literal channel types, a defined type wins to avoid // order dependencies. - if mode&exact == 0 { + if mode&unifyModeExact == 0 { switch { case xn: // x is a defined type: nothing to do. @@ -430,7 +442,7 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { u.set(px, y) default: // Neither x nor y are defined types. - if yc, _ := under(y).(*Chan); yc != nil && yc.dir != SendRecv { + if yc, _ := y.Underlying().(*types.Chan); yc != nil && yc.Dir() != types.SendRecv { // y is a directed channel type: select y. u.set(px, y) } @@ -445,105 +457,11 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { return true } - // x != y if we get here - assert(x != y && Unalias(x) != Unalias(y)) - // If u.EnableInterfaceInference is set and we don't require exact unification, // if both types are interfaces, one interface must have a subset of the // methods of the other and corresponding method signatures must unify. // If only one type is an interface, all its methods must be present in the // other type and corresponding method signatures must unify. - if u.enableInterfaceInference && mode&exact == 0 { - // One or both interfaces may be defined types. - // Look under the name, but not under type parameters (go.dev/issue/60564). - xi := asInterface(x) - yi := asInterface(y) - // If we have two interfaces, check the type terms for equivalence, - // and unify common methods if possible. - if xi != nil && yi != nil { - xset := xi.typeSet() - yset := yi.typeSet() - if xset.comparable != yset.comparable { - return false - } - // For now we require terms to be equal. - // We should be able to relax this as well, eventually. - if !xset.terms.equal(yset.terms) { - return false - } - // Interface types are the only types where cycles can occur - // that are not "terminated" via named types; and such cycles - // can only be created via method parameter types that are - // anonymous interfaces (directly or indirectly) embedding - // the current interface. Example: - // - // type T interface { - // m() interface{T} - // } - // - // If two such (differently named) interfaces are compared, - // endless recursion occurs if the cycle is not detected. - // - // If x and y were compared before, they must be equal - // (if they were not, the recursion would have stopped); - // search the ifacePair stack for the same pair. - // - // This is a quadratic algorithm, but in practice these stacks - // are extremely short (bounded by the nesting depth of interface - // type declarations that recur via parameter types, an extremely - // rare occurrence). An alternative implementation might use a - // "visited" map, but that is probably less efficient overall. - q := &ifacePair{xi, yi, p} - for p != nil { - if p.identical(q) { - return true // same pair was compared before - } - p = p.prev - } - // The method set of x must be a subset of the method set - // of y or vice versa, and the common methods must unify. - xmethods := xset.methods - ymethods := yset.methods - // The smaller method set must be the subset, if it exists. - if len(xmethods) > len(ymethods) { - xmethods, ymethods = ymethods, xmethods - } - // len(xmethods) <= len(ymethods) - // Collect the ymethods in a map for quick lookup. - ymap := make(map[string]*Func, len(ymethods)) - for _, ym := range ymethods { - ymap[ym.Id()] = ym - } - // All xmethods must exist in ymethods and corresponding signatures must unify. - for _, xm := range xmethods { - if ym := ymap[xm.Id()]; ym == nil || !u.nify(xm.typ, ym.typ, exact, p) { - return false - } - } - return true - } - - // We don't have two interfaces. If we have one, make sure it's in xi. - if yi != nil { - xi = yi - y = x - } - - // If we have one interface, at a minimum each of the interface methods - // must be implemented and thus unify with a corresponding method from - // the non-interface type, otherwise unification fails. - if xi != nil { - // All xi methods must exist in y and corresponding signatures must unify. - xmethods := xi.typeSet().methods - for _, xm := range xmethods { - obj, _, _ := LookupFieldOrMethod(y, false, xm.pkg, xm.name) - if ym, _ := obj.(*Func); ym == nil || !u.nify(xm.typ, ym.typ, exact, p) { - return false - } - } - return true - } - } // Unless we have exact unification, neither x nor y are interfaces now. // Except for unbound type parameters (see below), x and y must be structurally @@ -557,60 +475,58 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { // // TODO(gri) Factor out type parameter handling from the switch. if isTypeParam(y) { - if traceInference { - u.tracef("%s ≡ %s\t// swap", y, x) - } x, y = y, x } // Type elements (array, slice, etc. elements) use emode for unification. // Element types must match exactly if the types are used in an assignment. emode := mode - if mode&assign != 0 { - emode |= exact + if mode&unifyModeAssign != 0 { + emode |= unifyModeExact } // Continue with unaliased types but don't lose original alias names, if any (go.dev/issue/67628). - xorig, x := x, Unalias(x) - yorig, y := y, Unalias(y) + xorig, x := x, types.Unalias(x) + yorig, y := y, types.Unalias(y) switch x := x.(type) { - case *Basic: + case *types.Basic: // Basic types are singletons except for the rune and byte // aliases, thus we cannot solely rely on the x == y check // above. See also comment in TypeName.IsAlias. - if y, ok := y.(*Basic); ok { - return x.kind == y.kind + if y, ok := y.(*types.Basic); ok { + return x.Kind() == y.Kind() } - case *Array: + case *types.Array: // Two array types unify if they have the same array length // and their element types unify. - if y, ok := y.(*Array); ok { + if y, ok := y.(*types.Array); ok { // If one or both array lengths are unknown (< 0) due to some error, // assume they are the same to avoid spurious follow-on errors. - return (x.len < 0 || y.len < 0 || x.len == y.len) && u.nify(x.elem, y.elem, emode, p) + return (x.Len() < 0 || y.Len() < 0 || x.Len() == y.Len()) && u.nify(x.Elem(), y.Elem(), emode) } - case *Slice: + case *types.Slice: // Two slice types unify if their element types unify. - if y, ok := y.(*Slice); ok { - return u.nify(x.elem, y.elem, emode, p) + if y, ok := y.(*types.Slice); ok { + return u.nify(x.Elem(), y.Elem(), emode) } - case *Struct: + case *types.Struct: // Two struct types unify if they have the same sequence of fields, // and if corresponding fields have the same names, their (field) types unify, // and they have identical tags. Two embedded fields are considered to have the same // name. Lower-case field names from different packages are always different. - if y, ok := y.(*Struct); ok { + if y, ok := y.(*types.Struct); ok { if x.NumFields() == y.NumFields() { - for i, f := range x.fields { - g := y.fields[i] - if f.embedded != g.embedded || + for i := range x.NumFields() { + f := x.Field(i) + g := y.Field(i) + if f.Embedded() != g.Embedded() || x.Tag(i) != y.Tag(i) || - !f.sameId(g.pkg, g.name, false) || - !u.nify(f.typ, g.typ, emode, p) { + !sameId(f, g.Pkg(), g.Name(), false) || + !u.nify(f.Type(), g.Type(), emode) { return false } } @@ -618,21 +534,22 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { } } - case *Pointer: + case *types.Pointer: // Two pointer types unify if their base types unify. - if y, ok := y.(*Pointer); ok { - return u.nify(x.base, y.base, emode, p) + if y, ok := y.(*types.Pointer); ok { + return u.nify(x.Elem(), y.Elem(), emode) } - case *Tuple: + case *types.Tuple: // Two tuples types unify if they have the same number of elements // and the types of corresponding elements unify. - if y, ok := y.(*Tuple); ok { + if y, ok := y.(*types.Tuple); ok { if x.Len() == y.Len() { if x != nil { - for i, v := range x.vars { - w := y.vars[i] - if !u.nify(v.typ, w.typ, mode, p) { + for i := range x.Len() { + v := x.At(i) + w := y.At(i) + if !u.nify(v.Type(), w.Type(), mode) { return false } } @@ -641,119 +558,116 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { } } - case *Signature: + case *types.Signature: // Two function types unify if they have the same number of parameters // and result values, corresponding parameter and result types unify, // and either both functions are variadic or neither is. // Parameter and result names are not required to match. // TODO(gri) handle type parameters or document why we can ignore them. - if y, ok := y.(*Signature); ok { - return x.variadic == y.variadic && - u.nify(x.params, y.params, emode, p) && - u.nify(x.results, y.results, emode, p) + if y, ok := y.(*types.Signature); ok { + return x.Variadic() == y.Variadic() && + u.nify(x.Params(), y.Params(), emode) && + u.nify(x.Results(), y.Results(), emode) } - case *Interface: - assert(!u.enableInterfaceInference || mode&exact != 0) // handled before this switch + case *types.Interface: + return false + // Below is the original code // Two interface types unify if they have the same set of methods with // the same names, and corresponding function types unify. // Lower-case method names from different packages are always different. // The order of the methods is irrelevant. - if y, ok := y.(*Interface); ok { - xset := x.typeSet() - yset := y.typeSet() - if xset.comparable != yset.comparable { - return false - } - if !xset.terms.equal(yset.terms) { - return false - } - a := xset.methods - b := yset.methods - if len(a) == len(b) { - // Interface types are the only types where cycles can occur - // that are not "terminated" via named types; and such cycles - // can only be created via method parameter types that are - // anonymous interfaces (directly or indirectly) embedding - // the current interface. Example: - // - // type T interface { - // m() interface{T} - // } - // - // If two such (differently named) interfaces are compared, - // endless recursion occurs if the cycle is not detected. - // - // If x and y were compared before, they must be equal - // (if they were not, the recursion would have stopped); - // search the ifacePair stack for the same pair. - // - // This is a quadratic algorithm, but in practice these stacks - // are extremely short (bounded by the nesting depth of interface - // type declarations that recur via parameter types, an extremely - // rare occurrence). An alternative implementation might use a - // "visited" map, but that is probably less efficient overall. - q := &ifacePair{x, y, p} - for p != nil { - if p.identical(q) { - return true // same pair was compared before - } - p = p.prev - } - if debug { - assertSortedMethods(a) - assertSortedMethods(b) - } - for i, f := range a { - g := b[i] - if f.Id() != g.Id() || !u.nify(f.typ, g.typ, exact, q) { - return false - } - } - return true - } - } - - case *Map: + // xset := x.typeSet() + // yset := y.typeSet() + // if xset.comparable != yset.comparable { + // return false + // } + // if !xset.terms.equal(yset.terms) { + // return false + // } + // a := xset.methods + // b := yset.methods + // if len(a) == len(b) { + // // Interface types are the only types where cycles can occur + // // that are not "terminated" via named types; and such cycles + // // can only be created via method parameter types that are + // // anonymous interfaces (directly or indirectly) embedding + // // the current interface. Example: + // // + // // type T interface { + // // m() interface{T} + // // } + // // + // // If two such (differently named) interfaces are compared, + // // endless recursion occurs if the cycle is not detected. + // // + // // If x and y were compared before, they must be equal + // // (if they were not, the recursion would have stopped); + // // search the ifacePair stack for the same pair. + // // + // // This is a quadratic algorithm, but in practice these stacks + // // are extremely short (bounded by the nesting depth of interface + // // type declarations that recur via parameter types, an extremely + // // rare occurrence). An alternative implementation might use a + // // "visited" map, but that is probably less efficient overall. + // q := &ifacePair{x, y, p} + // for p != nil { + // if p.identical(q) { + // return true // same pair was compared before + // } + // p = p.prev + // } + // if debug { + // assertSortedMethods(a) + // assertSortedMethods(b) + // } + // for i, f := range a { + // g := b[i] + // if f.Id() != g.Id() || !u.nify(f.typ, g.typ, exact, q) { + // return false + // } + // } + // return true + // } + + case *types.Map: // Two map types unify if their key and value types unify. - if y, ok := y.(*Map); ok { - return u.nify(x.key, y.key, emode, p) && u.nify(x.elem, y.elem, emode, p) + if y, ok := y.(*types.Map); ok { + return u.nify(x.Key(), y.Key(), emode) && u.nify(x.Elem(), y.Elem(), emode) } - case *Chan: + case *types.Chan: // Two channel types unify if their value types unify // and if they have the same direction. // The channel direction is ignored for inexact unification. - if y, ok := y.(*Chan); ok { - return (mode&exact == 0 || x.dir == y.dir) && u.nify(x.elem, y.elem, emode, p) + if y, ok := y.(*types.Chan); ok { + return (mode&unifyModeExact == 0 || x.Dir() == y.Dir()) && u.nify(x.Elem(), y.Elem(), emode) } - case *Named: + case *types.Named: // Two named types unify if their type names originate in the same type declaration. // If they are instantiated, their type argument lists must unify. if y := asNamed(y); y != nil { // Check type arguments before origins so they unify // even if the origins don't match; for better error // messages (see go.dev/issue/53692). - xargs := x.TypeArgs().list() - yargs := y.TypeArgs().list() - if len(xargs) != len(yargs) { + xargs := x.TypeArgs() + yargs := y.TypeArgs() + if xargs.Len() != yargs.Len() { return false } - for i, xarg := range xargs { - if !u.nify(xarg, yargs[i], mode, p) { + for i := range xargs.Len() { + xarg := xargs.At(i) + yarg := yargs.At(i) + if !u.nify(xarg, yarg, mode) { return false } } return identicalOrigin(x, y) } - case *TypeParam: - // x must be an unbound type parameter (see comment above). - if debug { - assert(u.asBoundTypeParam(x) == nil) - } + case *types.TypeParam: // By definition, a valid type argument must be in the type set of // the respective type constraint. Therefore, the type argument's // underlying type must be in the set of underlying types of that @@ -774,14 +688,11 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { // up here again with x and y swapped, so we don't // need to take care of that case separately. if cx := coreType(x); cx != nil { - if traceInference { - u.tracef("core %s ≡ %s", xorig, yorig) - } // If y is a defined type, it may not match against cx which // is an underlying type (incl. int, string, etc.). Use assign // mode here so that the unifier automatically takes under(y) // if necessary. - return u.nify(cx, yorig, assign, p) + return u.nify(cx, yorig, unifyModeAssign) } } // x != y and there's nothing to do @@ -790,7 +701,7 @@ func (u *unifier) nify(x, y Type, mode unifyMode, p *ifacePair) (result bool) { // avoid a crash in case of nil type default: - panic(sprintf(nil, true, "u.nify(%s, %s, %d)", xorig, yorig, mode)) + panic(fmt.Sprintf("u.nify(%s, %s, %d)", xorig, yorig, mode)) } return false diff --git a/gopls/internal/test/integration/completion/completion_test.go b/gopls/internal/test/integration/completion/completion_test.go index c96e569f1ad..0e3dbe72c8b 100644 --- a/gopls/internal/test/integration/completion/completion_test.go +++ b/gopls/internal/test/integration/completion/completion_test.go @@ -970,6 +970,96 @@ use ./missing/ }) } +const reverseInferenceSrc = ` +-- go.mod -- +module mod.com + +go 1.18 +-- a.go -- +package a + +type Wrap[T any] struct { + inner *T +} + +func NewWrap[T any](x T) Wrap[T] { + return Wrap[T]{inner: &x} +} + +func DoubleWrap[T any, U any](t T, u U) (Wrap[T], Wrap[U]) { + return Wrap[T]{inner: &t}, Wrap[U]{inner: &u} +} + +type InterfaceA interface { + implA() +} + +type InterfaceB interface { + implB() +} + +type TypeA struct{} + +func (TypeA) implA() {} + +type TypeX string + +func (TypeX) implB() {} + +type TypeB struct{} + +func (TypeB) implB() {} + +func one[a int | string]() {} + +func main() { + var y Wrap[InterfaceA] + var x Wrap[InterfaceB] + avar := TypeA{} + bvar := TypeB{} + x = NewWrap[]() + x, y = DoubleWrap[,]() +} +` + +func TestReverseInferInterfaceCompletion(t *testing.T) { + Run(t, reverseInferenceSrc, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `NewWrap\[\]\(()\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + + want := []string{"bvar", "x.inner", "TypeB{}", "TypeX()", "nil"} + for i, item := range result.Items[:len(want)] { + if diff := cmp.Diff(want[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected mismatch (-want +got):\n%s", diff) + } + } + }) +} + +// TODO: implement after fixed index list expressions inference with multiple type parameters +func TestReverseInferDoubleTypeParamCompletion(t *testing.T) { + Run(t, reverseInferenceSrc, func(t *testing.T, env *Env) { + }) +} + +func TestReverseInferInterfaceTypeParamCompletion(t *testing.T) { + Run(t, reverseInferenceSrc, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `NewWrap\[\]\(()\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + + want := []string{"bvar", "x.inner", "TypeB{}", "TypeX()", "nil"} + for i, item := range result.Items[:len(want)] { + if diff := cmp.Diff(want[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected mismatch (-want +got):\n%s", diff) + } + } + }) +} + func TestBuiltinCompletion(t *testing.T) { const files = ` -- go.mod --