Skip to content

Commit

Permalink
Add iterators to unified resource cache
Browse files Browse the repository at this point in the history
This updates the UnifiedResourceCache with IterateResources as an
alternative to IterateUnifiedResources. The new function returns
an iterator instead of collecting and returning a page of results.
While this API may not entirely replace the current one, it offers
a better way for users that just want to iterate resources without
collecting them. Additionally, a few helper methods were included
for callers that might wish to only iterate one specific resource
type. Internally the UnifiedResourceCache was refactored to use the
same logic for all exposed iteration methods.
  • Loading branch information
rosstimothy committed Feb 15, 2025
1 parent d1a38bb commit 3621abd
Show file tree
Hide file tree
Showing 4 changed files with 900 additions and 157 deletions.
111 changes: 43 additions & 68 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -1519,62 +1519,46 @@ func (a *Server) runPeriodicOperations() {
}()
case heartbeatCheckKey:
go func() {
req := &proto.ListUnifiedResourcesRequest{Kinds: []string{types.KindNode}, SortBy: types.SortBy{Field: types.ResourceKind}}

for {
_, next, err := a.UnifiedResourceCache.IterateUnifiedResources(a.closeCtx,
func(rwl types.ResourceWithLabels) (bool, error) {
srv, ok := rwl.(types.Server)
if !ok {
return false, nil
}
if services.NodeHasMissedKeepAlives(srv) {
heartbeatsMissedByAuth.Inc()
}

if srv.GetSubKind() != types.SubKindOpenSSHNode {
return false, nil
}
// TODO(tross) DELETE in v20.0.0 - all invalid hostnames should have been sanitized by then.
if !validServerHostname(srv.GetHostname()) {
logger := a.logger.With("server", srv.GetName(), "hostname", srv.GetHostname())

logger.DebugContext(a.closeCtx, "sanitizing invalid static SSH server hostname")
// Any existing static hosts will not have their
// hostname sanitized since they don't heartbeat.
if err := sanitizeHostname(srv); err != nil {
logger.WarnContext(a.closeCtx, "failed to sanitize static SSH server hostname", "error", err)
return false, nil
}

if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) {
logger.WarnContext(a.closeCtx, "failed to update SSH server hostname", "error", err)
}
} else if oldHostname, ok := srv.GetLabel(replacedHostnameLabel); ok && validServerHostname(oldHostname) {
// If the hostname has been replaced by a sanitized version, revert it back to the original
// if the original is valid under the most recent rules.
logger := a.logger.With("server", srv.GetName(), "old_hostname", oldHostname, "sanitized_hostname", srv.GetHostname())
if err := restoreSanitizedHostname(srv); err != nil {
logger.WarnContext(a.closeCtx, "failed to restore sanitized static SSH server hostname", "error", err)
return false, nil
}
if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) {
logger.WarnContext(a.closeCtx, "Failed to update node hostname", "error", err)
}
}

return false, nil
},
req,
)
for srv, err := range a.UnifiedResourceCache.IterateNodes(a.closeCtx, "", services.IterateAscend) {
if err != nil {
a.logger.ErrorContext(a.closeCtx, "Failed to load nodes for heartbeat metric calculation", "error", err)
return
}

req.StartKey = next
if req.StartKey == "" {
break
if services.NodeHasMissedKeepAlives(srv) {
heartbeatsMissedByAuth.Inc()
}

if srv.GetSubKind() != types.SubKindOpenSSHNode {
continue
}

// TODO(tross) DELETE in v20.0.0 - all invalid hostnames should have been sanitized by then.
if !validServerHostname(srv.GetHostname()) {
logger := a.logger.With("server", srv.GetName(), "hostname", srv.GetHostname())

logger.DebugContext(a.closeCtx, "sanitizing invalid static SSH server hostname")
// Any existing static hosts will not have their
// hostname sanitized since they don't heartbeat.
if err := sanitizeHostname(srv); err != nil {
logger.WarnContext(a.closeCtx, "failed to sanitize static SSH server hostname", "error", err)
continue
}

if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) {
logger.WarnContext(a.closeCtx, "failed to update SSH server hostname", "error", err)
}
} else if oldHostname, ok := srv.GetLabel(replacedHostnameLabel); ok && validServerHostname(oldHostname) {
// If the hostname has been replaced by a sanitized version, revert it back to the original
// if the original is valid under the most recent rules.
logger := a.logger.With("server", srv.GetName(), "old_hostname", oldHostname, "sanitized_hostname", srv.GetHostname())
if err := restoreSanitizedHostname(srv); err != nil {
logger.WarnContext(a.closeCtx, "failed to restore sanitized static SSH server hostname", "error", err)
continue
}
if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) {
logger.WarnContext(a.closeCtx, "Failed to update node hostname", "error", err)
}
}
}
}()
Expand Down Expand Up @@ -3316,27 +3300,18 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types.
// If the certificate is targeting a trusted Teleport cluster, it is the
// responsibility of the cluster to ensure its existence.
if req.routeToCluster == clusterName && req.kubernetesCluster != "" {
found, _, err := a.UnifiedResourceCache.IterateUnifiedResources(a.closeCtx, func(rwl types.ResourceWithLabels) (bool, error) {
if rwl.GetKind() != types.KindKubeServer {
return false, nil
var found bool
for ks, err := range a.UnifiedResourceCache.IterateKubernetesServers(a.closeCtx, "", services.IterateAscend) {
if err != nil {
return nil, trace.Wrap(err)
}

ks, ok := rwl.(types.KubeServer)
if !ok {
return false, nil
if ks.GetCluster().GetName() == req.kubernetesCluster {
found = true
break
}

return ks.GetCluster().GetName() == req.kubernetesCluster, nil
}, &proto.ListUnifiedResourcesRequest{
Kinds: []string{types.KindKubeServer},
SortBy: types.SortBy{Field: services.SortByName},
Limit: 1,
})
if err != nil {
return nil, trace.Wrap(err)
}

if len(found) == 0 {
if !found {
return nil, trace.BadParameter("Kubernetes cluster %q is not registered in this Teleport cluster; you can list registered Kubernetes clusters using 'tsh kube ls'", req.kubernetesCluster)
}
}
Expand Down
9 changes: 4 additions & 5 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5584,13 +5584,12 @@ func TestListUnifiedResources_KindsFilter(t *testing.T) {
require.Equal(t, types.KindDatabaseServer, r.GetKind())
}

// Check for invalid sort error message
// Check that sorting is not required
_, err = clt.ListUnifiedResources(ctx, &proto.ListUnifiedResourcesRequest{
Kinds: []string{types.KindDatabase},
Limit: 5,
SortBy: types.SortBy{},
Kinds: []string{types.KindDatabase},
Limit: 5,
})
require.ErrorContains(t, err, "sort field is required")
require.NoError(t, err, "sort field is not required")
}

func TestListUnifiedResources_WithPinnedResources(t *testing.T) {
Expand Down
Loading

0 comments on commit 3621abd

Please sign in to comment.