@@ -111,7 +111,7 @@ def derive(self, password):
111
111
v = self ._algorithm .block_size
112
112
113
113
# Step 1 - Concatenate v/8 copies of ID
114
- d = chr ( self ._id ) * v
114
+ d = bytes ([ self ._id ] ) * v
115
115
116
116
def concatenate_string (inp ):
117
117
s = b''
@@ -135,7 +135,6 @@ def concatenate_string(inp):
135
135
c = int (math .ceil (float (self ._length ) / u ))
136
136
137
137
# Step 6
138
-
139
138
def digest (inp ):
140
139
h = Hash (self ._algorithm ())
141
140
h .update (inp )
@@ -144,17 +143,17 @@ def digest(inp):
144
143
def to_int (value ):
145
144
if value == b'' :
146
145
return 0
147
- return int ( value . encode ( "hex" ), 16 )
146
+ return int . from_bytes ( value , byteorder = 'big' )
148
147
149
- def to_bytes (value ):
150
- value = "%x" % value
151
- if len (value ) & 1 :
152
- value = "0" + value
153
- return value .decode ("hex" )
148
+ def to_bytes (value , length ):
149
+ try :
150
+ return value .to_bytes (length , byteorder = 'big' )
151
+ except OverflowError :
152
+ # If the integer is too large, we'll take the least significant bytes
153
+ return (value & ((1 << (8 * length )) - 1 )).to_bytes (length , byteorder = 'big' )
154
154
155
155
a = b'\x00 ' * (c * u )
156
156
for n in range (1 , c + 1 ):
157
-
158
157
a2 = digest (d + i )
159
158
for _ in range (2 , self ._iterations + 1 ):
160
159
a2 = digest (a2 )
@@ -172,13 +171,9 @@ def to_bytes(value):
172
171
start = n2 * v
173
172
end = (n2 + 1 ) * v
174
173
i_n2 = i [start :end ]
175
- i_n2 = to_bytes (to_int (i_n2 ) + b )
176
-
177
- i_n2_l = len (i_n2 )
178
- if i_n2_l > v :
179
- i_n2 = i_n2 [i_n2_l - v :]
174
+ i_n2 = to_bytes (to_int (i_n2 ) + b , v )
180
175
181
- i = i [0 :start ] + i_n2 + i [end :]
176
+ i = i [:start ] + i_n2 + i [end :]
182
177
183
178
# Step 7
184
179
start = (n - 1 ) * u
@@ -230,6 +225,8 @@ def __init__(self, salt, iterations, iv, password, hash_algorithm, enc_algorithm
230
225
self ._derive_key , self ._iv = self .derive_key (salt , iterations , password )
231
226
232
227
def derive_key (self , salt , iterations , password ):
228
+ if isinstance (password , str ):
229
+ password = password .encode ()
233
230
pkcs12_pbkdf1 = PKCS12_PBKDF1 (self ._hash_algorithm , 24 , salt , iterations , 1 )
234
231
key = pkcs12_pbkdf1 .derive (password )
235
232
@@ -248,11 +245,11 @@ def encrypt(self, plain_text):
248
245
return cipher_text
249
246
250
247
def decrypt (self , cipher_text ):
251
- padder = padding .PKCS7 (self ._hash_algorithm .block_size ).padder ()
252
- cipher_text = padder .update (cipher_text ) + padder .finalize ()
253
-
254
248
decryptor = Cipher (self ._enc_algorithm (self ._derive_key ), self ._enc_mode (self ._iv )).decryptor ()
255
- plain_text = decryptor .update (cipher_text ) + decryptor .finalize ()
249
+ padded_plain_text = decryptor .update (cipher_text ) + decryptor .finalize ()
250
+
251
+ unpadder = padding .PKCS7 (self ._hash_algorithm .block_size ).unpadder ()
252
+ plain_text = unpadder .update (padded_plain_text ) + unpadder .finalize ()
256
253
257
254
return plain_text
258
255
@@ -348,8 +345,8 @@ def rsec_decrypt(blob, key):
348
345
if len (key ) != 24 :
349
346
raise Exception ("Wrong key length" )
350
347
351
- blob = [ord ( i ) for i in blob ]
352
- key = [ord ( i ) for i in key ]
348
+ blob = [i for i in blob ]
349
+ key = [i for i in key ]
353
350
key1 = key [0 :8 ]
354
351
key2 = key [8 :16 ]
355
352
key3 = key [16 :24 ]
@@ -359,4 +356,4 @@ def rsec_decrypt(blob, key):
359
356
round_2 = cipher .crypt (RSECCipher .MODE_ENCODE , round_1 , key2 , len (round_1 ))
360
357
round_3 = cipher .crypt (RSECCipher .MODE_DECODE , round_2 , key1 , len (round_2 ))
361
358
362
- return '' . join ([ chr ( i ) for i in round_3 ] )
359
+ return bytes ( round_3 )
0 commit comments