Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ffi/miri-tests/mock-glide-core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,8 @@ impl Client {
) -> RedisResult<Value> {
todo!()
}

pub async fn refresh_iam_token(&mut self) -> RedisResult<()> {
todo!()
}
}
40 changes: 40 additions & 0 deletions ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1587,6 +1587,46 @@ pub unsafe extern "C-unwind" fn update_connection_password(
})
}

/// Manually refresh the IAM authentication token.
///
/// This function triggers an immediate refresh of the IAM token and updates the connection.
/// It is only available if the client was created with IAM authentication.
///
/// # Parameters
///
/// * `client_adapter_ptr`: Pointer to a valid client returned from [`create_client`].
/// * `request_id`: Unique identifier for a valid payload buffer created in the calling language.
///
/// # Returns
///
/// * A pointer to a [`CommandResult`] containing "OK" on success, or an error if:
/// - The client is not using IAM authentication
/// - Token generation fails
/// - Authentication with the new token fails
///
/// # Safety
///
/// * `client_adapter_ptr` must not be `null` and must be obtained from the `ConnectionResponse` returned from [`create_client`].
/// * `client_adapter_ptr` must be able to be safely casted to a valid [`Arc<ClientAdapter>`] via [`Arc::from_raw`].
/// * `request_id` must be valid until it is passed in a call to [`free_command_response`].
/// * This function should only be called with a `client_adapter_ptr` created by [`create_client`], before [`close_client`] was called with the pointer.
#[unsafe(no_mangle)]
pub unsafe extern "C-unwind" fn refresh_iam_token(
client_adapter_ptr: *const c_void,
request_id: usize,
) -> *mut CommandResult {
let client_adapter = unsafe {
// we increment the strong count to ensure that the client is not dropped just because we turned it into an Arc.
Arc::increment_strong_count(client_adapter_ptr);
Arc::from_raw(client_adapter_ptr as *mut ClientAdapter)
};

let mut client = client_adapter.core.client.clone();
client_adapter.execute_request(request_id, async move {
client.refresh_iam_token().await.map(|_| Value::Okay)
})
}

/// Executes a Lua script.
///
/// # Parameters
Expand Down
87 changes: 87 additions & 0 deletions go/base_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,93 @@ func (client *baseClient) ResetConnectionPassword(ctx context.Context) (string,
return client.submitConnectionPasswordUpdate(ctx, "", false)
}

func (client *baseClient) submitRefreshIamToken(ctx context.Context) (string, error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs documentation

// Check if context is already done
select {
case <-ctx.Done():
return models.DefaultStringResponse, ctx.Err()
default:
// Continue with execution
}

// Create a channel to receive the result
resultChannel := make(chan payload, 1)
resultChannelPtr := unsafe.Pointer(&resultChannel)

pinner := pinner{}
pinnedChannelPtr := uintptr(pinner.Pin(resultChannelPtr))
defer pinner.Unpin()

client.mu.Lock()
if client.coreClient == nil {
client.mu.Unlock()
return models.DefaultStringResponse, NewClosingError("RefreshIamToken failed. The client is closed.")
}
client.pending[resultChannelPtr] = struct{}{}

C.refresh_iam_token(
client.coreClient,
C.uintptr_t(pinnedChannelPtr),
)
client.mu.Unlock()

// Wait for result or context cancellation
var payload payload
select {
case <-ctx.Done():
client.mu.Lock()
if client.pending != nil {
delete(client.pending, resultChannelPtr)
}
client.mu.Unlock()
// Start cleanup goroutine
go func() {
// Wait for payload on separate channel
if payload := <-resultChannel; payload.value != nil {
C.free_command_response(payload.value)
}
}()
return models.DefaultStringResponse, ctx.Err()
case payload = <-resultChannel:
// Continue with normal processing
}

client.mu.Lock()
if client.pending != nil {
delete(client.pending, resultChannelPtr)
}
client.mu.Unlock()

if payload.error != nil {
return models.DefaultStringResponse, payload.error
}

return handleOkResponse(payload.value)
}

