Skip to content

Commit f434413

Browse files
SNOW-1029646: Add WithFileGetStream that supports downloading a file into a stream (#1192)
Added the WithFileGetStream context function to support downloading a file into an in-memory stream.
1 parent aa4b7cd commit f434413

12 files changed

+329
-115
lines changed

azure_storage_client.go

+24-10
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type azureAPI interface {
3434
UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
3535
UploadFile(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
3636
DownloadFile(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
37+
DownloadStream(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error)
3738
GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
3839
}
3940

@@ -276,16 +277,29 @@ func (util *snowflakeAzureClient) nativeDownloadFile(
276277
if meta.mockAzureClient != nil {
277278
blobClient = meta.mockAzureClient
278279
}
279-
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
280-
if err != nil {
281-
return err
282-
}
283-
defer f.Close()
284-
_, err = blobClient.DownloadFile(
285-
context.Background(), f, &azblob.DownloadFileOptions{
286-
Concurrency: uint16(maxConcurrency)})
287-
if err != nil {
288-
return err
280+
if meta.options.getFileToStream {
281+
blobDownloadResponse, err := blobClient.DownloadStream(context.Background(), &azblob.DownloadStreamOptions{})
282+
if err != nil {
283+
return err
284+
}
285+
retryReader := blobDownloadResponse.NewRetryReader(context.Background(), &azblob.RetryReaderOptions{})
286+
defer retryReader.Close()
287+
_, err = meta.dstStream.ReadFrom(retryReader)
288+
if err != nil {
289+
return err
290+
}
291+
} else {
292+
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
293+
if err != nil {
294+
return err
295+
}
296+
defer f.Close()
297+
_, err = blobClient.DownloadFile(
298+
context.Background(), f, &azblob.DownloadFileOptions{
299+
Concurrency: uint16(maxConcurrency)})
300+
if err != nil {
301+
return err
302+
}
289303
}
290304
meta.resStatus = downloaded
291305
return nil

azure_storage_client_test.go

+9-4
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,11 @@ func TestUnitDetectAzureTokenExpireError(t *testing.T) {
109109
}
110110

111111
type azureObjectAPIMock struct {
112-
UploadStreamFunc func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
113-
UploadFileFunc func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
114-
DownloadFileFunc func(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
115-
GetPropertiesFunc func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
112+
UploadStreamFunc func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
113+
UploadFileFunc func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
114+
DownloadFileFunc func(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
115+
DownloadStreamFunc func(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error)
116+
GetPropertiesFunc func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
116117
}
117118

118119
func (c *azureObjectAPIMock) UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) {
@@ -131,6 +132,10 @@ func (c *azureObjectAPIMock) DownloadFile(ctx context.Context, file *os.File, o
131132
return c.DownloadFileFunc(ctx, file, o)
132133
}
133134

135+
func (c *azureObjectAPIMock) DownloadStream(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error) {
136+
return c.DownloadStreamFunc(ctx, o)
137+
}
138+
134139
func TestUploadFileWithAzureUploadFailedError(t *testing.T) {
135140
info := execResponseStageInfo{
136141
Location: "azblob/storage/users/456/",

connection_util.go

+24-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package gosnowflake
55
import (
66
"bytes"
77
"context"
8+
"errors"
89
"fmt"
910
"io"
1011
"os"
@@ -88,10 +89,11 @@ func (sc *snowflakeConn) processFileTransfer(
8889
isInternal bool) (
8990
*execResponse, error) {
9091
sfa := snowflakeFileTransferAgent{
91-
sc: sc,
92-
data: &data.Data,
93-
command: query,
94-
options: new(SnowflakeFileTransferOptions),
92+
sc: sc,
93+
data: &data.Data,
94+
command: query,
95+
options: new(SnowflakeFileTransferOptions),
96+
streamBuffer: new(bytes.Buffer),
9597
}
9698
if fs := getFileStream(ctx); fs != nil {
9799
sfa.sourceStream = fs
@@ -112,6 +114,11 @@ func (sc *snowflakeConn) processFileTransfer(
112114
if err != nil {
113115
return nil, err
114116
}
117+
if sfa.options.getFileToStream {
118+
if err := writeFileStream(ctx, sfa.streamBuffer); err != nil {
119+
return nil, err
120+
}
121+
}
115122
return data, nil
116123
}
117124

@@ -138,6 +145,19 @@ func getFileTransferOptions(ctx context.Context) *SnowflakeFileTransferOptions {
138145
return o
139146
}
140147

148+
func writeFileStream(ctx context.Context, streamBuf *bytes.Buffer) error {
149+
s := ctx.Value(fileGetStream)
150+
w, ok := s.(io.Writer)
151+
if !ok {
152+
return errors.New("expected an io.Writer")
153+
}
154+
_, err := streamBuf.WriteTo(w)
155+
if err != nil {
156+
return err
157+
}
158+
return nil
159+
}
160+
141161
func (sc *snowflakeConn) populateSessionParameters(parameters []nameValueParameter) {
142162
// other session parameters (not all)
143163
logger.WithContext(sc.ctx).Infof("params: %#v", parameters)

doc.go

+13
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,19 @@ an absolute path rather than a relative path. For example:
12541254
12551255
db.Query("GET @~ file:///tmp/my_data_file auto_compress=false overwrite=false")
12561256
1257+
To download a file into an in-memory stream (rather than a file), use code similar to the code below.
1258+
1259+
var streamBuf bytes.Buffer
1260+
ctx := WithFileTransferOptions(context.Background(), &SnowflakeFileTransferOptions{getFileToStream: true})
1261+
ctx = WithFileGetStream(ctx, &streamBuf)
1262+
1263+
sql := "get @~/data1.txt.gz file:///tmp/testData"
1264+
dbt.mustExecContext(ctx, sql)
1265+
// streamBuf now holds the downloaded data. Use bytes.NewReader(streamBuf.Bytes()) to read an uncompressed stream or
1266+
// use gzip.NewReader(&streamBuf) to read a compressed stream.
1267+
1268+
Note: GET statements are not supported for multi-statement queries.
1269+
12571270
Specifying temporary directory for encryption and compression:
12581271
12591272
Putting and getting requires compression and/or encryption, which is done in the OS temporary directory.

encrypt_util.go

+51-21
Original file line numberDiff line numberDiff line change
@@ -190,46 +190,51 @@ func encryptFile(
190190
return meta, tmpOutputFile.Name(), nil
191191
}
192192

193-
func decryptFile(
193+
func decryptFileKey(
194194
metadata *encryptMetadata,
195-
sfe *snowflakeFileEncryption,
196-
filename string,
197-
chunkSize int,
198-
tmpDir string) (
199-
string, error) {
200-
if chunkSize == 0 {
201-
chunkSize = aes.BlockSize * 4 * 1024
202-
}
195+
sfe *snowflakeFileEncryption) ([]byte, []byte, error) {
203196
decodedKey, err := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey)
204197
if err != nil {
205-
return "", err
198+
return nil, nil, err
206199
}
207200
keyBytes, err := base64.StdEncoding.DecodeString(metadata.key) // encrypted file key
208201
if err != nil {
209-
return "", err
202+
return nil, nil, err
210203
}
211204
ivBytes, err := base64.StdEncoding.DecodeString(metadata.iv)
212205
if err != nil {
213-
return "", err
206+
return nil, nil, err
214207
}
215208

216209
// decrypt file key
217210
decryptedKey := make([]byte, len(keyBytes))
218211
if err = decryptECB(decryptedKey, keyBytes, decodedKey); err != nil {
219-
return "", err
212+
return nil, nil, err
220213
}
221214
decryptedKey, err = paddingTrim(decryptedKey)
222215
if err != nil {
223-
return "", err
216+
return nil, nil, err
224217
}
225218

226-
// decrypt file
219+
return decryptedKey, ivBytes, err
220+
}
221+
222+
func initCBC(decryptedKey []byte, ivBytes []byte) (cipher.BlockMode, error) {
227223
block, err := aes.NewCipher(decryptedKey)
228224
if err != nil {
229-
return "", err
225+
return nil, err
230226
}
231227
mode := cipher.NewCBCDecrypter(block, ivBytes)
232228

229+
return mode, err
230+
}
231+
232+
func decryptFile(
233+
metadata *encryptMetadata,
234+
sfe *snowflakeFileEncryption,
235+
filename string,
236+
chunkSize int,
237+
tmpDir string) (string, error) {
233238
tmpOutputFile, err := os.CreateTemp(tmpDir, baseName(filename)+"#")
234239
if err != nil {
235240
return "", err
@@ -240,11 +245,37 @@ func decryptFile(
240245
return "", err
241246
}
242247
defer infile.Close()
248+
totalFileSize, err := decryptStream(metadata, sfe, chunkSize, infile, tmpOutputFile)
249+
if err != nil {
250+
return "", err
251+
}
252+
tmpOutputFile.Truncate(int64(totalFileSize))
253+
return tmpOutputFile.Name(), nil
254+
}
255+
256+
func decryptStream(
257+
metadata *encryptMetadata,
258+
sfe *snowflakeFileEncryption,
259+
chunkSize int,
260+
src io.Reader,
261+
out io.Writer) (int, error) {
262+
if chunkSize == 0 {
263+
chunkSize = aes.BlockSize * 4 * 1024
264+
}
265+
decryptedKey, ivBytes, err := decryptFileKey(metadata, sfe)
266+
if err != nil {
267+
return 0, err
268+
}
269+
mode, err := initCBC(decryptedKey, ivBytes)
270+
if err != nil {
271+
return 0, err
272+
}
273+
243274
var totalFileSize int
244275
var prevChunk []byte
245276
for {
246277
chunk := make([]byte, chunkSize)
247-
n, err := infile.Read(chunk)
278+
n, err := src.Read(chunk)
248279
if n == 0 || err != nil {
249280
break
250281
} else if n%aes.BlockSize != 0 {
@@ -255,17 +286,16 @@ func decryptFile(
255286
totalFileSize += n
256287
chunk = chunk[:n]
257288
mode.CryptBlocks(chunk, chunk)
258-
tmpOutputFile.Write(chunk)
289+
out.Write(chunk)
259290
prevChunk = chunk
260291
}
261292
if err != nil {
262-
return "", err
293+
return 0, err
263294
}
264295
if prevChunk != nil {
265296
totalFileSize -= paddingOffset(prevChunk)
266297
}
267-
tmpOutputFile.Truncate(int64(totalFileSize))
268-
return tmpOutputFile.Name(), nil
298+
return totalFileSize, err
269299
}
270300

271301
type materialDescriptor struct {

file_transfer_agent.go

+5
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ type SnowflakeFileTransferOptions struct {
9191
/* streaming PUT */
9292
compressSourceFromStream bool
9393

94+
/* streaming GET */
95+
getFileToStream bool
96+
9497
/* PUT */
9598
putCallback *snowflakeProgressPercentage
9699
putAzureCallback *snowflakeProgressPercentage
@@ -124,6 +127,7 @@ type snowflakeFileTransferAgent struct {
124127
useAccelerateEndpoint bool
125128
presignedURLs []string
126129
options *SnowflakeFileTransferOptions
130+
streamBuffer *bytes.Buffer
127131
}
128132

129133
func (sfa *snowflakeFileTransferAgent) execute() error {
@@ -411,6 +415,7 @@ func (sfa *snowflakeFileTransferAgent) initFileMetadata() error {
411415
name: baseName(fileName),
412416
srcFileName: fileName,
413417
dstFileName: dstFileName,
418+
dstStream: new(bytes.Buffer),
414419
stageLocationType: sfa.stageLocationType,
415420
stageInfo: sfa.stageInfo,
416421
localLocation: sfa.localLocation,

file_util.go

+3
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ type fileMetadata struct {
137137
srcStream *bytes.Buffer
138138
realSrcStream *bytes.Buffer
139139

140+
/* streaming GET */
141+
dstStream *bytes.Buffer
142+
140143
/* GCS */
141144
presignedURL *url.URL
142145
gcsFileHeaderDigest string

gcs_storage_client.go

+18-13
Original file line numberDiff line numberDiff line change
@@ -322,13 +322,24 @@ func (util *snowflakeGcsClient) nativeDownloadFile(
322322
return meta.lastError
323323
}
324324

325-
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
326-
if err != nil {
327-
return err
328-
}
329-
defer f.Close()
330-
if _, err = io.Copy(f, resp.Body); err != nil {
331-
return err
325+
if meta.options.getFileToStream {
326+
if _, err := io.Copy(meta.dstStream, resp.Body); err != nil {
327+
return err
328+
}
329+
} else {
330+
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
331+
if err != nil {
332+
return err
333+
}
334+
defer f.Close()
335+
if _, err = io.Copy(f, resp.Body); err != nil {
336+
return err
337+
}
338+
fi, err := os.Stat(fullDstFileName)
339+
if err != nil {
340+
return err
341+
}
342+
meta.srcFileSize = fi.Size()
332343
}
333344

334345
var encryptMeta encryptMetadata
@@ -348,12 +359,6 @@ func (util *snowflakeGcsClient) nativeDownloadFile(
348359
}
349360
}
350361
}
351-
352-
fi, err := os.Stat(fullDstFileName)
353-
if err != nil {
354-
return err
355-
}
356-
meta.srcFileSize = fi.Size()
357362
meta.resStatus = downloaded
358363
meta.gcsFileHeaderDigest = resp.Header.Get(gcsMetadataSfcDigest)
359364
meta.gcsFileHeaderContentLength = resp.ContentLength

0 commit comments

Comments
 (0)