Skip to content

Commit 52f6504

Browse files
authored
Merge pull request grpc#867 from iamqizhao/master
Support client side interceptor
2 parents 8d57dd3 + 61f62e0 commit 52f6504

File tree

5 files changed

+142
-19
lines changed

5 files changed

+142
-19
lines changed

call.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,14 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
112112
// Invoke sends the RPC request on the wire and returns after response is received.
113113
// Invoke is called by generated code. Also users can call Invoke directly when it
114114
// is really needed in their use cases.
115-
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
115+
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) error {
116+
if cc.dopts.unaryInt != nil {
117+
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
118+
}
119+
return invoke(ctx, method, args, reply, cc, opts...)
120+
}
121+
122+
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
116123
c := defaultCallInfo
117124
for _, o := range opts {
118125
if err := o.before(&c); err != nil {

clientconn.go

+25-9
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,17 @@ var (
8383
// dialOptions configure a Dial call. dialOptions are set by the DialOption
8484
// values passed to Dial.
8585
type dialOptions struct {
86-
codec Codec
87-
cp Compressor
88-
dc Decompressor
89-
bs backoffStrategy
90-
balancer Balancer
91-
block bool
92-
insecure bool
93-
timeout time.Duration
94-
copts transport.ConnectOptions
86+
unaryInt UnaryClientInterceptor
87+
streamInt StreamClientInterceptor
88+
codec Codec
89+
cp Compressor
90+
dc Decompressor
91+
bs backoffStrategy
92+
balancer Balancer
93+
block bool
94+
insecure bool
95+
timeout time.Duration
96+
copts transport.ConnectOptions
9597
}
9698

9799
// DialOption configures how we set up the connection.
@@ -215,6 +217,20 @@ func WithUserAgent(s string) DialOption {
215217
}
216218
}
217219

220+
// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs.
221+
func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
222+
return func(o *dialOptions) {
223+
o.unaryInt = f
224+
}
225+
}
226+
227+
// WithStreamInterceptor returns a DialOption that specifies the interceptor for streaming RPCs.
228+
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
229+
return func(o *dialOptions) {
230+
o.streamInt = f
231+
}
232+
}
233+
218234
// Dial creates a client connection to the given target.
219235
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
220236
return DialContext(context.Background(), target, opts...)

interceptor.go

+16
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,22 @@ import (
3737
"golang.org/x/net/context"
3838
)
3939

40+
// UnaryInvoker is called by UnaryClientInterceptor to complete RPCs.
41+
type UnaryInvoker func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error
42+
43+
// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. inovker is the handler to complete the RPC
44+
// and it is the responsibility of the interceptor to call it.
45+
// This is the EXPERIMENTAL API.
46+
type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error
47+
48+
// Streamer is called by StreamClientInterceptor to create a ClientStream.
49+
type Streamer func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error)
50+
51+
// StreamClientInterceptor intercepts the creation of ClientStream. It may return a custom ClientStream to intercept all I/O
52+
// operations. streamer is the handlder to create a ClientStream and it is the responsibility of the interceptor to call it.
53+
// This is the EXPERIMENTAL API.
54+
type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)
55+
4056
// UnaryServerInfo consists of various information about a unary RPC on
4157
// server side. All per-rpc information may be mutated by the interceptor.
4258
type UnaryServerInfo struct {

stream.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,14 @@ type ClientStream interface {
9797

9898
// NewClientStream creates a new Stream for the client side. This is called
9999
// by generated code.
100-
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
100+
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
101+
if cc.dopts.streamInt != nil {
102+
return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
103+
}
104+
return newClientStream(ctx, desc, cc, method, opts...)
105+
}
106+
107+
func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
101108
var (
102109
t transport.ClientTransport
103110
s *transport.Stream

test/end2end_test.go

+85-8
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,10 @@ type test struct {
368368
userAgent string
369369
clientCompression bool
370370
serverCompression bool
371-
unaryInt grpc.UnaryServerInterceptor
372-
streamInt grpc.StreamServerInterceptor
371+
unaryClientInt grpc.UnaryClientInterceptor
372+
streamClientInt grpc.StreamClientInterceptor
373+
unaryServerInt grpc.UnaryServerInterceptor
374+
streamServerInt grpc.StreamServerInterceptor
373375

374376
// srv and srvAddr are set once startServer is called.
375377
srv *grpc.Server
@@ -423,11 +425,11 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
423425
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
424426
)
425427
}
426-
if te.unaryInt != nil {
427-
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt))
428+
if te.unaryServerInt != nil {
429+
sopts = append(sopts, grpc.UnaryInterceptor(te.unaryServerInt))
428430
}
429-
if te.streamInt != nil {
430-
sopts = append(sopts, grpc.StreamInterceptor(te.streamInt))
431+
if te.streamServerInt != nil {
432+
sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt))
431433
}
432434
la := "localhost:0"
433435
switch te.e.network {
@@ -492,6 +494,12 @@ func (te *test) clientConn() *grpc.ClientConn {
492494
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
493495
)
494496
}
497+
if te.unaryClientInt != nil {
498+
opts = append(opts, grpc.WithUnaryInterceptor(te.unaryClientInt))
499+
}
500+
if te.streamClientInt != nil {
501+
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
502+
}
495503
switch te.e.security {
496504
case "tls":
497505
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
@@ -2137,6 +2145,75 @@ func testCompressOK(t *testing.T, e env) {
21372145
}
21382146
}
21392147

