Skip to content

Commit 86956a4

Browse files
Merge pull request #3002 from czeslavo/liveness-status-only
Add NodesStatusInfo method to the Scylla client
2 parents 71549c7 + 3229d0d commit 86956a4

File tree

9 files changed

+223
-65
lines changed

9 files changed

+223
-65
lines changed

pkg/probeserver/scylladbapistatus/prober.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (p *Prober) Readyz(w http.ResponseWriter, req *http.Request) {
7777
defer scyllaClient.Close()
7878

7979
// Contact Scylla to learn about the status of the member
80-
nodeStatuses, err := scyllaClient.Status(ctx, localhost)
80+
nodeStatuses, err := scyllaClient.NodesStatusAndStateInfo(ctx, localhost)
8181
if err != nil {
8282
klog.ErrorS(err, "readyz probe: can't get scylla node status", "Service", p.serviceRef())
8383
w.WriteHeader(http.StatusInternalServerError)

pkg/scyllaclient/client.go

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -85,30 +85,50 @@ func (c *Client) Close() {
8585
}
8686
}
8787

88-
func (c *Client) Status(ctx context.Context, host string) (NodeStatusInfoSlice, error) {
88+
func (c *Client) NodesStatusInfo(ctx context.Context, host string) (NodeStatusInfoSlice, error) {
8989
if len(host) > 0 {
90-
// Always query same host
9190
ctx = forceHost(ctx, host)
9291
}
9392

94-
// Get all hosts
95-
resp, err := c.scyllaClient.Operations.StorageServiceHostIDGet(&scyllaoperations.StorageServiceHostIDGetParams{Context: ctx})
93+
getHostIDResp, err := c.scyllaClient.Operations.StorageServiceHostIDGet(&scyllaoperations.StorageServiceHostIDGetParams{Context: ctx})
9694
if err != nil {
9795
return nil, err
9896
}
9997

100-
all := make([]NodeStatusInfo, len(resp.Payload))
101-
for i, p := range resp.Payload {
102-
all[i].Addr = p.Key
103-
all[i].HostID = p.Value
98+
all := make([]NodeStatusInfo, len(getHostIDResp.Payload))
99+
for i, n := range getHostIDResp.Payload {
100+
all[i].Addr = n.Key
101+
all[i].HostID = n.Value
102+
all[i].Status = NodeStatusDown
104103
}
105104

106-
// Get live nodes
107-
live, err := c.scyllaClient.Operations.GossiperEndpointLiveGet(&scyllaoperations.GossiperEndpointLiveGetParams{Context: ctx})
105+
getLiveNodesResp, err := c.scyllaClient.Operations.GossiperEndpointLiveGet(&scyllaoperations.GossiperEndpointLiveGetParams{Context: ctx})
108106
if err != nil {
109107
return nil, err
110108
}
111-
setNodeStatus(all, NodeStatusUp, live.Payload)
109+
liveNodeAddrs := strset.New(getLiveNodesResp.Payload...)
110+
111+
for i, n := range all {
112+
if liveNodeAddrs.Has(n.Addr) {
113+
all[i].Status = NodeStatusUp
114+
}
115+
}
116+
117+
return all, nil
118+
}
119+
120+
func (c *Client) NodesStatusAndStateInfo(ctx context.Context, host string) (NodeStatusAndStateInfoSlice, error) {
121+
allNodesLivenessStatus, err := c.NodesStatusInfo(ctx, host)
122+
if err != nil {
123+
return nil, err
124+
}
125+
126+
all := make([]NodeStatusAndStateInfo, len(allNodesLivenessStatus))
127+
for i, n := range allNodesLivenessStatus {
128+
all[i] = NodeStatusAndStateInfo{
129+
NodeStatusInfo: n,
130+
}
131+
}
112132

113133
// Get joining nodes
114134
joining, err := c.scyllaClient.Operations.StorageServiceNodesJoiningGet(&scyllaoperations.StorageServiceNodesJoiningGetParams{Context: ctx})
@@ -438,7 +458,7 @@ func DefaultTransport() *http.Transport {
438458
}
439459
}
440460

441-
func setNodeState(all []NodeStatusInfo, state NodeState, addrs []string) {
461+
func setNodeState(all []NodeStatusAndStateInfo, state NodeState, addrs []string) {
442462
if len(addrs) == 0 {
443463
return
444464
}
@@ -451,19 +471,6 @@ func setNodeState(all []NodeStatusInfo, state NodeState, addrs []string) {
451471
}
452472
}
453473

454-
func setNodeStatus(all []NodeStatusInfo, status NodeStatus, addrs []string) {
455-
if len(addrs) == 0 {
456-
return
457-
}
458-
m := strset.New(addrs...)
459-
460-
for i := range all {
461-
if m.Has(all[i].Addr) {
462-
all[i].Status = status
463-
}
464-
}
465-
}
466-
467474
// fixContentType adjusts Scylla REST API response so that it can be consumed
468475
// by Open API.
469476
func fixContentType(next http.RoundTripper) http.RoundTripper {

pkg/scyllaclient/model.go

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -120,36 +120,30 @@ func operationalModeFromString(str string) OperationalMode {
120120
type CompactionType string
121121

122122
const (
123-
CompactionCompactionType CompactionType = "COMPACTION"
124-
CleanupCompactionType CompactionType = "CLEANUP"
125-
ScrubCompactionType CompactionType = "SCRUB"
126-
UpgradeCompactionType CompactionType = "UPGRADE"
127-
ReshapeCompactionType CompactionType = "RESHAPE"
123+
CleanupCompactionType CompactionType = "CLEANUP"
128124
)
129125

130-
// NodeStatusInfo represents a nodetool status line.
131-
type NodeStatusInfo struct {
132-
HostID string
133-
Addr string
134-
Status NodeStatus
135-
State NodeState
126+
// NodeStatusAndStateInfo represents a node's status and state (like in nodetool status).
127+
type NodeStatusAndStateInfo struct {
128+
NodeStatusInfo
129+
State NodeState
136130
}
137131

138-
func (s NodeStatusInfo) String() string {
132+
func (s NodeStatusAndStateInfo) String() string {
139133
return fmt.Sprintf("host: %s, Status: %s%s", s.Addr, s.Status, s.State)
140134
}
141135

142136
// IsUN returns true if host is Up and NORMAL meaning it's a fully functional
143137
// live node.
144-
func (s NodeStatusInfo) IsUN() bool {
138+
func (s NodeStatusAndStateInfo) IsUN() bool {
145139
return s.Status == NodeStatusUp && s.State == NodeStateNormal
146140
}
147141

148-
// NodeStatusInfoSlice adds functionality to Status response.
149-
type NodeStatusInfoSlice []NodeStatusInfo
142+
// NodeStatusAndStateInfoSlice adds functionality to Status response.
143+
type NodeStatusAndStateInfoSlice []NodeStatusAndStateInfo
150144

151145
// Hosts returns slice of address of all nodes.
152-
func (s NodeStatusInfoSlice) Hosts() []string {
146+
func (s NodeStatusAndStateInfoSlice) Hosts() []string {
153147
var hosts []string
154148
for _, h := range s {
155149
hosts = append(hosts, h.Addr)
@@ -158,7 +152,7 @@ func (s NodeStatusInfoSlice) Hosts() []string {
158152
}
159153

160154
// HostIDs returns slice of HostID of all nodes.
161-
func (s NodeStatusInfoSlice) HostIDs() []string {
155+
func (s NodeStatusAndStateInfoSlice) HostIDs() []string {
162156
var hostIDs []string
163157
for _, h := range s {
164158
hostIDs = append(hostIDs, h.HostID)
@@ -167,7 +161,7 @@ func (s NodeStatusInfoSlice) HostIDs() []string {
167161
}
168162

169163
// LiveHosts returns slice of address of nodes in UN state.
170-
func (s NodeStatusInfoSlice) LiveHosts() []string {
164+
func (s NodeStatusAndStateInfoSlice) LiveHosts() []string {
171165
var hosts []string
172166
for _, h := range s {
173167
if h.IsUN() {
@@ -178,7 +172,7 @@ func (s NodeStatusInfoSlice) LiveHosts() []string {
178172
}
179173

180174
// DownHosts returns slice of address of nodes that are down.
181-
func (s NodeStatusInfoSlice) DownHosts() []string {
175+
func (s NodeStatusAndStateInfoSlice) DownHosts() []string {
182176
var hosts []string
183177
for _, h := range s {
184178
if h.Status == NodeStatusDown {
@@ -189,7 +183,7 @@ func (s NodeStatusInfoSlice) DownHosts() []string {
189183
}
190184

191185
// DownHostIDs returns slice of HostID of nodes that are down.
192-
func (s NodeStatusInfoSlice) DownHostIDs() []string {
186+
func (s NodeStatusAndStateInfoSlice) DownHostIDs() []string {
193187
var hostIDs []string
194188
for _, h := range s {
195189
if h.Status == NodeStatusDown {
@@ -198,3 +192,43 @@ func (s NodeStatusInfoSlice) DownHostIDs() []string {
198192
}
199193
return hostIDs
200194
}
195+
196+
// NodeStatusInfo represents the status (Up/Down) of a node.
197+
type NodeStatusInfo struct {
198+
HostID string
199+
Addr string
200+
Status NodeStatus
201+
}
202+
203+
type NodeStatusInfoSlice []NodeStatusInfo
204+
205+
// UpHostIDs returns slice of HostID of nodes that are up.
206+
func (s NodeStatusInfoSlice) UpHostIDs() []string {
207+
var hosts []string
208+
for _, h := range s {
209+
if h.Status == NodeStatusUp {
210+
hosts = append(hosts, h.Addr)
211+
}
212+
}
213+
return hosts
214+
}
215+
216+
// HostIDs returns slice of HostID of all nodes.
217+
func (s NodeStatusInfoSlice) HostIDs() []string {
218+
var hostIDs []string
219+
for _, h := range s {
220+
hostIDs = append(hostIDs, h.HostID)
221+
}
222+
return hostIDs
223+
}
224+
225+
// DownHostIDs returns slice of HostID of nodes that are down.
226+
func (s NodeStatusInfoSlice) DownHostIDs() []string {
227+
var hosts []string
228+
for _, h := range s {
229+
if h.Status == NodeStatusDown {
230+
hosts = append(hosts, h.HostID)
231+
}
232+
}
233+
return hosts
234+
}

pkg/scyllaclient/status.go

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,35 @@
33
package scyllaclient
44

55
import (
6+
"errors"
7+
68
"github.com/go-openapi/runtime"
7-
"github.com/pkg/errors"
89
scyllav2models "github.com/scylladb/scylladb-swagger-go-client/scylladb/gen/v2/models"
910
)
1011

1112
// StatusCodeOf returns HTTP status code carried by the error or it's cause.
1213
// If not status can be found it returns 0.
1314
func StatusCodeOf(err error) int {
14-
err = errors.Cause(err)
15-
switch v := err.(type) {
16-
case interface {
15+
type coder interface {
1716
Code() int
18-
}:
19-
return v.Code()
20-
case *runtime.APIError:
21-
return v.Code
22-
case interface {
17+
}
18+
if v := new(coder); errors.As(err, v) {
19+
return (*v).Code()
20+
}
21+
22+
type payloader interface {
2323
GetPayload() *scyllav2models.ErrorModel
24-
}:
25-
p := v.GetPayload()
24+
}
25+
if v := new(payloader); errors.As(err, v) {
26+
p := (*v).GetPayload()
2627
if p != nil {
2728
return int(p.Code)
2829
}
2930
}
31+
32+
if v := new(runtime.APIError); errors.As(err, &v) {
33+
return v.Code
34+
}
35+
3036
return 0
3137
}

pkg/scyllaclient/status_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package scyllaclient_test
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"testing"
7+
8+
"github.com/go-openapi/runtime"
9+
"github.com/scylladb/scylla-operator/pkg/scyllaclient"
10+
scyllav2models "github.com/scylladb/scylladb-swagger-go-client/scylladb/gen/v2/models"
11+
)
12+
13+
type mockCoderError struct {
14+
code int
15+
}
16+
17+
func (e *mockCoderError) Code() int {
18+
return e.code
19+
}
20+
21+
func (e *mockCoderError) Error() string {
22+
return fmt.Sprintf("mock coder error with code %d", e.code)
23+
}
24+
25+
type mockPayloaderError struct {
26+
payload *scyllav2models.ErrorModel
27+
}
28+
29+
func (e *mockPayloaderError) GetPayload() *scyllav2models.ErrorModel {
30+
return e.payload
31+
}
32+
33+
func (e *mockPayloaderError) Error() string {
34+
if e.payload != nil {
35+
return fmt.Sprintf("mock payloader error with code %d", e.payload.Code)
36+
}
37+
return "mock payloader error with nil payload"
38+
}
39+
40+
func TestStatusCodeOf(t *testing.T) {
41+
testCases := []struct {
42+
name string
43+
err error
44+
expectedCode int
45+
}{
46+
{
47+
name: "coder interface - direct error",
48+
err: &mockCoderError{code: 404},
49+
expectedCode: 404,
50+
},
51+
{
52+
name: "coder interface - wrapped error",
53+
err: fmt.Errorf("wrapped: %w", &mockCoderError{code: 500}),
54+
expectedCode: 500,
55+
},
56+
{
57+
name: "payloader interface with valid payload - direct error",
58+
err: &mockPayloaderError{payload: &scyllav2models.ErrorModel{Code: 400}},
59+
expectedCode: 400,
60+
},
61+
{
62+
name: "payloader interface with valid payload - wrapped error",
63+
err: fmt.Errorf("wrapped: %w", &mockPayloaderError{payload: &scyllav2models.ErrorModel{Code: 403}}),
64+
expectedCode: 403,
65+
},
66+
{
67+
name: "payloader interface with nil payload - direct error",
68+
err: &mockPayloaderError{payload: nil},
69+
expectedCode: 0,
70+
},
71+
{
72+
name: "payloader interface with nil payload - wrapped error",
73+
err: fmt.Errorf("wrapped: %w", &mockPayloaderError{payload: nil}),
74+
expectedCode: 0,
75+
},
76+
{
77+
name: "runtime.APIError - direct error",
78+
err: &runtime.APIError{Code: 502},
79+
expectedCode: 502,
80+
},
81+
{
82+
name: "runtime.APIError - wrapped error",
83+
err: fmt.Errorf("wrapped: %w", &runtime.APIError{Code: 503}),
84+
expectedCode: 503,
85+
},
86+
{
87+
name: "unknown error type - direct error",
88+
err: errors.New("some unknown error"),
89+
expectedCode: 0,
90+
},
91+
{
92+
name: "unknown error type - wrapped error",
93+
err: fmt.Errorf("wrapped: %w", errors.New("some unknown error")),
94+
expectedCode: 0,
95+
},
96+
{
97+
name: "nil error",
98+
err: nil,
99+
expectedCode: 0,
100+
},
101+
}
102+
103+
for _, tc := range testCases {
104+
t.Run(tc.name, func(t *testing.T) {
105+
result := scyllaclient.StatusCodeOf(tc.err)
106+
if result != tc.expectedCode {
107+
t.Errorf("StatusCodeOf(%v) = %d, expected %d", tc.err, result, tc.expectedCode)
108+
}
109+
})
110+
}
111+
}

0 commit comments

Comments
 (0)