Skip to content

Commit 0250449

Browse files
committed
feat: Generate SSH server keys in host agent and use them in guest OS
This change changes the SSH server keys that have been generated for each boot in guest OS to be generated by hostagent for each boot. This allows the hostagent to obtain the public key before booting, so that knownhosts can be used with an ssh connection. The code that uses `ssh.InsecureIgnoreHostKey()` in `x/crypto/ssh` is pointed out in CodeQL as `Use of insecure HostKeyCallback implementation (High)`, so it is an implementation to avoid this. Signed-off-by: Norio Nomura <[email protected]>
1 parent 8ca744e commit 0250449

File tree

10 files changed

+172
-39
lines changed

10 files changed

+172
-39
lines changed

pkg/cidata/cidata.TEMPLATE.d/user-data

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,11 @@ bootcmd:
104104
{{- end }}
105105
{{- end }}
106106
{{- end }}
107+
108+
{{- if .SSHHostKeys }}
109+
ssh_keys:
110+
{{- range $type, $key := .SSHHostKeys }}
111+
{{ $type }}: |
112+
{{ indent 4 $key }}
113+
{{- end }}
114+
{{- end }}

pkg/cidata/cidata.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func setupEnv(instConfigEnv map[string]string, propagateProxyEnv bool, slirpGate
118118
return env, nil
119119
}
120120

121-
func templateArgs(ctx context.Context, bootScripts bool, instDir, name string, instConfig *limatype.LimaYAML, udpDNSLocalPort, tcpDNSLocalPort, vsockPort int, virtioPort string, noCloudInit, rosettaEnabled, rosettaBinFmt bool) (*TemplateArgs, error) {
121+
func templateArgs(ctx context.Context, bootScripts bool, instDir, name string, instConfig *limatype.LimaYAML, udpDNSLocalPort, tcpDNSLocalPort, vsockPort int, virtioPort string, noCloudInit, rosettaEnabled, rosettaBinFmt, hostKeys bool) (*TemplateArgs, error) {
122122
if err := limayaml.Validate(instConfig, false); err != nil {
123123
return nil, err
124124
}
@@ -342,11 +342,19 @@ func templateArgs(ctx context.Context, bootScripts bool, instDir, name string, i
342342
}
343343
}
344344

345+
if hostKeys {
346+
sshHostKeys, err := sshutil.GenerateSSHHostKeys(instDir, args.Hostname)
347+
if err != nil {
348+
return nil, fmt.Errorf("failed to generate SSH host keys: %w", err)
349+
}
350+
args.SSHHostKeys = sshHostKeys
351+
}
352+
345353
return &args, nil
346354
}
347355

348356
func GenerateCloudConfig(ctx context.Context, instDir, name string, instConfig *limatype.LimaYAML) error {
349-
args, err := templateArgs(ctx, false, instDir, name, instConfig, 0, 0, 0, "", false, false, false)
357+
args, err := templateArgs(ctx, false, instDir, name, instConfig, 0, 0, 0, "", false, false, false, false)
350358
if err != nil {
351359
return err
352360
}
@@ -369,7 +377,7 @@ func GenerateCloudConfig(ctx context.Context, instDir, name string, instConfig *
369377
}
370378

371379
func GenerateISO9660(ctx context.Context, drv driver.Driver, instDir, name string, instConfig *limatype.LimaYAML, udpDNSLocalPort, tcpDNSLocalPort int, guestAgentBinary, nerdctlArchive string, vsockPort int, virtioPort string, noCloudInit, rosettaEnabled, rosettaBinFmt bool) error {
372-
args, err := templateArgs(ctx, true, instDir, name, instConfig, udpDNSLocalPort, tcpDNSLocalPort, vsockPort, virtioPort, noCloudInit, rosettaEnabled, rosettaBinFmt)
380+
args, err := templateArgs(ctx, true, instDir, name, instConfig, udpDNSLocalPort, tcpDNSLocalPort, vsockPort, virtioPort, noCloudInit, rosettaEnabled, rosettaBinFmt, true)
373381
if err != nil {
374382
return err
375383
}
@@ -467,6 +475,13 @@ func GenerateISO9660(ctx context.Context, drv driver.Driver, instDir, name strin
467475
Path: "ssh_authorized_keys",
468476
Reader: strings.NewReader(strings.Join(args.SSHPubKeys, "\n")),
469477
})
478+
for keyType, keyContent := range args.SSHHostKeys {
479+
suffix := strings.Replace(strings.Replace(keyType, "_public", "_key.pub", 1), "_private", "_key", 1)
480+
layout = append(layout, iso9660util.Entry{
481+
Path: "ssh_host_" + suffix,
482+
Reader: strings.NewReader(keyContent),
483+
})
484+
}
470485
return writeCIDataDir(filepath.Join(instDir, filenames.CIDataISODir), layout)
471486
}
472487

