@@ -2,10 +2,13 @@ package system
2
2
3
3
import (
4
4
"fmt"
5
+ "time"
5
6
6
7
"github.com/kubernetes-csi/csi-proxy/pkg/cim"
7
8
"github.com/kubernetes-csi/csi-proxy/pkg/server/system/impl"
8
- "github.com/kubernetes-csi/csi-proxy/pkg/utils"
9
+ "github.com/microsoft/wmi/pkg/errors"
10
+ wmiinst "github.com/microsoft/wmi/pkg/wmiinstance"
11
+ "github.com/microsoft/wmi/server2019/root/cimv2"
9
12
)
10
13
11
14
// Implements the System OS API calls. All code here should be very simple
@@ -24,6 +27,28 @@ type ServiceInfo struct {
24
27
Status uint32 `json:"Status"`
25
28
}
26
29
30
+ type periodicalCheckFunc func () (bool , error )
31
+
32
+ const (
33
+ // startServiceErrorCodeAccepted indicates the request is accepted
34
+ startServiceErrorCodeAccepted = 0
35
+
36
+ // startServiceErrorCodeAlreadyRunning indicates a service is already running
37
+ startServiceErrorCodeAlreadyRunning = 10
38
+
39
+ // stopServiceErrorCodeAccepted indicates the request is accepted
40
+ stopServiceErrorCodeAccepted = 0
41
+
42
+ // stopServiceErrorCodeStopPending indicates the request cannot be sent to the service because the state of the service is 0,1,2 (pending)
43
+ stopServiceErrorCodeStopPending = 5
44
+
45
+ // stopServiceErrorCodeDependentRunning indicates a service cannot be stopped as its dependents may still be running
46
+ stopServiceErrorCodeDependentRunning = 3
47
+
48
+ serviceStateRunning = "Running"
49
+ serviceStateStopped = "Stopped"
50
+ )
51
+
27
52
var (
28
53
startModeMappings = map [string ]uint32 {
29
54
"Boot" : impl .START_TYPE_BOOT ,
@@ -33,24 +58,27 @@ var (
33
58
"Disabled" : impl .START_TYPE_DISABLED ,
34
59
}
35
60
36
- statusMappings = map [string ]uint32 {
37
- "Unknown" : impl .SERVICE_STATUS_UNKNOWN ,
38
- "Stopped" : impl .SERVICE_STATUS_STOPPED ,
39
- "Start Pending" : impl .SERVICE_STATUS_START_PENDING ,
40
- "Stop Pending" : impl .SERVICE_STATUS_STOP_PENDING ,
41
- "Running" : impl .SERVICE_STATUS_RUNNING ,
42
- "Continue Pending" : impl .SERVICE_STATUS_CONTINUE_PENDING ,
43
- "Pause Pending" : impl .SERVICE_STATUS_PAUSE_PENDING ,
44
- "Paused" : impl .SERVICE_STATUS_PAUSED ,
61
+ stateMappings = map [string ]uint32 {
62
+ "Unknown" : impl .SERVICE_STATUS_UNKNOWN ,
63
+ serviceStateStopped : impl .SERVICE_STATUS_STOPPED ,
64
+ "Start Pending" : impl .SERVICE_STATUS_START_PENDING ,
65
+ "Stop Pending" : impl .SERVICE_STATUS_STOP_PENDING ,
66
+ serviceStateRunning : impl .SERVICE_STATUS_RUNNING ,
67
+ "Continue Pending" : impl .SERVICE_STATUS_CONTINUE_PENDING ,
68
+ "Pause Pending" : impl .SERVICE_STATUS_PAUSE_PENDING ,
69
+ "Paused" : impl .SERVICE_STATUS_PAUSED ,
45
70
}
71
+
72
+ serviceStateCheckInternal = 500 * time .Millisecond
73
+ serviceStateCheckTimeout = 5 * time .Second
46
74
)
47
75
48
76
func serviceStartModeToStartType (startMode string ) uint32 {
49
77
return startModeMappings [startMode ]
50
78
}
51
79
52
80
func serviceState (status string ) uint32 {
53
- return statusMappings [status ]
81
+ return stateMappings [status ]
54
82
}
55
83
56
84
type APIImplementor struct {}
@@ -101,23 +129,180 @@ func (APIImplementor) GetService(name string) (*ServiceInfo, error) {
101
129
}, nil
102
130
}
103
131
132
+ func waitForServiceState (serviceCheck periodicalCheckFunc , interval time.Duration , timeout time.Duration ) error {
133
+ timeoutChan := time .After (timeout )
134
+ ticker := time .NewTicker (interval )
135
+ defer ticker .Stop ()
136
+
137
+ for {
138
+ select {
139
+ case <- timeoutChan :
140
+ return errors .Timedout
141
+ case <- ticker .C :
142
+ done , err := serviceCheck ()
143
+ if err != nil {
144
+ return err
145
+ }
146
+
147
+ if done {
148
+ return nil
149
+ }
150
+ }
151
+ }
152
+ }
153
+
154
+ func getServiceState (name string ) (string , * cimv2.Win32_Service , error ) {
155
+ service , err := cim .QueryServiceByName (name , nil )
156
+ if err != nil {
157
+ return "" , nil , err
158
+ }
159
+
160
+ state , err := service .GetPropertyState ()
161
+ if err != nil {
162
+ return "" , nil , fmt .Errorf ("failed to get state property of service %s: %w" , name , err )
163
+ }
164
+
165
+ return state , service , nil
166
+ }
167
+
104
168
func (APIImplementor ) StartService (name string ) error {
105
- // Note: both StartService and StopService are not implemented by WMI
106
- script := `Start-Service -Name $env:ServiceName`
107
- cmdEnv := fmt .Sprintf ("ServiceName=%s" , name )
108
- out , err := utils .RunPowershellCmd (script , cmdEnv )
169
+ state , service , err := getServiceState (name )
109
170
if err != nil {
110
- return fmt .Errorf ("error starting service name=%s. cmd: %s, output: %s, error: %v" , name , script , string (out ), err )
171
+ return err
172
+ }
173
+
174
+ if state != serviceStateRunning {
175
+ var retVal uint32
176
+ retVal , err = service .StartService ()
177
+ if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning ) {
178
+ return fmt .Errorf ("error starting service name %s. return value: %d, error: %v" , name , retVal , err )
179
+ }
180
+
181
+ err = waitForServiceState (func () (bool , error ) {
182
+ state , service , err = getServiceState (name )
183
+ if err != nil {
184
+ return false , err
185
+ }
186
+
187
+ return state == serviceStateRunning , nil
188
+
189
+ }, serviceStateCheckInternal , serviceStateCheckTimeout )
190
+ if err != nil {
191
+ return fmt .Errorf ("error waiting service %s become running. error: %v" , name , err )
192
+ }
193
+ }
194
+
195
+ if state != serviceStateRunning {
196
+ return fmt .Errorf ("error starting service name %s. current state: %s" , name , state )
111
197
}
112
198
113
199
return nil
114
200
}
115
201
116
202
func (APIImplementor ) StopService (name string , force bool ) error {
117
- script := `Stop-Service -Name $env:ServiceName -Force:$([System.Convert]::ToBoolean($env:Force))`
118
- out , err := utils .RunPowershellCmd (script , fmt .Sprintf ("ServiceName=%s" , name ), fmt .Sprintf ("Force=%t" , force ))
203
+ state , service , err := getServiceState (name )
119
204
if err != nil {
120
- return fmt .Errorf ("error stopping service name=%s. cmd: %s, output: %s, error: %v" , name , script , string (out ), err )
205
+ return err
206
+ }
207
+
208
+ if state == serviceStateStopped {
209
+ return nil
210
+ }
211
+
212
+ stopSingleService := func (name string , service * wmiinst.WmiInstance ) (bool , error ) {
213
+ retVal , err := service .InvokeMethodWithReturn ("StopService" )
214
+ if err != nil || (retVal != stopServiceErrorCodeAccepted && retVal != stopServiceErrorCodeStopPending ) {
215
+ if retVal == stopServiceErrorCodeDependentRunning {
216
+ return true , fmt .Errorf ("error stopping service %s as dependent services are not stopped" , name )
217
+ }
218
+ return false , fmt .Errorf ("error stopping service %s. return value: %d, error: %v" , name , retVal , err )
219
+ }
220
+
221
+ var serviceState string
222
+ err = waitForServiceState (func () (bool , error ) {
223
+ serviceState , _ , err = getServiceState (name )
224
+ if err != nil {
225
+ return false , err
226
+ }
227
+
228
+ return serviceState == serviceStateStopped , nil
229
+
230
+ }, serviceStateCheckInternal , serviceStateCheckTimeout )
231
+ if err != nil {
232
+ return false , fmt .Errorf ("error waiting service %s become stopped. error: %v" , name , err )
233
+ }
234
+
235
+ if serviceState != serviceStateStopped {
236
+ return false , fmt .Errorf ("error stopping service name %s. current state: %s" , name , serviceState )
237
+ }
238
+
239
+ return false , nil
240
+ }
241
+
242
+ dependentRunning , err := stopSingleService (name , service .WmiInstance )
243
+ if ! force || err == nil || ! dependentRunning {
244
+ return err
245
+ }
246
+
247
+ var serviceNames []string
248
+ var servicesToCheck wmiinst.WmiInstanceCollection
249
+ servicesByName := map [string ]* wmiinst.WmiInstance {}
250
+
251
+ servicesToCheck = append (servicesToCheck , service .WmiInstance )
252
+ i := 0
253
+ for i < len (servicesToCheck ) {
254
+ current := servicesToCheck [i ]
255
+ i += 1
256
+
257
+ currentNameVal , err := current .GetProperty ("Name" )
258
+ if err != nil {
259
+ return err
260
+ }
261
+
262
+ currentName := currentNameVal .(string )
263
+ if _ , ok := servicesByName [currentName ]; ok {
264
+ continue
265
+ }
266
+
267
+ currentStateVal , err := current .GetProperty ("State" )
268
+ if err != nil {
269
+ return err
270
+ }
271
+
272
+ currentState := currentStateVal
273
+ if currentState != serviceStateRunning {
274
+ continue
275
+ }
276
+
277
+ servicesByName [currentName ] = current
278
+ serviceNames = append (serviceNames , currentName )
279
+
280
+ dependents , err := current .GetAssociated ("Win32_DependentService" , "Win32_Service" , "Dependent" , "Antecedent" )
281
+ if err != nil {
282
+ return err
283
+ }
284
+
285
+ servicesToCheck = append (servicesToCheck , dependents ... )
286
+ }
287
+
288
+ i = len (serviceNames ) - 1
289
+ for i >= 0 {
290
+ serviceName := serviceNames [i ]
291
+ i -= 1
292
+
293
+ state , service , err := getServiceState (serviceName )
294
+ if err != nil {
295
+ return err
296
+ }
297
+
298
+ if state == serviceStateStopped {
299
+ continue
300
+ }
301
+
302
+ _ , err = stopSingleService (serviceName , service .WmiInstance )
303
+ if err != nil {
304
+ return err
305
+ }
121
306
}
122
307
123
308
return nil
0 commit comments