Skip to content

Commit 61f62e0

Browse files
committed
Merge branch 'master' of https://github.com/grpc/grpc-go
2 parents 1e47e17 + d736c11 commit 61f62e0

File tree

4 files changed

+67
-9
lines changed

4 files changed

+67
-9
lines changed

clientconn.go

+5-7
Original file line numberDiff line numberDiff line change
@@ -250,13 +250,13 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
250250
defer func() {
251251
select {
252252
case <-ctx.Done():
253-
if conn != nil {
254-
conn.Close()
255-
}
256-
conn = nil
257-
err = ctx.Err()
253+
conn, err = nil, ctx.Err()
258254
default:
259255
}
256+
257+
if err != nil {
258+
cc.Close()
259+
}
260260
}()
261261

262262
for _, opt := range opts {
@@ -312,11 +312,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
312312
return nil, ctx.Err()
313313
case err := <-waitC:
314314
if err != nil {
315-
cc.Close()
316315
return nil, err
317316
}
318317
case <-timeoutCh:
319-
cc.Close()
320318
return nil, ErrClientConnTimeout
321319
}
322320
// If balancer is nil or balancer.Notify() is nil, ok will be false here.

credentials/credentials.go

+7
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ package credentials // import "google.golang.org/grpc/credentials"
4040
import (
4141
"crypto/tls"
4242
"crypto/x509"
43+
"errors"
4344
"fmt"
4445
"io/ioutil"
4546
"net"
@@ -86,6 +87,12 @@ type AuthInfo interface {
8687
AuthType() string
8788
}
8889

90+
var (
91+
// ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
92+
// and the caller should not close rawConn.
93+
ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
94+
)
95+
8996
// TransportCredentials defines the common interface for all the live gRPC wire
9097
// protocols and supported transport security protocols (e.g., TLS, SSL).
9198
type TransportCredentials interface {

server.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
367367
s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
368368
s.mu.Unlock()
369369
grpclog.Printf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
370-
rawConn.Close()
370+
// If serverHandShake returns ErrConnDispatched, keep rawConn open.
371+
if err != credentials.ErrConnDispatched {
372+
rawConn.Close()
373+
}
371374
return
372375
}
373376

test/end2end_test.go

+51-1
Original file line numberDiff line numberDiff line change
@@ -848,9 +848,11 @@ func testFailFast(t *testing.T, e env) {
848848
te.srv.Stop()
849849
// Loop until the server teardown is propagated to the client.
850850
for {
851-
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) == codes.Unavailable {
851+
_, err := tc.EmptyCall(context.Background(), &testpb.Empty{})
852+
if grpc.Code(err) == codes.Unavailable {
852853
break
853854
}
855+
fmt.Printf("%v.EmptyCall(_, _) = _, %v", tc, err)
854856
time.Sleep(10 * time.Millisecond)
855857
}
856858
// The client keeps reconnecting and ongoing fail-fast RPCs should fail with code.Unavailable.
@@ -2462,6 +2464,54 @@ func TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) {
24622464
}
24632465
}
24642466

2467+
type serverDispatchCred struct {
2468+
ready chan struct{}
2469+
rawConn net.Conn
2470+
}
2471+
2472+
func newServerDispatchCred() *serverDispatchCred {
2473+
return &serverDispatchCred{
2474+
ready: make(chan struct{}),
2475+
}
2476+
}
2477+
func (c *serverDispatchCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
2478+
return rawConn, nil, nil
2479+
}
2480+
func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
2481+
c.rawConn = rawConn
2482+
close(c.ready)
2483+
return nil, nil, credentials.ErrConnDispatched
2484+
}
2485+
func (c *serverDispatchCred) Info() credentials.ProtocolInfo {
2486+
return credentials.ProtocolInfo{}
2487+
}
2488+
func (c *serverDispatchCred) getRawConn() net.Conn {
2489+
<-c.ready
2490+
return c.rawConn
2491+
}
2492+
2493+
func TestServerCredsDispatch(t *testing.T) {
2494+
lis, err := net.Listen("tcp", ":0")
2495+
if err != nil {
2496+
t.Fatalf("Failed to listen: %v", err)
2497+
}
2498+
cred := newServerDispatchCred()
2499+
s := grpc.NewServer(grpc.Creds(cred))
2500+
go s.Serve(lis)
2501+
defer s.Stop()
2502+
2503+
cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(cred))
2504+
if err != nil {
2505+
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
2506+
}
2507+
defer cc.Close()
2508+
2509+
// Check rawConn is not closed.
2510+
if n, err := cred.getRawConn().Write([]byte{0}); n <= 0 || err != nil {
2511+
t.Errorf("Read() = %v, %v; want n>0, <nil>", n, err)
2512+
}
2513+
}
2514+
24652515
// interestingGoroutines returns all goroutines we care about for the purpose
24662516
// of leak checking. It excludes testing or runtime ones.
24672517
func interestingGoroutines() (gs []string) {

0 commit comments

Comments
 (0)