
Commit 673e693

optimize PUT memory usage
1 parent 951ae61 commit 673e693

File tree

3 files changed: +63 −40 lines changed


connection_util.go

+23 −9

@@ -119,7 +119,7 @@ func (sc *snowflakeConn) processFileTransfer(
 	if sfa.options.MultiPartThreshold == 0 {
 		sfa.options.MultiPartThreshold = dataSizeThreshold
 	}
-	if err := sfa.execute(); err != nil {
+	if err = sfa.execute(); err != nil {
 		return nil, err
 	}
 	data, err = sfa.result()
@@ -134,18 +134,32 @@ func (sc *snowflakeConn) processFileTransfer(
 	return data, nil
 }
 
-func getFileStream(ctx context.Context) (*bytes.Buffer, error) {
+func getReaderFromContext(ctx context.Context) io.Reader {
 	s := ctx.Value(fileStreamFile)
-	if s == nil {
-		return nil, nil
-	}
 	r, ok := s.(io.Reader)
 	if !ok {
-		return nil, errors.New("incorrect io.Reader")
+		return nil
+	}
+	return r
+}
+
+func getFileStream(ctx context.Context) (*bytes.Buffer, error) {
+	r := getReaderFromContext(ctx)
+	if r == nil {
+		return nil, nil
+	}
+
+	// read a small amount of data to check if file stream will be used
+	buf := make([]byte, defaultStringBufferSize)
+	for {
+		_, err := r.Read(buf)
+		if err != nil {
+			return nil, err
+		} else {
+			break
+		}
 	}
-	buf := new(bytes.Buffer)
-	_, err := buf.ReadFrom(r)
-	return buf, err
+	return bytes.NewBuffer(buf), nil
 }
 
 func getFileTransferOptions(ctx context.Context) *SnowflakeFileTransferOptions {
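
The new getReaderFromContext helper swaps the old buffer-everything getFileStream for a plain context lookup, so the PUT source is only peeked at here and streamed later. Below is a minimal, self-contained sketch of that pattern, assuming a hypothetical fileStreamKey/readerFromContext pair in place of the driver's internal fileStreamFile key and helper:

package main

import (
	"bytes"
	"context"
	"fmt"
	"io"
)

// contextKey is a private key type; fileStreamKey is hypothetical and only
// mirrors the role played by the driver's fileStreamFile key.
type contextKey string

const fileStreamKey contextKey = "fileStream"

// readerFromContext returns the io.Reader stored in ctx, or nil when the
// value is missing or has the wrong type (same shape as getReaderFromContext).
func readerFromContext(ctx context.Context) io.Reader {
	r, ok := ctx.Value(fileStreamKey).(io.Reader)
	if !ok {
		return nil
	}
	return r
}

func main() {
	ctx := context.WithValue(context.Background(), fileStreamKey, bytes.NewBufferString("hello"))
	if r := readerFromContext(ctx); r != nil {
		data, _ := io.ReadAll(r)
		fmt.Printf("%s\n", data) // hello
	}
}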

file_transfer_agent.go

+5 −20

@@ -463,20 +463,9 @@ func (sfa *snowflakeFileTransferAgent) processFileCompressionType() error {
 	if currentFileCompressionType == nil {
 		var mtype *mimetype.MIME
 		var err error
-		if meta.srcStream != nil {
-			r := getReaderFromBuffer(&meta.srcStream)
-			mtype, err = mimetype.DetectReader(r)
-			if err != nil {
-				return err
-			}
-			if _, err = io.ReadAll(r); err != nil { // flush out tee buffer
-				return err
-			}
-		} else {
-			mtype, err = mimetype.DetectFile(fileName)
-			if err != nil {
-				return err
-			}
+		mtype, err = mimetype.DetectFile(fileName)
+		if err != nil {
+			return err
 		}
 		currentFileCompressionType = lookupByExtension(mtype.Extension())
 	}
@@ -858,7 +847,7 @@ func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileM
 	fileUtil := new(snowflakeFileUtil)
 	if meta.requireCompress {
 		if meta.srcStream != nil {
-			meta.realSrcStream, _, err = fileUtil.compressFileWithGzipFromStream(&meta.srcStream)
+			meta.realSrcStream, _, err = fileUtil.compressFileWithGzipFromStream(sfa.ctx)
 		} else {
 			meta.realSrcFileName, _, err = fileUtil.compressFileWithGzip(meta.srcFileName, tmpDir)
 		}
@@ -868,11 +857,7 @@ func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileM
 	}
 
 	if meta.srcStream != nil {
-		if meta.realSrcStream != nil {
-			meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(&meta.realSrcStream)
-		} else {
-			meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(&meta.srcStream)
-		}
+		meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(&meta.realSrcStream, &meta.srcStream, sfa.ctx)
 	} else {
 		meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForFile(meta.realSrcFileName)
 	}
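
processFileCompressionType now always sniffs the compression type from the file on disk via mimetype.DetectFile instead of teeing the in-memory stream. A rough sketch of that detection step follows, using the same github.com/gabriel-vasile/mimetype package the code already imports; detectCompression and the sample file name are hypothetical:

package main

import (
	"fmt"
	"log"

	"github.com/gabriel-vasile/mimetype"
)

// detectCompression mirrors the DetectFile-based path: it reports the MIME
// type and the extension that would be used to look up a compression type.
func detectCompression(fileName string) (string, string, error) {
	mtype, err := mimetype.DetectFile(fileName)
	if err != nil {
		return "", "", err
	}
	// Extension comes back with a leading dot, e.g. ".gz" for gzip data.
	return mtype.String(), mtype.Extension(), nil
}

func main() {
	mime, ext, err := detectCompression("data.csv.gz")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(mime, ext) // e.g. "application/gzip" ".gz"
}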

file_util.go

+35 −11

@@ -5,6 +5,7 @@ package gosnowflake
 import (
 	"bytes"
 	"compress/gzip"
+	"context"
 	"crypto/sha256"
 	"encoding/base64"
 	"io"
@@ -23,16 +24,28 @@ const (
 	readWriteFileMode os.FileMode = 0666
 )
 
-func (util *snowflakeFileUtil) compressFileWithGzipFromStream(srcStream **bytes.Buffer) (*bytes.Buffer, int, error) {
-	r := getReaderFromBuffer(srcStream)
-	buf, err := io.ReadAll(r)
-	if err != nil {
-		return nil, -1, err
-	}
+func (util *snowflakeFileUtil) compressFileWithGzipFromStream(ctx context.Context) (*bytes.Buffer, int, error) {
 	var c bytes.Buffer
 	w := gzip.NewWriter(&c)
-	if _, err := w.Write(buf); err != nil { // write buf to gzip writer
-		return nil, -1, err
+	buf := make([]byte, fileChunkSize)
+	r := getReaderFromContext(ctx)
+	if r == nil {
+		return nil, -1, nil
+	}
+
+	// read the whole file in chunks
+	for {
+		n, err := r.Read(buf)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return nil, -1, err
+		}
+		// write buf to gzip writer
+		if _, err = w.Write(buf[:n]); err != nil {
+			return nil, -1, err
+		}
 	}
 	if err := w.Close(); err != nil {
 		return nil, -1, err
@@ -75,11 +88,22 @@ func (util *snowflakeFileUtil) compressFileWithGzip(fileName string, tmpDir stri
 	return gzipFileName, stat.Size(), err
 }
 
-func (util *snowflakeFileUtil) getDigestAndSizeForStream(stream **bytes.Buffer) (string, int64, error) {
+func (util *snowflakeFileUtil) getDigestAndSizeForStream(realSrcStream **bytes.Buffer, srcStream **bytes.Buffer, ctx context.Context) (string, int64, error) {
+	var r io.Reader
+	var stream **bytes.Buffer
+	if realSrcStream != nil {
+		r = getReaderFromBuffer(srcStream)
+		stream = realSrcStream
+	} else {
+		r = getReaderFromContext(ctx)
+		stream = srcStream
+	}
+	if r == nil {
+		return "", 0, nil
+	}
+
 	m := sha256.New()
-	r := getReaderFromBuffer(stream)
 	chunk := make([]byte, fileChunkSize)
-
 	for {
 		n, err := r.Read(chunk)
 		if err == io.EOF {
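
Both loops above process the PUT payload in fixed-size chunks instead of reading it fully into memory, which is where the memory saving comes from. The sketch below shows the same idea end to end, gzipping and hashing a stream chunk by chunk; chunkSize and compressAndDigest are hypothetical stand-ins, not the driver's fileChunkSize or its actual helpers:

package main

import (
	"bytes"
	"compress/gzip"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"io"
	"log"
	"strings"
)

// chunkSize is a hypothetical stand-in for the driver's fileChunkSize constant.
const chunkSize = 8 * 1024

// compressAndDigest gzips src into an in-memory buffer chunk by chunk while
// hashing the uncompressed bytes, so src is never read into memory whole.
func compressAndDigest(src io.Reader) (compressed *bytes.Buffer, digest string, size int64, err error) {
	var c bytes.Buffer
	w := gzip.NewWriter(&c)
	h := sha256.New()
	buf := make([]byte, chunkSize)
	for {
		n, readErr := src.Read(buf)
		if n > 0 {
			size += int64(n)
			h.Write(buf[:n]) // hash the plain (uncompressed) bytes
			if _, err = w.Write(buf[:n]); err != nil {
				return nil, "", 0, err
			}
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return nil, "", 0, readErr
		}
	}
	if err = w.Close(); err != nil {
		return nil, "", 0, err
	}
	return &c, base64.StdEncoding.EncodeToString(h.Sum(nil)), size, nil
}

func main() {
	buf, digest, size, err := compressAndDigest(strings.NewReader("some PUT payload"))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(buf.Len(), digest, size)
}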
