Skip to content

Commit 6ad8e67

Browse files
committed
Use WMI to implement Service related System APIs
1 parent a9bd679 commit 6ad8e67

File tree

1 file changed

+204
-19
lines changed

1 file changed

+204
-19
lines changed

pkg/os/system/api.go

Lines changed: 204 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ package system
22

33
import (
44
"fmt"
5+
"time"
56

67
"github.com/kubernetes-csi/csi-proxy/pkg/cim"
78
"github.com/kubernetes-csi/csi-proxy/pkg/server/system/impl"
8-
"github.com/kubernetes-csi/csi-proxy/pkg/utils"
9+
"github.com/microsoft/wmi/pkg/errors"
10+
wmiinst "github.com/microsoft/wmi/pkg/wmiinstance"
11+
"github.com/microsoft/wmi/server2019/root/cimv2"
912
)
1013

1114
// Implements the System OS API calls. All code here should be very simple
@@ -24,6 +27,28 @@ type ServiceInfo struct {
2427
Status uint32 `json:"Status"`
2528
}
2629

30+
type periodicalCheckFunc func() (bool, error)
31+
32+
const (
33+
// startServiceErrorCodeAccepted indicates the request is accepted
34+
startServiceErrorCodeAccepted = 0
35+
36+
// startServiceErrorCodeAlreadyRunning indicates a service is already running
37+
startServiceErrorCodeAlreadyRunning = 10
38+
39+
// stopServiceErrorCodeAccepted indicates the request is accepted
40+
stopServiceErrorCodeAccepted = 0
41+
42+
// stopServiceErrorCodeStopPending indicates the request cannot be sent to the service because the state of the service is 0,1,2 (pending)
43+
stopServiceErrorCodeStopPending = 5
44+
45+
// stopServiceErrorCodeDependentRunning indicates a service cannot be stopped as its dependents may still be running
46+
stopServiceErrorCodeDependentRunning = 3
47+
48+
serviceStateRunning = "Running"
49+
serviceStateStopped = "Stopped"
50+
)
51+
2752
var (
2853
startModeMappings = map[string]uint32{
2954
"Boot": impl.START_TYPE_BOOT,
@@ -33,24 +58,27 @@ var (
3358
"Disabled": impl.START_TYPE_DISABLED,
3459
}
3560

36-
statusMappings = map[string]uint32{
37-
"Unknown": impl.SERVICE_STATUS_UNKNOWN,
38-
"Stopped": impl.SERVICE_STATUS_STOPPED,
39-
"Start Pending": impl.SERVICE_STATUS_START_PENDING,
40-
"Stop Pending": impl.SERVICE_STATUS_STOP_PENDING,
41-
"Running": impl.SERVICE_STATUS_RUNNING,
42-
"Continue Pending": impl.SERVICE_STATUS_CONTINUE_PENDING,
43-
"Pause Pending": impl.SERVICE_STATUS_PAUSE_PENDING,
44-
"Paused": impl.SERVICE_STATUS_PAUSED,
61+
stateMappings = map[string]uint32{
62+
"Unknown": impl.SERVICE_STATUS_UNKNOWN,
63+
serviceStateStopped: impl.SERVICE_STATUS_STOPPED,
64+
"Start Pending": impl.SERVICE_STATUS_START_PENDING,
65+
"Stop Pending": impl.SERVICE_STATUS_STOP_PENDING,
66+
serviceStateRunning: impl.SERVICE_STATUS_RUNNING,
67+
"Continue Pending": impl.SERVICE_STATUS_CONTINUE_PENDING,
68+
"Pause Pending": impl.SERVICE_STATUS_PAUSE_PENDING,
69+
"Paused": impl.SERVICE_STATUS_PAUSED,
4570
}
71+
72+
serviceStateCheckInternal = 500 * time.Millisecond
73+
serviceStateCheckTimeout = 5 * time.Second
4674
)
4775

4876
func serviceStartModeToStartType(startMode string) uint32 {
4977
return startModeMappings[startMode]
5078
}
5179

5280
func serviceState(status string) uint32 {
53-
return statusMappings[status]
81+
return stateMappings[status]
5482
}
5583

5684
type APIImplementor struct{}
@@ -101,23 +129,180 @@ func (APIImplementor) GetService(name string) (*ServiceInfo, error) {
101129
}, nil
102130
}
103131

132+
func waitForServiceState(serviceCheck periodicalCheckFunc, interval time.Duration, timeout time.Duration) error {
133+
timeoutChan := time.After(timeout)
134+
ticker := time.NewTicker(interval)
135+
defer ticker.Stop()
136+
137+
for {
138+
select {
139+
case <-timeoutChan:
140+
return errors.Timedout
141+
case <-ticker.C:
142+
done, err := serviceCheck()
143+
if err != nil {
144+
return err
145+
}
146+
147+
if done {
148+
return nil
149+
}
150+
}
151+
}
152+
}
153+
154+
func getServiceState(name string) (string, *cimv2.Win32_Service, error) {
155+
service, err := cim.QueryServiceByName(name, nil)
156+
if err != nil {
157+
return "", nil, err
158+
}
159+
160+
state, err := service.GetPropertyState()
161+
if err != nil {
162+
return "", nil, fmt.Errorf("failed to get state property of service %s: %w", name, err)
163+
}
164+
165+
return state, service, nil
166+
}
167+
104168
func (APIImplementor) StartService(name string) error {
105-
// Note: both StartService and StopService are not implemented by WMI
106-
script := `Start-Service -Name $env:ServiceName`
107-
cmdEnv := fmt.Sprintf("ServiceName=%s", name)
108-
out, err := utils.RunPowershellCmd(script, cmdEnv)
169+
state, service, err := getServiceState(name)
109170
if err != nil {
110-
return fmt.Errorf("error starting service name=%s. cmd: %s, output: %s, error: %v", name, script, string(out), err)
171+
return err
172+
}
173+
174+
if state != serviceStateRunning {
175+
var retVal uint32
176+
retVal, err = service.StartService()
177+
if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning) {
178+
return fmt.Errorf("error starting service name %s. return value: %d, error: %v", name, retVal, err)
179+
}
180+
181+
err = waitForServiceState(func() (bool, error) {
182+
state, service, err = getServiceState(name)
183+
if err != nil {
184+
return false, err
185+
}
186+
187+
return state == serviceStateRunning, nil
188+
189+
}, serviceStateCheckInternal, serviceStateCheckTimeout)
190+
if err != nil {
191+
return fmt.Errorf("error waiting service %s become running. error: %v", name, err)
192+
}
193+
}
194+
195+
if state != serviceStateRunning {
196+
return fmt.Errorf("error starting service name %s. current state: %s", name, state)
111197
}
112198

113199
return nil
114200
}
115201

