Skip to content

Commit ac1c301

Browse files
committed
pkg/driver/vz: Try SSH handshake to check if SSH port is available.
Check the SSH server in a way that complies with the SSH protocol, using x/crypto/ssh. This change fixes #4334 by falling back to the usernet port forwarder when SSH connections over VSOCK fail. - pkg/networks/usernet: Rename the entry point from `/extension/wait_port` to `/extension/wait-ssh-server`, because it is now an SSH-server-specific entry point. When a client accesses the old entry point, the request fails and the client falls back to the usernet forwarder. - pkg/sshutil: Add `WaitSSHReady()`. WaitSSHReady waits until the SSH server is ready to accept connections. The dialContext function is used to create a connection to the SSH server. The addr and user parameters are used for ssh.ClientConn creation. The timeoutSeconds parameter specifies the maximum number of seconds to wait. Signed-off-by: Norio Nomura <[email protected]> # Conflicts: # go.mod
1 parent b1e6640 commit ac1c301

File tree

8 files changed

+121
-41
lines changed

8 files changed

+121
-41
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ require (
117117
github.com/x448/float16 v0.8.4 // indirect
118118
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
119119
github.com/yuin/gopher-lua v1.1.1 // indirect
120-
golang.org/x/crypto v0.44.0 // indirect
120+
golang.org/x/crypto v0.44.0
121121
golang.org/x/mod v0.29.0 // indirect
122122
golang.org/x/oauth2 v0.32.0 // indirect
123123
golang.org/x/term v0.37.0 // indirect

pkg/driver/vz/vm_darwin.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,18 @@ func startVM(ctx context.Context, inst *limatype.Instance, sshLocalPort int) (vm
113113
useSSHOverVsock = b
114114
}
115115
}
116+
hostAddress := net.JoinHostPort(inst.SSHAddress, strconv.Itoa(usernetSSHLocalPort))
116117
if !useSSHOverVsock {
117118
logrus.Info("LIMA_SSH_OVER_VSOCK is false, skipping detection of SSH server on vsock port")
118-
} else if err := usernetClient.WaitOpeningSSHPort(ctx, inst); err == nil {
119-
hostAddress := net.JoinHostPort(inst.SSHAddress, strconv.Itoa(usernetSSHLocalPort))
120-
if err := wrapper.startVsockForwarder(ctx, 22, hostAddress); err == nil {
121-
logrus.Infof("Detected SSH server is listening on the vsock port; changed %s to proxy for the vsock port", hostAddress)
122-
usernetSSHLocalPort = 0 // disable gvisor ssh port forwarding
123-
} else {
124-
logrus.WithError(err).Warn("Failed to detect SSH server on vsock port, falling back to usernet forwarder")
125-
}
119+
} else if err := usernetClient.WaitOpeningSSHPort(ctx, inst); err != nil {
120+
logrus.WithError(err).Info("Failed to wait for the guest SSH server to become available, falling back to usernet forwarder")
121+
} else if err := wrapper.checkSSHOverVsockAvailable(ctx, inst); err != nil {
122+
logrus.WithError(err).Info("Failed to detect SSH server on vsock port, falling back to usernet forwarder")
123+
} else if err := wrapper.startVsockForwarder(ctx, 22, hostAddress); err != nil {
124+
logrus.WithError(err).Info("Failed to start SSH server forwarder on vsock port, falling back to usernet forwarder")
126125
} else {
127-
logrus.WithError(err).Warn("Failed to wait for the guest SSH server to become available, falling back to usernet forwarder")
126+
logrus.Infof("Detected SSH server is listening on the vsock port; changed %s to proxy for the vsock port", hostAddress)
127+
usernetSSHLocalPort = 0 // disable gvisor ssh port forwarding
128128
}
129129
err := usernetClient.ConfigureDriver(ctx, inst, usernetSSHLocalPort)
130130
if err != nil {

pkg/driver/vz/vsock_forwarder.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,14 @@ import (
1212

1313
"github.com/containers/gvisor-tap-vsock/pkg/tcpproxy"
1414
"github.com/sirupsen/logrus"
15+
16+
"github.com/lima-vm/lima/v2/pkg/limatype"
17+
"github.com/lima-vm/lima/v2/pkg/sshutil"
1518
)
1619

1720
func (m *virtualMachineWrapper) startVsockForwarder(ctx context.Context, vsockPort uint32, hostAddress string) error {
18-
// Test if the vsock port is open
19-
conn, err := m.dialVsock(ctx, vsockPort)
20-
if err != nil {
21-
return err
22-
}
23-
conn.Close()
2421
// Start listening on localhost:hostPort and forward to vsock:vsockPort
25-
_, _, err = net.SplitHostPort(hostAddress)
22+
_, _, err := net.SplitHostPort(hostAddress)
2623
if err != nil {
2724
return err
2825
}
@@ -73,3 +70,9 @@ func (m *virtualMachineWrapper) dialVsock(_ context.Context, port uint32) (conn
7370
}
7471
return nil, err
7572
}
73+
74+
func (m *virtualMachineWrapper) checkSSHOverVsockAvailable(ctx context.Context, inst *limatype.Instance) error {
75+
return sshutil.WaitSSHReady(ctx, func(ctx context.Context) (net.Conn, error) {
76+
return m.dialVsock(ctx, uint32(22))
77+
}, "vsock:22", *inst.Config.User.Name, 1)
78+
}

pkg/networks/usernet/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ func (c *Client) WaitOpeningSSHPort(ctx context.Context, inst *limatype.Instance
140140
if err != nil {
141141
return err
142142
}
143+
user := *inst.Config.User.Name
143144
// -1 avoids both sides timing out simultaneously.
144-
u := fmt.Sprintf("%s/extension/wait_port?ip=%s&port=22&timeout=%d", c.base, ipAddr, timeoutSeconds-1)
145+
u := fmt.Sprintf("%s/extension/wait-ssh-server?ip=%s&port=22&timeout=%d&user=%s", c.base, ipAddr, timeoutSeconds-1, user)
145146
res, err := httpclientutil.Get(ctx, c.client, u)
146147
if err != nil {
147148
return err

pkg/networks/usernet/gvproxy.go

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import (
2222
"github.com/containers/gvisor-tap-vsock/pkg/virtualnetwork"
2323
"github.com/sirupsen/logrus"
2424
"golang.org/x/sync/errgroup"
25+
26+
"github.com/lima-vm/lima/v2/pkg/sshutil"
2527
)
2628

2729
type GVisorNetstackOpts struct {
@@ -243,7 +245,7 @@ func httpServe(ctx context.Context, g *errgroup.Group, ln net.Listener, mux http
243245

244246
func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux {
245247
m := n.Mux()
246-
m.HandleFunc("/extension/wait_port", func(w http.ResponseWriter, r *http.Request) {
248+
m.HandleFunc("/extension/wait-ssh-server", func(w http.ResponseWriter, r *http.Request) {
247249
ip := r.URL.Query().Get("ip")
248250
if net.ParseIP(ip) == nil {
249251
msg := fmt.Sprintf("invalid ip address: %s", ip)
@@ -255,8 +257,14 @@ func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux {
255257
http.Error(w, err.Error(), http.StatusBadRequest)
256258
return
257259
}
258-
port := uint16(port16)
259-
addr := fmt.Sprintf("%s:%d", ip, port)
260+
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", uint16(port16)))
261+
262+
user := r.URL.Query().Get("user")
263+
if user == "" {
264+
msg := "user query parameter is required"
265+
http.Error(w, msg, http.StatusBadRequest)
266+
return
267+
}
260268

261269
timeoutSeconds := 10
262270
if timeoutString := r.URL.Query().Get("timeout"); timeoutString != "" {
@@ -267,27 +275,14 @@ func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux {
267275
}
268276
timeoutSeconds = int(timeout16)
269277
}
270-
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds)*time.Second)
271-
defer cancel()
278+
dialContext := func(ctx context.Context) (net.Conn, error) {
279+
return n.DialContextTCP(ctx, addr)
280+
}
272281
// Wait until the port is available.
273-
for {
274-
conn, err := n.DialContextTCP(ctx, addr)
275-
if err == nil {
276-
conn.Close()
277-
logrus.Debugf("Port is available on %s", addr)
278-
w.WriteHeader(http.StatusOK)
279-
break
280-
}
281-
select {
282-
case <-ctx.Done():
283-
msg := fmt.Sprintf("timed out waiting for port to become available on %s", addr)
284-
logrus.Warn(msg)
285-
http.Error(w, msg, http.StatusRequestTimeout)
286-
return
287-
default:
288-
}
289-
logrus.Debugf("Waiting for port to become available on %s", addr)
290-
time.Sleep(1 * time.Second)
282+
if err = sshutil.WaitSSHReady(r.Context(), dialContext, addr, user, timeoutSeconds); err != nil {
283+
http.Error(w, err.Error(), http.StatusRequestTimeout)
284+
} else {
285+
w.WriteHeader(http.StatusOK)
291286
}
292287
})
293288
return m

pkg/osutil/osutil_unix.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package osutil
88
import (
99
"bytes"
1010
"context"
11+
"errors"
1112
"fmt"
1213
"os"
1314
"os/exec"
@@ -36,3 +37,7 @@ func Sysctl(ctx context.Context, name string) (string, error) {
3637
}
3738
return strings.TrimSuffix(string(stdout), "\n"), nil
3839
}
40+
41+
// IsConnectionResetError reports whether err is, or wraps, syscall.ECONNRESET.
func IsConnectionResetError(err error) bool {
42+
return errors.Is(err, syscall.ECONNRESET)
43+
}

pkg/osutil/osutil_windows.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,7 @@ func SignalName(sig os.Signal) string {
5757
func Sysctl(_ context.Context, _ string) (string, error) {
5858
return "", errors.New("sysctl: unimplemented on Windows")
5959
}
60+
61+
// IsConnectionResetError reports whether err is, or wraps, syscall.WSAECONNRESET.
func IsConnectionResetError(err error) bool {
62+
return errors.Is(err, syscall.WSAECONNRESET)
63+
}

pkg/sshutil/sshutil.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"errors"
1212
"fmt"
1313
"io/fs"
14+
"net"
1415
"os"
1516
"os/exec"
1617
"path/filepath"
@@ -24,6 +25,7 @@ import (
2425
"github.com/coreos/go-semver/semver"
2526
"github.com/mattn/go-shellwords"
2627
"github.com/sirupsen/logrus"
28+
"golang.org/x/crypto/ssh"
2729
"golang.org/x/sys/cpu"
2830

2931
"github.com/lima-vm/lima/v2/pkg/ioutilx"
@@ -509,3 +511,73 @@ func detectAESAcceleration() bool {
509511
}
510512
return cpu.ARM.HasAES || cpu.ARM64.HasAES || cpu.PPC64.IsPOWER8 || cpu.S390X.HasAES || cpu.X86.HasAES
511513
}
514+
515+
// WaitSSHReady waits until the SSH server is ready to accept connections.
516+
// The dialContext function is used to create a connection to the SSH server.
517+
// The addr, user, parameter is used for ssh.ClientConn creation.
518+
// The timeoutSeconds parameter specifies the maximum number of seconds to wait.
519+
func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Conn, error), addr, user string, timeoutSeconds int) error {
520+
ctx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
521+
defer cancel()
522+
523+
// Prepare signer
524+
signer, err := userPrivateKeySigner()
525+
if err != nil {
526+
return err
527+
}
528+
// Prepare ssh client config
529+
sshConfig := &ssh.ClientConfig{
530+
User: user,
531+
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
532+
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
533+
Timeout: 10 * time.Second,
534+
}
535+
// Wait until the SSH server is available.
536+
for {
537+
conn, err := dialContext(ctx)
538+
if err == nil {
539+
sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, sshConfig)
540+
if err == nil {
541+
sshClient := ssh.NewClient(sshConn, chans, reqs)
542+
return sshClient.Close()
543+
}
544+
conn.Close()
545+
if !isRetryableError(err) {
546+
return fmt.Errorf("failed to create ssh.Conn to %q: %w", addr, err)
547+
}
548+
}
549+
logrus.Debugf("Waiting for SSH port to accept connections on %s", addr)
550+
select {
551+
case <-ctx.Done():
552+
return fmt.Errorf("failed to waiting for SSH port to become available on %s: %w", addr, ctx.Err())
553+
case <-time.After(1 * time.Second):
554+
continue
555+
}
556+
}
557+
}
558+
559+
func isRetryableError(err error) bool {
560+
// Port forwarder accepted the connection, but the destination is not ready yet.
561+
return osutil.IsConnectionResetError(err) ||
562+
// SSH server not ready yet (e.g. host key not generated on initial boot).
563+
strings.HasSuffix(err.Error(), "no supported methods remain")
564+
}
565+
566+
// userPrivateKeySigner returns the user's private key signer.
567+
// The public key is always installed in the VM.
568+
func userPrivateKeySigner() (ssh.Signer, error) {
569+
configDir, err := dirnames.LimaConfigDir()
570+
if err != nil {
571+
return nil, err
572+
}
573+
privateKeyPath := filepath.Join(configDir, filenames.UserPrivateKey)
574+
key, err := os.ReadFile(privateKeyPath)
575+
if err != nil {
576+
return nil, fmt.Errorf("failed to read private key %q: %w", privateKeyPath, err)
577+
}
578+
signer, err := ssh.ParsePrivateKey(key)
579+
if err != nil {
580+
return nil, fmt.Errorf("failed to parse private key %q: %w", privateKeyPath, err)
581+
}
582+
return signer, nil
583+
}

0 commit comments

Comments
 (0)