// RefreshIamToken manually refreshes the IAM token for the current connection.
//
// This method is only available if the client was created with IAM authentication.
// It triggers an immediate refresh of the IAM token and updates the connection.
//
// Parameters:
//
// ctx - The context for controlling the command execution.
//
// Return value:
//
// `"OK"` response on success.
//
// Example:
//
// result, err := client.RefreshIamToken(context.Background())
// if err != nil {
// // handle error
// }
func (client *baseClient) RefreshIamToken(ctx context.Context) (string, error) {
return client.submitRefreshIamToken(ctx)
}

// Set the given key with the given value. The return value is a response from Valkey containing the string "OK".
//
// See [valkey.io] for details.
Expand Down
100 changes: 97 additions & 3 deletions go/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,86 @@ func (addr *NodeAddress) toProtobuf() *protobuf.NodeAddress {
return &protobuf.NodeAddress{Host: addr.Host, Port: uint32(addr.Port)}
}

// ServiceType represents the types of AWS services that can be used for IAM authentication.
type ServiceType int

const (
// Elasticache represents Amazon ElastiCache service.
Elasticache ServiceType = iota
// MemoryDB represents Amazon MemoryDB service.
MemoryDB
)
Comment on lines +40 to +45
Copy link

@currantw currantw Oct 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we follow the existing CamelCase convention? Or, if you think it is better to match now Amazon names then, then should it just be "ElastiCache"?

Suggested change
const (
// Elasticache represents Amazon ElastiCache service.
Elasticache ServiceType = iota
// MemoryDB represents Amazon MemoryDB service.
MemoryDB
)
const (
// Amazon ElastiCache service.
ElastiCache ServiceType = iota
// Amazon MemoryDB service.
MemoryDb
)


// IamAuthConfig represents configuration settings for IAM authentication.
type IamAuthConfig struct {
// The name of the ElastiCache/MemoryDB cluster.
clusterName string
// The type of service being used (ElastiCache or MemoryDB).
service ServiceType
// The AWS region where the ElastiCache/MemoryDB cluster is located.
region string
// Optional refresh interval in seconds for renewing IAM authentication tokens.
// If not provided, defaults to 300 seconds (5 min).
refreshIntervalSeconds *uint32

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason for making refreshIntervalSeconds an *uint32 instead of just an uint32? I'm not really familiar with Go, but seems like this would help avoid the nil check in the toProtobuf method below as well?

}

// NewIamAuthConfig returns an [IamAuthConfig] struct with the given configuration.
func NewIamAuthConfig(clusterName string, service ServiceType, region string) *IamAuthConfig {
defaultRefresh := uint32(300)
return &IamAuthConfig{
clusterName: clusterName,
service: service,
region: region,
refreshIntervalSeconds: &defaultRefresh,
}
}

// WithRefreshIntervalSeconds sets the refresh interval in seconds for IAM token renewal.
func (config *IamAuthConfig) WithRefreshIntervalSeconds(seconds uint32) *IamAuthConfig {
config.refreshIntervalSeconds = &seconds
return config
}

func (config *IamAuthConfig) toProtobuf() *protobuf.IamCredentials {
iamCreds := &protobuf.IamCredentials{
ClusterName: config.clusterName,
Region: config.region,
}

if config.service == Elasticache {
iamCreds.ServiceType = protobuf.ServiceType_ELASTICACHE
} else {
iamCreds.ServiceType = protobuf.ServiceType_MEMORYDB
}

if config.refreshIntervalSeconds != nil {
iamCreds.RefreshIntervalSeconds = config.refreshIntervalSeconds
}

return iamCreds
}

// ServerCredentials represents the credentials for connecting to servers.
// Supports two authentication modes:
// - Password-based authentication: Use username and password
// - IAM authentication: Use username (required) and iamConfig
//
// These modes are mutually exclusive.
type ServerCredentials struct {
// The username that will be used for authenticating connections to the servers. If not supplied, "default"
// will be used.
// will be used for password-based authentication. Required for IAM authentication.
username string
// The password that will be used for authenticating connections to the servers.
// Mutually exclusive with iamConfig.
password string
// IAM authentication configuration. Mutually exclusive with password.
// The client will automatically generate and refresh the authentication token based on the provided configuration.
iamConfig *IamAuthConfig
}

// NewServerCredentials returns a [ServerCredentials] struct with the given username and password.
func NewServerCredentials(username string, password string) *ServerCredentials {
return &ServerCredentials{username, password}
return &ServerCredentials{username: username, password: password}
}

