@@ -190,46 +190,51 @@ func encryptFile(
190
190
return meta , tmpOutputFile .Name (), nil
191
191
}
192
192
193
- func decryptFile (
193
+ func decryptFileKey (
194
194
metadata * encryptMetadata ,
195
- sfe * snowflakeFileEncryption ,
196
- filename string ,
197
- chunkSize int ,
198
- tmpDir string ) (
199
- string , error ) {
200
- if chunkSize == 0 {
201
- chunkSize = aes .BlockSize * 4 * 1024
202
- }
195
+ sfe * snowflakeFileEncryption ) ([]byte , []byte , error ) {
203
196
decodedKey , err := base64 .StdEncoding .DecodeString (sfe .QueryStageMasterKey )
204
197
if err != nil {
205
- return "" , err
198
+ return nil , nil , err
206
199
}
207
200
keyBytes , err := base64 .StdEncoding .DecodeString (metadata .key ) // encrypted file key
208
201
if err != nil {
209
- return "" , err
202
+ return nil , nil , err
210
203
}
211
204
ivBytes , err := base64 .StdEncoding .DecodeString (metadata .iv )
212
205
if err != nil {
213
- return "" , err
206
+ return nil , nil , err
214
207
}
215
208
216
209
// decrypt file key
217
210
decryptedKey := make ([]byte , len (keyBytes ))
218
211
if err = decryptECB (decryptedKey , keyBytes , decodedKey ); err != nil {
219
- return "" , err
212
+ return nil , nil , err
220
213
}
221
214
decryptedKey , err = paddingTrim (decryptedKey )
222
215
if err != nil {
223
- return "" , err
216
+ return nil , nil , err
224
217
}
225
218
226
- // decrypt file
219
+ return decryptedKey , ivBytes , err
220
+ }
221
+
222
+ func initCBC (decryptedKey []byte , ivBytes []byte ) (cipher.BlockMode , error ) {
227
223
block , err := aes .NewCipher (decryptedKey )
228
224
if err != nil {
229
- return "" , err
225
+ return nil , err
230
226
}
231
227
mode := cipher .NewCBCDecrypter (block , ivBytes )
232
228
229
+ return mode , err
230
+ }
231
+
232
+ func decryptFile (
233
+ metadata * encryptMetadata ,
234
+ sfe * snowflakeFileEncryption ,
235
+ filename string ,
236
+ chunkSize int ,
237
+ tmpDir string ) (string , error ) {
233
238
tmpOutputFile , err := os .CreateTemp (tmpDir , baseName (filename )+ "#" )
234
239
if err != nil {
235
240
return "" , err
@@ -240,11 +245,37 @@ func decryptFile(
240
245
return "" , err
241
246
}
242
247
defer infile .Close ()
248
+ totalFileSize , err := decryptStream (metadata , sfe , chunkSize , infile , tmpOutputFile )
249
+ if err != nil {
250
+ return "" , err
251
+ }
252
+ tmpOutputFile .Truncate (int64 (totalFileSize ))
253
+ return tmpOutputFile .Name (), nil
254
+ }
255
+
256
+ func decryptStream (
257
+ metadata * encryptMetadata ,
258
+ sfe * snowflakeFileEncryption ,
259
+ chunkSize int ,
260
+ src io.Reader ,
261
+ out io.Writer ) (int , error ) {
262
+ if chunkSize == 0 {
263
+ chunkSize = aes .BlockSize * 4 * 1024
264
+ }
265
+ decryptedKey , ivBytes , err := decryptFileKey (metadata , sfe )
266
+ if err != nil {
267
+ return 0 , err
268
+ }
269
+ mode , err := initCBC (decryptedKey , ivBytes )
270
+ if err != nil {
271
+ return 0 , err
272
+ }
273
+
243
274
var totalFileSize int
244
275
var prevChunk []byte
245
276
for {
246
277
chunk := make ([]byte , chunkSize )
247
- n , err := infile .Read (chunk )
278
+ n , err := src .Read (chunk )
248
279
if n == 0 || err != nil {
249
280
break
250
281
} else if n % aes .BlockSize != 0 {
@@ -255,17 +286,16 @@ func decryptFile(
255
286
totalFileSize += n
256
287
chunk = chunk [:n ]
257
288
mode .CryptBlocks (chunk , chunk )
258
- tmpOutputFile .Write (chunk )
289
+ out .Write (chunk )
259
290
prevChunk = chunk
260
291
}
261
292
if err != nil {
262
- return "" , err
293
+ return 0 , err
263
294
}
264
295
if prevChunk != nil {
265
296
totalFileSize -= paddingOffset (prevChunk )
266
297
}
267
- tmpOutputFile .Truncate (int64 (totalFileSize ))
268
- return tmpOutputFile .Name (), nil
298
+ return totalFileSize , err
269
299
}
270
300
271
301
type materialDescriptor struct {
0 commit comments