116202
func (APIImplementor) StopService(name string, force bool) error {
117-
script := `Stop-Service -Name $env:ServiceName -Force:$([System.Convert]::ToBoolean($env:Force))`
118-
out, err := utils.RunPowershellCmd(script, fmt.Sprintf("ServiceName=%s", name), fmt.Sprintf("Force=%t", force))
203+
state, service, err := getServiceState(name)
119204
if err != nil {
120-
return fmt.Errorf("error stopping service name=%s. cmd: %s, output: %s, error: %v", name, script, string(out), err)
205+
return err
206+
}
207+
208+
if state == serviceStateStopped {
209+
return nil
210+
}
211+
212+
stopSingleService := func(name string, service *wmiinst.WmiInstance) (bool, error) {
213+
retVal, err := service.InvokeMethodWithReturn("StopService")
214+
if err != nil || (retVal != stopServiceErrorCodeAccepted && retVal != stopServiceErrorCodeStopPending) {
215+
if retVal == stopServiceErrorCodeDependentRunning {
216+
return true, fmt.Errorf("error stopping service %s as dependent services are not stopped", name)
217+
}
218+
return false, fmt.Errorf("error stopping service %s. return value: %d, error: %v", name, retVal, err)
219+
}
220+
221+
var serviceState string
222+
err = waitForServiceState(func() (bool, error) {
223+
serviceState, _, err = getServiceState(name)
224+
if err != nil {
225+
return false, err
226+
}
227+
228+
return serviceState == serviceStateStopped, nil
229+
230+
}, serviceStateCheckInternal, serviceStateCheckTimeout)
231+
if err != nil {
232+
return false, fmt.Errorf("error waiting service %s become stopped. error: %v", name, err)
233+
}
234+
235+
if serviceState != serviceStateStopped {
236+
return false, fmt.Errorf("error stopping service name %s. current state: %s", name, serviceState)
237+
}
238+
239+
return false, nil
240+
}
241+
242+
dependentRunning, err := stopSingleService(name, service.WmiInstance)
243+
if !force || err == nil || !dependentRunning {
244+
return err
245+
}
246+
247+
var serviceNames []string
248+
var servicesToCheck wmiinst.WmiInstanceCollection
249+
servicesByName := map[string]*wmiinst.WmiInstance{}
250+
251+
servicesToCheck = append(servicesToCheck, service.WmiInstance)
252+
i := 0
253+
for i < len(servicesToCheck) {
254+
current := servicesToCheck[i]
255+
i += 1
256+
257+
currentNameVal, err := current.GetProperty("Name")
258+
if err != nil {
259+
return err
260+
}
261+
262+
currentName := currentNameVal.(string)
263+
if _, ok := servicesByName[currentName]; ok {
264+
continue
265+
}
266+
267+
currentStateVal, err := current.GetProperty("State")
268+
if err != nil {
269+
return err
270+
}
271+
272+
currentState := currentStateVal
273+
if currentState != serviceStateRunning {
274+
continue
275+
}
276+
277+
servicesByName[currentName] = current
278+
serviceNames = append(serviceNames, currentName)
279+
280+
dependents, err := current.GetAssociated("Win32_DependentService", "Win32_Service", "Dependent", "Antecedent")
281+
if err != nil {
282+
return err
283+
}
284+
285+
servicesToCheck = append(servicesToCheck, dependents...)
286+
}
287+
288+
i = len(serviceNames) - 1
289+
for i >= 0 {
290+
serviceName := serviceNames[i]
291+
i -= 1
292+
293+
state, service, err := getServiceState(serviceName)
294+
if err != nil {
295+
return err
296+
}
297+
298+
if state == serviceStateStopped {
299+
continue
300+
}
301+
302+
_, err = stopSingleService(serviceName, service.WmiInstance)
303+
if err != nil {
304+
return err
305+
}
121306
}
122307

123308
return nil

0 commit comments

Comments
 (0)