Skip to content

Commit cf5b841

Browse files
adjackuracopybara-github
authored andcommitted
Refactor how clients are used, add new NewClient and NewConnection methods.
Add SendAgentMessage to acs client Use the FQDN to avoid unnecessary DNS lookups PiperOrigin-RevId: 725650044
1 parent 71bfb94 commit cf5b841

File tree

2 files changed

+272
-108
lines changed

2 files changed

+272
-108
lines changed

client.go

Lines changed: 151 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,82 @@ var (
5959
logger *log.Logger
6060
)
6161

62+
func loggerPrintf(format string, v ...any) {
63+
if DebugLogging {
64+
logger.Output(2, fmt.Sprintf(format, v...))
65+
}
66+
}
67+
68+
// NewClient creates a new agent communication grpc client.
69+
// Caller must close the returned client when it is done being used to clean up its underlying
70+
// connections.
71+
func NewClient(ctx context.Context, regional bool, opts ...option.ClientOption) (*agentcommunication.Client, error) {
72+
zone, err := cm.Zone()
73+
if err != nil {
74+
return nil, err
75+
}
76+
77+
location := zone
78+
if regional {
79+
location = location[:len(location)-2]
80+
}
81+
82+
defaultOpts := []option.ClientOption{
83+
option.WithoutAuthentication(), // Do not use oauth.
84+
option.WithGRPCDialOption(grpc.WithTransportCredentials(credentials.NewTLS(nil))), // Because we disabled Auth we need to specifically enable TLS.
85+
option.WithGRPCDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{Time: 60 * time.Second, Timeout: 10 * time.Second})),
86+
// Use the FQDN to avoid unnecessary DNS lookups.
87+
option.WithEndpoint(fmt.Sprintf("%s-agentcommunication.googleapis.com.:443", location)),
88+
}
89+
90+
opts = append(defaultOpts, opts...)
91+
return agentcommunication.NewClient(ctx, opts...)
92+
}
93+
94+
// SendAgentMessage sends a message to the client. This is equivalent to sending a message via
95+
// StreamAgentMessages with a single message and waiting for the response.
96+
func SendAgentMessage(ctx context.Context, channelID string, client *agentcommunication.Client, msg *acpb.MessageBody) (*acpb.SendAgentMessageResponse, error) {
97+
loggerPrintf("SendAgentMessage")
98+
zone, err := cm.Zone()
99+
if err != nil {
100+
return nil, err
101+
}
102+
projectNum, err := cm.NumericProjectID()
103+
if err != nil {
104+
return nil, err
105+
}
106+
instanceID, err := cm.InstanceID()
107+
if err != nil {
108+
return nil, err
109+
}
110+
resourceID := fmt.Sprintf("projects/%s/zones/%s/instances/%s", projectNum, zone, instanceID)
111+
112+
token, err := cm.Get("instance/service-accounts/default/identity?audience=agentcommunication.googleapis.com&format=full")
113+
if err != nil {
114+
return nil, fmt.Errorf("%w: %v", ErrGettingInstanceToken, err)
115+
}
116+
117+
ctx = metadata.NewOutgoingContext(ctx, metadata.New(map[string]string{
118+
"authentication": "Bearer " + token,
119+
"agent-communication-resource-id": resourceID,
120+
"agent-communication-channel-id": channelID,
121+
}))
122+
123+
loggerPrintf("Using ResourceID %q", resourceID)
124+
loggerPrintf("Using ChannelID %q", channelID)
125+
126+
return client.SendAgentMessage(ctx, &acpb.SendAgentMessageRequest{
127+
ChannelId: channelID,
128+
ResourceId: resourceID,
129+
MessageBody: msg,
130+
})
131+
}
132+
62133
// Connection is an AgentCommunication connection.
63134
type Connection struct {
64135
client *agentcommunication.Client
65-
stream acpb.AgentCommunication_StreamAgentMessagesClient
136+
// Indicates that the client is caller managed and should not be closed.
137+
callerManagedClient bool
66138
// Indicates that the entire connection is closed and will not reopen.
67139
closed chan struct{}
68140
closeErr error
@@ -81,12 +153,6 @@ type Connection struct {
81153
timeToWaitForResp time.Duration
82154
}
83155

84-
func loggerPrintf(format string, v ...any) {
85-
if DebugLogging {
86-
logger.Output(2, fmt.Sprintf(format, v...))
87-
}
88-
}
89-
90156
// Close the connection.
91157
func (c *Connection) Close() {
92158
c.close(ErrConnectionClosed)
@@ -114,7 +180,9 @@ func (c *Connection) close(err error) {
114180
default:
115181
close(c.closed)
116182
c.setCloseErr(err)
117-
c.client.Close()
183+
if !c.callerManagedClient {
184+
c.client.Close()
185+
}
118186
}
119187
}
120188

@@ -212,13 +280,14 @@ func (c *Connection) Receive() (*acpb.MessageBody, error) {
212280
}
213281
}
214282

