diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9685fda85..13cf97e34 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -12,7 +12,6 @@ on: - 'docs/**' jobs: - build: name: Build and test runs-on: ${{ matrix.os }} @@ -25,7 +24,6 @@ jobs: GO: ${{ matrix.go_version }} steps: - - name: Check out code into the Go module directory uses: actions/checkout@v4 @@ -71,6 +69,25 @@ jobs: name: code-coverage-report-${{ matrix.os }} path: coverage.out + test-ssh: + name: Test SSH client + runs-on: ubuntu-latest + + steps: + - name: Check out code into the Go module directory + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: 1.24 + + - name: Run tests + run: | + make start-ssh-server + make ssh-test + make stop-ssh-server + sonarCloudTrigger: needs: build name: SonarCloud Trigger diff --git a/.vscode/settings.json b/.vscode/settings.json index 3af104a0f..cd83a24e9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -18,5 +18,6 @@ "go.lintTool": "golangci-lint", "githubPullRequests.ignoredPullRequestBranches": [ "master" - ] + ], + "go.buildTags": "ssh" } \ No newline at end of file diff --git a/Makefile b/Makefile index 15f1a00d4..17368908d 100644 --- a/Makefile +++ b/Makefile @@ -54,6 +54,9 @@ ifeq ($(UNAME),Darwin) TMP_MOUNT=${TMP_MOUNT_DARWIN} endif +TMPDIR ?= /tmp +SSH_TESTS_TMPDIR=$(shell echo "$(TMPDIR)/resticprofile-ssh-tests" | tr -s /) + TOC_START=<\!--ts--> TOC_END=<\!--te--> TOC_PATH=toc.md @@ -336,3 +339,25 @@ deploy-current: build-linux build-pi rsync -avz --progress $(BINARY_PI) $$server: ; \ ssh $$server "sudo -S install $(BINARY_PI) /usr/local/bin/resticprofile" ; \ done + +.PHONY: start-ssh-server +start-ssh-server: + @echo "[*] $@" + @mkdir -p $(SSH_TESTS_TMPDIR) && rm -f $(SSH_TESTS_TMPDIR)/id_rsa* || echo "Failed to create temporary directory" + @ssh-keygen -t rsa -b 2048 -f $(SSH_TESTS_TMPDIR)/id_rsa -N "" -C "resticprofile@$(shell hostname)" + @cd ./ssh/test && \ + USER_ID=$(shell id -u) GROUP_ID=$(shell id -g) SSH_TESTS_TMPDIR=$(SSH_TESTS_TMPDIR) \ + docker compose up -d --force-recreate + @sleep 1 + @ssh-keyscan -p 2222 -H localhost > $(SSH_TESTS_TMPDIR)/known_hosts + +.PHONY: stop-ssh-server +stop-ssh-server: + @echo "[*] $@" + cd ./ssh/test && SSH_TESTS_TMPDIR=$(SSH_TESTS_TMPDIR) docker compose down --remove-orphans + @test -d "$(SSH_TESTS_TMPDIR)" && rm -rf "$(SSH_TESTS_TMPDIR)" || echo "temporary directory not found, nothing to remove" + +.PHONY: ssh-test +ssh-test: + @echo "[*] $@" + @go test -run TestSSHClient -v -tags ssh ./ssh diff --git a/commands.go b/commands.go index 629d9c7aa..979ce9180 100644 --- a/commands.go +++ b/commands.go @@ -168,6 +168,24 @@ func getOwnCommands() []ownCommand { needConfiguration: false, hide: true, }, + { + name: "send", + description: "send a configuration profile to a remote client and execute a command", + action: sendProfileCommand, + needConfiguration: true, + noProfile: true, + hide: true, + experimental: true, + }, + { + name: "serve", + description: "serve configuration profiles to remote clients", + action: serveCommand, + needConfiguration: true, + noProfile: true, + hide: true, + experimental: true, + }, } } diff --git a/config/config.go b/config/config.go index aa7a92f08..7c8323959 100644 --- a/config/config.go +++ b/config/config.go @@ -716,6 +716,22 @@ func (c *Config) getProfilePath(key string) string { return c.flatKey(constants.SectionConfigurationProfiles, key) } +// HasRemote returns true if the remote exists in the configuration +func (c *Config) HasRemote(remoteName string) bool { + return c.IsSet(c.flatKey(constants.SectionConfigurationRemotes, remoteName)) +} + +func (c *Config) GetRemote(remoteName string) (*Remote, error) { + // we don't need to check the file version: the remotes can be in a separate configuration file + + remote := NewRemote(c, remoteName) + err := c.unmarshalKey(c.flatKey(constants.SectionConfigurationRemotes, remoteName), remote) + + rootPath := filepath.Dir(c.GetConfigFile()) + remote.SetRootPath(rootPath) + return remote, err +} + // unmarshalConfig returns the decoder config options depending on the configuration version and format func (c *Config) unmarshalConfig() viper.DecoderConfigOption { if c.GetVersion() == Version01 { diff --git a/config/config_test.go b/config/config_test.go index 4ac352ab0..398b161a9 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -366,15 +366,6 @@ x=0 } func TestIncludes(t *testing.T) { - files := []string{} - cleanFiles := func() { - for _, file := range files { - os.Remove(file) - } - files = files[:0] - } - defer cleanFiles() - createFile := func(t *testing.T, suffix, content string) string { t.Helper() name := "" @@ -383,7 +374,9 @@ func TestIncludes(t *testing.T) { defer file.Close() _, err = file.WriteString(content) name = file.Name() - files = append(files, name) + t.Cleanup(func() { + _ = os.Remove(name) + }) } require.NoError(t, err) return name @@ -399,7 +392,6 @@ func TestIncludes(t *testing.T) { testID := fmt.Sprintf("%d", time.Now().Unix()) t.Run("multiple-includes", func(t *testing.T) { - defer cleanFiles() content := fmt.Sprintf(`includes=['*%[1]s.inc.toml','*%[1]s.inc.yaml','*%[1]s.inc.json']`, testID) configFile := createFile(t, "profiles.conf", content) @@ -415,8 +407,6 @@ func TestIncludes(t *testing.T) { }) t.Run("overrides", func(t *testing.T) { - defer cleanFiles() - configFile := createFile(t, "profiles.conf", ` includes = "*`+testID+`.inc.toml" [default] @@ -435,8 +425,6 @@ repository = "overridden-repo"`) }) t.Run("mixins", func(t *testing.T) { - defer cleanFiles() - configFile := createFile(t, "profiles.conf", ` version = 2 includes = "*`+testID+`.inc.toml" @@ -463,8 +451,6 @@ use = "another-run-before2"`) }) t.Run("hcl-includes-only-hcl", func(t *testing.T) { - defer cleanFiles() - configFile := createFile(t, "profiles.hcl", `includes = "*`+testID+`.inc.*"`) createFile(t, "pass-"+testID+".inc.hcl", `one { }`) @@ -473,13 +459,11 @@ use = "another-run-before2"`) createFile(t, "fail-"+testID+".inc.toml", `[two]`) _, err := LoadFile(configFile, "") - assert.Error(t, err) + require.Error(t, err) assert.Regexp(t, ".+ is in hcl format, includes must use the same format", err.Error()) }) t.Run("non-hcl-include-no-hcl", func(t *testing.T) { - defer cleanFiles() - configFile := createFile(t, "profiles.toml", `includes = "*`+testID+`.inc.*"`) createFile(t, "pass-"+testID+".inc.toml", "[one]\nk='v'") @@ -488,12 +472,11 @@ use = "another-run-before2"`) createFile(t, "fail-"+testID+".inc.hcl", `one { }`) _, err := LoadFile(configFile, "") - assert.Error(t, err) + require.Error(t, err) assert.Regexp(t, "hcl format .+ cannot be used in includes from toml", err.Error()) }) t.Run("cannot-load-different-versions", func(t *testing.T) { - defer cleanFiles() content := fmt.Sprintf(`includes=['*%s.inc.json']`, testID) configFile := createFile(t, "profiles.conf", content) @@ -505,7 +488,6 @@ use = "another-run-before2"`) }) t.Run("cannot-load-different-versions", func(t *testing.T) { - defer cleanFiles() content := fmt.Sprintf(`{"version": 2, "includes":["*%s.inc.json"]}`, testID) configFile := createFile(t, "profiles.json", content) diff --git a/config/error.go b/config/error.go index eb85ce37b..394c88c01 100644 --- a/config/error.go +++ b/config/error.go @@ -3,5 +3,6 @@ package config import "errors" var ( - ErrNotFound = errors.New("not found") + ErrNotFound = errors.New("not found") + ErrNotSupportedInVersion1 = errors.New("not supported in configuration version 1") ) diff --git a/config/remote.go b/config/remote.go new file mode 100644 index 000000000..a8d70fba4 --- /dev/null +++ b/config/remote.go @@ -0,0 +1,36 @@ +package config + +type Remote struct { + name string + config *Config + Connection string `mapstructure:"connection" default:"ssh" description:"Connection type to use to connect to the remote client"` + Host string `mapstructure:"host" description:"Address of the remote client. Format: :"` + Username string `mapstructure:"username" description:"User to connect to the remote client"` + PrivateKeyPath string `mapstructure:"private-key" description:"Path to the private key to use for authentication"` + KnownHostsPath string `mapstructure:"known-hosts" description:"Path to the known hosts file"` + BinaryPath string `mapstructure:"binary-path" description:"Path to the resticprofile binary to use on the remote client"` + ConfigurationFile string `mapstructure:"configuration-file" description:"Path to the configuration file to transfer to the remote client"` + ProfileName string `mapstructure:"profile-name" description:"Name of the profile to use on the remote client"` + SendFiles []string `mapstructure:"send-files" description:"Other configuration files to transfer to the remote client"` + SSHConfig string `mapstructure:"ssh-config" description:"Path to the OpenSSH config file to use for the connection"` +} + +func NewRemote(config *Config, name string) *Remote { + remote := &Remote{ + name: name, + config: config, + } + return remote +} + +// SetRootPath changes the path of all the relative paths and files in the configuration +func (r *Remote) SetRootPath(rootPath string) { + r.PrivateKeyPath = fixPath(r.PrivateKeyPath, expandEnv, absolutePrefix(rootPath)) + r.KnownHostsPath = fixPath(r.KnownHostsPath, expandEnv, absolutePrefix(rootPath)) + r.ConfigurationFile = fixPath(r.ConfigurationFile, expandEnv, absolutePrefix(rootPath)) + r.SSHConfig = fixPath(r.SSHConfig, expandEnv, absolutePrefix(rootPath)) + + for i := range r.SendFiles { + r.SendFiles[i] = fixPath(r.SendFiles[i], expandEnv, absolutePrefix(rootPath)) + } +} diff --git a/constants/exit_code.go b/constants/exit_code.go new file mode 100644 index 000000000..a5cf66a99 --- /dev/null +++ b/constants/exit_code.go @@ -0,0 +1,10 @@ +package constants + +const ( + ExitSuccess = iota + ExitGeneralError + ExitErrorInvalidFlags + ExitRunningOnBattery + ExitCannotSetupRemoteConfiguration + ExitErrorChildHasNoParentPort = 10 +) diff --git a/constants/other.go b/constants/other.go index ec8f74d9d..ff94cc7fa 100644 --- a/constants/other.go +++ b/constants/other.go @@ -3,4 +3,5 @@ package constants const ( TemporaryDirMarker = "temp:" JSONSchema = "$schema" + ManifestFilename = ".manifest.json" ) diff --git a/constants/section.go b/constants/section.go index 9e9bf3a60..1a8224e46 100644 --- a/constants/section.go +++ b/constants/section.go @@ -13,6 +13,7 @@ const ( SectionConfigurationMixins = "mixins" SectionConfigurationMixinUse = "use" SectionConfigurationSchedule = "schedule" + SectionConfigurationRemotes = "remotes" SectionDefinitionCommon = "common" SectionDefinitionForget = "forget" diff --git a/examples/linux.yaml b/examples/linux.yaml index 5790dad53..dd5492282 100644 --- a/examples/linux.yaml +++ b/examples/linux.yaml @@ -106,7 +106,15 @@ self: at: "*:15,20,25" permission: system after-network-online: true + exclude-file: + - root-excludes + - excludes check: + schedule: + at: "*:15" + permission: system + copy: + initialize: true schedule-permission: user schedule: - "*:15" @@ -131,9 +139,8 @@ src: run-after: echo All Done! run-before: - echo Starting! - - ls -al ~/go source: - - ~/go + - ~/go/src/github.com/creativeprojects/resticprofile tag: - test - dev diff --git a/filesearch/filesearch_test.go b/filesearch/filesearch_test.go index f736b01c7..32b5347f3 100644 --- a/filesearch/filesearch_test.go +++ b/filesearch/filesearch_test.go @@ -282,6 +282,7 @@ func TestFindResticBinaryWithTilde(t *testing.T) { t.Skip("not supported on Windows") return } + home, err := os.UserHomeDir() require.NoError(t, err) @@ -336,7 +337,6 @@ func TestShellExpand(t *testing.T) { func TestFindConfigurationIncludes(t *testing.T) { t.Parallel() - fs := afero.NewMemMapFs() testID := fmt.Sprintf("%x", time.Now().UnixNano()) tempDir := os.TempDir() files := []string{ @@ -346,6 +346,7 @@ func TestFindConfigurationIncludes(t *testing.T) { filepath.Join(tempDir, "inc3."+testID+".conf"), } + fs := afero.NewMemMapFs() for _, file := range files { require.NoError(t, afero.WriteFile(fs, file, []byte{}, iofs.ModePerm)) } diff --git a/flags.go b/flags.go index 174ce88f9..a3611d772 100644 --- a/flags.go +++ b/flags.go @@ -38,6 +38,7 @@ type commandLineFlags struct { noPriority bool ignoreOnBattery int usagesHelp string + remote string // url of the remote server to download configuration files from } func envValueOverride[T any](defaultValue T, keys ...string) T { @@ -92,6 +93,7 @@ func loadFlags(args []string) (*pflag.FlagSet, commandLineFlags, error) { noPriority: envValueOverride(false, "RESTICPROFILE_NO_PRIORITY"), wait: envValueOverride(false, "RESTICPROFILE_WAIT"), ignoreOnBattery: envValueOverride(0, "RESTICPROFILE_IGNORE_ON_BATTERY"), + remote: envValueOverride("", "RESTICPROFILE_REMOTE"), } flagset.BoolVarP(&flags.help, "help", "h", flags.help, "display this help") @@ -113,6 +115,9 @@ func loadFlags(args []string) (*pflag.FlagSet, commandLineFlags, error) { flagset.BoolVarP(&flags.wait, "wait", "w", flags.wait, "wait at the end until the user presses the enter key") flagset.IntVar(&flags.ignoreOnBattery, "ignore-on-battery", flags.ignoreOnBattery, "don't start the profile when the computer is running on battery. You can specify a value to ignore only when the % charge left is less or equal than the value") flagset.Lookup("ignore-on-battery").NoOptDefVal = "100" // 0 is flag not set, 100 is for a flag with no value (meaning just battery discharge) + flagset.StringVarP(&flags.remote, "remote", "r", flags.remote, "remote server to download configuration files from") + // keep the "remote" flag hidden for now + _ = flagset.MarkHidden("remote") flagset.SetNormalizeFunc(func(f *pflag.FlagSet, name string) pflag.NormalizedName { switch name { diff --git a/fuse/file.go b/fuse/file.go new file mode 100644 index 000000000..e79cfac6f --- /dev/null +++ b/fuse/file.go @@ -0,0 +1,25 @@ +package fuse + +import "io/fs" + +type File struct { + name string + fileInfo fs.FileInfo + data []byte +} + +func NewFile(name string, fileInfo fs.FileInfo, data []byte) *File { + return &File{ + name: name, + fileInfo: fileInfo, + data: data, + } +} + +func (f *File) Close() { + // emptying file data + for i := range f.data { + f.data[i] = 0 + } + f.data = nil +} diff --git a/fuse/fs_file.go b/fuse/fs_file.go new file mode 100644 index 000000000..e3f7b5ac5 --- /dev/null +++ b/fuse/fs_file.go @@ -0,0 +1,40 @@ +//go:build !windows + +package fuse + +import ( + "context" + "syscall" + + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" +) + +type fsFile struct { + fs.Inode + attr fuse.Attr + file File +} + +var _ = (fs.NodeOpener)((*fsFile)(nil)) +var _ = (fs.NodeGetattrer)((*fsFile)(nil)) + +func (fsf *fsFile) Getattr(ctx context.Context, f fs.FileHandle, out *fuse.AttrOut) syscall.Errno { + out.Attr = fsf.attr + return 0 +} + +// Open only needs to send the flags back to the kernel +func (fsf *fsFile) Open(ctx context.Context, flags uint32) (fs.FileHandle, uint32, syscall.Errno) { + // tell the kernel not to cache the data + return fsf, fuse.FOPEN_DIRECT_IO, fs.OK +} + +// Read simply returns the data from the file +func (fsf *fsFile) Read(ctx context.Context, f fs.FileHandle, dest []byte, off int64) (fuse.ReadResult, syscall.Errno) { + end := int(off) + len(dest) + if end > len(fsf.file.data) { + end = len(fsf.file.data) + } + return fuse.ReadResultData(fsf.file.data[off:end]), fs.OK +} diff --git a/fuse/memfs.go b/fuse/memfs.go new file mode 100644 index 000000000..fb9bc8320 --- /dev/null +++ b/fuse/memfs.go @@ -0,0 +1,122 @@ +//go:build !windows + +// Simple implementation of a read-only filesystem in memory. +// +// Based on the examples at https://pkg.go.dev/github.com/hanwen/go-fuse/v2/fs#pkg-examples +package fuse + +import ( + "archive/tar" + "context" + iofs "io/fs" + "log" + "os" + "path/filepath" + "strings" + "syscall" + + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" +) + +type memFS struct { + fs.Inode + + files []File +} + +func newMemFS(files []File) *memFS { + return &memFS{ + files: files, + } +} + +var _ = (fs.InodeEmbedder)((*memFS)(nil)) + +// The root populates the tree in its OnAdd method +var _ = (fs.NodeOnAdder)((*memFS)(nil)) + +// Close erases the data from all the files +func (memfs *memFS) Close() { + for i := range memfs.files { + memfs.files[i].Close() + } + memfs.files = nil +} + +// OnAdd is called once we are attached to an Inode. We can +// then construct a tree. We construct the entire tree, and +// we don't want parts of the tree to disappear when the +// kernel is short on memory, so we use persistent inodes. +func (memfs *memFS) OnAdd(ctx context.Context) { + for _, file := range memfs.files { + dir, base := filepath.Split(filepath.Clean(file.name)) + + p := memfs.EmbeddedInode() + for _, comp := range strings.Split(dir, "/") { + if len(comp) == 0 { + continue + } + ch := p.GetChild(comp) + if ch == nil { + ch = p.NewPersistentInode(ctx, + &fs.Inode{}, + fs.StableAttr{Mode: syscall.S_IFDIR}) + p.AddChild(comp, ch, false) + } + p = ch + } + + attr := fileInfoToAttr(file.fileInfo) + switch { + case file.fileInfo.Mode().Type()&os.ModeSymlink == os.ModeSymlink: + file.data = nil + fsfile := &fsFile{ + attr: attr, + file: file, + } + p.AddChild(base, memfs.NewPersistentInode(ctx, fsfile, fs.StableAttr{Mode: syscall.S_IFLNK}), false) + + case file.fileInfo.Mode().IsDir(): + fsdir := &fsFile{ + attr: attr, + file: file, + } + p.AddChild(base, memfs.NewPersistentInode(ctx, fsdir, fs.StableAttr{Mode: syscall.S_IFDIR}), false) + + case file.fileInfo.Mode().IsRegular(): + fsfile := &fsFile{ + attr: attr, + file: file, + } + p.AddChild(base, memfs.NewPersistentInode(ctx, fsfile, fs.StableAttr{}), false) + + default: + log.Printf("entry %q: unsupported type '%c'", file.name, file.fileInfo.Mode().Type()) + } + } +} + +func fileInfoToAttr(fileInfo iofs.FileInfo) fuse.Attr { + var out fuse.Attr + if header, ok := fileInfo.Sys().(*tar.Header); ok { + out.Mode = uint32(header.Mode) //nolint:gosec + out.Size = uint64(header.Size) //nolint:gosec + out.Uid = uint32(header.Uid) //nolint:gosec + out.Gid = uint32(header.Gid) //nolint:gosec + out.SetTimes(&header.AccessTime, &header.ModTime, &header.ChangeTime) + } else { + out.Mode = uint32(fileInfo.Mode()) + out.Size = uint64(fileInfo.Size()) //nolint:gosec + out.Uid = uint32(os.Geteuid()) //nolint:gosec + out.Gid = uint32(os.Getegid()) //nolint:gosec + modTime := fileInfo.ModTime() + out.SetTimes(nil, &modTime, nil) + } + out.Nlink = 1 + const bs = 512 + out.Blksize = bs + out.Blocks = (out.Size + bs - 1) / bs + + return out +} diff --git a/fuse/memfs_test.go b/fuse/memfs_test.go new file mode 100644 index 000000000..1845336e9 --- /dev/null +++ b/fuse/memfs_test.go @@ -0,0 +1,77 @@ +//go:build !windows + +package fuse + +import ( + "archive/tar" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var memfsContents = map[string]string{ + "emptydir/": "", + "file.txt": "content", + "dir/subfile.txt": "other content", + "dir with space/other file.txt": "different content", +} + +func TestMemFS(t *testing.T) { + const fileMode = 0o764 + + files := make([]File, 0) + now := time.Now() + for filename, fileContents := range memfsContents { + h := &tar.Header{ + Name: filename, + Size: int64(len(fileContents)), + Mode: fileMode, + Uid: 100, + Gid: 100, + ModTime: now, + } + + isDir := strings.HasSuffix(filename, "/") + if isDir { + h.Typeflag = tar.TypeDir + } + + files = append(files, File{ + name: filename, + fileInfo: h.FileInfo(), + data: []byte(fileContents), + }) + } + + mnt := t.TempDir() + closeMount, err := MountFS(mnt, files) + if err != nil && strings.Contains(err.Error(), "no FUSE mount utility found") { + t.Skip("no FUSE mount utility found") + } + require.NoError(t, err, "cannot mount FS") + defer closeMount() + + for filename, fileContents := range memfsContents { + fullPath := filepath.Join(mnt, filename) + + filestat, err := os.Stat(fullPath) + require.NoErrorf(t, err, "os.Stat %q", filename) + + if strings.HasSuffix(filename, "/") { + assert.True(t, filestat.IsDir(), "is dir %q", filename) + + } else { + assert.False(t, filestat.IsDir(), "is file %q", filename) + + contents, err := os.ReadFile(fullPath) + assert.NoErrorf(t, err, "read %q", filename) + + assert.Equalf(t, fileContents, string(contents), "file %q", filename) + } + } +} diff --git a/fuse/memfs_windows_test.go b/fuse/memfs_windows_test.go new file mode 100644 index 000000000..e99b6381e --- /dev/null +++ b/fuse/memfs_windows_test.go @@ -0,0 +1,14 @@ +//go:build windows + +package fuse + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMountFS(t *testing.T) { + _, err := MountFS("mnt", []File{}) + assert.Error(t, err) +} diff --git a/fuse/mount.go b/fuse/mount.go new file mode 100644 index 000000000..7aae32fef --- /dev/null +++ b/fuse/mount.go @@ -0,0 +1,40 @@ +//go:build !windows + +package fuse + +import ( + "fmt" + + "github.com/creativeprojects/clog" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" +) + +func MountFS(mountpoint string, files []File) (func(), error) { + memFS := newMemFS(files) + + clog.Debugf("mounting filesystem at %s", mountpoint) + + opts := &fs.Options{ + MountOptions: fuse.MountOptions{ + Debug: false, // generates a LOT of logs + FsName: "resticprofile", + DisableXAttrs: true, + EnableLocks: false, + }, + } + server, err := fs.Mount(mountpoint, memFS, opts) + if err != nil { + return nil, fmt.Errorf("failed to mount filesystem: %w", err) + } + closeFS := func() { + clog.Debug("unmounting filesystem") + err := server.Unmount() // don't need to call Wait after Unmount + if err != nil { + clog.Errorf("failed to unmount filesystem: %v", err) + } + + memFS.Close() + } + return closeFS, nil +} diff --git a/fuse/mount_windows.go b/fuse/mount_windows.go new file mode 100644 index 000000000..625407b0e --- /dev/null +++ b/fuse/mount_windows.go @@ -0,0 +1,7 @@ +package fuse + +import "errors" + +func MountFS(_ string, _ []File) (func(), error) { + return nil, errors.New("not supported on Windows") +} diff --git a/go.mod b/go.mod index 15affef83..143c09813 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/creativeprojects/go-selfupdate v1.5.0 github.com/distatus/battery v0.11.0 github.com/fatih/color v1.18.0 + github.com/hanwen/go-fuse/v2 v2.8.0 github.com/joho/godotenv v1.5.1 github.com/mackerelio/go-osstat v0.2.6 github.com/mattn/go-colorable v0.1.14 @@ -23,6 +24,7 @@ require ( github.com/spf13/pflag v1.0.7 github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.10.0 + golang.org/x/crypto v0.40.0 golang.org/x/sys v0.34.0 golang.org/x/term v0.33.0 golang.org/x/text v0.27.0 @@ -70,7 +72,6 @@ require ( github.com/xanzy/go-gitlab v0.115.0 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.38.0 // indirect golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/time v0.11.0 // indirect diff --git a/go.sum b/go.sum index b3a87e76d..c79cc3d7c 100644 --- a/go.sum +++ b/go.sum @@ -48,6 +48,8 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/govalues/decimal v0.1.36 h1:dojDpsSvrk0ndAx8+saW5h9WDIHdWpIwrH/yhl9olyU= github.com/govalues/decimal v0.1.36/go.mod h1:Ee7eI3Llf7hfqDZtpj8Q6NCIgJy1iY3kH1pSwDrNqlM= +github.com/hanwen/go-fuse/v2 v2.8.0 h1:wV8rG7rmCz8XHSOwBZhG5YcVqcYjkzivjmbaMafPlAs= +github.com/hanwen/go-fuse/v2 v2.8.0/go.mod h1:yE6D2PqWwm3CbYRxFXV9xUd8Md5d6NG0WBs5spCswmI= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= @@ -65,6 +67,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 h1:PpXWgLPs+Fqr325bN2FD2ISlRRztXibcX6e8f5FR5Dc= github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/mackerelio/go-osstat v0.2.6 h1:gs4U8BZeS1tjrL08tt5VUliVvSWP26Ai2Ob8Lr7f2i0= @@ -77,6 +81,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/moby/sys/mountinfo v0.7.2 h1:1shs6aH5s4o5H2zQLn796ADW1wMrIwHsyJ2v9KouLrg= +github.com/moby/sys/mountinfo v0.7.2/go.mod h1:1YOa8w8Ih7uW0wALDUgT1dTTSBrZ+HiBLGws92L2RU4= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= @@ -142,8 +148,8 @@ go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN8 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= -golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= -golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= diff --git a/main.go b/main.go index 8b9046687..b17767040 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "math/rand" @@ -138,6 +139,27 @@ func main() { banner() + if flags.remote != "" { + ctx := context.TODO() + closeFS, remoteParameters, err := setupRemoteConfiguration(ctx, flags.remote) + if err != nil { + // need to setup console logging to display the error message + closeLogger := setupLogging(nil) + defer closeLogger() + clog.Error(err) + exitCode = constants.ExitCannotSetupRemoteConfiguration + return + } + if flags.config == constants.DefaultConfigurationFile && remoteParameters.ConfigurationFile != "" { + flags.config = remoteParameters.ConfigurationFile + } + if flags.name == constants.DefaultProfileName && remoteParameters.ProfileName != "" { + flags.name = remoteParameters.ProfileName + } + flags.resticArgs = remoteParameters.CommandLineArguments + shutdown.AddHook(closeFS) + } + // resticprofile own commands (configuration file may not be loaded) if len(flags.resticArgs) > 0 { if ownCommands.Exists(flags.resticArgs[0], false) { @@ -270,7 +292,13 @@ func main() { } func banner() { - clog.Debugf("resticprofile %s compiled with %s", version, runtime.Version()) + clog.Debugf( + "resticprofile %s compiled with %s %s/%s", + version, + runtime.Version(), + runtime.GOOS, + runtime.GOARCH, + ) } func loadConfig(flags commandLineFlags, silent bool) (cfg *config.Config, global *config.Global, err error) { diff --git a/own_commands.go b/own_commands.go index ff23beae8..b8caf0e44 100644 --- a/own_commands.go +++ b/own_commands.go @@ -5,6 +5,8 @@ import ( "io" "os" "strings" + + "github.com/creativeprojects/clog" ) // commandContext is the context for running a command. @@ -23,6 +25,7 @@ type ownCommand struct { hide bool // don't display the command in help and completion hideInCompletion bool // don't display the command in completion noProfile bool // true if the command doesn't need a profile name + experimental bool // display a warning when using this command flags map[string]string // own command flags should be simple enough to be handled manually for now } @@ -61,6 +64,9 @@ func (o *OwnCommands) Run(ctx *Context) error { if command == nil { return fmt.Errorf("command not found: %v", ctx.request.command) } + if command.experimental { + clog.Warningf("%s: this command is experimental and its behaviour may change in the future", ctx.request.command) + } return command.action(os.Stdout, commandContext{ ownCommands: o, Context: *ctx, diff --git a/remote.go b/remote.go new file mode 100644 index 000000000..1f31fb420 --- /dev/null +++ b/remote.go @@ -0,0 +1,135 @@ +package main + +import ( + "archive/tar" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/creativeprojects/clog" + "github.com/creativeprojects/resticprofile/constants" + "github.com/creativeprojects/resticprofile/fuse" + "github.com/creativeprojects/resticprofile/remote" +) + +func loadRemoteFiles(ctx context.Context, endpoint string) ([]fuse.File, *remote.Manifest, error) { + var parameters *remote.Manifest + + client := http.DefaultClient + request, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, http.NoBody) + if err != nil { + return nil, nil, fmt.Errorf("failed to create request: %w", err) + } + request.Header.Set("Accept", "application/x-tar") + + resp, err := client.Do(request) + if err != nil { + return nil, nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + buf := &bytes.Buffer{} + _, _ = buf.ReadFrom(resp.Body) + return nil, nil, fmt.Errorf("http error %d: %q", resp.StatusCode, strings.TrimSpace(buf.String())) + } + + if resp.Header.Get("Content-Type") != "application/x-tar" { + return nil, nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type")) + } + + files := []fuse.File{} + reader := tar.NewReader(resp.Body) + for { + hdr, err := reader.Next() + if err == io.EOF { + break // End of archive + } + if err != nil { + return nil, nil, fmt.Errorf("failed to read tar header: %w", err) + } + if !filepath.IsLocal(hdr.Name) { + return nil, nil, fmt.Errorf("invalid file name: %s", hdr.Name) + } + if hdr.Name == constants.ManifestFilename { + clog.Debugf("downloading manifest (%d bytes)", hdr.Size) + parameters, err = getManifestParameters(reader) + if err != nil { + return nil, nil, fmt.Errorf("failed to read manifest: %w", err) + } + } else { + clog.Debugf("downloading file %s (%d bytes)", hdr.Name, hdr.Size) + data := make([]byte, hdr.Size) + read, err := reader.Read(data) + if err != nil && err != io.EOF { + return nil, nil, fmt.Errorf("failed to download file content: %w", err) + } + if read != int(hdr.Size) { + return nil, nil, fmt.Errorf("file size mismatch: expected %d, got %d", hdr.Size, read) + } + files = append(files, *fuse.NewFile(hdr.Name, hdr.FileInfo(), data)) + } + } + + return files, parameters, nil +} + +func getManifestParameters(reader io.Reader) (*remote.Manifest, error) { + manifest := &remote.Manifest{} + decoder := json.NewDecoder(reader) + err := decoder.Decode(manifest) + if err != nil { + return nil, fmt.Errorf("failed to decode manifest: %w", err) + } + return manifest, nil +} + +// setupRemoteConfiguration downloads the configuration files from the remote endpoint and mounts the virtual FS +func setupRemoteConfiguration(ctx context.Context, remoteEndpoint string) (func(), *remote.Manifest, error) { + files, parameters, err := loadRemoteFiles(ctx, remoteEndpoint) + if err != nil { + return nil, nil, err + } + + closeMountpoint := func() {} + mountpoint := parameters.Mountpoint + if mountpoint == "" { + // generates a temporary directory + mountpoint, err = os.MkdirTemp("", "resticprofile-") + if err != nil { + return nil, parameters, fmt.Errorf("failed to create mount directory: %w", err) + } + closeMountpoint = func() { + err = os.Remove(mountpoint) + if err != nil { + clog.Errorf("failed to remove mountpoint: %v", err) + } + } + } + + closeFs, err := fuse.MountFS(mountpoint, files) + if err != nil { + return closeMountpoint, parameters, err + } + + wd, _ := os.Getwd() + err = os.Chdir(mountpoint) + if err != nil { + return func() { + closeFs() + closeMountpoint() + }, parameters, fmt.Errorf("failed to change directory: %w", err) + } + + return func() { + _ = os.Chdir(wd) + closeFs() + closeMountpoint() + }, parameters, nil +} diff --git a/remote/manifest.go b/remote/manifest.go new file mode 100644 index 000000000..53fdb6a7a --- /dev/null +++ b/remote/manifest.go @@ -0,0 +1,9 @@ +package remote + +type Manifest struct { + Version string // resticprofile version + ConfigurationFile string + ProfileName string + Mountpoint string // Mountpoint of the virtual FS if configured + CommandLineArguments []string +} diff --git a/remote/tar.go b/remote/tar.go new file mode 100644 index 000000000..5cfd3357a --- /dev/null +++ b/remote/tar.go @@ -0,0 +1,95 @@ +package remote + +import ( + "archive/tar" + "fmt" + "io" + "os" + "time" + + "github.com/creativeprojects/clog" + "github.com/spf13/afero" +) + +type Tar struct { + writer *tar.Writer + fs afero.Fs +} + +func NewTar(w io.Writer) *Tar { + return &Tar{ + writer: tar.NewWriter(w), + fs: afero.NewOsFs(), + } +} + +func (t *Tar) WithFs(fs afero.Fs) *Tar { + t.fs = fs + return t +} + +func (t *Tar) SendFiles(files []string) error { + for _, filename := range files { + fileInfo, err := t.fs.Stat(filename) + if err != nil { + clog.Errorf("unable to stat file %s: %v", filename, err) + continue + } + fileHeader, err := tar.FileInfoHeader(fileInfo, "") + if err != nil { + clog.Errorf("unable to create tar header for file %s: %v", filename, err) + continue + } + err = t.writer.WriteHeader(fileHeader) + if err != nil { + clog.Errorf("unable to write tar header for file %s: %v", filename, err) + break + } + file, err := t.fs.Open(filename) + if err != nil { + clog.Errorf("unable to open file %s: %v", filename, err) + continue + } + defer file.Close() + + written, err := io.Copy(t.writer, file) + if err != nil { + clog.Errorf("unable to write file %s: %v", filename, err) + break + } + if written != fileInfo.Size() { + clog.Errorf("file %s: written %d bytes, expected %d", filename, written, fileInfo.Size()) + break + } + clog.Debugf("file %s: written %d bytes", filename, written) + } + return nil +} + +func (t *Tar) SendFile(name string, data []byte) error { + header := &tar.Header{ + Name: name, + Size: int64(len(data)), + ModTime: time.Now(), + Mode: 0o444, + Typeflag: tar.TypeReg, + Uid: os.Geteuid(), + Gid: os.Getegid(), + } + if err := t.writer.WriteHeader(header); err != nil { + return err + } + written, err := t.writer.Write(data) + if err != nil { + return err + } + if written != len(data) { + return fmt.Errorf("manifest written %d bytes, expected %d", written, len(data)) + } + clog.Debugf("manifest written %d bytes", written) + return nil +} + +func (t *Tar) Close() error { + return t.writer.Close() +} diff --git a/remote/tar_test.go b/remote/tar_test.go new file mode 100644 index 000000000..6961960fe --- /dev/null +++ b/remote/tar_test.go @@ -0,0 +1,170 @@ +package remote + +import ( + "archive/tar" + "bytes" + "io" + "testing" + + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSendFile(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + fileName string + data []byte + }{ + { + name: "send empty file", + fileName: "empty.txt", + data: []byte{}, + }, + { + name: "send file with content", + fileName: "test.txt", + data: []byte("test content"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + buf := new(bytes.Buffer) + tar := NewTar(buf) + + // Execute + err := tar.SendFile(tt.fileName, tt.data) + require.NoError(t, err) + tar.Close() + + // Verify the tar contains the correct file + fs := afero.NewMemMapFs() + err = extractTarToFs(buf.Bytes(), fs) + assert.NoError(t, err) + + // Check file exists and has correct content + fileExists, err := afero.Exists(fs, tt.fileName) + assert.NoError(t, err) + assert.True(t, fileExists) + + content, err := afero.ReadFile(fs, tt.fileName) + assert.NoError(t, err) + assert.Equal(t, tt.data, content) + }) + } +} + +func TestSendFiles(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + files map[string]string + filePaths []string + }{ + { + name: "send multiple files", + files: map[string]string{"file1.txt": "content1", "file2.txt": "content2"}, + filePaths: []string{"file1.txt", "file2.txt"}, + }, + { + name: "send empty file among others", + files: map[string]string{"empty.txt": "", "notempty.txt": "content"}, + filePaths: []string{"empty.txt", "notempty.txt"}, + }, + { + name: "send non-existent file", + files: map[string]string{"exists.txt": "content"}, + filePaths: []string{"exists.txt", "nonexistent.txt"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup + buf := new(bytes.Buffer) + tar := NewTar(buf) + + // Create a memory filesystem with the test files + memFs := afero.NewMemMapFs() + for name, content := range tt.files { + err := afero.WriteFile(memFs, name, []byte(content), 0644) + assert.NoError(t, err) + } + + tar.WithFs(memFs) + + // Execute + err := tar.SendFiles(tt.filePaths) + require.NoError(t, err) + tar.Close() + + // Verify the tar contains the correct files + outputFs := afero.NewMemMapFs() + err = extractTarToFs(buf.Bytes(), outputFs) + assert.NoError(t, err) + + // Check each expected file exists and has correct content + for name, expectedContent := range tt.files { + // Only check files that were in the filePaths list + included := false + for _, path := range tt.filePaths { + if path == name { + included = true + break + } + } + + if !included { + continue + } + + fileExists, err := afero.Exists(outputFs, name) + assert.NoError(t, err) + + if _, ok := tt.files[name]; ok { + assert.True(t, fileExists) + + content, err := afero.ReadFile(outputFs, name) + assert.NoError(t, err) + assert.Equal(t, []byte(expectedContent), content) + } + } + }) + } +} + +// Helper function to extract tar contents to an afero filesystem +func extractTarToFs(tarData []byte, fs afero.Fs) error { + reader := bytes.NewReader(tarData) + tr := tar.NewReader(reader) + + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + + switch header.Typeflag { + case tar.TypeReg: + file, err := fs.Create(header.Name) + if err != nil { + return err + } + if _, err := io.CopyN(file, tr, 1000); err != nil && err != io.EOF { + file.Close() + return err + } + file.Close() + } + } + return nil +} diff --git a/serve.go b/serve.go new file mode 100644 index 000000000..4307057e5 --- /dev/null +++ b/serve.go @@ -0,0 +1,152 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "os/signal" + "path" + "time" + + "github.com/creativeprojects/clog" + "github.com/creativeprojects/resticprofile/config" + "github.com/creativeprojects/resticprofile/constants" + "github.com/creativeprojects/resticprofile/remote" + "github.com/creativeprojects/resticprofile/ssh" +) + +func serveCommand(w io.Writer, cmdCtx commandContext) error { + if len(cmdCtx.flags.resticArgs) < 2 { + return fmt.Errorf("missing argument: port") + } + handler := http.NewServeMux() + handler.HandleFunc("GET /configuration/{remote}", func(resp http.ResponseWriter, req *http.Request) { + remoteName := req.PathValue("remote") + if !cmdCtx.config.HasRemote(remoteName) { + sendError(resp, http.StatusNotFound, fmt.Errorf("remote %q not found", remoteName)) + return + } + remoteConfig, err := cmdCtx.config.GetRemote(remoteName) + if err != nil { + sendError(resp, http.StatusBadRequest, fmt.Errorf("error while getting remote configuration: %w", err)) + return + } + + sendRemoteFiles(remoteConfig, remoteName, nil, resp) + }) + + quit := make(chan os.Signal, 1) + signal.Notify(quit, os.Interrupt) + defer signal.Stop(quit) + + server := &http.Server{ + Addr: fmt.Sprintf("localhost:%s", cmdCtx.flags.resticArgs[1]), + Handler: handler, + ReadHeaderTimeout: 5 * time.Second, + } + + // put the shutdown code in a goroutine + go func(server *http.Server, quit chan os.Signal) { + <-quit + + clog.Info("shutting down the server") + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + err := server.Shutdown(ctx) + if err != nil { + clog.Errorf("error while shutting down the server: %v", err) + } + + }(server, quit) + + // we want to return the server error if any so we need to keep it in the main thread. + clog.Infof("listening on %s", server.Addr) + err := server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + return err + } + return nil +} + +func sendProfileCommand(w io.Writer, cmdCtx commandContext) error { + if len(cmdCtx.flags.resticArgs) < 2 { + return fmt.Errorf("missing argument: remote name") + } + remoteName := cmdCtx.flags.resticArgs[1] + if !cmdCtx.config.HasRemote(remoteName) { + return fmt.Errorf("remote not found") + } + remoteConfig, err := cmdCtx.config.GetRemote(remoteName) + if err != nil { + return err + } + // send the files to the remote using tar + handler := http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + sendRemoteFiles(remoteConfig, remoteName, cmdCtx.flags.resticArgs[2:], resp) + }) + sshConfig := ssh.Config{ + Host: remoteConfig.Host, + Username: remoteConfig.Username, + PrivateKeyPath: remoteConfig.PrivateKeyPath, + KnownHostsPath: remoteConfig.KnownHostsPath, + SSHConfigPath: remoteConfig.SSHConfig, + Handler: handler, + } + // cnx := ssh.NewInternalClient(sshConfig) + cnx := ssh.NewOpenSSHClient(sshConfig) + defer cnx.Close() + + err = cnx.Connect() + if err != nil { + return err + } + binaryPath := remoteConfig.BinaryPath + if binaryPath == "" { + binaryPath = "resticprofile" + } + arguments := []string{ + "-v", + "-r", fmt.Sprintf("http://localhost:%d/configuration/%s", cnx.TunnelPeerPort(), remoteName), + } + err = cnx.Run(binaryPath, arguments...) + if err != nil { + return fmt.Errorf("failed to run resticprofile on peer: %w", err) + } + return nil +} + +func sendRemoteFiles(remoteConfig *config.Remote, remoteName string, extraArgs []string, resp http.ResponseWriter) { + // prepare manifest file + manifest := remote.Manifest{ + Version: version, + ConfigurationFile: path.Base(remoteConfig.ConfigurationFile), // need to take file path into consideration + ProfileName: remoteConfig.ProfileName, + CommandLineArguments: extraArgs, + } + manifestData, err := json.Marshal(manifest) + if err != nil { + sendError(resp, http.StatusInternalServerError, fmt.Errorf("error while generating manifest: %w", err)) + return + } + + clog.Debugf("sending configuration for %q", remoteName) + resp.Header().Set("Content-Type", "application/x-tar") + resp.WriteHeader(http.StatusOK) + + tar := remote.NewTar(resp) + defer tar.Close() + _ = tar.SendFiles(append(remoteConfig.SendFiles, remoteConfig.ConfigurationFile)) + _ = tar.SendFile(constants.ManifestFilename, manifestData) +} + +func sendError(resp http.ResponseWriter, status int, err error) { + resp.Header().Set("Content-Type", "text/plain") + resp.WriteHeader(status) + _, _ = resp.Write([]byte(err.Error())) + _, _ = resp.Write([]byte("\n")) + clog.Error(err) +} diff --git a/serve_test.go b/serve_test.go new file mode 100644 index 000000000..d8f68c890 --- /dev/null +++ b/serve_test.go @@ -0,0 +1,19 @@ +package main + +import ( + "net/http/httptest" + "testing" + + "github.com/creativeprojects/resticprofile/config" + "github.com/stretchr/testify/assert" +) + +func TestSendRemoteFiles(t *testing.T) { + recorder := httptest.NewRecorder() + sendRemoteFiles(&config.Remote{ + ConfigurationFile: "test_config.json", + ProfileName: "test_profile", + }, "test_remote", []string{"arg1", "arg2"}, recorder) + assert.Equal(t, recorder.Code, 200) + assert.Equal(t, recorder.Header().Get("Content-Type"), "application/x-tar") +} diff --git a/ssh/client.go b/ssh/client.go new file mode 100644 index 000000000..112078d34 --- /dev/null +++ b/ssh/client.go @@ -0,0 +1,9 @@ +package ssh + +type Client interface { + Name() string + Connect() error + Close() + Run(command string, arguments ...string) error + TunnelPeerPort() int +} diff --git a/ssh/config.go b/ssh/config.go new file mode 100644 index 000000000..99ced1a7e --- /dev/null +++ b/ssh/config.go @@ -0,0 +1,45 @@ +package ssh + +import ( + "fmt" + "net/http" + "os" + "path/filepath" + "strings" +) + +type Config struct { + Host string + Username string + PrivateKeyPath string + KnownHostsPath string + SSHConfigPath string // Path to the OpenSSH config file, if any + Handler http.Handler +} + +func (c *Config) Validate() error { + if c.Host == "" { + return fmt.Errorf("host is required") + } + if c.Username == "" { + return fmt.Errorf("username is required") + } + if c.PrivateKeyPath == "" { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("unable to get current user home directory: %w", err) + } + c.PrivateKeyPath = filepath.Join(home, ".ssh/id_rsa") // we can go through all the default name for each key type + } + if c.KnownHostsPath == "" { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("unable to get current user home directory: %w", err) + } + c.KnownHostsPath = filepath.Join(home, ".ssh/known_hosts") + } + if !strings.Contains(c.Host, ":") { + c.Host = c.Host + ":22" + } + return nil +} diff --git a/ssh/internal_client.go b/ssh/internal_client.go new file mode 100644 index 000000000..7502b957c --- /dev/null +++ b/ssh/internal_client.go @@ -0,0 +1,171 @@ +package ssh + +import ( + "context" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "time" + + "github.com/creativeprojects/clog" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" +) + +const startPort = 10001 + +type InternalClient struct { + config Config + port int + client *ssh.Client + tunnel net.Listener + server *http.Server +} + +func NewInternalClient(config Config) *InternalClient { + return &InternalClient{ + config: config, + port: startPort, + } +} + +func (s *InternalClient) Name() string { + return "InternalSSH" +} + +func (s *InternalClient) Connect() error { + err := s.config.Validate() + if err != nil { + return err + } + var hostKeyCallback ssh.HostKeyCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error { + clog.Debugf("initiating SSH connection to %s using internal client", remote.String()) + return nil + } + if s.config.KnownHostsPath != "" && s.config.KnownHostsPath != "none" && s.config.KnownHostsPath != "/dev/null" { + hostKeyCallback, err = knownhosts.New(s.config.KnownHostsPath) + if err != nil { + return fmt.Errorf("cannot load host keys from known_hosts: %w", err) + } + } + key, err := os.ReadFile(s.config.PrivateKeyPath) + if err != nil { + return fmt.Errorf("unable to read private key: %w", err) + } + + // Create the Signer for this private key. + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + return fmt.Errorf("unable to parse private key: %w", err) + } + + // The algorithms returned by ssh.SupportedAlgorithms() are different from + // the default ones and do not include algorithms that are considered + // insecure, such as those using SHA-1, returned by + // ssh.InsecureAlgorithms(). + algorithms := ssh.SupportedAlgorithms() + + config := &ssh.ClientConfig{ + User: s.config.Username, + Auth: []ssh.AuthMethod{ + // Use the PublicKeys method for remote authentication. + ssh.PublicKeys(signer), + }, + HostKeyCallback: hostKeyCallback, + HostKeyAlgorithms: algorithms.HostKeys, + Config: ssh.Config{ + KeyExchanges: algorithms.KeyExchanges, + Ciphers: algorithms.Ciphers, + MACs: algorithms.MACs, + }, + } + + // Connect to the remote server and perform the SSH handshake. + s.client, err = ssh.Dial("tcp", s.config.Host, config) + if err != nil { + return fmt.Errorf("unable to connect: %w", err) + } + + // Request the remote side to open a local port + s.tunnel, err = s.client.Listen("tcp", fmt.Sprintf("localhost:%d", s.port)) // increment the port in a loop in case of an error + if err != nil { + log.Fatal("unable to register tcp forward: ", err) + } + + go func() { + s.server = &http.Server{ + Handler: s.config.Handler, + ReadHeaderTimeout: 5 * time.Second, + } + // Serve HTTP with your SSH server acting as a reverse proxy. + err := s.server.Serve(s.tunnel) + if err != nil && err != http.ErrServerClosed && err != io.EOF { + clog.Warningf("unable to serve http: %s", err) + } + }() + time.Sleep(100 * time.Millisecond) // wait for the server to start + return nil +} + +func (s *InternalClient) TunnelPeerPort() int { + return s.port +} + +func (s *InternalClient) Run(command string, arguments ...string) error { + // Each ClientConn can support multiple interactive sessions, + // represented by a Session. + session, err := s.client.NewSession() + if err != nil { + return fmt.Errorf("failed to create session: %w", err) + } + defer session.Close() + + // request a pseudo terminal to display colors + if termType := os.Getenv("TERM"); termType != "" { + modes := ssh.TerminalModes{ + ssh.ECHO: 0, // disable echoing + } + if err := session.RequestPty(termType, 40, 80, modes); err != nil { + clog.Warningf("request for pseudo terminal failed: %s", err) + } + } + + // Once a Session is created, we can execute a single command on + // the remote side using the Run method. + session.Stdout = os.Stdout + session.Stderr = os.Stderr + if err := session.Run(command); err != nil { + return fmt.Errorf("failed to run: %w", err) + } + return nil +} + +func (s *InternalClient) Close() { + // close the tunnel first otherwise it fails with error: "ssh: cancel-tcpip-forward failed" + if s.tunnel != nil { + err := s.tunnel.Close() + if err != nil { + clog.Warningf("unable to close tunnel: %s", err) + } + } + if s.server != nil { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + err := s.server.Shutdown(ctx) + if err != nil { + clog.Warningf("unable to close http server: %s", err) + } + } + if s.client != nil { + err := s.client.Close() + if err != nil { + clog.Warningf("unable to close ssh connection: %s", err) + } + } +} + +// verify interface +var _ Client = (*InternalClient)(nil) diff --git a/ssh/openssh_client.go b/ssh/openssh_client.go new file mode 100644 index 000000000..a9448a983 --- /dev/null +++ b/ssh/openssh_client.go @@ -0,0 +1,213 @@ +package ssh + +import ( + "bytes" + "context" + "fmt" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "strconv" + "sync" + "time" + + "github.com/creativeprojects/clog" +) + +type OpenSSHClient struct { + config Config + sshHost string + sshUserHost string + sshPort int + listener net.Listener + server *http.Server + wg sync.WaitGroup + socket string + peerTunnelPort int +} + +func NewOpenSSHClient(config Config) *OpenSSHClient { + return &OpenSSHClient{ + config: config, + } +} + +func (c *OpenSSHClient) Name() string { + return "OpenSSH" +} + +// Connect establishes the SSH connection and starts the file server. +// It returns an error if the connection or server setup fails. +func (c *OpenSSHClient) Connect() error { + c.sshHost, c.sshPort = parseHost(c.config.Host) + c.sshUserHost = c.sshHost + if c.config.Username != "" { + c.sshUserHost = fmt.Sprintf("%s@%s", c.config.Username, c.sshHost) + } + err := c.startFileServer() + if err != nil { + return err + } + err = c.startSSH(context.Background()) + if err != nil { + return fmt.Errorf("error while starting SSH connection: %w", err) + } + err = c.startTunnel(context.Background()) + if err != nil { + return fmt.Errorf("error while starting SSH tunnel: %w", err) + } + return nil +} + +func (c *OpenSSHClient) startFileServer() error { + var err error + c.listener, err = net.Listen("tcp", "localhost:0") + if err != nil { + return err + } + c.server = &http.Server{ + Handler: c.config.Handler, + ReadHeaderTimeout: 5 * time.Second, + } + c.wg.Add(1) + go func() { + defer c.wg.Done() + defer c.listener.Close() + + clog.Debugf("file server listening locally on %s", c.listener.Addr().String()) + err := c.server.Serve(c.listener) + if err != nil && err != http.ErrServerClosed { + clog.Error("error while serving HTTP:", err) + } + }() + return nil +} + +func (c *OpenSSHClient) startSSH(ctx context.Context) error { + c.socket = filepath.Join(os.TempDir(), fmt.Sprintf("ssh-%d.sock", os.Getpid())) + args := make([]string, 0, 10) + args = append(args, + "-f", // Requests ssh to go to background just before command execution + "-M", // Places the ssh client into “master” mode for connection sharing + "-N", // Do not execute a remote command + "-S", c.socket, // Specifies the location of the control socket + ) + if c.config.SSHConfigPath != "" { + args = append(args, "-F", c.config.SSHConfigPath) + } + if c.sshPort > 0 { + args = append(args, "-p", strconv.Itoa(c.sshPort)) + } + if c.config.KnownHostsPath != "" { + args = append(args, "-o", fmt.Sprintf("UserKnownHostsFile=%s", c.config.KnownHostsPath)) + } + if c.config.PrivateKeyPath != "" { + args = append(args, "-i", c.config.PrivateKeyPath) + } + args = append(args, c.sshUserHost) + cmd := exec.CommandContext(ctx, "ssh", args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + clog.Debugf("running command: %s", cmd.String()) + err := cmd.Run() + if err != nil { + return fmt.Errorf("error while running ssh command: %w", err) + } + return nil +} + +func (c *OpenSSHClient) stopSSH(ctx context.Context) error { + if c.socket == "" { + return nil + } + args := []string{ + "-S", c.socket, // Specifies the location of the control socket + "-O", "exit", // Requests the master to exit + c.sshUserHost, // Not used in this case, but required by ssh + } + cmd := exec.CommandContext(ctx, "ssh", args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + clog.Debugf("running command: %s", cmd.String()) + err := cmd.Run() + if err != nil { + return fmt.Errorf("error while running ssh command: %w", err) + } + return nil +} + +func (c *OpenSSHClient) startTunnel(ctx context.Context) error { + if c.socket == "" { + return nil + } + args := []string{ + "-S", c.socket, // Specifies the location of the control socket + "-O", "forward", // Requests the master to exit + fmt.Sprintf("-R 0:localhost:%d", c.listener.Addr().(*net.TCPAddr).Port), // Forward random remote port to local port + c.sshUserHost, // Not used in this case, but required by ssh + } + cmd := exec.CommandContext(ctx, "ssh", args...) + clog.Debugf("running command: %s", cmd.String()) + output, err := cmd.Output() + if err != nil { + return fmt.Errorf("error while running ssh command: %w", err) + } + if len(output) == 0 { + return fmt.Errorf("no output from SSH tunnel command") + } + output = bytes.TrimSpace(output) + port, err := strconv.Atoi(string(output)) + if err != nil { + return fmt.Errorf("error parsing SSH tunnel output: %w", err) + } + c.peerTunnelPort = port + clog.Debugf("port %d opened in tunnel", c.peerTunnelPort) + return nil +} + +func (c *OpenSSHClient) Close() { + ctx := context.Background() + if c.server != nil { + err := c.server.Shutdown(ctx) + if err != nil { + clog.Warningf("unable to shutdown server: %s", err) + } + c.server = nil + } + err := c.stopSSH(ctx) + if err != nil { + clog.Warningf("unable to stop SSH connection: %s", err) + } + c.wg.Wait() +} + +func (c *OpenSSHClient) Run(command string, arguments ...string) error { + if c.socket == "" { + return nil + } + args := append([]string{ + "-t", // Force pseudo-terminal allocation + "-t", // Even when stdin is not attached + "-S", c.socket, // Specifies the location of the control socket + c.sshUserHost, // Not used in this case, but required by ssh + command, + }, arguments...) + cmd := exec.CommandContext(context.Background(), "ssh", args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + clog.Debugf("running command: %s", cmd.String()) + err := cmd.Run() + if err != nil { + return fmt.Errorf("error while running ssh command: %w", err) + } + return nil +} + +func (c *OpenSSHClient) TunnelPeerPort() int { + return c.peerTunnelPort +} + +// verify interface +var _ Client = (*OpenSSHClient)(nil) diff --git a/ssh/parse_host.go b/ssh/parse_host.go new file mode 100644 index 000000000..17c264162 --- /dev/null +++ b/ssh/parse_host.go @@ -0,0 +1,21 @@ +package ssh + +import ( + "strconv" + "strings" +) + +func parseHost(host string) (string, int) { + if strings.Contains(host, ":") { + parts := strings.Split(host, ":") + if len(parts) > 2 { + // If there are more than two parts, we assume the first part is the host and the rest is the port + host = strings.Join(parts[:len(parts)-1], ":") + port, _ := strconv.Atoi(parts[len(parts)-1]) + return host, port + } + port, _ := strconv.Atoi(parts[1]) + return parts[0], port + } + return host, 0 +} diff --git a/ssh/parse_host_test.go b/ssh/parse_host_test.go new file mode 100644 index 000000000..f4974db81 --- /dev/null +++ b/ssh/parse_host_test.go @@ -0,0 +1,69 @@ +package ssh + +import ( + "testing" +) + +func TestParseHost(t *testing.T) { + tests := []struct { + name string + host string + wantHost string + wantPort int + }{ + { + name: "host only", + host: "example.com", + wantHost: "example.com", + wantPort: 0, + }, + { + name: "host with port", + host: "example.com:22", + wantHost: "example.com", + wantPort: 22, + }, + { + name: "IPv4 with port", + host: "192.168.1.1:2222", + wantHost: "192.168.1.1", + wantPort: 2222, + }, + { + name: "IPv6 with port", + host: "[2001:db8::1]:22", + wantHost: "[2001:db8::1]", + wantPort: 22, + }, + { + name: "IPv6 without brackets with port", + host: "2001:db8::1:22", + wantHost: "2001:db8::1", + wantPort: 22, + }, + { + name: "host with multiple colons", + host: "user:pass@example.com:22", + wantHost: "user:pass@example.com", + wantPort: 22, + }, + { + name: "invalid port", + host: "example.com:abc", + wantHost: "example.com", + wantPort: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotHost, gotPort := parseHost(tt.host) + if gotHost != tt.wantHost { + t.Errorf("parseHost() host = %v, want %v", gotHost, tt.wantHost) + } + if gotPort != tt.wantPort { + t.Errorf("parseHost() port = %v, want %v", gotPort, tt.wantPort) + } + }) + } +} diff --git a/ssh/ssh_test.go b/ssh/ssh_test.go new file mode 100644 index 000000000..0dc927ecf --- /dev/null +++ b/ssh/ssh_test.go @@ -0,0 +1,70 @@ +//go:build ssh + +package ssh + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSSHClient(t *testing.T) { + tmpDir := os.Getenv("SSH_TESTS_TMPDIR") + if tmpDir == "" { + tmpDir = filepath.Join(os.TempDir(), "resticprofile-ssh-tests") + } + + fixtures := []struct { + name string + config Config + connectErr bool + }{ + { + name: "no public key", + config: Config{ + Host: "localhost:2222", + Username: "resticprofile", + KnownHostsPath: filepath.Join(tmpDir, "known_hosts"), + }, + connectErr: true, + }, + { + name: "wrong username", + config: Config{ + Host: "localhost:2222", + Username: "otheruser", + KnownHostsPath: filepath.Join(tmpDir, "known_hosts"), + PrivateKeyPath: filepath.Join(tmpDir, "id_rsa"), + }, + connectErr: true, + }, + { + name: "successful connection", + config: Config{ + Host: "localhost:2222", + Username: "resticprofile", + KnownHostsPath: filepath.Join(tmpDir, "known_hosts"), + PrivateKeyPath: filepath.Join(tmpDir, "id_rsa"), + }, + connectErr: false, + }, + } + + for _, fixture := range fixtures { + for _, client := range []Client{NewOpenSSHClient(fixture.config), NewInternalClient(fixture.config)} { + t.Run(client.Name()+" "+fixture.name, func(t *testing.T) { + defer client.Close() + + err := client.Connect() + if fixture.connectErr { + require.Error(t, err) + t.Log(err) + return + } + require.NoError(t, err) + }) + } + } +} diff --git a/ssh/test/allow_tcp_forwarding.conf b/ssh/test/allow_tcp_forwarding.conf new file mode 100644 index 000000000..79af9a816 --- /dev/null +++ b/ssh/test/allow_tcp_forwarding.conf @@ -0,0 +1,2 @@ +Match User resticprofile + AllowTcpForwarding yes diff --git a/ssh/test/docker-compose.yml b/ssh/test/docker-compose.yml new file mode 100644 index 000000000..ca650602d --- /dev/null +++ b/ssh/test/docker-compose.yml @@ -0,0 +1,20 @@ +--- +services: + openssh-server: + image: lscr.io/linuxserver/openssh-server:latest + container_name: openssh-server + hostname: openssh-server + environment: + - PUID=${USER_ID:-1000} + - PGID=${GROUP_ID:-1000} + - TZ=Europe/London + - PUBLIC_KEY_FILE=/id_rsa.pub + - SUDO_ACCESS=false + - PASSWORD_ACCESS=false + - USER_NAME=resticprofile + - LOG_STDOUT=true + volumes: + - ${PWD}/allow_tcp_forwarding.conf:/config/sshd/sshd_config.d/allow_tcp_forwarding.conf + - ${SSH_TESTS_TMPDIR}/id_rsa.pub:/id_rsa.pub + ports: + - 2222:2222