Skip to content

Commit

Permalink
Refactor how clients are used, add new NewClient and NewConnection methods.
Browse files Browse the repository at this point in the history

Add SendAgentMessage to acs client
Use the FQDN to avoid unnecessary DNS lookups

PiperOrigin-RevId: 725650044
  • Loading branch information
adjackura authored and copybara-github committed Feb 11, 2025
1 parent 71bfb94 commit cf5b841
Show file tree
Hide file tree
Showing 2 changed files with 272 additions and 108 deletions.
199 changes: 151 additions & 48 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,82 @@ var (
logger *log.Logger
)

// loggerPrintf formats a message and writes it through the package logger.
// It does nothing unless DebugLogging is enabled.
func loggerPrintf(format string, v ...any) {
	if !DebugLogging {
		return
	}
	msg := fmt.Sprintf(format, v...)
	// Call depth 2 attributes the log entry to the caller of loggerPrintf
	// rather than to this helper.
	logger.Output(2, msg)
}

// NewClient creates a new agent communication gRPC client.
//
// The endpoint host is built from the zone reported by cm.Zone(). When regional
// is true the last two characters of the zone are dropped first (presumably the
// "-x" zone suffix, leaving the region name); an error is returned if the zone
// is too short for that. Caller-supplied opts are appended after the defaults
// set here.
//
// Caller must close the returned client when it is done being used to clean up
// its underlying connections.
func NewClient(ctx context.Context, regional bool, opts ...option.ClientOption) (*agentcommunication.Client, error) {
	zone, err := cm.Zone()
	if err != nil {
		return nil, err
	}

	location := zone
	if regional {
		// Guard the slice expression: a zone of fewer than two characters would
		// panic, and a two-character zone would produce an empty location.
		if len(zone) <= 2 {
			return nil, fmt.Errorf("cannot derive region from zone %q", zone)
		}
		location = zone[:len(zone)-2]
	}

	defaultOpts := []option.ClientOption{
		option.WithoutAuthentication(), // Do not use oauth.
		option.WithGRPCDialOption(grpc.WithTransportCredentials(credentials.NewTLS(nil))), // Because we disabled Auth we need to specifically enable TLS.
		option.WithGRPCDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{Time: 60 * time.Second, Timeout: 10 * time.Second})),
		// Use the FQDN to avoid unnecessary DNS lookups.
		option.WithEndpoint(fmt.Sprintf("%s-agentcommunication.googleapis.com.:443", location)),
	}

	opts = append(defaultOpts, opts...)
	return agentcommunication.NewClient(ctx, opts...)
}

// SendAgentMessage sends msg on the given channel as a single request through
// client and returns the response. This is equivalent to sending a message via
// StreamAgentMessages with a single message and waiting for the response.
//
// The resource ID is assembled from the project number, zone and instance ID
// obtained through cm, and an identity token fetched through cm is attached to
// the outgoing context. A failure to fetch the token is reported wrapped in
// ErrGettingInstanceToken; other lookup failures are returned as-is.
func SendAgentMessage(ctx context.Context, channelID string, client *agentcommunication.Client, msg *acpb.MessageBody) (*acpb.SendAgentMessageResponse, error) {
	loggerPrintf("SendAgentMessage")

	// Gather the components of the instance resource name.
	zone, err := cm.Zone()
	if err != nil {
		return nil, err
	}
	project, err := cm.NumericProjectID()
	if err != nil {
		return nil, err
	}
	instance, err := cm.InstanceID()
	if err != nil {
		return nil, err
	}
	resourceID := fmt.Sprintf("projects/%s/zones/%s/instances/%s", project, zone, instance)

	idToken, err := cm.Get("instance/service-accounts/default/identity?audience=agentcommunication.googleapis.com&format=full")
	if err != nil {
		return nil, fmt.Errorf("%w: %v", ErrGettingInstanceToken, err)
	}

	// Attach the token and routing identifiers as outgoing gRPC metadata.
	headers := metadata.New(map[string]string{
		"authentication":                  "Bearer " + idToken,
		"agent-communication-resource-id": resourceID,
		"agent-communication-channel-id":  channelID,
	})
	ctx = metadata.NewOutgoingContext(ctx, headers)

	loggerPrintf("Using ResourceID %q", resourceID)
	loggerPrintf("Using ChannelID %q", channelID)

	req := &acpb.SendAgentMessageRequest{
		ChannelId:   channelID,
		ResourceId:  resourceID,
		MessageBody: msg,
	}
	return client.SendAgentMessage(ctx, req)
}

