@@ -59,10 +59,82 @@ var (
59
59
logger * log.Logger
60
60
)
61
61
62
+ func loggerPrintf (format string , v ... any ) {
63
+ if DebugLogging {
64
+ logger .Output (2 , fmt .Sprintf (format , v ... ))
65
+ }
66
+ }
67
+
68
+ // NewClient creates a new agent communication grpc client.
69
+ // Caller must close the returned client when it is done being used to clean up its underlying
70
+ // connections.
71
+ func NewClient (ctx context.Context , regional bool , opts ... option.ClientOption ) (* agentcommunication.Client , error ) {
72
+ zone , err := cm .Zone ()
73
+ if err != nil {
74
+ return nil , err
75
+ }
76
+
77
+ location := zone
78
+ if regional {
79
+ location = location [:len (location )- 2 ]
80
+ }
81
+
82
+ defaultOpts := []option.ClientOption {
83
+ option .WithoutAuthentication (), // Do not use oauth.
84
+ option .WithGRPCDialOption (grpc .WithTransportCredentials (credentials .NewTLS (nil ))), // Because we disabled Auth we need to specifically enable TLS.
85
+ option .WithGRPCDialOption (grpc .WithKeepaliveParams (keepalive.ClientParameters {Time : 60 * time .Second , Timeout : 10 * time .Second })),
86
+ // Use the FQDN to avoid unnecessary DNS lookups.
87
+ option .WithEndpoint (fmt .Sprintf ("%s-agentcommunication.googleapis.com.:443" , location )),
88
+ }
89
+
90
+ opts = append (defaultOpts , opts ... )
91
+ return agentcommunication .NewClient (ctx , opts ... )
92
+ }
93
+
94
+ // SendAgentMessage sends a message to the client. This is equivalent to sending a message via
95
+ // StreamAgentMessages with a single message and waiting for the response.
96
+ func SendAgentMessage (ctx context.Context , channelID string , client * agentcommunication.Client , msg * acpb.MessageBody ) (* acpb.SendAgentMessageResponse , error ) {
97
+ loggerPrintf ("SendAgentMessage" )
98
+ zone , err := cm .Zone ()
99
+ if err != nil {
100
+ return nil , err
101
+ }
102
+ projectNum , err := cm .NumericProjectID ()
103
+ if err != nil {
104
+ return nil , err
105
+ }
106
+ instanceID , err := cm .InstanceID ()
107
+ if err != nil {
108
+ return nil , err
109
+ }
110
+ resourceID := fmt .Sprintf ("projects/%s/zones/%s/instances/%s" , projectNum , zone , instanceID )
111
+
112
+ token , err := cm .Get ("instance/service-accounts/default/identity?audience=agentcommunication.googleapis.com&format=full" )
113
+ if err != nil {
114
+ return nil , fmt .Errorf ("%w: %v" , ErrGettingInstanceToken , err )
115
+ }
116
+
117
+ ctx = metadata .NewOutgoingContext (ctx , metadata .New (map [string ]string {
118
+ "authentication" : "Bearer " + token ,
119
+ "agent-communication-resource-id" : resourceID ,
120
+ "agent-communication-channel-id" : channelID ,
121
+ }))
122
+
123
+ loggerPrintf ("Using ResourceID %q" , resourceID )
124
+ loggerPrintf ("Using ChannelID %q" , channelID )
125
+
126
+ return client .SendAgentMessage (ctx , & acpb.SendAgentMessageRequest {
127
+ ChannelId : channelID ,
128
+ ResourceId : resourceID ,
129
+ MessageBody : msg ,
130
+ })
131
+ }
132
+
62
133
// Connection is an AgentCommunication connection.
63
134
type Connection struct {
64
135
client * agentcommunication.Client
65
- stream acpb.AgentCommunication_StreamAgentMessagesClient
136
+ // Indicates that the client is caller managed and should not be closed.
137
+ callerManagedClient bool
66
138
// Indicates that the entire connection is closed and will not reopen.
67
139
closed chan struct {}
68
140
closeErr error
@@ -81,12 +153,6 @@ type Connection struct {
81
153
timeToWaitForResp time.Duration
82
154
}
83
155
84
- func loggerPrintf (format string , v ... any ) {
85
- if DebugLogging {
86
- logger .Output (2 , fmt .Sprintf (format , v ... ))
87
- }
88
- }
89
-
90
156
// Close the connection.
91
157
func (c * Connection ) Close () {
92
158
c .close (ErrConnectionClosed )
@@ -114,7 +180,9 @@ func (c *Connection) close(err error) {
114
180
default :
115
181
close (c .closed )
116
182
c .setCloseErr (err )
117
- c .client .Close ()
183
+ if ! c .callerManagedClient {
184
+ c .client .Close ()
185
+ }
118
186
}
119
187
}
120
188
@@ -212,13 +280,14 @@ func (c *Connection) Receive() (*acpb.MessageBody, error) {
212
280
}
213
281
}
214
282
215
- func (c * Connection ) streamSend (req * acpb.StreamAgentMessagesRequest , streamClosed chan struct {}) error {
283
+ func (c * Connection ) streamSend (req * acpb.StreamAgentMessagesRequest , streamClosed , streamSendLock chan struct {}, stream acpb. AgentCommunication_StreamAgentMessagesClient ) error {
216
284
select {
217
285
case <- streamClosed :
218
286
return errors .New ("stream closed" )
219
- default :
287
+ case streamSendLock <- struct {}{}:
288
+ defer func () { <- streamSendLock }()
220
289
}
221
- if err := c . stream .Send (req ); err != nil {
290
+ if err := stream .Send (req ); err != nil {
222
291
if err != io .EOF && ! errors .Is (err , io .EOF ) {
223
292
// Something is very broken, just close the stream here.
224
293
loggerPrintf ("Unexpected send error, closing connection: %v" , err )
@@ -241,11 +310,16 @@ func (c *Connection) streamSend(req *acpb.StreamAgentMessagesRequest, streamClos
241
310
return nil
242
311
}
243
312
244
- func (c * Connection ) send (streamClosed chan struct {}) {
313
+ func (c * Connection ) send (streamClosed , streamSendLock chan struct {}, stream acpb.AgentCommunication_StreamAgentMessagesClient ) {
314
+ defer func () {
315
+ // Lock the stream sends so we can close the stream.
316
+ streamSendLock <- struct {}{}
317
+ stream .CloseSend ()
318
+ }()
245
319
for {
246
320
select {
247
321
case req := <- c .sends :
248
- if err := c .streamSend (req , streamClosed ); err != nil {
322
+ if err := c .streamSend (req , streamClosed , streamSendLock , stream ); err != nil {
249
323
return
250
324
}
251
325
case <- c .closed :
@@ -256,7 +330,7 @@ func (c *Connection) send(streamClosed chan struct{}) {
256
330
}
257
331
}
258
332
259
- func (c * Connection ) acknowledgeMessage (messageID string , streamClosed chan struct {}) error {
333
+ func (c * Connection ) acknowledgeMessage (messageID string , streamClosed , streamSendLock chan struct {}, stream acpb. AgentCommunication_StreamAgentMessagesClient ) error {
260
334
ackReq := & acpb.StreamAgentMessagesRequest {
261
335
MessageId : messageID ,
262
336
Type : & acpb.StreamAgentMessagesRequest_MessageResponse {},
@@ -265,15 +339,15 @@ func (c *Connection) acknowledgeMessage(messageID string, streamClosed chan stru
265
339
case <- c .closed :
266
340
return fmt .Errorf ("connection closed with err: %w" , c .closeErr )
267
341
default :
268
- return c .streamSend (ackReq , streamClosed )
342
+ return c .streamSend (ackReq , streamClosed , streamSendLock , stream )
269
343
}
270
344
}
271
345
272
346
// recv keeps receiving and acknowledging new messages.
273
- func (c * Connection ) recv (ctx context.Context , streamClosed chan struct {}) {
347
+ func (c * Connection ) recv (ctx context.Context , streamClosed , streamSendLock chan struct {}, stream acpb. AgentCommunication_StreamAgentMessagesClient ) {
274
348
loggerPrintf ("Receiving messages" )
275
349
for {
276
- resp , err := c . stream .Recv ()
350
+ resp , err := stream .Recv ()
277
351
if err != nil {
278
352
select {
279
353
case <- streamClosed :
@@ -284,6 +358,7 @@ func (c *Connection) recv(ctx context.Context, streamClosed chan struct{}) {
284
358
select {
285
359
case <- c .closed :
286
360
// Connection is closed, return now.
361
+ loggerPrintf ("Connection closed, recv returning" )
287
362
return
288
363
default :
289
364
}
@@ -315,7 +390,7 @@ func (c *Connection) recv(ctx context.Context, streamClosed chan struct{}) {
315
390
case * acpb.StreamAgentMessagesResponse_MessageBody :
316
391
// Acknowledge message first, if this ack fails dont forward the message on to the handling
317
392
// logic since that indicates a stream disconnect.
318
- if err := c .acknowledgeMessage (resp .GetMessageId (), streamClosed ); err != nil {
393
+ if err := c .acknowledgeMessage (resp .GetMessageId (), streamClosed , streamSendLock , stream ); err != nil {
319
394
loggerPrintf ("Error acknowledging message %q: %v" , resp .GetMessageId (), err )
320
395
continue
321
396
}
@@ -337,37 +412,36 @@ func (c *Connection) recv(ctx context.Context, streamClosed chan struct{}) {
337
412
}
338
413
}
339
414
340
- func ( c * Connection ) createStreamLoop (ctx context.Context ) error {
415
+ func createStreamLoop (ctx context.Context , client * agentcommunication. Client , resourceID string , channelID string ) (acpb. AgentCommunication_StreamAgentMessagesClient , error ) {
341
416
resourceExhaustedRetries := 0
342
417
unavailableRetries := 0
343
- var err error
344
418
for {
345
- c . stream , err = c . client .StreamAgentMessages (ctx )
419
+ stream , err := client .StreamAgentMessages (ctx )
346
420
if err != nil {
347
- return fmt .Errorf ("error creating stream: %v" , err )
421
+ return nil , fmt .Errorf ("error creating stream: %v" , err )
348
422
}
349
423
350
424
// RegisterConnection is a special message that must be sent before any other messages.
351
425
req := & acpb.StreamAgentMessagesRequest {
352
426
MessageId : uuid .New ().String (),
353
427
Type : & acpb.StreamAgentMessagesRequest_RegisterConnection {
354
- RegisterConnection : & acpb.RegisterConnection {ResourceId : c . resourceID , ChannelId : c . channelID }}}
428
+ RegisterConnection : & acpb.RegisterConnection {ResourceId : resourceID , ChannelId : channelID }}}
355
429
356
- if err := c . stream .Send (req ); err != nil {
357
- return fmt .Errorf ("error sending register connection: %v" , err )
430
+ if err := stream .Send (req ); err != nil {
431
+ return nil , fmt .Errorf ("error sending register connection: %v" , err )
358
432
}
359
433
360
434
// We expect the first message to be a MessageResponse.
361
- resp , err := c . stream .Recv ()
435
+ resp , err := stream .Recv ()
362
436
if err == nil {
363
437
switch resp .GetType ().(type ) {
364
438
case * acpb.StreamAgentMessagesResponse_MessageResponse :
365
439
if resp .GetMessageResponse ().GetStatus ().GetCode () != int32 (codes .OK ) {
366
- return fmt .Errorf ("unexpected register response: %+v" , resp .GetMessageResponse ().GetStatus ())
440
+ return nil , fmt .Errorf ("unexpected register response: %+v" , resp .GetMessageResponse ().GetStatus ())
367
441
}
368
442
}
369
443
// Stream is connected.
370
- return nil
444
+ return stream , nil
371
445
}
372
446
373
447
st , ok := status .FromError (err )
@@ -394,7 +468,7 @@ func (c *Connection) createStreamLoop(ctx context.Context) error {
394
468
}
395
469
loggerPrintf ("Stream returned Unavailable, exceeded max number of reconnects, closing connection: %v" , err )
396
470
}
397
- return err
471
+ return nil , err
398
472
}
399
473
}
400
474
@@ -416,15 +490,19 @@ func (c *Connection) createStream(ctx context.Context) error {
416
490
417
491
// Set a timeout for the stream, this is well above service side timeout.
418
492
cnclCtx , cancel := context .WithTimeout (ctx , 60 * time .Minute )
419
- if err := c .createStreamLoop (cnclCtx ); err != nil {
493
+ stream , err := createStreamLoop (cnclCtx , c .client , c .resourceID , c .channelID )
494
+ if err != nil {
420
495
cancel ()
421
496
c .close (err )
422
497
return err
423
498
}
424
499
500
+ // Used to signal that the stream is closed.
425
501
streamClosed := make (chan struct {})
426
- go c .recv (ctx , streamClosed )
427
- go c .send (streamClosed )
502
+ // This ensures that only one send is happening at a time.
503
+ streamSendLock := make (chan struct {}, 1 )
504
+ go c .recv (ctx , streamClosed , streamSendLock , stream )
505
+ go c .send (streamClosed , streamSendLock , stream )
428
506
429
507
go func () {
430
508
defer cancel ()
@@ -442,10 +520,49 @@ func (c *Connection) createStream(ctx context.Context) error {
442
520
return nil
443
521
}
444
522
523
+ // NewConnection creates a new streaming connection.
524
+ // Caller is responsible for calling Close() on the connection when done, certain errors will cause
525
+ // the connection to be closed automatically. The passed in client will not be closed and can be
526
+ // reused.
527
+ func NewConnection (ctx context.Context , channelID string , client * agentcommunication.Client ) (* Connection , error ) {
528
+ conn := & Connection {
529
+ channelID : channelID ,
530
+ closed : make (chan struct {}),
531
+ messages : make (chan * acpb.MessageBody ),
532
+ responseSubs : make (map [string ]chan * status.Status ),
533
+ streamReady : make (chan struct {}),
534
+ sends : make (chan * acpb.StreamAgentMessagesRequest ),
535
+ timeToWaitForResp : 2 * time .Second ,
536
+ client : client ,
537
+ callerManagedClient : true ,
538
+ }
539
+
540
+ zone , err := cm .Zone ()
541
+ if err != nil {
542
+ return nil , err
543
+ }
544
+ projectNum , err := cm .NumericProjectID ()
545
+ if err != nil {
546
+ return nil , err
547
+ }
548
+ instanceID , err := cm .InstanceID ()
549
+ if err != nil {
550
+ return nil , err
551
+ }
552
+ conn .resourceID = fmt .Sprintf ("projects/%s/zones/%s/instances/%s" , projectNum , zone , instanceID )
553
+
554
+ if err := conn .createStream (ctx ); err != nil {
555
+ conn .close (err )
556
+ return nil , err
557
+ }
558
+
559
+ return conn , nil
560
+ }
561
+
445
562
// CreateConnection creates a new connection.
563
+ // DEPRECATED: Use NewConnection instead.
446
564
func CreateConnection (ctx context.Context , channelID string , regional bool , opts ... option.ClientOption ) (* Connection , error ) {
447
565
conn := & Connection {
448
- regional : regional ,
449
566
channelID : channelID ,
450
567
closed : make (chan struct {}),
451
568
messages : make (chan * acpb.MessageBody ),
@@ -469,21 +586,7 @@ func CreateConnection(ctx context.Context, channelID string, regional bool, opts
469
586
}
470
587
conn .resourceID = fmt .Sprintf ("projects/%s/zones/%s/instances/%s" , projectNum , zone , instanceID )
471
588
472
- location := zone
473
- if conn .regional {
474
- location = location [:len (location )- 2 ]
475
- }
476
-
477
- defaultOpts := []option.ClientOption {
478
- option .WithoutAuthentication (), // Do not use oauth.
479
- option .WithGRPCDialOption (grpc .WithTransportCredentials (credentials .NewTLS (nil ))), // Because we disabled Auth we need to specifically enable TLS.
480
- option .WithGRPCDialOption (grpc .WithKeepaliveParams (keepalive.ClientParameters {Time : 60 * time .Second , Timeout : 10 * time .Second })),
481
- option .WithEndpoint (fmt .Sprintf ("%s-agentcommunication.googleapis.com:443" , location )),
482
- }
483
-
484
- opts = append (defaultOpts , opts ... )
485
-
486
- conn .client , err = agentcommunication .NewClient (ctx , opts ... )
589
+ conn .client , err = NewClient (ctx , regional , opts ... )
487
590
if err != nil {
488
591
return nil , err
489
592
}
0 commit comments