Skip to content

Commit 27b0dba

Browse files
Rebase from master
1 parent ffffabf commit 27b0dba

File tree

10 files changed

+394
-1
lines changed

10 files changed

+394
-1
lines changed

commands.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,22 @@ func getOwnCommands() []ownCommand {
158158
needConfiguration: false,
159159
hide: true,
160160
},
161+
{
162+
name: "send",
163+
description: "send a configuration profile to a remote client",
164+
action: sendProfileCommand,
165+
needConfiguration: true,
166+
noProfile: true,
167+
hide: true,
168+
},
169+
{
170+
name: "serve",
171+
description: "serve configuration profiles to remote clients",
172+
action: serveCommand,
173+
needConfiguration: true,
174+
noProfile: true,
175+
hide: true,
176+
},
161177
}
162178
}
163179

config/config.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,22 @@ func (c *Config) getProfilePath(key string) string {
701701
return c.flatKey(constants.SectionConfigurationProfiles, key)
702702
}
703703

704+
// HasRemote returns true if the remote exists in the configuration
705+
func (c *Config) HasRemote(remoteName string) bool {
706+
return c.IsSet(c.flatKey(constants.SectionConfigurationRemotes, remoteName))
707+
}
708+
709+
func (c *Config) GetRemote(remoteName string) (*Remote, error) {
710+
// we don't need to check the file version: the remotes can be in a separate configuration file
711+
712+
remote := NewRemote(c, remoteName)
713+
err := c.unmarshalKey(c.flatKey(constants.SectionConfigurationRemotes, remoteName), remote)
714+
715+
rootPath := filepath.Dir(c.GetConfigFile())
716+
remote.SetRootPath(rootPath)
717+
return remote, err
718+
}
719+
704720
// unmarshalConfig returns the decoder config options depending on the configuration version and format
705721
func (c *Config) unmarshalConfig() viper.DecoderConfigOption {
706722
if c.GetVersion() == Version01 {

config/error.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@ package config
33
import "errors"
44

55
var (
6-
ErrNotFound = errors.New("not found")
6+
ErrNotFound = errors.New("not found")
7+
ErrNotSupportedInVersion1 = errors.New("not supported in configuration version 1")
78
)

config/remote .go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package config
2+
3+
type Remote struct {
4+
name string
5+
config *Config
6+
Connection string `mapstructure:"connection" default:"ssh" description:"Connection type to use to connect to the remote client"`
7+
Host string `mapstructure:"host" description:"Address of the remote client. Format: <host>:<port>"`
8+
Username string `mapstructure:"username" description:"User to connect to the remote client"`
9+
PrivateKeyPath string `mapstructure:"private-key" description:"Path to the private key to use for authentication"`
10+
KnownHostsPath string `mapstructure:"known-hosts" description:"Path to the known hosts file"`
11+
SendFiles []string `mapstructure:"send-files" description:"Configuration files to transfer to the remote client"`
12+
}
13+
14+
func NewRemote(config *Config, name string) *Remote {
15+
remote := &Remote{
16+
name: name,
17+
config: config,
18+
}
19+
return remote
20+
}
21+
22+
// SetRootPath changes the path of all the relative paths and files in the configuration
23+
func (r *Remote) SetRootPath(rootPath string) {
24+
r.PrivateKeyPath = fixPath(r.PrivateKeyPath, expandEnv, absolutePrefix(rootPath))
25+
r.KnownHostsPath = fixPath(r.KnownHostsPath, expandEnv, absolutePrefix(rootPath))
26+
27+
for i := range r.SendFiles {
28+
r.SendFiles[i] = fixPath(r.SendFiles[i], expandEnv, absolutePrefix(rootPath))
29+
}
30+
}

constants/section.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ const (
1313
SectionConfigurationMixins = "mixins"
1414
SectionConfigurationMixinUse = "use"
1515
SectionConfigurationSchedule = "schedule"
16+
SectionConfigurationRemotes = "remotes"
1617

1718
SectionDefinitionCommon = "common"
1819
SectionDefinitionForget = "forget"

send.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
8+
"github.com/creativeprojects/resticprofile/ssh"
9+
)
10+
11+
func sendProfileCommand(w io.Writer, cmdCtx commandContext) error {
12+
if len(cmdCtx.flags.resticArgs) < 2 {
13+
return fmt.Errorf("missing argument: remote name")
14+
}
15+
remote, err := cmdCtx.config.GetRemote(cmdCtx.flags.resticArgs[1])
16+
if err != nil {
17+
return err
18+
}
19+
// send the files to the remote using tar
20+
handler := http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
21+
sendFiles(resp, remote.SendFiles)
22+
})
23+
cnx := ssh.NewSSH(ssh.Config{
24+
Host: remote.Host,
25+
Username: remote.Username,
26+
PrivateKeyPath: remote.PrivateKeyPath,
27+
KnownHostsPath: remote.KnownHostsPath,
28+
Handler: handler,
29+
})
30+
defer cnx.Close()
31+
32+
err = cnx.Connect()
33+
if err != nil {
34+
return err
35+
}
36+
return nil
37+
}

