diff --git a/src/cmd/compile/internal/devirtualize/devirtualize.go b/src/cmd/compile/internal/devirtualize/devirtualize.go
index 372d05809401ff..5dd4d65b07be56 100644
--- a/src/cmd/compile/internal/devirtualize/devirtualize.go
+++ b/src/cmd/compile/internal/devirtualize/devirtualize.go
@@ -40,14 +40,17 @@ func StaticCall(call *ir.CallExpr) {
 	}

 	sel := call.Fun.(*ir.SelectorExpr)
-	r := ir.StaticValue(sel.X)
-	if r.Op() != ir.OCONVIFACE {
+	typ := staticType(sel.X)
+	if typ == nil {
 		return
 	}
-	recv := r.(*ir.ConvExpr)
-	typ := recv.X.Type()
-	if typ.IsInterface() {
+	// Don't try to devirtualize calls that we statically know would have failed at runtime.
+	// This can happen in a case like any(0).(interface{ A() }).A(), which typechecks without
+	// any errors, but will cause a runtime panic. We statically know that int(0) does not
+	// implement that interface, thus we skip the devirtualization, as it is not possible
+	// to make a type assertion from interface{ A() } to int (int does not implement interface{ A() }).
+	if !typecheck.Implements(typ, sel.X.Type()) {
 		return
 	}

@@ -138,3 +141,85 @@ func StaticCall(call *ir.CallExpr) {
 	// Desugar OCALLMETH, if we created one (#57309).
 	typecheck.FixMethodCall(call)
 }
+
+func staticType(n ir.Node) *types.Type {
+	for {
+		switch n1 := n.(type) {
+		case *ir.ConvExpr:
+			if n1.Op() == ir.OCONVNOP || n1.Op() == ir.OCONVIFACE {
+				n = n1.X
+				continue
+			}
+		case *ir.InlinedCallExpr:
+			if n1.Op() == ir.OINLCALL {
+				n = n1.SingleResult()
+				continue
+			}
+		case *ir.ParenExpr:
+			n = n1.X
+			continue
+		case *ir.TypeAssertExpr:
+			n = n1.X
+			continue
+		}
+
+		n1 := staticValue(n)
+		if n1 == nil {
+			if n.Type().IsInterface() {
+				return nil
+			}
+			return n.Type()
+		}
+		n = n1
+	}
+}
+
+func staticValue(nn ir.Node) ir.Node {
+	if nn.Op() != ir.ONAME {
+		return nil
+	}
+
+	n := nn.(*ir.Name).Canonical()
+	if n.Class != ir.PAUTO {
+		return nil
+	}
+
+	defn := n.Defn
+	if defn == nil {
+		return nil
+	}
+
+	var rhs ir.Node
+FindRHS:
+	switch defn.Op() {
+	case ir.OAS:
+		defn := defn.(*ir.AssignStmt)
+		rhs = defn.Y
+	case ir.OAS2:
+		defn := defn.(*ir.AssignListStmt)
+		for i, lhs := range defn.Lhs {
+			if lhs == n {
+				rhs = defn.Rhs[i]
+				break FindRHS
+			}
+		}
+		base.Fatalf("%v missing from LHS of %v", n, defn)
+	case ir.OAS2DOTTYPE:
+		defn := defn.(*ir.AssignListStmt)
+		if defn.Lhs[0] == n {
+			rhs = defn.Rhs[0]
+		}
+	default:
+		return nil
+	}
+
+	if rhs == nil {
+		base.Fatalf("RHS is nil: %v", defn)
+	}
+
+	if ir.Reassigned(n) {
+		return nil
+	}
+
+	return rhs
+}
diff --git a/src/cmd/compile/internal/noder/reader.go b/src/cmd/compile/internal/noder/reader.go
index eca66487fa26da..bdfef70f216527 100644
--- a/src/cmd/compile/internal/noder/reader.go
+++ b/src/cmd/compile/internal/noder/reader.go
@@ -2941,6 +2941,7 @@ func (r *reader) multiExpr() []ir.Node {
 		as.Def = true
 		for i := range results {
 			tmp := r.temp(pos, r.typ())
+			tmp.Defn = as
 			as.PtrInit().Append(ir.NewDecl(pos, ir.ODCL, tmp))
 			as.Lhs.Append(tmp)

diff --git a/src/crypto/sha256/sha256_test.go b/src/crypto/sha256/sha256_test.go
index e1af9640e25547..04be3461075017 100644
--- a/src/crypto/sha256/sha256_test.go
+++ b/src/crypto/sha256/sha256_test.go
@@ -391,3 +391,17 @@ func BenchmarkHash1K(b *testing.B) {
 func BenchmarkHash8K(b *testing.B) {
 	benchmarkSize(b, 8192)
 }
+
+func TestAllocationsWithTypeAsserts(t *testing.T) {
+	cryptotest.SkipTestAllocations(t)
+	allocs := testing.AllocsPerRun(100, func() {
+		h := New()
+		h.Write([]byte{1, 2, 3})
+		marshaled, _ := h.(encoding.BinaryMarshaler).MarshalBinary()
+		marshaled, _ = h.(encoding.BinaryAppender).AppendBinary(marshaled[:0])
+		h.(encoding.BinaryUnmarshaler).UnmarshalBinary(marshaled)
+	})
+	if allocs != 0 {
+		t.Fatalf("allocs = %v; want = 0", allocs)
+	}
+}
diff --git a/test/escape_iface_with_devirt_type_assertions.go b/test/escape_iface_with_devirt_type_assertions.go
new file mode 100644
index 00000000000000..5a3d762aa16205
--- /dev/null
+++ b/test/escape_iface_with_devirt_type_assertions.go
@@ -0,0 +1,214 @@
+// errorcheck -0 -m
+
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package escape
+
+type M interface{ M() }
+
+type A interface{ A() }
+
+type C interface{ C() }
+
+type Impl struct{}
+
+func (*Impl) M() {} // ERROR "can inline"
+
+func (*Impl) A() {} // ERROR "can inline"
+
+type CImpl struct{}
+
+func (CImpl) C() {} // ERROR "can inline"
+
+func t() {
+	var a M = &Impl{} // ERROR "&Impl{} does not escape"
+
+	a.(M).M()     // ERROR "devirtualizing a.\(M\).M" "inlining call"
+	a.(A).A()     // ERROR "devirtualizing a.\(A\).A" "inlining call"
+	a.(*Impl).M() // ERROR "inlining call"
+	a.(*Impl).A() // ERROR "inlining call"
+
+	v := a.(M)
+	v.M()         // ERROR "devirtualizing v.M" "inlining call"
+	v.(A).A()     // ERROR "devirtualizing v.\(A\).A" "inlining call"
+	v.(*Impl).A() // ERROR "inlining call"
+	v.(*Impl).M() // ERROR "inlining call"
+
+	v2 := a.(A)
+	v2.A()         // ERROR "devirtualizing v2.A" "inlining call"
+	v2.(M).M()     // ERROR "devirtualizing v2.\(M\).M" "inlining call"
+	v2.(*Impl).A() // ERROR "inlining call"
+	v2.(*Impl).M() // ERROR "inlining call"
+
+	a.(M).(A).A() // ERROR "devirtualizing a.\(M\).\(A\).A" "inlining call"
+	a.(A).(M).M() // ERROR "devirtualizing a.\(A\).\(M\).M" "inlining call"
+
+	a.(M).(A).(*Impl).A() // ERROR "inlining call"
+	a.(A).(M).(*Impl).M() // ERROR "inlining call"
+
+	any(a).(M).M()           // ERROR "devirtualizing" "inlining call"
+	any(a).(A).A()           // ERROR "devirtualizing" "inlining call"
+	any(a).(M).(any).(A).A() // ERROR "devirtualizing" "inlining call"
+
+	c := any(a)
+	c.(A).A() // ERROR "devirtualizing" "inlining call"
+	c.(M).M() // ERROR "devirtualizing" "inlining call"
+
+	{
+		var a C = &CImpl{}   // ERROR "does not escape"
+		a.(any).(C).C()      // ERROR "devirtualizing" "inlining"
+		a.(any).(*CImpl).C() // ERROR "inlining"
+	}
+}
+
+func t2() {
+	{
+		var a M = &Impl{} // ERROR "does not escape"
+		if v, ok := a.(M); ok {
+			v.M() // ERROR "devirtualizing" "inlining call"
+		}
+	}
+	{
+		var a M = &Impl{} // ERROR "does not escape"
+		if v, ok := a.(A); ok {
+			v.A() // ERROR "devirtualizing" "inlining call"
+		}
+	}
+	{
+		var a M = &Impl{} // ERROR "does not escape"
+		v, ok := a.(M)
+		if ok {
+			v.M() // ERROR "devirtualizing" "inlining call"
+		}
+	}
+	{
+		var a M = &Impl{} // ERROR "does not escape"
+		v, ok := a.(A)
+		if ok {
+			v.A() // ERROR "devirtualizing" "inlining call"
+		}
+	}
+	{
+		var a M = &Impl{} // ERROR "does not escape"
+		v, ok := a.(*Impl)
+		if ok {
+			v.A() // ERROR "inlining"
+			v.M() // ERROR "inlining"
+		}
+	}
+	{
+		var a M = &Impl{} // ERROR "does not escape"
+		v, _ := a.(M)
+		v.M() // ERROR "devirtualizing" "inlining call"
+	}
+	{
+		var a M = &Impl{} // ERROR "does not escape"
+		v, _ := a.(A)
+		v.A() // ERROR "devirtualizing" "inlining call"
+	}
+	{
+		var a M = &Impl{} // ERROR "does not escape"
+		v, _ := a.(*Impl)
+		v.A() // ERROR "inlining"
+		v.M() // ERROR "inlining"
+	}
+	{
+		a := newM()  // ERROR "does not escape" "inlining call"
+		callA(a)     // ERROR "devirtualizing" "inlining call"
+		callIfA(a)   // ERROR "devirtualizing" "inlining call"
+	}
+
+	{
+		var a M = &Impl{} // ERROR "does not escape"
+		// Note the !ok condition; devirtualizing here is fine.
+		if v, ok := a.(M); !ok {
+			v.M() // ERROR "devirtualizing" "inlining call"
+		}
+	}
+}
+
+func newM() M { // ERROR "can inline"
+	return &Impl{} // ERROR "escapes"
+}
+
+func callA(m M) { // ERROR "can inline" "leaking param"
+	m.(A).A()
+}
+
+func callIfA(m M) { // ERROR "can inline" "leaking param"
+	if v, ok := m.(A); ok {
+		v.A()
+	}
+}
+
+//go:noinline
+func newImplNoInline() *Impl {
+	return &Impl{} // ERROR "escapes"
+}
+
+func t3() {
+	{
+		var a A = newImplNoInline()
+		if v, ok := a.(M); ok {
+			v.M() // ERROR "devirtualizing" "inlining call"
+		}
+	}
+	{
+		m := make(map[*Impl]struct{}) // ERROR "does not escape"
+		for v := range m {
+			var v A = v
+			if v, ok := v.(M); ok {
+				v.M() // ERROR "devirtualizing" "inlining call"
+			}
+		}
+	}
+	{
+		m := make(map[int]*Impl) // ERROR "does not escape"
+		for _, v := range m {
+			var v A = v
+			if v, ok := v.(M); ok {
+				v.M() // ERROR "devirtualizing" "inlining call"
+			}
+		}
+	}
+	{
+		m := make(map[int]*Impl) // ERROR "does not escape"
+		var v A = m[0]
+		if v, ok := v.(M); ok {
+			v.M() // ERROR "devirtualizing" "inlining call"
+		}
+	}
+	{
+		m := make(chan *Impl)
+		var v A = <-m
+		if v, ok := v.(M); ok {
+			v.M() // ERROR "devirtualizing" "inlining call"
+		}
+	}
+}
+
+//go:noinline
+func testInvalidAsserts() {
+	any(0).(interface{ A() }).A() // ERROR "escapes"
+	{
+		var a M = &Impl{} // ERROR "escapes"
+		a.(C).C()       // this will panic
+		a.(any).(C).C() // this will panic
+	}
+	{
+		var a C = &CImpl{} // ERROR "escapes"
+		a.(M).M()       // this will panic
+		a.(any).(M).M() // this will panic
+	}
+	{
+		var a C = &CImpl{} // ERROR "does not escape"
+
+		// this will panic
+		a.(M).(*Impl).M() // ERROR "inlining"
+
+		// this will panic
+		a.(any).(M).(*Impl).M() // ERROR "inlining"
+	}
+}
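
Note (illustration, not part of the patch): the standalone program below sketches the call pattern the devirtualize.go change targets, assuming it is built with a toolchain that includes this change. The interface and type names mirror the new escape test; whether the asserted calls are actually devirtualized and inlined is an internal optimization, observable only through the -gcflags=-m diagnostics that the errorcheck test above asserts.

package main

import "fmt"

// M, A, and Impl mirror the declarations in the new escape test.
type M interface{ M() }
type A interface{ A() }

type Impl struct{}

func (*Impl) M() { fmt.Println("M called") }
func (*Impl) A() { fmt.Println("A called") }

func main() {
	// The concrete type behind m is statically known to be *Impl, so a
	// compiler with this change can devirtualize (and then inline) the
	// two asserted calls below; build with -gcflags=-m to see the
	// "devirtualizing" diagnostics.
	var m M = &Impl{}
	m.(M).M()
	m.(A).A()

	// An assertion that the concrete type cannot satisfy still typechecks
	// but panics at run time; the new typecheck.Implements guard in
	// StaticCall deliberately leaves such calls alone.
	defer func() { fmt.Println("recovered:", recover()) }()
	any(0).(interface{ A() }).A()
}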