// NewServerCredentialsWithDefaultUsername returns a [ServerCredentials] struct with a default username of "default" and the
Expand All @@ -54,8 +122,34 @@ func NewServerCredentialsWithDefaultUsername(password string) *ServerCredentials
return &ServerCredentials{password: password}
}

// NewServerCredentialsWithIam returns a [ServerCredentials] struct configured for IAM authentication.
// The username is required for IAM authentication.
func NewServerCredentialsWithIam(username string, iamConfig *IamAuthConfig) (*ServerCredentials, error) {
if username == "" {
return nil, errors.New("username is required for IAM authentication")
}
if iamConfig == nil {
return nil, errors.New("iamConfig cannot be nil")
}
return &ServerCredentials{username: username, iamConfig: iamConfig}, nil
}

func (creds *ServerCredentials) toProtobuf() *protobuf.AuthenticationInfo {
return &protobuf.AuthenticationInfo{Username: creds.username, Password: creds.password}
authInfo := &protobuf.AuthenticationInfo{
Username: creds.username,
Password: creds.password,
}

if creds.iamConfig != nil {
authInfo.IamCredentials = creds.iamConfig.toProtobuf()
}

return authInfo
}

// IsIamAuth returns true if this credential is configured for IAM authentication.
func (creds *ServerCredentials) IsIamAuth() bool {
return creds.iamConfig != nil
}

// ReadFrom represents the client's read from strategy.
Expand Down
57 changes: 51 additions & 6 deletions go/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,60 @@ func TestServerCredentials(t *testing.T) {
},
}

for i, parameter := range parameters {
t.Run(fmt.Sprintf("Testing [%v]", i), func(t *testing.T) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This creates a subtest. This allows you to treat each iteration as a separate test and documents each iteration in the output. Why remove it?

result := parameter.input.toProtobuf()

assert.Equal(t, parameter.expected, result)
})
for _, param := range parameters {
result := param.input.toProtobuf()
assert.Equal(t, param.expected, result)
}
}

func TestServerCredentialsWithIam(t *testing.T) {
iamConfig := NewIamAuthConfig("my-cluster", Elasticache, "us-east-1")
creds, err := NewServerCredentialsWithIam("myUser", iamConfig)

assert.Nil(t, err)
assert.NotNil(t, creds)
assert.True(t, creds.IsIamAuth())

authInfo := creds.toProtobuf()
assert.Equal(t, "myUser", authInfo.Username)
assert.Equal(t, "", authInfo.Password)
assert.NotNil(t, authInfo.IamCredentials)
assert.Equal(t, "my-cluster", authInfo.IamCredentials.ClusterName)
assert.Equal(t, "us-east-1", authInfo.IamCredentials.Region)
assert.Equal(t, protobuf.ServiceType_ELASTICACHE, authInfo.IamCredentials.ServiceType)
assert.Equal(t, uint32(300), *authInfo.IamCredentials.RefreshIntervalSeconds)
}

func TestServerCredentialsWithIamCustomRefresh(t *testing.T) {
iamConfig := NewIamAuthConfig("my-cluster", MemoryDB, "us-west-2").
WithRefreshIntervalSeconds(600)
creds, err := NewServerCredentialsWithIam("myUser", iamConfig)

assert.Nil(t, err)
assert.NotNil(t, creds)

authInfo := creds.toProtobuf()
assert.Equal(t, protobuf.ServiceType_MEMORYDB, authInfo.IamCredentials.ServiceType)
assert.Equal(t, uint32(600), *authInfo.IamCredentials.RefreshIntervalSeconds)
}

func TestServerCredentialsWithIamRequiresUsername(t *testing.T) {
iamConfig := NewIamAuthConfig("my-cluster", Elasticache, "us-east-1")
creds, err := NewServerCredentialsWithIam("", iamConfig)

assert.NotNil(t, err)
assert.Nil(t, creds)
assert.Contains(t, err.Error(), "username is required")
}

func TestServerCredentialsWithIamRequiresConfig(t *testing.T) {
creds, err := NewServerCredentialsWithIam("myUser", nil)

assert.NotNil(t, err)
assert.Nil(t, creds)
assert.Contains(t, err.Error(), "iamConfig cannot be nil")
}

func TestConfig_AzAffinity(t *testing.T) {
hosts := []string{"host1", "host2"}
ports := []int{1234, 5678}
Expand Down
Loading
Loading