serve.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
"net/http"
8+
"os"
9+
"os/signal"
10+
"time"
11+
12+
"github.com/creativeprojects/clog"
13+
)
14+
15+
func serveCommand(w io.Writer, cmdCtx commandContext) error {
16+
if len(cmdCtx.flags.resticArgs) < 2 {
17+
return fmt.Errorf("missing argument: port")
18+
}
19+
handler := http.NewServeMux()
20+
handler.HandleFunc("GET /configuration/{remote}", func(resp http.ResponseWriter, req *http.Request) {
21+
remoteName := req.PathValue("remote")
22+
if !cmdCtx.config.HasRemote(remoteName) {
23+
resp.WriteHeader(http.StatusNotFound)
24+
resp.Write([]byte("remote not found"))
25+
return
26+
}
27+
remote, err := cmdCtx.config.GetRemote(remoteName)
28+
if err != nil {
29+
resp.WriteHeader(http.StatusBadRequest)
30+
resp.Write([]byte(err.Error()))
31+
return
32+
}
33+
34+
clog.Infof("sending configuration for %q", remoteName)
35+
resp.Header().Set("Content-Type", "application/x-tar")
36+
resp.WriteHeader(http.StatusOK)
37+
sendFiles(resp, remote.SendFiles)
38+
})
39+
40+
quit := make(chan os.Signal, 1)
41+
signal.Notify(quit, os.Interrupt)
42+
defer signal.Stop(quit)
43+
44+
server := &http.Server{
45+
Addr: fmt.Sprintf("localhost:%s", cmdCtx.flags.resticArgs[1]),
46+
Handler: handler,
47+
}
48+
49+
// put the shutdown code in a goroutine
50+
go func(server *http.Server, quit chan os.Signal) {
51+
<-quit
52+
53+
clog.Info("shutting down the server")
54+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
55+
defer cancel()
56+
57+
err := server.Shutdown(ctx)
58+
if err != nil {
59+
clog.Errorf("error while shutting down the server: %v", err)
60+
}
61+
62+
}(server, quit)
63+
64+
// we want to return the server error if any so we need to keep it in the main thread.
65+
clog.Infof("listening on %s", server.Addr)
66+
err := server.ListenAndServe()
67+
if err != nil && err != http.ErrServerClosed {
68+
return err
69+
}
70+
return nil
71+
}

ssh/config.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package ssh
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"os"
7+
"path/filepath"
8+
"strings"
9+
)
10+
11+
type Config struct {
12+
Host string
13+
Username string
14+
PrivateKeyPath string
15+
KnownHostsPath string
16+
Handler http.Handler
17+
}
18+
19+
func (c *Config) Validate() error {
20+
if c.Host == "" {
21+
return fmt.Errorf("host is required")
22+
}
23+
if c.Username == "" {
24+
return fmt.Errorf("username is required")
25+
}
26+
if c.PrivateKeyPath == "" {
27+
home, err := os.UserHomeDir()
28+
if err != nil {
29+
return fmt.Errorf("unable to get current user home directory: %w", err)
30+
}
31+
c.PrivateKeyPath = filepath.Join(home, ".ssh/id_rsa") // we can go through all the default name for each key type
32+
}
33+
if c.KnownHostsPath == "" {
34+
home, err := os.UserHomeDir()
35+
if err != nil {
36+
return fmt.Errorf("unable to get current user home directory: %w", err)
37+
}
38+
c.KnownHostsPath = filepath.Join(home, ".ssh/known_hosts")
39+
}
40+
if !strings.Contains(c.Host, ":") {
41+
c.Host = c.Host + ":22"
42+
}
43+
return nil
44+
}