// Connection is an AgentCommunication connection.
type Connection struct {
client *agentcommunication.Client
stream acpb.AgentCommunication_StreamAgentMessagesClient
// Indicates that the client is caller managed and should not be closed.
callerManagedClient bool
// Indicates that the entire connection is closed and will not reopen.
closed chan struct{}
closeErr error
Expand All @@ -81,12 +153,6 @@ type Connection struct {
timeToWaitForResp time.Duration
}

// loggerPrintf emits a formatted debug line via the package logger when
// DebugLogging is set; otherwise it is a no-op.
func loggerPrintf(format string, v ...any) {
	if !DebugLogging {
		return
	}
	// Depth 2 makes the logger report the caller's location, not this helper's.
	logger.Output(2, fmt.Sprintf(format, v...))
}

// Close the connection.
func (c *Connection) Close() {
c.close(ErrConnectionClosed)
Expand Down Expand Up @@ -114,7 +180,9 @@ func (c *Connection) close(err error) {
default:
close(c.closed)
c.setCloseErr(err)
c.client.Close()
if !c.callerManagedClient {
c.client.Close()
}
}
}

Expand Down Expand Up @@ -212,13 +280,14 @@ func (c *Connection) Receive() (*acpb.MessageBody, error) {
}
}

func (c *Connection) streamSend(req *acpb.StreamAgentMessagesRequest, streamClosed chan struct{}) error {
func (c *Connection) streamSend(req *acpb.StreamAgentMessagesRequest, streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) error {
select {
case <-streamClosed:
return errors.New("stream closed")
default:
case streamSendLock <- struct{}{}:
defer func() { <-streamSendLock }()
}
if err := c.stream.Send(req); err != nil {
if err := stream.Send(req); err != nil {
if err != io.EOF && !errors.Is(err, io.EOF) {
// Something is very broken, just close the stream here.
loggerPrintf("Unexpected send error, closing connection: %v", err)
Expand All @@ -241,11 +310,16 @@ func (c *Connection) streamSend(req *acpb.StreamAgentMessagesRequest, streamClos
return nil
}

func (c *Connection) send(streamClosed chan struct{}) {
func (c *Connection) send(streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) {
defer func() {
// Lock the stream sends so we can close the stream.
streamSendLock <- struct{}{}
stream.CloseSend()
}()
for {
select {
case req := <-c.sends:
if err := c.streamSend(req, streamClosed); err != nil {
if err := c.streamSend(req, streamClosed, streamSendLock, stream); err != nil {
return
}
case <-c.closed:
Expand All @@ -256,7 +330,7 @@ func (c *Connection) send(streamClosed chan struct{}) {
}
}

func (c *Connection) acknowledgeMessage(messageID string, streamClosed chan struct{}) error {
func (c *Connection) acknowledgeMessage(messageID string, streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) error {
ackReq := &acpb.StreamAgentMessagesRequest{
MessageId: messageID,
Type: &acpb.StreamAgentMessagesRequest_MessageResponse{},
Expand All @@ -265,15 +339,15 @@ func (c *Connection) acknowledgeMessage(messageID string, streamClosed chan stru
case <-c.closed:
return fmt.Errorf("connection closed with err: %w", c.closeErr)
default:
return c.streamSend(ackReq, streamClosed)
return c.streamSend(ackReq, streamClosed, streamSendLock, stream)
}
}

// recv keeps receiving and acknowledging new messages.
func (c *Connection) recv(ctx context.Context, streamClosed chan struct{}) {
func (c *Connection) recv(ctx context.Context, streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) {
loggerPrintf("Receiving messages")
for {
resp, err := c.stream.Recv()
resp, err := stream.Recv()
if err != nil {
select {
case <-streamClosed:
Expand All @@ -284,6 +358,7 @@ func (c *Connection) recv(ctx context.Context, streamClosed chan struct{}) {
select {
case <-c.closed:
// Connection is closed, return now.
loggerPrintf("Connection closed, recv returning")
return
default:
}
Expand Down Expand Up @@ -315,7 +390,7 @@ func (c *Connection) recv(ctx context.Context, streamClosed chan struct{}) {
case *acpb.StreamAgentMessagesResponse_MessageBody:
// Acknowledge the message first; if this ack fails, don't forward the message on to the
// handling logic, since that indicates a stream disconnect.
if err := c.acknowledgeMessage(resp.GetMessageId(), streamClosed); err != nil {
if err := c.acknowledgeMessage(resp.GetMessageId(), streamClosed, streamSendLock, stream); err != nil {
loggerPrintf("Error acknowledging message %q: %v", resp.GetMessageId(), err)
continue
}
Expand All @@ -337,37 +412,36 @@ func (c *Connection) recv(ctx context.Context, streamClosed chan struct{}) {
}
}

func (c *Connection) createStreamLoop(ctx context.Context) error {
func createStreamLoop(ctx context.Context, client *agentcommunication.Client, resourceID string, channelID string) (acpb.AgentCommunication_StreamAgentMessagesClient, error) {
resourceExhaustedRetries := 0
unavailableRetries := 0
var err error
for {
c.stream, err = c.client.StreamAgentMessages(ctx)
stream, err := client.StreamAgentMessages(ctx)
if err != nil {
return fmt.Errorf("error creating stream: %v", err)
return nil, fmt.Errorf("error creating stream: %v", err)
}

// RegisterConnection is a special message that must be sent before any other messages.
req := &acpb.StreamAgentMessagesRequest{
MessageId: uuid.New().String(),
Type: &acpb.StreamAgentMessagesRequest_RegisterConnection{
RegisterConnection: &acpb.RegisterConnection{ResourceId: c.resourceID, ChannelId: c.channelID}}}
RegisterConnection: &acpb.RegisterConnection{ResourceId: resourceID, ChannelId: channelID}}}

if err := c.stream.Send(req); err != nil {
return fmt.Errorf("error sending register connection: %v", err)
if err := stream.Send(req); err != nil {
return nil, fmt.Errorf("error sending register connection: %v", err)
}

// We expect the first message to be a MessageResponse.
resp, err := c.stream.Recv()
resp, err := stream.Recv()
if err == nil {
switch resp.GetType().(type) {
case *acpb.StreamAgentMessagesResponse_MessageResponse:
if resp.GetMessageResponse().GetStatus().GetCode() != int32(codes.OK) {
return fmt.Errorf("unexpected register response: %+v", resp.GetMessageResponse().GetStatus())
return nil, fmt.Errorf("unexpected register response: %+v", resp.GetMessageResponse().GetStatus())
}
}
// Stream is connected.
return nil
return stream, nil
}

st, ok := status.FromError(err)
Expand All @@ -394,7 +468,7 @@ func (c *Connection) createStreamLoop(ctx context.Context) error {
}
loggerPrintf("Stream returned Unavailable, exceeded max number of reconnects, closing connection: %v", err)
}
return err
return nil, err
}
}

Expand All @@ -416,15 +490,19 @@ func (c *Connection) createStream(ctx context.Context) error {

// Set a timeout for the stream, this is well above service side timeout.
cnclCtx, cancel := context.WithTimeout(ctx, 60*time.Minute)
if err := c.createStreamLoop(cnclCtx); err != nil {
stream, err := createStreamLoop(cnclCtx, c.client, c.resourceID, c.channelID)
if err != nil {
cancel()
c.close(err)
return err
}

// Used to signal that the stream is closed.
streamClosed := make(chan struct{})
go c.recv(ctx, streamClosed)
go c.send(streamClosed)
// This ensures that only one send is happening at a time.
streamSendLock := make(chan struct{}, 1)
go c.recv(ctx, streamClosed, streamSendLock, stream)
go c.send(streamClosed, streamSendLock, stream)

go func() {
defer cancel()
Expand All @@ -442,10 +520,49 @@ func (c *Connection) createStream(ctx context.Context) error {
return nil
}

// NewConnection creates a new streaming connection on channelID using the
// supplied client.
//
// The connection's resource ID is built from the project number, zone and
// instance ID obtained through cm; any lookup error is returned unchanged. If
// the stream cannot be created the connection is closed and the error returned.
//
// Caller is responsible for calling Close() on the connection when done,
// certain errors will cause the connection to be closed automatically. The
// passed in client is marked as caller managed: it will not be closed by the
// connection and can be reused.
func NewConnection(ctx context.Context, channelID string, client *agentcommunication.Client) (*Connection, error) {
	// Resolve the instance identity first; nothing else is needed if this fails.
	zone, err := cm.Zone()
	if err != nil {
		return nil, err
	}
	project, err := cm.NumericProjectID()
	if err != nil {
		return nil, err
	}
	instance, err := cm.InstanceID()
	if err != nil {
		return nil, err
	}

	conn := &Connection{
		client:              client,
		callerManagedClient: true,
		channelID:           channelID,
		resourceID:          fmt.Sprintf("projects/%s/zones/%s/instances/%s", project, zone, instance),
		closed:              make(chan struct{}),
		streamReady:         make(chan struct{}),
		messages:            make(chan *acpb.MessageBody),
		sends:               make(chan *acpb.StreamAgentMessagesRequest),
		responseSubs:        make(map[string]chan *status.Status),
		timeToWaitForResp:   2 * time.Second,
	}

	if err = conn.createStream(ctx); err != nil {
		conn.close(err)
		return nil, err
	}
	return conn, nil
}

// CreateConnection creates a new connection.
//
// Deprecated: Use NewConnection instead.
func CreateConnection(ctx context.Context, channelID string, regional bool, opts ...option.ClientOption) (*Connection, error) {
conn := &Connection{
regional: regional,
channelID: channelID,
closed: make(chan struct{}),
messages: make(chan *acpb.MessageBody),
Expand All @@ -469,21 +586,7 @@ func CreateConnection(ctx context.Context, channelID string, regional bool, opts
}
conn.resourceID = fmt.Sprintf("projects/%s/zones/%s/instances/%s", projectNum, zone, instanceID)

location := zone
if conn.regional {
location = location[:len(location)-2]
}

defaultOpts := []option.ClientOption{
option.WithoutAuthentication(), // Do not use oauth.
option.WithGRPCDialOption(grpc.WithTransportCredentials(credentials.NewTLS(nil))), // Because we disabled Auth we need to specifically enable TLS.
option.WithGRPCDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{Time: 60 * time.Second, Timeout: 10 * time.Second})),
option.WithEndpoint(fmt.Sprintf("%s-agentcommunication.googleapis.com:443", location)),
}

opts = append(defaultOpts, opts...)

conn.client, err = agentcommunication.NewClient(ctx, opts...)
conn.client, err = NewClient(ctx, regional, opts...)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit cf5b841

Please sign in to comment.