Skip to content

Commit 228fabf

Browse files
committed
review fixes
1 parent 872ecd2 commit 228fabf

File tree

5 files changed

+113
-70
lines changed

5 files changed

+113
-70
lines changed

aes.go

+3-13
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@ import (
77
"fmt"
88
)
99

10-
const (
11-
// in FileWriter we use chunks upto aesChunkSize bytes to encrypt data
12-
aesChunkSize = 1024 * 1024
13-
)
14-
1510
// calculateIV `shifts` IV to given offset
1611
// based on calculateIV from AesCtrCryptoCodec.java
1712
func calculateIV(offset int64, initIV []byte) ([]byte, error) {
@@ -37,10 +32,8 @@ func calculateIV(offset int64, initIV []byte) ([]byte, error) {
3732
return iv, nil
3833
}
3934

40-
// aesCtrStep perform AES-CTR XOR operation on given byte string.
41-
// Once encryption and decryption are exactly the same operation for CTR mode,
42-
// this function can be used to perform both.
43-
func aesCtrStep(offset int64, enc *transparentEncryptionInfo, b []byte) ([]byte, error) {
35+
// aesCreateCTRStream create stream to encrypt/decrypt data from specific offset
36+
func aesCreateCTRStream(offset int64, enc *transparentEncryptionInfo) (cipher.Stream, error) {
4437
iv, err := calculateIV(offset, enc.iv)
4538
if err != nil {
4639
return nil, err
@@ -61,8 +54,5 @@ func aesCtrStep(offset int64, enc *transparentEncryptionInfo, b []byte) ([]byte,
6154
tmp := make([]byte, padding)
6255
stream.XORKeyStream(tmp, tmp)
6356
}
64-
65-
text := make([]byte, len(b))
66-
stream.XORKeyStream(text, b)
67-
return text, nil
57+
return stream, nil
6858
}

aes_test.go

+21-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,31 @@
11
package hdfs
22

33
import (
4+
"bytes"
5+
"crypto/cipher"
46
"testing"
57

68
"github.com/stretchr/testify/assert"
79
)
810

9-
func TestAesChunks(t *testing.T) {
11+
// aesCtrRead perform AES-CTR XOR operation on given byte string.
12+
// Once encryption and decryption are exactly the same operation for CTR mode,
13+
// this function can be used to perform both.
14+
func aesCtrStep(offset int64, enc *transparentEncryptionInfo, b []byte) ([]byte, error) {
15+
stream, err := aesCreateCTRStream(offset, enc)
16+
if err != nil {
17+
return nil, err
18+
}
19+
20+
r := make([]byte, len(b))
21+
_, err = cipher.StreamReader{S: stream, R: bytes.NewReader(b)}.Read(r)
22+
if err != nil {
23+
return nil, err
24+
}
25+
return r, nil
26+
}
27+
28+
func TestAesIV(t *testing.T) {
1029
originalText := []byte("some random plain text, nice to have it quite long")
1130
key := []byte("0123456789abcdef")
1231

@@ -46,5 +65,5 @@ func TestAesChunks(t *testing.T) {
4665
decryptedByChunks = append(decryptedByChunks, tmp...)
4766
pos += int64(x)
4867
}
49-
assert.Equal(t, decryptedByChunks, originalText)
68+
assert.Equal(t, decryptedByChunks, decryptedText)
5069
}

file_reader.go

+22-20
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ type transparentEncryptionInfo struct {
4040
key []byte
4141
iv []byte
4242
cipher cipher.Block
43+
stream cipher.Stream
4344
}
4445

4546
// Open returns an FileReader which can be used for reading.
@@ -184,6 +185,12 @@ func (f *FileReader) Seek(offset int64, whence int) (int64, error) {
184185
if f.offset != off {
185186
f.offset = off
186187

188+
// To make things simpler, we just destroy cipher.Stream (if any)
189+
// It will be recreated in Read()
190+
if f.enc != nil {
191+
f.enc.stream = nil
192+
}
193+
187194
if f.blockReader != nil {
188195
// If the seek is within the next few chunks, it's much more
189196
// efficient to throw away a few bytes than to reconnect and start
@@ -209,25 +216,6 @@ func (f *FileReader) Read(b []byte) (int, error) {
209216
return 0, io.ErrClosedPipe
210217
}
211218

212-
offset := f.offset
213-
n, err := f.readImpl(b)
214-
215-
// Decrypt data chunk if file from HDFS encrypted zone.
216-
if f.enc != nil && n > 0 {
217-
plaintext, err := aesCtrStep(offset, f.enc, b[:n])
218-
if err != nil {
219-
f.offset = offset
220-
return 0, err
221-
}
222-
for i := 0; i < n; i++ {
223-
b[i] = plaintext[i]
224-
}
225-
}
226-
227-
return n, err
228-
}
229-
230-
func (f *FileReader) readImpl(b []byte) (int, error) {
231219
if f.info.IsDir() {
232220
return 0, &os.PathError{
233221
"read",
@@ -259,7 +247,21 @@ func (f *FileReader) readImpl(b []byte) (int, error) {
259247
}
260248
}
261249

262-
n, err := f.blockReader.Read(b)
250+
var n int
251+
var err error
252+
253+
if f.enc != nil {
254+
if f.enc.stream == nil {
255+
f.enc.stream, err = aesCreateCTRStream(f.offset, f.enc)
256+
if err != nil {
257+
return 0, err
258+
}
259+
}
260+
n, err = cipher.StreamReader{S: f.enc.stream, R: f.blockReader}.Read(b)
261+
} else {
262+
n, err = f.blockReader.Read(b)
263+
}
264+
263265
f.offset += int64(n)
264266

265267
if err != nil && err != io.EOF {

file_writer.go

+20-30
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package hdfs
22

33
import (
4+
"crypto/cipher"
45
"errors"
56
"os"
67
"time"
@@ -201,28 +202,6 @@ func (f *FileWriter) SetDeadline(t time.Time) error {
201202
// of this, it is important that Close is called after all data has been
202203
// written.
203204
func (f *FileWriter) Write(b []byte) (int, error) {
204-
// Encrypt data chunk if file in HDFS encrypted zone.
205-
if f.enc != nil && len(b) > 0 {
206-
var offset int
207-
for offset < len(b) {
208-
size := min(len(b)-offset, aesChunkSize)
209-
ciphertext, err := aesCtrStep(f.offset, f.enc, b[offset:offset+size])
210-
if err != nil {
211-
return offset, err
212-
}
213-
writtenSize, err := f.writeImpl(ciphertext)
214-
offset += writtenSize
215-
if err != nil {
216-
return offset, err
217-
}
218-
}
219-
return offset, nil
220-
} else {
221-
return f.writeImpl(b)
222-
}
223-
}
224-
225-
func (f *FileWriter) writeImpl(b []byte) (int, error) {
226205
if f.blockWriter == nil {
227206
err := f.startNewBlock()
228207
if err != nil {
@@ -232,7 +211,25 @@ func (f *FileWriter) writeImpl(b []byte) (int, error) {
232211

233212
off := 0
234213
for off < len(b) {
235-
n, err := f.blockWriter.Write(b[off:])
214+
var n int
215+
var err error
216+
217+
if f.enc != nil {
218+
if f.enc.stream == nil {
219+
f.enc.stream, err = aesCreateCTRStream(f.offset, f.enc)
220+
if err != nil {
221+
return 0, err
222+
}
223+
}
224+
n, err = cipher.StreamWriter{S: f.enc.stream, W: f.blockWriter}.Write(b[off:])
225+
// If blockWriter writes less than expected bytes,
226+
// we must recreate stream chipher, since it's internal counter goes forward.
227+
if n != len(b[off:]) {
228+
f.enc.stream = nil
229+
}
230+
} else {
231+
n, err = f.blockWriter.Write(b[off:])
232+
}
236233
off += n
237234
f.offset += int64(n)
238235
if err == transfer.ErrEndOfBlock {
@@ -364,10 +361,3 @@ func (f *FileWriter) finalizeBlock() error {
364361
f.blockWriter = nil
365362
return nil
366363
}
367-
368-
func min(a, b int) int {
369-
if a < b {
370-
return a
371-
}
372-
return b
373-
}

file_writer_test.go

+47-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"io/ioutil"
77
"math/rand"
88
"os"
9+
"os/exec"
910
"path/filepath"
1011
"strings"
1112
"testing"
@@ -560,6 +561,10 @@ func TestEncryptedZoneWriteChunks(t *testing.T) {
560561
bytes, err := ioutil.ReadAll(reader)
561562
require.NoError(t, err)
562563
assert.Equal(t, originalText, bytes)
564+
565+
hdfsOut, err := exec.Command("hadoop", "dfs", "-cat", "/_test/kms/write_chunks.txt").Output()
566+
require.NoError(t, err)
567+
assert.Equal(t, originalText, hdfsOut)
563568
}
564569

565570
func TestEncryptedZoneAppendChunks(t *testing.T) {
@@ -586,28 +591,65 @@ func TestEncryptedZoneAppendChunks(t *testing.T) {
586591
bytes, err := ioutil.ReadAll(reader)
587592
require.NoError(t, err)
588593
assert.Equal(t, originalText, bytes)
594+
595+
hdfsOut, err := exec.Command("hadoop", "dfs", "-cat", "/_test/kms/append_chunks.txt").Output()
596+
require.NoError(t, err)
597+
assert.Equal(t, originalText, hdfsOut)
589598
}
590599

591600
func TestEncryptedZoneLargeBlock(t *testing.T) {
592601
skipWithoutEncryptedZone(t)
593602

594-
// Generate quite large (aesChunkSize * 1.5 bytes) block, so we can trigger encryption in chunks.
595-
str := "some random text"
596-
originalText := []byte(strings.Repeat(str, aesChunkSize*1.5/len(str)))
603+
// Generate quite large data block, so we can trigger encryption in chunks.
604+
mobydick, err := os.Open("testdata/mobydick.txt")
605+
require.NoError(t, err)
606+
originalText, err := ioutil.ReadAll(mobydick)
607+
require.NoError(t, err)
597608
client := getClient(t)
598609

599610
// Create file with small (128Kb) block size, so encrypted chunk will be placed over multiple hdfs blocks.
600-
writer, err := client.CreateFile("/_test/kms/large_write.txt", 1, 131072, 0755)
611+
writer, err := client.CreateFile("/_test/kms/mobydick.unittest", 1, 131072, 0755)
601612
require.NoError(t, err)
602613

603614
_, err = writer.Write(originalText)
604615
require.NoError(t, err)
605616
assertClose(t, writer)
606617

607-
reader, err := client.Open("/_test/kms/large_write.txt")
618+
reader, err := client.Open("/_test/kms/mobydick.unittest")
608619
require.NoError(t, err)
620+
bytes, err := ioutil.ReadAll(reader)
621+
require.NoError(t, err)
622+
assert.Equal(t, originalText, bytes)
609623

624+
// Ensure read after seek works as expected:
625+
_, err = reader.Seek(35657, io.SeekStart)
626+
require.NoError(t, err)
627+
bytes = make([]byte, 64)
628+
_, err = reader.Read(bytes)
629+
require.NoError(t, err)
630+
assert.Equal(t, []byte("By reason of these things, then, the whaling voyage was welcome;"), bytes)
631+
632+
hdfsOut, err := exec.Command("hadoop", "dfs", "-cat", "/_test/kms/mobydick.unittest").Output()
633+
require.NoError(t, err)
634+
assert.Equal(t, originalText, hdfsOut)
635+
}
636+
637+
func TestEncryptedZoneReadAfterJava(t *testing.T) {
638+
skipWithoutEncryptedZone(t)
639+
640+
err := exec.Command("hadoop", "dfs", "-copyFromLocal", "testdata/mobydick.txt", "/_test/kms/mobydick.java").Run()
641+
require.NoError(t, err)
642+
643+
mobydick, err := os.Open("testdata/mobydick.txt")
644+
require.NoError(t, err)
645+
originalText, err := ioutil.ReadAll(mobydick)
646+
require.NoError(t, err)
647+
648+
client := getClient(t)
649+
reader, err := client.Open("/_test/kms/mobydick.java")
650+
require.NoError(t, err)
610651
bytes, err := ioutil.ReadAll(reader)
611652
require.NoError(t, err)
653+
612654
assert.Equal(t, originalText, bytes)
613655
}

0 commit comments

Comments
 (0)