215-
func (c *Connection) streamSend(req *acpb.StreamAgentMessagesRequest, streamClosed chan struct{}) error {
283+
func (c *Connection) streamSend(req *acpb.StreamAgentMessagesRequest, streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) error {
216284
select {
217285
case <-streamClosed:
218286
return errors.New("stream closed")
219-
default:
287+
case streamSendLock <- struct{}{}:
288+
defer func() { <-streamSendLock }()
220289
}
221-
if err := c.stream.Send(req); err != nil {
290+
if err := stream.Send(req); err != nil {
222291
if err != io.EOF && !errors.Is(err, io.EOF) {
223292
// Something is very broken, just close the stream here.
224293
loggerPrintf("Unexpected send error, closing connection: %v", err)
@@ -241,11 +310,16 @@ func (c *Connection) streamSend(req *acpb.StreamAgentMessagesRequest, streamClos
241310
return nil
242311
}
243312

244-
func (c *Connection) send(streamClosed chan struct{}) {
313+
func (c *Connection) send(streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) {
314+
defer func() {
315+
// Lock the stream sends so we can close the stream.
316+
streamSendLock <- struct{}{}
317+
stream.CloseSend()
318+
}()
245319
for {
246320
select {
247321
case req := <-c.sends:
248-
if err := c.streamSend(req, streamClosed); err != nil {
322+
if err := c.streamSend(req, streamClosed, streamSendLock, stream); err != nil {
249323
return
250324
}
251325
case <-c.closed:
@@ -256,7 +330,7 @@ func (c *Connection) send(streamClosed chan struct{}) {
256330
}
257331
}
258332

259-
func (c *Connection) acknowledgeMessage(messageID string, streamClosed chan struct{}) error {
333+
func (c *Connection) acknowledgeMessage(messageID string, streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) error {
260334
ackReq := &acpb.StreamAgentMessagesRequest{
261335
MessageId: messageID,
262336
Type: &acpb.StreamAgentMessagesRequest_MessageResponse{},
@@ -265,15 +339,15 @@ func (c *Connection) acknowledgeMessage(messageID string, streamClosed chan stru
265339
case <-c.closed:
266340
return fmt.Errorf("connection closed with err: %w", c.closeErr)
267341
default:
268-
return c.streamSend(ackReq, streamClosed)
342+
return c.streamSend(ackReq, streamClosed, streamSendLock, stream)
269343
}
270344
}
271345

272346
// recv keeps receiving and acknowledging new messages.
273-
func (c *Connection) recv(ctx context.Context, streamClosed chan struct{}) {
347+
func (c *Connection) recv(ctx context.Context, streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) {
274348
loggerPrintf("Receiving messages")
275349
for {
276-
resp, err := c.stream.Recv()
350+
resp, err := stream.Recv()
277351
if err != nil {
278352
select {
279353
case <-streamClosed:
@@ -284,6 +358,7 @@ func (c *Connection) recv(ctx context.Context, streamClosed chan struct{}) {
284358
select {
285359
case <-c.closed:
286360
// Connection is closed, return now.
361+
loggerPrintf("Connection closed, recv returning")
287362
return
288363
default:
289364
}
@@ -315,7 +390,7 @@ func (c *Connection) recv(ctx context.Context, streamClosed chan struct{}) {
315390
case *acpb.StreamAgentMessagesResponse_MessageBody:
316391
// Acknowledge message first, if this ack fails dont forward the message on to the handling
317392
// logic since that indicates a stream disconnect.
318-
if err := c.acknowledgeMessage(resp.GetMessageId(), streamClosed); err != nil {
393+
if err := c.acknowledgeMessage(resp.GetMessageId(), streamClosed, streamSendLock, stream); err != nil {
319394
loggerPrintf("Error acknowledging message %q: %v", resp.GetMessageId(), err)
320395
continue
321396
}
@@ -337,37 +412,36 @@ func (c *Connection) recv(ctx context.Context, streamClosed chan struct{}) {
337412
}
338413
}
339414

340-
func (c *Connection) createStreamLoop(ctx context.Context) error {
415+
func createStreamLoop(ctx context.Context, client *agentcommunication.Client, resourceID string, channelID string) (acpb.AgentCommunication_StreamAgentMessagesClient, error) {
341416
resourceExhaustedRetries := 0
342417
unavailableRetries := 0
343-
var err error
344418
for {
345-
c.stream, err = c.client.StreamAgentMessages(ctx)
419+
stream, err := client.StreamAgentMessages(ctx)
346420
if err != nil {
347-
return fmt.Errorf("error creating stream: %v", err)
421+
return nil, fmt.Errorf("error creating stream: %v", err)
348422
}
349423

350424
// RegisterConnection is a special message that must be sent before any other messages.
351425
req := &acpb.StreamAgentMessagesRequest{
352426
MessageId: uuid.New().String(),
353427
Type: &acpb.StreamAgentMessagesRequest_RegisterConnection{
354-
RegisterConnection: &acpb.RegisterConnection{ResourceId: c.resourceID, ChannelId: c.channelID}}}
428+
RegisterConnection: &acpb.RegisterConnection{ResourceId: resourceID, ChannelId: channelID}}}
355429

356-
if err := c.stream.Send(req); err != nil {
357-
return fmt.Errorf("error sending register connection: %v", err)
430+
if err := stream.Send(req); err != nil {
431+
return nil, fmt.Errorf("error sending register connection: %v", err)
358432
}
359433

360434
// We expect the first message to be a MessageResponse.
361-
resp, err := c.stream.Recv()
435+
resp, err := stream.Recv()
362436
if err == nil {
363437
switch resp.GetType().(type) {
364438
case *acpb.StreamAgentMessagesResponse_MessageResponse:
365439
if resp.GetMessageResponse().GetStatus().GetCode() != int32(codes.OK) {
366-
return fmt.Errorf("unexpected register response: %+v", resp.GetMessageResponse().GetStatus())
440+
return nil, fmt.Errorf("unexpected register response: %+v", resp.GetMessageResponse().GetStatus())
367441
}
368442
}
369443
// Stream is connected.
370-
return nil
444+
return stream, nil
371445
}
372446

373447
st, ok := status.FromError(err)
@@ -394,7 +468,7 @@ func (c *Connection) createStreamLoop(ctx context.Context) error {
394468
}
395469
loggerPrintf("Stream returned Unavailable, exceeded max number of reconnects, closing connection: %v", err)
396470
}
397-
return err
471+
return nil, err
398472
}
399473
}
400474

@@ -416,15 +490,19 @@ func (c *Connection) createStream(ctx context.Context) error {
416490

417491
// Set a timeout for the stream, this is well above service side timeout.
418492
cnclCtx, cancel := context.WithTimeout(ctx, 60*time.Minute)
419-
if err := c.createStreamLoop(cnclCtx); err != nil {
493+
stream, err := createStreamLoop(cnclCtx, c.client, c.resourceID, c.channelID)
494+
if err != nil {
420495
cancel()
421496
c.close(err)
422497
return err
423498
}
424499

500+
// Used to signal that the stream is closed.
425501
streamClosed := make(chan struct{})
426-
go c.recv(ctx, streamClosed)
427-
go c.send(streamClosed)
502+
// This ensures that only one send is happening at a time.
503+
streamSendLock := make(chan struct{}, 1)
504+
go c.recv(ctx, streamClosed, streamSendLock, stream)
505+
go c.send(streamClosed, streamSendLock, stream)
428506

429507
go func() {
430508
defer cancel()
@@ -442,10 +520,49 @@ func (c *Connection) createStream(ctx context.Context) error {
442520
return nil
443521
}
444522

523+
// NewConnection creates a new streaming connection.
524+
// Caller is responsible for calling Close() on the connection when done, certain errors will cause
525+
// the connection to be closed automatically. The passed in client will not be closed and can be
526+
// reused.
527+
func NewConnection(ctx context.Context, channelID string, client *agentcommunication.Client) (*Connection, error) {
528+
conn := &Connection{
529+
channelID: channelID,
530+
closed: make(chan struct{}),
531+
messages: make(chan *acpb.MessageBody),
532+
responseSubs: make(map[string]chan *status.Status),
533+
streamReady: make(chan struct{}),
534+
sends: make(chan *acpb.StreamAgentMessagesRequest),
535+
timeToWaitForResp: 2 * time.Second,
536+
client: client,
537+
callerManagedClient: true,
538+
}
539+
540+
zone, err := cm.Zone()
541+
if err != nil {
542+
return nil, err
543+
}
544+
projectNum, err := cm.NumericProjectID()
545+
if err != nil {
546+
return nil, err
547+
}
548+
instanceID, err := cm.InstanceID()
549+
if err != nil {
550+
return nil, err
551+
}
552+
conn.resourceID = fmt.Sprintf("projects/%s/zones/%s/instances/%s", projectNum, zone, instanceID)
553+
554+
if err := conn.createStream(ctx); err != nil {
555+
conn.close(err)
556+
return nil, err
557+
}
558+
559+
return conn, nil
560+
}
561+
445562
// CreateConnection creates a new connection.
563+
// DEPRECATED: Use NewConnection instead.
446564
func CreateConnection(ctx context.Context, channelID string, regional bool, opts ...option.ClientOption) (*Connection, error) {
447565
conn := &Connection{
448-
regional: regional,
449566
channelID: channelID,
450567
closed: make(chan struct{}),
451568
messages: make(chan *acpb.MessageBody),
@@ -469,21 +586,7 @@ func CreateConnection(ctx context.Context, channelID string, regional bool, opts
469586
}
470587
conn.resourceID = fmt.Sprintf("projects/%s/zones/%s/instances/%s", projectNum, zone, instanceID)
471588

472-
location := zone
473-
if conn.regional {
474-
location = location[:len(location)-2]
475-
}
476-
477-
defaultOpts := []option.ClientOption{
478-
option.WithoutAuthentication(), // Do not use oauth.
479-
option.WithGRPCDialOption(grpc.WithTransportCredentials(credentials.NewTLS(nil))), // Because we disabled Auth we need to specifically enable TLS.
480-
option.WithGRPCDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{Time: 60 * time.Second, Timeout: 10 * time.Second})),
481-
option.WithEndpoint(fmt.Sprintf("%s-agentcommunication.googleapis.com:443", location)),
482-
}
483-
484-
opts = append(defaultOpts, opts...)
485-
486-
conn.client, err = agentcommunication.NewClient(ctx, opts...)
589+
conn.client, err = NewClient(ctx, regional, opts...)
487590
if err != nil {
488591
return nil, err
489592
}

0 commit comments

Comments
 (0)