pkg/cidata/template.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ type TemplateArgs struct {
115115
Plain bool
116116
TimeZone string
117117
NoCloudInit bool
118+
SSHHostKeys map[string]string // `ssh_keys` field in cloud-init SSH module
118119
}
119120

120121
func ValidateTemplateArgs(args *TemplateArgs) error {

pkg/driver/vz/vsock_forwarder.go

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,11 @@ import (
99
"context"
1010
"errors"
1111
"net"
12-
"path/filepath"
1312

1413
"github.com/containers/gvisor-tap-vsock/pkg/tcpproxy"
1514
"github.com/sirupsen/logrus"
1615

1716
"github.com/lima-vm/lima/v2/pkg/limatype"
18-
"github.com/lima-vm/lima/v2/pkg/limatype/dirnames"
19-
"github.com/lima-vm/lima/v2/pkg/limatype/filenames"
2017
"github.com/lima-vm/lima/v2/pkg/sshutil"
2118
)
2219

@@ -75,15 +72,7 @@ func (m *virtualMachineWrapper) dialVsock(_ context.Context, port uint32) (conn
7572
}
7673

7774
func (m *virtualMachineWrapper) checkSSHOverVsockAvailable(ctx context.Context, inst *limatype.Instance) error {
78-
user := *inst.Config.User.Name
79-
configDir, err := dirnames.LimaConfigDir()
80-
if err != nil {
81-
return err
82-
}
83-
privateKeyPath := filepath.Join(configDir, filenames.UserPrivateKey)
84-
vsockPort := uint32(22)
85-
addr := "vsock:22"
8675
return sshutil.WaitSSHReady(ctx, func(ctx context.Context) (net.Conn, error) {
87-
return m.dialVsock(ctx, vsockPort)
88-
}, addr, user, privateKeyPath, 1)
76+
return m.dialVsock(ctx, uint32(22))
77+
}, "vsock:22", *inst.Config.User.Name, inst.Name, 1)
8978
}

pkg/driver/wsl2/boot/02-no-cloud-init-setup.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ chmod 700 "${LIMA_CIDATA_HOME}"/.ssh/
1717
cp "${LIMA_CIDATA_MNT}"/ssh_authorized_keys "${LIMA_CIDATA_HOME}"/.ssh/authorized_keys
1818
chown "${LIMA_CIDATA_UID}:${LIMA_CIDATA_GID}" "${LIMA_CIDATA_HOME}"/.ssh/authorized_keys
1919
chmod 600 "${LIMA_CIDATA_HOME}"/.ssh/authorized_keys
20+
# copy SSH host keys
21+
mkdir -p /etc/ssh/
22+
cp "${LIMA_CIDATA_MNT}"/ssh_host_* /etc/ssh/
23+
chmod 600 /etc/ssh/ssh_host_*
24+
chmod 644 /etc/ssh/ssh_host_*.pub
2025

2126
# add $LIMA_CIDATA_USER to sudoers
2227
echo "${LIMA_CIDATA_USER} ALL=(ALL) NOPASSWD:ALL" | tee -a /etc/sudoers.d/99_lima_sudoers