ssh/ssh.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package ssh
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"fmt"
7+
"log"
8+
"net"
9+
"net/http"
10+
"os"
11+
"time"
12+
13+
"github.com/creativeprojects/clog"
14+
"golang.org/x/crypto/ssh"
15+
"golang.org/x/crypto/ssh/knownhosts"
16+
)
17+
18+
const startPort = 10001
19+
20+
type SSH struct {
21+
config Config
22+
port int
23+
client *ssh.Client
24+
tunnel net.Listener
25+
server *http.Server
26+
}
27+
28+
func NewSSH(config Config) *SSH {
29+
return &SSH{
30+
config: config,
31+
port: startPort,
32+
}
33+
}
34+
35+
func (s *SSH) Connect() error {
36+
err := s.config.Validate()
37+
if err != nil {
38+
return err
39+
}
40+
hostKeyCallback, err := knownhosts.New(s.config.KnownHostsPath)
41+
if err != nil {
42+
return fmt.Errorf("cannot load host keys from known_hosts: %w", err)
43+
}
44+
key, err := os.ReadFile(s.config.PrivateKeyPath)
45+
if err != nil {
46+
return fmt.Errorf("unable to read private key: %w", err)
47+
}
48+
49+
// Create the Signer for this private key.
50+
signer, err := ssh.ParsePrivateKey(key)
51+
if err != nil {
52+
return fmt.Errorf("unable to parse private key: %w", err)
53+
}
54+
55+
config := &ssh.ClientConfig{
56+
User: s.config.Username,
57+
Auth: []ssh.AuthMethod{
58+
// Use the PublicKeys method for remote authentication.
59+
ssh.PublicKeys(signer),
60+
},
61+
HostKeyCallback: hostKeyCallback,
62+
}
63+
64+
// Connect to the remote server and perform the SSH handshake.
65+
s.client, err = ssh.Dial("tcp", s.config.Host, config)
66+
if err != nil {
67+
return fmt.Errorf("unable to connect: %w", err)
68+
}
69+
70+
// Request the remote side to open a local port
71+
s.tunnel, err = s.client.Listen("tcp", fmt.Sprintf("localhost:%d", s.port))
72+
if err != nil {
73+
log.Fatal("unable to register tcp forward: ", err)
74+
}
75+
76+
go func() {
77+
s.server = &http.Server{
78+
Handler: s.config.Handler,
79+
}
80+
// Serve HTTP with your SSH server acting as a reverse proxy.
81+
err := s.server.Serve(s.tunnel)
82+
if err != nil && err != http.ErrServerClosed {
83+
clog.Warningf("unable to serve http: %s", err)
84+
}
85+
}()
86+
87+
// Each ClientConn can support multiple interactive sessions,
88+
// represented by a Session.
89+
session, err := s.client.NewSession()
90+
if err != nil {
91+
log.Fatal("Failed to create session: ", err)
92+
}
93+
defer session.Close()
94+
95+
// Once a Session is created, we can execute a single command on
96+
// the remote side using the Run method.
97+
var b bytes.Buffer
98+
session.Stdout = &b
99+
if err := session.Run(fmt.Sprintf("curl http://localhost:%d", s.port)); err != nil {
100+
log.Fatal("Failed to run: " + err.Error())
101+
}
102+
fmt.Println(b.String())
103+
return nil
104+
}
105+
106+
func (s *SSH) Close() {
107+
if s.tunnel != nil {
108+
err := s.tunnel.Close()
109+
if err != nil {
110+
clog.Warningf("unable to close tunnel: %s", err)
111+
}
112+
}
113+
if s.server != nil {
114+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
115+
defer cancel()
116+
err := s.server.Shutdown(ctx)
117+
if err != nil {
118+
clog.Warningf("unable to close http server: %s", err)
119+
}
120+
}
121+
if s.client != nil {
122+
err := s.client.Close()
123+
if err != nil {
124+
clog.Warningf("unable to close ssh connection: %s", err)
125+
}
126+
}
127+
}

0 commit comments

Comments
 (0)