diff --git a/README.md b/README.md index ef77c41b..6787530b 100644 --- a/README.md +++ b/README.md @@ -234,7 +234,8 @@ func connect() { // ... etc } ``` -### Using DNS to identify an instance + +### Using DNS domain names to identify instances The connector can be configured to use DNS to look up an instance. This would allow you to configure your application to connect to a database instance, and @@ -292,6 +293,41 @@ func connect() { } ``` +### Automatic fail-over using DNS domain names + +When the connector is configured using a domain name, the connector will +periodically check if the DNS record for an instance changes. When the connector +detects that the domain name refers to a different instance, the connector will +close all open connections to the old instance. Subsequent connection attempts +will be directed to the new instance. + +For example: suppose application is configured to connect using the +domain name `prod-db.mycompany.example.com`. Initially the corporate DNS +zone has a TXT record with the value `my-project:region:my-instance`. The +application establishes connections to the `my-project:region:my-instance` +Cloud SQL instance. + +Then, to reconfigure the application using a different database +instance: `my-project:other-region:my-instance-2`. You update the DNS record +for `prod-db.mycompany.example.com` with the target +`my-project:other-region:my-instance-2` + +The connector inside the application detects the change to this +DNS entry. Now, when the application connects to its database using the +domain name `prod-db.mycompany.example.com`, it will connect to the +`my-project:other-region:my-instance-2` Cloud SQL instance. + +The connector will automatically close all existing connections to +`my-project:region:my-instance`. This will force the connection pools to +establish new connections. Also, it may cause database queries in progress +to fail. + +The connector will poll for changes to the DNS name every 30 seconds by default. +You may configure the frequency of the connections using the option +`WithFailoverPeriod(d time.Duration)`. When this is set to 0, the connector will +disable polling and only check if the DNS record changed when it is +creating a new connection. + ### Using Options diff --git a/dialer.go b/dialer.go index 8a62a997..5c5607de 100644 --- a/dialer.go +++ b/dialer.go @@ -110,20 +110,12 @@ type connectionInfoCache interface { io.Closer } -// monitoredCache is a wrapper around a connectionInfoCache that tracks the -// number of connections to the associated instance. -type monitoredCache struct { - openConns *uint64 - - connectionInfoCache -} - // A Dialer is used to create connections to Cloud SQL instances. // // Use NewDialer to initialize a Dialer. type Dialer struct { lock sync.RWMutex - cache map[instance.ConnName]monitoredCache + cache map[instance.ConnName]*monitoredCache keyGenerator *keyGenerator refreshTimeout time.Duration // closed reports if the dialer has been closed. @@ -155,7 +147,8 @@ type Dialer struct { iamTokenSource oauth2.TokenSource // resolver converts instance names into DNS names. - resolver instance.ConnectionNameResolver + resolver instance.ConnectionNameResolver + failoverPeriod time.Duration } var ( @@ -179,6 +172,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { logger: nullLogger{}, useragents: []string{userAgent}, serviceUniverse: "googleapis.com", + failoverPeriod: cloudsql.FailoverPeriod, } for _, opt := range opts { opt(cfg) @@ -192,6 +186,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN { return nil, errUseTokenSource } + // Add this to the end to make sure it's not overridden cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " "))) @@ -263,7 +258,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { d := &Dialer{ closed: make(chan struct{}), - cache: make(map[instance.ConnName]monitoredCache), + cache: make(map[instance.ConnName]*monitoredCache), lazyRefresh: cfg.lazyRefresh, keyGenerator: g, refreshTimeout: cfg.refreshTimeout, @@ -274,7 +269,9 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { iamTokenSource: cfg.iamLoginTokenSource, dialFunc: cfg.dialFunc, resolver: r, + failoverPeriod: cfg.failoverPeriod, } + return d, nil } @@ -380,15 +377,24 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn latency := time.Since(startTime).Milliseconds() go func() { - n := atomic.AddUint64(c.openConns, 1) + n := atomic.AddUint64(c.openConnsCount, 1) trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String()) trace.RecordDialLatency(ctx, icn, d.dialerID, latency) }() - return newInstrumentedConn(tlsConn, func() { - n := atomic.AddUint64(c.openConns, ^uint64(0)) + iConn := newInstrumentedConn(tlsConn, func() { + n := atomic.AddUint64(c.openConnsCount, ^uint64(0)) trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String()) - }, d.dialerID, cn.String()), nil + }, d.dialerID, cn.String()) + + // If this connection was opened using a Domain Name, then store it for later + // in case it needs to be forcibly closed. + if cn.DomainName() != "" { + c.mu.Lock() + c.openConns = append(c.openConns, iConn) + c.mu.Unlock() + } + return iConn, nil } // removeCached stops all background refreshes and deletes the connection @@ -448,7 +454,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) } ci, err := c.ConnectionInfo(ctx) if err != nil { - d.removeCached(ctx, cn, c, err) + d.removeCached(ctx, cn, c.connectionInfoCache, err) return "", err } return ci.DBVersion, nil @@ -472,7 +478,7 @@ func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) err } _, err = c.ConnectionInfo(ctx) if err != nil { - d.removeCached(ctx, cn, c, err) + d.removeCached(ctx, cn, c.connectionInfoCache, err) } return err } @@ -493,6 +499,8 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str type instrumentedConn struct { net.Conn closeFunc func() + mu sync.RWMutex + closed bool dialerID string connName string } @@ -517,9 +525,19 @@ func (i *instrumentedConn) Write(b []byte) (int, error) { return bytesWritten, err } +// isClosed returns true if this connection is closing or is already closed. +func (i *instrumentedConn) isClosed() bool { + i.mu.RLock() + defer i.mu.RUnlock() + return i.closed +} + // Close delegates to the underlying net.Conn interface and reports the close // to the provided closeFunc only when Close returns no error. func (i *instrumentedConn) Close() error { + i.mu.Lock() + defer i.mu.Unlock() + i.closed = true err := i.Conn.Close() if err != nil { return err @@ -551,50 +569,104 @@ func (d *Dialer) Close() error { // modify the existing one, or leave it unchanged as needed. func (d *Dialer) connectionInfoCache( ctx context.Context, cn instance.ConnName, useIAMAuthN *bool, -) (monitoredCache, error) { +) (*monitoredCache, error) { d.lock.RLock() c, ok := d.cache[cn] d.lock.RUnlock() - if !ok { - d.lock.Lock() - defer d.lock.Unlock() - // Recheck to ensure instance wasn't created or changed between locks - c, ok = d.cache[cn] - if !ok { - var useIAMAuthNDial bool - if useIAMAuthN != nil { - useIAMAuthNDial = *useIAMAuthN - } - d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String()) - k, err := d.keyGenerator.rsaKey() - if err != nil { - return monitoredCache{}, err - } - var cache connectionInfoCache - if d.lazyRefresh { - cache = cloudsql.NewLazyRefreshCache( - cn, - d.logger, - d.sqladmin, k, - d.refreshTimeout, d.iamTokenSource, - d.dialerID, useIAMAuthNDial, - ) - } else { - cache = cloudsql.NewRefreshAheadCache( - cn, - d.logger, - d.sqladmin, k, - d.refreshTimeout, d.iamTokenSource, - d.dialerID, useIAMAuthNDial, - ) - } - var count uint64 - c = monitoredCache{openConns: &count, connectionInfoCache: cache} - d.cache[cn] = c - } + + // recheck the domain name, this may close the cache. + if ok { + c.checkDomainName(ctx) + } + + if ok && !c.isClosed() { + c.UpdateRefresh(useIAMAuthN) + return c, nil } - c.UpdateRefresh(useIAMAuthN) + d.lock.Lock() + defer d.lock.Unlock() + + // Recheck to ensure instance wasn't created or changed between locks + c, ok = d.cache[cn] + + // c exists and is not closed + if ok && !c.isClosed() { + c.UpdateRefresh(useIAMAuthN) + return c, nil + } + + // c exists and is closed, remove it from the cache + if ok { + // remove it. + _ = c.Close() + delete(d.cache, cn) + } + + // c does not exist, check for matching domain and close it + oldCn, old, ok := d.findByDn(cn) + if ok { + _ = old.Close() + delete(d.cache, oldCn) + } + + // Create a new instance of monitoredCache + var useIAMAuthNDial bool + if useIAMAuthN != nil { + useIAMAuthNDial = *useIAMAuthN + } + d.logger.Debugf(ctx, "[%v] Connection info added to cache", cn.String()) + k, err := d.keyGenerator.rsaKey() + if err != nil { + return nil, err + } + var cache connectionInfoCache + if d.lazyRefresh { + cache = cloudsql.NewLazyRefreshCache( + cn, + d.logger, + d.sqladmin, k, + d.refreshTimeout, d.iamTokenSource, + d.dialerID, useIAMAuthNDial, + ) + } else { + cache = cloudsql.NewRefreshAheadCache( + cn, + d.logger, + d.sqladmin, k, + d.refreshTimeout, d.iamTokenSource, + d.dialerID, useIAMAuthNDial, + ) + } + c = newMonitoredCache(ctx, cache, cn, d.failoverPeriod, d.resolver, d.logger) + d.cache[cn] = c return c, nil } + +// getOrAdd returns the cache entry, creating it if necessary. This will also +// take care to remove entries with the same domain name. +// +// cn - the connection name to getOrAdd +// +// returns: +// +// monitoredCache - the cached entry +// bool ok - the instance exists +// instance.ConnName - the key to the old entry with the same domain name +// +// This method does not manage locks. +func (d *Dialer) findByDn(cn instance.ConnName) (instance.ConnName, *monitoredCache, bool) { + + // Try to get an instance with the same domain name but different instance + // Remove this instance from the cache, it will be replaced. + if cn.HasDomainName() { + for oldCn, oc := range d.cache { + if oldCn.DomainName() == cn.DomainName() && oldCn != cn { + return oldCn, oc, true + } + } + } + + return instance.ConnName{}, nil, false +} diff --git a/dialer_test.go b/dialer_test.go index 2e640af4..b337eaf0 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -25,6 +25,7 @@ import ( "os" "strings" "sync" + "sync/atomic" "testing" "time" @@ -476,9 +477,7 @@ func TestEngineVersionRemovesInvalidInstancesFromCache(t *testing.T) { spy := &spyConnectionInfoCache{ connectInfoCalls: []connectionInfoResp{tc.resp}, } - d.cache[inst] = monitoredCache{ - connectionInfoCache: spy, - } + d.cache[inst] = newMonitoredCache(nil, spy, inst, 0, nil, nil) _, err = d.EngineVersion(context.Background(), tc.icn) if err == nil { @@ -626,9 +625,7 @@ func TestWarmupRemovesInvalidInstancesFromCache(t *testing.T) { spy := &spyConnectionInfoCache{ connectInfoCalls: []connectionInfoResp{tc.resp}, } - d.cache[inst] = monitoredCache{ - connectionInfoCache: spy, - } + d.cache[inst] = newMonitoredCache(nil, spy, inst, 0, nil, nil) err = d.Warmup(context.Background(), tc.icn, tc.opts...) if err == nil { @@ -802,9 +799,7 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) { spy := &spyConnectionInfoCache{ connectInfoCalls: []connectionInfoResp{tc.resp}, } - d.cache[inst] = monitoredCache{ - connectionInfoCache: spy, - } + d.cache[inst] = newMonitoredCache(nil, spy, inst, 0, nil, nil) _, err = d.Dial(context.Background(), tc.icn, tc.opts...) if err == nil { @@ -854,7 +849,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) { }, }, } - d.cache[cn] = monitoredCache{connectionInfoCache: spy} + d.cache[cn] = newMonitoredCache(nil, spy, cn, 0, nil, nil) _, err = d.Dial(context.Background(), icn) if !errors.Is(err, sentinel) { @@ -1028,16 +1023,13 @@ func TestDialerInitializesLazyCache(t *testing.T) { } type fakeResolver struct { - domainName string - instanceName instance.ConnName + entries map[string]instance.ConnName } func (r *fakeResolver) Resolve(_ context.Context, name string) (instance.ConnName, error) { - // For TestDialerSuccessfullyDialsDnsTxtRecord - if name == r.domainName { - return r.instanceName, nil + if val, ok := r.entries[name]; ok { + return val, nil } - // TestDialerFailsDnsTxtRecordMissing return instance.ConnName{}, fmt.Errorf("no resolution for %q", name) } @@ -1045,18 +1037,23 @@ func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) { inst := mock.NewFakeCSQLInstance( "my-project", "my-region", "my-instance", ) - wantName, _ := instance.ParseConnName("my-project:my-region:my-instance") + wantName, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com") + wantName2, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db2.example.com") + // This will create 2 separate connectionInfoCache entries, one for + // each DNS name. d := setupDialer(t, setupConfig{ testInstance: inst, reqs: []*mock.Request{ - mock.InstanceGetSuccess(inst, 1), - mock.CreateEphemeralSuccess(inst, 1), + mock.InstanceGetSuccess(inst, 2), + mock.CreateEphemeralSuccess(inst, 2), }, dialerOptions: []Option{ WithTokenSource(mock.EmptyTokenSource{}), WithResolver(&fakeResolver{ - domainName: "db.example.com", - instanceName: wantName, + entries: map[string]instance.ConnName{ + "db.example.com": wantName, + "db2.example.com": wantName2, + }, }), }, }) @@ -1065,6 +1062,10 @@ func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) { context.Background(), t, d, "db.example.com", ) + testSuccessfulDial( + context.Background(), t, d, + "db2.example.com", + ) } func TestDialerFailsDnsTxtRecordMissing(t *testing.T) { @@ -1085,3 +1086,141 @@ func TestDialerFailsDnsTxtRecordMissing(t *testing.T) { t.Fatalf("want = %v, got = %v", wantMsg, err) } } + +type changingResolver struct { + stage *int32 +} + +func (r *changingResolver) Resolve(_ context.Context, name string) (instance.ConnName, error) { + // For TestDialerFailoverOnInstanceChange + if name == "update.example.com" { + if atomic.LoadInt32(r.stage) == 0 { + return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") + } + return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance2", "update.example.com") + } + // TestDialerFailsDnsSrvRecordMissing + return instance.ConnName{}, fmt.Errorf("no resolution for %q", name) +} + +func TestDialerUpdatesOnDialAfterDnsChange(t *testing.T) { + // At first, the resolver will resolve + // update.example.com to "my-instance" + // Then, the resolver will resolve the same domain name to + // "my-instance2". + // This shows that on every call to Dial(), the dialer will resolve the + // SRV record and connect to the correct instance. + inst := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance", + ) + inst2 := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance2", + ) + r := &changingResolver{ + stage: new(int32), + } + + d := setupDialer(t, setupConfig{ + skipServer: true, + reqs: []*mock.Request{ + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + mock.InstanceGetSuccess(inst2, 1), + mock.CreateEphemeralSuccess(inst2, 1), + }, + dialerOptions: []Option{ + WithResolver(r), + WithTokenSource(mock.EmptyTokenSource{}), + }, + }) + + // Start the proxy for instance 1 + stop1 := mock.StartServerProxy(t, inst) + t.Cleanup(func() { + stop1() + }) + + testSuccessfulDial( + context.Background(), t, d, + "update.example.com", + ) + stop1() + + atomic.StoreInt32(r.stage, 1) + + // Start the proxy for instance 2 + stop2 := mock.StartServerProxy(t, inst2) + t.Cleanup(func() { + stop2() + }) + + testSucessfulDialWithInstanceName( + context.Background(), t, d, + "update.example.com", "my-instance2", + ) +} + +func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) { + // At first, the resolver will resolve + // update.example.com to "my-instance" + // Then, the resolver will resolve the same domain name to + // "my-instance2". + // This shows that on every call to Dial(), the dialer will resolve the + // SRV record and connect to the correct instance. + inst := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance", + ) + inst2 := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance2", + ) + r := &changingResolver{ + stage: new(int32), + } + + d := setupDialer(t, setupConfig{ + skipServer: true, + reqs: []*mock.Request{ + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + mock.InstanceGetSuccess(inst2, 1), + mock.CreateEphemeralSuccess(inst2, 1), + }, + dialerOptions: []Option{ + WithFailoverPeriod(10 * time.Millisecond), + WithResolver(r), + WithTokenSource(mock.EmptyTokenSource{}), + }, + }) + + // Start the proxy for instance 1 + stop1 := mock.StartServerProxy(t, inst) + t.Cleanup(func() { + stop1() + }) + + testSuccessfulDial( + context.Background(), t, d, + "update.example.com", + ) + stop1() + atomic.StoreInt32(r.stage, 1) + + time.Sleep(1 * time.Second) + instCn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") + c, _ := d.cache[instCn] + if !c.isClosed() { + t.Fatal("Expected monitoredCache to be closed after domain name changed. monitoredCache was not closed.") + } + + // Start the proxy for instance 2 + stop2 := mock.StartServerProxy(t, inst2) + t.Cleanup(func() { + stop2() + }) + + testSucessfulDialWithInstanceName( + context.Background(), t, d, + "update.example.com", "my-instance2", + ) + +} diff --git a/instance/conn_name.go b/instance/conn_name.go index 2dd3de73..ab2fa956 100644 --- a/instance/conn_name.go +++ b/instance/conn_name.go @@ -32,9 +32,10 @@ var ( // ConnName represents the "instance connection name", in the format // "project:region:name". type ConnName struct { - project string - region string - name string + project string + region string + name string + domainName string } func (c *ConnName) String() string { @@ -56,8 +57,24 @@ func (c *ConnName) Name() string { return c.name } +// DomainName returns the domain name for this instance +func (c *ConnName) DomainName() string { + return c.domainName +} + +// HasDomainName returns the Cloud SQL domain name +func (c *ConnName) HasDomainName() bool { + return c.domainName != "" +} + // ParseConnName initializes a new ConnName struct. func ParseConnName(cn string) (ConnName, error) { + return ParseConnNameWithDomainName(cn, "") +} + +// ParseConnNameWithDomainName initializes a new ConnName struct, +// also setting the domain name. +func ParseConnNameWithDomainName(cn string, dn string) (ConnName, error) { b := []byte(cn) m := connNameRegex.FindSubmatch(b) if m == nil { @@ -69,9 +86,10 @@ func ParseConnName(cn string) (ConnName, error) { } c := ConnName{ - project: string(m[1]), - region: string(m[3]), - name: string(m[4]), + project: string(m[1]), + region: string(m[3]), + name: string(m[4]), + domainName: dn, } return c, nil } diff --git a/instance/conn_name_test.go b/instance/conn_name_test.go index 315dec4d..e07f759a 100644 --- a/instance/conn_name_test.go +++ b/instance/conn_name_test.go @@ -23,11 +23,11 @@ func TestParseConnName(t *testing.T) { }{ { "project:region:instance", - ConnName{"project", "region", "instance"}, + ConnName{project: "project", region: "region", name: "instance"}, }, { "google.com:project:region:instance", - ConnName{"google.com:project", "region", "instance"}, + ConnName{project: "google.com:project", region: "region", name: "instance"}, }, { "project:instance", // missing region diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index f8e44b3b..bc25e672 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -45,6 +45,11 @@ const ( // refreshInterval. RefreshTimeout = 60 * time.Second + // FailoverPeriod is the frequency with which the dialer will check + // if the DNS record has changed for connections configured using + // a DNS name. + FailoverPeriod = 30 * time.Second + // refreshBurst is the initial burst allowed by the rate limiter. refreshBurst = 2 ) diff --git a/monitored_cache.go b/monitored_cache.go new file mode 100644 index 00000000..7f17f3c0 --- /dev/null +++ b/monitored_cache.go @@ -0,0 +1,151 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsqlconn + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "cloud.google.com/go/cloudsqlconn/debug" + "cloud.google.com/go/cloudsqlconn/instance" + "cloud.google.com/go/cloudsqlconn/internal/cloudsql" +) + +// monitoredCache is a wrapper around a connectionInfoCache that tracks the +// number of connections to the associated instance. +type monitoredCache struct { + openConnsCount *uint64 + cn instance.ConnName + resolver instance.ConnectionNameResolver + logger debug.ContextLogger + + // domainNameTicker periodically checks any domain names to see if they + // changed. + domainNameTicker *time.Ticker + closedCh chan struct{} + + mu sync.Mutex + openConns []*instrumentedConn + closed bool + + connectionInfoCache +} + +func newMonitoredCache(ctx context.Context, cache connectionInfoCache, cn instance.ConnName, failoverPeriod time.Duration, resolver instance.ConnectionNameResolver, logger debug.ContextLogger) *monitoredCache { + c := &monitoredCache{ + openConnsCount: new(uint64), + closedCh: make(chan struct{}), + cn: cn, + resolver: resolver, + logger: logger, + connectionInfoCache: cache, + } + if cn.HasDomainName() { + c.domainNameTicker = time.NewTicker(failoverPeriod) + go func() { + for { + select { + case <-c.domainNameTicker.C: + c.purgeClosedConns() + c.checkDomainName(ctx) + case <-c.closedCh: + return + } + } + }() + + } + + return c +} +func (c *monitoredCache) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +func (c *monitoredCache) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil + } + + c.closed = true + close(c.closedCh) + + if c.domainNameTicker != nil { + c.domainNameTicker.Stop() + } + + if atomic.LoadUint64(c.openConnsCount) > 0 { + for _, socket := range c.openConns { + if !socket.isClosed() { + _ = socket.Close() // force socket closed, ok to ignore error. + } + } + atomic.StoreUint64(c.openConnsCount, 0) + } + + return c.connectionInfoCache.Close() +} + +func (c *monitoredCache) ForceRefresh() { + c.connectionInfoCache.ForceRefresh() +} + +func (c *monitoredCache) UpdateRefresh(b *bool) { + c.connectionInfoCache.UpdateRefresh(b) +} +func (c *monitoredCache) ConnectionInfo(ctx context.Context) (cloudsql.ConnectionInfo, error) { + return c.connectionInfoCache.ConnectionInfo(ctx) +} + +func (c *monitoredCache) purgeClosedConns() { + c.mu.Lock() + defer c.mu.Unlock() + + var open []*instrumentedConn + for _, s := range c.openConns { + if !s.isClosed() { + open = append(open, s) + } + } + c.openConns = open +} + +func (c *monitoredCache) checkDomainName(ctx context.Context) { + if !c.cn.HasDomainName() { + return + } + newCn, err := c.resolver.Resolve(ctx, c.cn.DomainName()) + if err != nil { + // The domain name could not be resolved. + c.logger.Debugf(ctx, "domain name %s for instance %s did not resolve, "+ + "closing all connections: %v", + c.cn.DomainName(), c.cn.Name(), err) + c.Close() + } + if newCn != c.cn { + // The instance changed. + c.logger.Debugf(ctx, "domain name %s changed from %s to %s, "+ + "closing all connections.", + c.cn.DomainName(), c.cn.Name(), newCn.Name()) + c.Close() + } + +} diff --git a/monitored_cache_test.go b/monitored_cache_test.go new file mode 100644 index 00000000..0fa42a58 --- /dev/null +++ b/monitored_cache_test.go @@ -0,0 +1,180 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsqlconn + +import ( + "context" + "net" + "net/netip" + "sync/atomic" + "testing" + "time" + + "cloud.google.com/go/cloudsqlconn/instance" +) + +type testLog struct { + t *testing.T +} + +func (l *testLog) Debugf(_ context.Context, f string, args ...interface{}) { + l.t.Logf(f, args...) +} + +func TestMonitoredCache_purgeClosedConns(t *testing.T) { + cn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com") + c := newMonitoredCache(context.TODO(), + &spyConnectionInfoCache{}, + cn, + 10*time.Millisecond, + &fakeResolver{entries: map[string]instance.ConnName{"db.example.com": cn}}, + &testLog{t: t}, + ) + + // Add connections + c.mu.Lock() + c.openConns = []*instrumentedConn{ + &instrumentedConn{closed: false}, + &instrumentedConn{closed: true}, + } + c.mu.Unlock() + + // wait for the resolver to run + time.Sleep(100 * time.Millisecond) + c.mu.Lock() + if got := len(c.openConns); got != 1 { + t.Fatalf("got %d, want 1. Expected openConns to only contain open connections", got) + } + c.mu.Unlock() + +} + +func TestMonitoredCache_checkDomainName_instanceChanged(t *testing.T) { + cn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") + r := &changingResolver{ + stage: new(int32), + } + c := newMonitoredCache(context.TODO(), + &spyConnectionInfoCache{}, + cn, + 10*time.Millisecond, + r, + &testLog{t: t}, + ) + + // Dont' change the instance yet. Check that the connection is open. + // wait for the resolver to run + time.Sleep(100 * time.Millisecond) + if c.isClosed() { + t.Fatal("got cache closed, want cache open") + } + // update the domain name + atomic.StoreInt32(r.stage, 1) + + // wait for the resolver to run + time.Sleep(100 * time.Millisecond) + if !c.isClosed() { + t.Fatal("got cache open, want cache closed") + } + +} + +func TestMonitoredCache_Close(t *testing.T) { + cn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com") + var closeFuncCalls int32 + + r := &changingResolver{ + stage: new(int32), + } + + c := newMonitoredCache(context.TODO(), + &spyConnectionInfoCache{}, + cn, + 10*time.Millisecond, + r, + &testLog{t: t}, + ) + inc := func() { + atomic.AddInt32(&closeFuncCalls, 1) + } + + c.mu.Lock() + // set up the state as if there were 2 open connections. + c.openConns = []*instrumentedConn{ + { + closed: false, + closeFunc: inc, + Conn: &mockConn{}, + }, + { + closed: false, + closeFunc: inc, + Conn: &mockConn{}, + }, + { + closed: true, + closeFunc: inc, + Conn: &mockConn{}, + }, + } + *c.openConnsCount = 2 + c.mu.Unlock() + + c.Close() + if !c.isClosed() { + t.Fatal("got cache open, want cache closed") + } + // wait for closeFunc() to be called. + time.Sleep(100 * time.Millisecond) + if got := atomic.LoadInt32(&closeFuncCalls); got != 2 { + t.Fatalf("got %d, want 2", got) + } + +} + +type mockConn struct { +} + +func (m *mockConn) Read(_ []byte) (int, error) { + return 0, nil +} + +func (m *mockConn) Write(_ []byte) (int, error) { + return 0, nil +} + +func (m *mockConn) Close() error { + return nil +} + +func (m *mockConn) LocalAddr() net.Addr { + return net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:3307")) +} + +func (m *mockConn) RemoteAddr() net.Addr { + return net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:3307")) +} + +func (m *mockConn) SetDeadline(_ time.Time) error { + return nil +} + +func (m *mockConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (m *mockConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/options.go b/options.go index c21fcd2a..a719eca9 100644 --- a/options.go +++ b/options.go @@ -54,6 +54,7 @@ type dialerConfig struct { setTokenSource bool setIAMAuthNTokenSource bool resolver instance.ConnectionNameResolver + failoverPeriod time.Duration // err tracks any dialer options that may have failed. err error } @@ -271,6 +272,16 @@ func WithDNSResolver() Option { } } +// WithFailoverPeriod will cause the connector to periodically check the SRV DNS +// records of instance configured using DNS names. By default, this is 30 +// seconds. If this is set to 0, the connector will only check for domain name +// changes when establishing a new connection. +func WithFailoverPeriod(f time.Duration) Option { + return func(d *dialerConfig) { + d.failoverPeriod = f + } +} + type debugLoggerWithoutContext struct { logger debug.Logger }