pkg/limatype/filenames/filenames.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ const (
5050
SerialVirtioSock = "serialv.sock"
5151
SSHSock = "ssh.sock"
5252
SSHConfig = "ssh.config"
53+
SSHKnownHosts = "ssh_known_hosts"
5354
VhostSock = "virtiofsd-%d.sock"
5455
VNCDisplayFile = "vncdisplay"
5556
VNCPasswordFile = "vncpassword"

pkg/networks/usernet/client.go

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"net"
1212
"net/http"
1313
"os"
14-
"path/filepath"
1514
"strconv"
1615
"time"
1716

@@ -20,8 +19,6 @@ import (
2019

2120
"github.com/lima-vm/lima/v2/pkg/httpclientutil"
2221
"github.com/lima-vm/lima/v2/pkg/limatype"
23-
"github.com/lima-vm/lima/v2/pkg/limatype/dirnames"
24-
"github.com/lima-vm/lima/v2/pkg/limatype/filenames"
2522
"github.com/lima-vm/lima/v2/pkg/limayaml"
2623
"github.com/lima-vm/lima/v2/pkg/networks/usernet/dnshosts"
2724
)
@@ -144,13 +141,9 @@ func (c *Client) WaitOpeningSSHPort(ctx context.Context, inst *limatype.Instance
144141
return err
145142
}
146143
user := *inst.Config.User.Name
147-
configDir, err := dirnames.LimaConfigDir()
148-
if err != nil {
149-
return err
150-
}
151-
privateKeyPath := filepath.Join(configDir, filenames.UserPrivateKey)
144+
instanceName := inst.Name
152145
// -1 avoids both sides timing out simultaneously.
153-
u := fmt.Sprintf("%s/extension/wait_ssh_server?ip=%s&port=22&timeout=%d&user=%s&privateKeyPath=%s", c.base, ipAddr, timeoutSeconds-1, user, privateKeyPath)
146+
u := fmt.Sprintf("%s/extension/wait-ssh-server?ip=%s&port=22&timeout=%d&user=%s&instance-name=%s", c.base, ipAddr, timeoutSeconds-1, user, instanceName)
154147
res, err := httpclientutil.Get(ctx, c.client, u)
155148
if err != nil {
156149
return err

pkg/networks/usernet/gvproxy.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ func httpServe(ctx context.Context, g *errgroup.Group, ln net.Listener, mux http
245245

246246
func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux {
247247
m := n.Mux()
248-
m.HandleFunc("/extension/wait_ssh_server", func(w http.ResponseWriter, r *http.Request) {
248+
m.HandleFunc("/extension/wait-ssh-server", func(w http.ResponseWriter, r *http.Request) {
249249
ip := r.URL.Query().Get("ip")
250250
if net.ParseIP(ip) == nil {
251251
msg := fmt.Sprintf("invalid ip address: %s", ip)
@@ -260,9 +260,9 @@ func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux {
260260
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", uint16(port16)))
261261

262262
user := r.URL.Query().Get("user")
263-
privateKeyPath := r.URL.Query().Get("privateKeyPath")
264-
if user == "" || privateKeyPath == "" {
265-
msg := "user and privateKeyPath query parameters are required"
263+
instanceName := r.URL.Query().Get("instance-name")
264+
if user == "" || instanceName == "" {
265+
msg := "user and instanceName query parameters are required"
266266
http.Error(w, msg, http.StatusBadRequest)
267267
return
268268
}
@@ -280,7 +280,7 @@ func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux {
280280
return n.DialContextTCP(ctx, addr)
281281
}
282282
// Wait until the port is available.
283-
if err = sshutil.WaitSSHReady(r.Context(), dialContext, addr, user, privateKeyPath, timeoutSeconds); err != nil {
283+
if err = sshutil.WaitSSHReady(r.Context(), dialContext, addr, user, instanceName, timeoutSeconds); err != nil {
284284
http.Error(w, err.Error(), http.StatusRequestTimeout)
285285
} else {
286286
w.WriteHeader(http.StatusOK)

pkg/sshutil/sshutil.go

Lines changed: 129 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,18 @@ package sshutil
66
import (
77
"bytes"
88
"context"
9+
"crypto"
10+
"crypto/ecdsa"
11+
"crypto/ed25519"
12+
"crypto/elliptic"
13+
"crypto/rand"
14+
"crypto/rsa"
915
"encoding/base64"
1016
"encoding/binary"
17+
"encoding/pem"
1118
"errors"
1219
"fmt"
20+
"io"
1321
"io/fs"
1422
"net"
1523
"os"
@@ -26,8 +34,10 @@ import (
2634
"github.com/mattn/go-shellwords"
2735
"github.com/sirupsen/logrus"
2836
"golang.org/x/crypto/ssh"
37+
"golang.org/x/crypto/ssh/knownhosts"
2938
"golang.org/x/sys/cpu"
3039

40+
"github.com/lima-vm/lima/v2/pkg/instance/hostname"
3141
"github.com/lima-vm/lima/v2/pkg/ioutilx"
3242
"github.com/lima-vm/lima/v2/pkg/limatype/dirnames"
3343
"github.com/lima-vm/lima/v2/pkg/limatype/filenames"
@@ -244,7 +254,6 @@ func CommonOpts(ctx context.Context, sshExe SSHExe, useDotSSH bool) ([]string, e
244254

245255
opts = append(opts,
246256
"StrictHostKeyChecking=no",
247-
"UserKnownHostsFile=/dev/null",
248257
"NoHostAuthenticationForLocalhost=yes",
249258
"PreferredAuthentications=publickey",
250259
"Compression=no",
@@ -345,18 +354,28 @@ func SSHOpts(ctx context.Context, sshExe SSHExe, instDir, username string, useDo
345354
return nil, err
346355
}
347356
controlPath := fmt.Sprintf(`ControlPath="%s"`, controlSock)
357+
userKnownHostsPath := filepath.Join(instDir, filenames.SSHKnownHosts)
358+
userKnownHosts := fmt.Sprintf(`UserKnownHostsFile="%s"`, userKnownHostsPath)
348359
if runtime.GOOS == "windows" {
349360
controlSock, err = ioutilx.WindowsSubsystemPath(ctx, controlSock)
350361
if err != nil {
351362
return nil, err
352363
}
353364
controlPath = fmt.Sprintf(`ControlPath='%s'`, controlSock)
365+
userKnownHostsPath, err = ioutilx.WindowsSubsystemPath(ctx, userKnownHostsPath)
366+
if err != nil {
367+
return nil, err
368+
}
369+
userKnownHosts = fmt.Sprintf(`UserKnownHostsFile='%s'`, userKnownHostsPath)
354370
}
371+
hostKeyAlias := fmt.Sprintf("HostKeyAlias=%s", hostname.FromInstName(filepath.Base(instDir)))
355372
opts = append(opts,
356373
fmt.Sprintf("User=%s", username), // guest and host have the same username, but we should specify the username explicitly (#85)
357374
"ControlMaster=auto",
358375
controlPath,
359376
"ControlPersist=yes",
377+
userKnownHosts,
378+
hostKeyAlias,
360379
)
361380
if forwardAgent {
362381
opts = append(opts, "ForwardAgent=yes")
@@ -516,27 +535,27 @@ func detectAESAcceleration() bool {
516535
// The dialContext function is used to create a connection to the SSH server.
517536
// The addr, user, privateKeyPath parameter is used for ssh.ClientConn creation.
518537
// The timeoutSeconds parameter specifies the maximum number of seconds to wait.
519-
func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Conn, error), addr, user, privateKeyPath string, timeoutSeconds int) error {
538+
func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Conn, error), addr, user, instanceName string, timeoutSeconds int) error {
520539
ctx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
521540
defer cancel()
522541

523542
// Prepare signer
524-
key, err := os.ReadFile(privateKeyPath)
543+
signer, err := UserPrivateKey()
525544
if err != nil {
526-
return fmt.Errorf("failed to read private key %q: %w", privateKeyPath, err)
545+
return err
527546
}
528-
signer, err := ssh.ParsePrivateKey(key)
547+
// Prepare HostKeyCallback
548+
hostKeyChecker, err := HostKeyCheckerWithKeysInKnownHosts(instanceName)
529549
if err != nil {
530-
return fmt.Errorf("failed to parse private key %q: %w", privateKeyPath, err)
550+
return err
531551
}
532552
// Prepare ssh client config
533553
sshConfig := &ssh.ClientConfig{
534554
User: user,
535555
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
536-
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
556+
HostKeyCallback: hostKeyChecker,
537557
Timeout: 10 * time.Second,
538558
}
539-
540559
// Wait until the SSH server is available.
541560
for {
542561
conn, err := dialContext(ctx)
@@ -567,3 +586,105 @@ func isRetryableError(err error) bool {
567586
// SSH server not ready yet (e.g. host key not generated on initial boot).
568587
strings.HasSuffix(err.Error(), "no supported methods remain")
569588
}
589+
590+
// UserPrivateKey returns the user's private key signer.
591+
// The public key is always installed in the VM.
592+
func UserPrivateKey() (ssh.Signer, error) {
593+
configDir, err := dirnames.LimaConfigDir()
594+
if err != nil {
595+
return nil, err
596+
}
597+
privateKeyPath := filepath.Join(configDir, filenames.UserPrivateKey)
598+
key, err := os.ReadFile(privateKeyPath)
599+
if err != nil {
600+
return nil, fmt.Errorf("failed to read private key %q: %w", privateKeyPath, err)
601+
}
602+
signer, err := ssh.ParsePrivateKey(key)
603+
if err != nil {
604+
return nil, fmt.Errorf("failed to parse private key %q: %w", privateKeyPath, err)
605+
}
606+
return signer, nil
607+
}
608+
609+
func HostKeyCheckerWithKeysInKnownHosts(instanceName string) (ssh.HostKeyCallback, error) {
610+
publicKeys, err := PublicKeysFromKnownHosts(instanceName)
611+
if err != nil {
612+
return nil, err
613+
}
614+
return func(_ string, _ net.Addr, key ssh.PublicKey) error {
615+
keyBytes := key.Marshal()
616+
for _, pk := range publicKeys {
617+
if bytes.Equal(keyBytes, pk.Marshal()) {
618+
return nil
619+
}
620+
}
621+
return errors.New("ssh: host key mismatch")
622+
}, nil
623+
}
624+
625+
// PublicKeysFromKnownHosts returns the public keys from the known_hosts file located in the instance directory.
626+
func PublicKeysFromKnownHosts(instanceName string) ([]ssh.PublicKey, error) {
627+
// Load known_hosts from the instance directory
628+
instanceDir, err := dirnames.InstanceDir(instanceName)
629+
if err != nil {
630+
return nil, fmt.Errorf("failed to get instance dir for instance %q: %w", instanceName, err)
631+
}
632+
knownHostsPath := filepath.Join(instanceDir, filenames.SSHKnownHosts)
633+
knownHostsBytes, err := os.ReadFile(knownHostsPath)
634+
if err != nil {
635+
return nil, fmt.Errorf("failed to read known_hosts file at %s: %w", knownHostsPath, err)
636+
}
637+
var publicKeys []ssh.PublicKey
638+
rest := knownHostsBytes
639+
for len(rest) > 0 {
640+
var publicKey ssh.PublicKey
641+
publicKey, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
642+
if err != nil {
643+
return nil, fmt.Errorf("failed to parse public key from known_hosts file %s: %w", knownHostsPath, err)
644+
}
645+
publicKeys = append(publicKeys, publicKey)
646+
}
647+
return publicKeys, nil
648+
}
649+
650+
// GenerateSSHHostKeys generates an Ed25519 host key pair for the SSH server.
651+
// The private key is returned in PEM format, and the public key.
652+
func GenerateSSHHostKeys(instDir, hostname string) (map[string]string, error) {
653+
generators := map[string]func(io.Reader) (crypto.PrivateKey, error){
654+
"ecdsa": func(rand io.Reader) (crypto.PrivateKey, error) {
655+
return ecdsa.GenerateKey(elliptic.P256(), rand)
656+
},
657+
"ed25519": func(rand io.Reader) (crypto.PrivateKey, error) {
658+
_, priv, err := ed25519.GenerateKey(rand)
659+
return priv, err
660+
},
661+
"rsa": func(rand io.Reader) (crypto.PrivateKey, error) {
662+
return rsa.GenerateKey(rand, 3072)
663+
},
664+
}
665+
res := make(map[string]string, len(generators))
666+
var sshKnownHosts []byte
667+
for keyType, generator := range generators {
668+
priv, err := generator(rand.Reader)
669+
if err != nil {
670+
return nil, err
671+
}
672+
privPem, err := ssh.MarshalPrivateKey(priv, hostname)
673+
if err != nil {
674+
return nil, fmt.Errorf("failed to marshal %s private key to PEM format: %w", keyType, err)
675+
}
676+
pub, err := ssh.NewPublicKey(priv.(crypto.Signer).Public())
677+
if err != nil {
678+
return nil, fmt.Errorf("failed to create ssh %s public key: %w", keyType, err)
679+
}
680+
res[keyType+"_private"] = string(pem.EncodeToMemory(privPem))
681+
res[keyType+"_public"] = string(ssh.MarshalAuthorizedKey(pub))
682+
sshKnownHosts = append(sshKnownHosts, knownhosts.Line([]string{hostname}, pub)...)
683+
sshKnownHosts = append(sshKnownHosts, '\n')
684+
}
685+
knownHostsPath := filepath.Join(instDir, filenames.SSHKnownHosts)
686+
if err := os.WriteFile(knownHostsPath, sshKnownHosts, 0o644); err != nil {
687+
return nil, fmt.Errorf("failed to write known_hosts file at %s: %w", knownHostsPath, err)
688+
}
689+
return res, nil
690+
}

0 commit comments

Comments
 (0)