2148+
func TestUnaryClientInterceptor(t *testing.T) {
2149+
defer leakCheck(t)()
2150+
for _, e := range listTestEnv() {
2151+
testUnaryClientInterceptor(t, e)
2152+
}
2153+
}
2154+
2155+
func failOkayRPC(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
2156+
err := invoker(ctx, method, req, reply, cc, opts...)
2157+
if err == nil {
2158+
return grpc.Errorf(codes.NotFound, "")
2159+
}
2160+
return err
2161+
}
2162+
2163+
func testUnaryClientInterceptor(t *testing.T, e env) {
2164+
te := newTest(t, e)
2165+
te.userAgent = testAppUA
2166+
te.unaryClientInt = failOkayRPC
2167+
te.startServer(&testServer{security: e.security})
2168+
defer te.tearDown()
2169+
2170+
tc := testpb.NewTestServiceClient(te.clientConn())
2171+
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.NotFound {
2172+
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %s", tc, err, codes.NotFound)
2173+
}
2174+
}
2175+
2176+
func TestStreamClientInterceptor(t *testing.T) {
2177+
defer leakCheck(t)()
2178+
for _, e := range listTestEnv() {
2179+
testStreamClientInterceptor(t, e)
2180+
}
2181+
}
2182+
2183+
func failOkayStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
2184+
s, err := streamer(ctx, desc, cc, method, opts...)
2185+
if err == nil {
2186+
return nil, grpc.Errorf(codes.NotFound, "")
2187+
}
2188+
return s, nil
2189+
}
2190+
2191+
func testStreamClientInterceptor(t *testing.T, e env) {
2192+
te := newTest(t, e)
2193+
te.streamClientInt = failOkayStream
2194+
te.startServer(&testServer{security: e.security})
2195+
defer te.tearDown()
2196+
2197+
tc := testpb.NewTestServiceClient(te.clientConn())
2198+
respParam := []*testpb.ResponseParameters{
2199+
{
2200+
Size: proto.Int32(int32(1)),
2201+
},
2202+
}
2203+
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
2204+
if err != nil {
2205+
t.Fatal(err)
2206+
}
2207+
req := &testpb.StreamingOutputCallRequest{
2208+
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
2209+
ResponseParameters: respParam,
2210+
Payload: payload,
2211+
}
2212+
if _, err := tc.StreamingOutputCall(context.Background(), req); grpc.Code(err) != codes.NotFound {
2213+
t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want _, error code %s", tc, err, codes.NotFound)
2214+
}
2215+
}
2216+
21402217
func TestUnaryServerInterceptor(t *testing.T) {
21412218
defer leakCheck(t)()
21422219
for _, e := range listTestEnv() {
@@ -2150,7 +2227,7 @@ func errInjector(ctx context.Context, req interface{}, info *grpc.UnaryServerInf
21502227

21512228
func testUnaryServerInterceptor(t *testing.T, e env) {
21522229
te := newTest(t, e)
2153-
te.unaryInt = errInjector
2230+
te.unaryServerInt = errInjector
21542231
te.startServer(&testServer{security: e.security})
21552232
defer te.tearDown()
21562233

@@ -2181,7 +2258,7 @@ func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServ
21812258

21822259
func testStreamServerInterceptor(t *testing.T, e env) {
21832260
te := newTest(t, e)
2184-
te.streamInt = fullDuplexOnly
2261+
te.streamServerInt = fullDuplexOnly
21852262
te.startServer(&testServer{security: e.security})
21862263
defer te.tearDown()
21872264

0 commit comments

Comments
 (0)