Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions policies/recipes/steps.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
agentendpointpb.SoftwareRecipe_Step_RunScript_POWERSHELL: ".ps1",
}

var chown = chownFunc

func stepCopyFile(step *agentendpointpb.SoftwareRecipe_Step_CopyFile, artifacts map[string]string, runEnvs []string, stepDir string) error {
dest, err := util.NormPath(step.Destination)
if err != nil {
Expand Down Expand Up @@ -127,19 +129,46 @@
return strings.HasSuffix(name, "/")
}

func ensureFilePathBelongsToDir(dirPath string, filePath string) error {
dirAbs, err := filepath.Abs(dirPath)
if err != nil {
return err
}
fileAbs, err := filepath.Abs(filePath)
if err != nil {
return err
}

rel, err := filepath.Rel(dirAbs, fileAbs)
if err != nil {
return err
}

if strings.HasPrefix(rel, "..") {
return fmt.Errorf("path %s, does not belongs to dir %s, rel %s", filePath, dirPath, rel)
}

return nil
}

func extractZip(zipPath string, dst string) error {
zr, err := zip.OpenReader(zipPath)
if err != nil {
return err
}
defer zr.Close()

// Check for conflicts
// Check that we can extract zip
for _, f := range zr.File {
filen, err := util.NormPath(util.SanitizePath(filepath.Join(dst, f.Name)))
filen, err := util.NormPath(filepath.Join(dst, f.Name))
if err != nil {
return err
}

if err := ensureFilePathBelongsToDir(dst, filen); err != nil {
return fmt.Errorf("unable to extract zip arhive %s: %w", zipPath, err)
}

stat, err := os.Stat(filen)
if os.IsNotExist(err) {
continue
Expand All @@ -151,12 +180,12 @@
// it's ok if directories already exist
continue
}
return fmt.Errorf("file exists: %s", filen)
return fmt.Errorf("unable to extract zip archive %s: file %s is already exists", zipPath, filen)
}

// Create files.
for _, f := range zr.File {
filen, err := util.NormPath(util.SanitizePath(filepath.Join(dst, f.Name)))
filen, err := util.NormPath(filepath.Join(dst, f.Name))
if err != nil {
return err
}
Expand Down Expand Up @@ -240,6 +269,11 @@
if err != nil {
return err
}

if err := ensureFilePathBelongsToDir(dst, filen); err != nil {
return err
}

stat, err := os.Stat(filen)
if os.IsNotExist(err) {
continue
Expand Down Expand Up @@ -270,7 +304,7 @@
tr := tar.NewReader(decompressed)

if err := checkForConflicts(tr, dst); err != nil {
return err
return fmt.Errorf("unable to extract tar arhive %s: %s", tarName, err)
}

file.Seek(0, 0)
Expand All @@ -289,10 +323,11 @@
if err != nil {
return err
}
filen, err := util.NormPath(filepath.Join(dst, util.SanitizePath(header.Name)))
filen, err := util.NormPath(filepath.Join(dst, header.Name))
if err != nil {
return err
}

filedir := filepath.Dir(filen)

if err := os.MkdirAll(filedir, 0700); err != nil {
Expand Down Expand Up @@ -506,7 +541,7 @@
return err
}

func chown(file string, uid, gid int) error {
func chownFunc(file string, uid, gid int) error {
// os.Chown unsupported on windows
if runtime.GOOS == "windows" {
return nil
Expand Down
207 changes: 207 additions & 0 deletions policies/recipes/steps_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package recipes

import (
"archive/tar"
"archive/zip"
"context"
"fmt"
"os"
"path/filepath"
"regexp"
"testing"
"time"

"cloud.google.com/go/osconfig/agentendpoint/apiv1beta/agentendpointpb"
)

type fileEntry struct {
name string
content []byte
}

func Test_extractTar(t *testing.T) {
chownActual := chownFunc
chown = func(string, int, int) error {
return nil
}

defer func() { chown = chownActual }()

tests := []struct {
name string
entries []fileEntry
wantErrRegexp *regexp.Regexp
}{
{
name: "base case scenario",
entries: []fileEntry{
{
name: "test1", content: []byte("test1"),
},
{
name: "test2", content: []byte("test2"),
},
},
wantErrRegexp: nil,
},
{
name: "tar with vulnerable path, fail with expected error",
entries: []fileEntry{
{
name: "../test1", content: []byte("test1"),
},
{
name: "test2", content: []byte("test2"),
},
},
wantErrRegexp: regexp.MustCompile("^unable to extract tar arhive /tmp/[0-9]+/extractTar.tar: path /tmp/test1, does not belongs to dir /tmp/[0-9]+, rel ../test1$"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

tmpDir, tmpFile, err := getTempDirAndFile(t, "extractTar.tar")
if err != nil {
t.Errorf("unable to create tmp file: %s", err)
}

ensureTar(t, tmpFile.Name(), tt.entries)

ctx := context.Background()
err = extractTar(ctx, tmpFile.Name(), tmpDir, agentendpointpb.SoftwareRecipe_Step_ExtractArchive_TAR)
if tt.wantErrRegexp == nil && err == nil {
return
}
fmt.Println(err.Error())

msg := fmt.Sprintf("%s", err)
if !tt.wantErrRegexp.MatchString(msg) {
t.Errorf("Unexpecte error, expect message to match regexp %s, got %s", tt.wantErrRegexp, err)
}
})

}
}
func Test_extractZip(t *testing.T) {
tests := []struct {
name string
entries []fileEntry
wantErrRegexp *regexp.Regexp
}{
{
name: "base case scenario",
entries: []fileEntry{
{
name: "test1", content: []byte("test1"),
},
{
name: "test2", content: []byte("test2"),
},
},
wantErrRegexp: nil,
},
{
name: "zip with vulnerable path, fail with expected error",
entries: []fileEntry{
{
name: "../test1", content: []byte("test1"),
},
{
name: "test2", content: []byte("test2"),
},
},
wantErrRegexp: regexp.MustCompile("^unable to extract zip arhive /tmp/[0-9]+/extractZip.zip: path /tmp/test1, does not belongs to dir /tmp/[0-9]+, rel ../test1$"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

tmpDir, tmpFile, err := getTempDirAndFile(t, "extractZip.zip")
if err != nil {
t.Errorf("unable to create tmp file: %s", err)
}

ensureZip(t, tmpFile.Name(), tt.entries)

err = extractZip(tmpFile.Name(), tmpDir)
if tt.wantErrRegexp == nil && err == nil {
return
}

msg := fmt.Sprintf("%s", err)
if !tt.wantErrRegexp.MatchString(msg) {
t.Errorf("Unexpecte error, expect message to match regexp %s, got %s", tt.wantErrRegexp, err)
}
})

}
}

func getTempDirAndFile(t *testing.T, fileName string) (dir string, file *os.File, err error) {
tmpDir := filepath.Join(os.TempDir(), fmt.Sprintf("%d", time.Now().UnixNano()))
if err := os.MkdirAll(tmpDir, os.ModePerm); err != nil {
t.Errorf("unable to create tmp dir: %s", err)
return "", nil, err
}

tmpFile, err := os.OpenFile(filepath.Join(tmpDir, fileName), os.O_CREATE|os.O_RDWR, os.ModePerm)
if err != nil {
t.Errorf("unable to create tmp file: %s", err)
return "", nil, err
}

return tmpDir, tmpFile, nil
}

func ensureZip(t *testing.T, dst string, entries []fileEntry) {
fd, err := os.OpenFile(dst, os.O_RDWR, os.ModePerm)
if err != nil {
t.Errorf("unable to open file: %s", err)
}
w := zip.NewWriter(fd)

for _, entry := range entries {
f, err := w.Create(entry.name)
if err != nil {
t.Errorf("unable to create file: %s", err)
}

if _, err = f.Write(entry.content); err != nil {
t.Errorf("unable to write content to file: %s", err)
}
}

if err := w.Close(); err != err {
t.Errorf("unable to close file: %s", err)
}
}

func ensureTar(t *testing.T, dst string, entries []fileEntry) {
fd, err := os.OpenFile(dst, os.O_RDWR, os.ModePerm)
if err != nil {
t.Errorf("unable to open file: %s", err)
}
w := tar.NewWriter(fd)

for _, entry := range entries {
hdr := &tar.Header{
Name: entry.name,
Mode: 0600,
Size: int64(len(entry.content)),
}

if err := w.WriteHeader(hdr); err != nil {
t.Errorf("unable to create file: %s", err)
}

if _, err = w.Write(entry.content); err != nil {
t.Errorf("unable to write content to file: %s", err)
}
}

if err := w.Close(); err != err {
t.Errorf("unable to close file: %s", err)
}
}