Skip to content

Commit b0c7cbd

Browse files
committed
refactor: cleanup move pkcs1 v2.2 related functions into one file
1 parent 6a1d139 commit b0c7cbd

File tree

6 files changed

+272
-252
lines changed

6 files changed

+272
-252
lines changed

Diff for: src/ToyPublicKeys.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ include("utils/random.jl")
55
include("utils/number_theory.jl")
66
include("utils/string.jl")
77
include("rsa.jl")
8+
include("pkcs1_v2_2.jl")
89
include("padding/pkcs_1_v1_5.jl")
9-
include("padding/pkcs_1_v2_2.jl")
1010
include("dh.jl")
1111
export RSAKey, RSAPrivateKey, RSAPublicKey, generate_rsa_key_pair, encrypt, decrypt, sign, verify_signature, dh_params
1212
end

Diff for: src/padding/pkcs_1_v2_2.jl

-111
This file was deleted.

Diff for: src/pkcs1_v2_2.jl

+266
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
function I2OSP(x::BigInt, xLen::Integer)
2+
_order = 0
3+
_endian = 0
4+
_nails = 0
5+
n = (Base.GMP.MPZ.sizeinbase(x, 2) / 8) |> ceil |> Integer
6+
ret = zeros(UInt8, n)
7+
Base.GMP.MPZ.export!(ret, x; order=_order, nails=_nails, endian=_endian)
8+
ret_len = ret |> length
9+
if ret_len < xLen
10+
ret = vcat(zeros(UInt8, xLen - ret_len), ret)
11+
elseif ret_len > xLen
12+
error("ret_len > xLen") |> throw
13+
end
14+
return ret
15+
end
16+
17+
function OS2IP(x::Vector{UInt8})
18+
bi = BigInt()
19+
_order = 0
20+
_endian = 0
21+
_nails = 0
22+
Base.GMP.MPZ.import!(
23+
bi, length(x), _order, sizeof(eltype(x)), _endian, _nails, pointer(x)
24+
)
25+
return bi
26+
end
27+
28+
function rsaes_oaep_encode(msg::Vector{UInt8},
29+
key::RSAKey;
30+
label=UInt8[],
31+
MGF=ToyPublicKeys.MGF1,
32+
hash=SHA.sha1)
33+
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
34+
lHash = hash(label)
35+
hLen = lHash |> length
36+
LLen = label |> length
37+
LLen > (big"2" << 60) && throw(error("label too long"))
38+
mLen = msg |> length
39+
mLen > k - 2 * hLen - 2 && throw(error("message too long"))
40+
pLen = (k - mLen - 2 * hLen - 2)
41+
PS = zeros(UInt8, pLen)
42+
seed = rand(UInt8, hLen)
43+
dbMask = MGF(seed, k - hLen - 1, hash=hash)
44+
DB = vcat(lHash, PS, ones(UInt8, 1), msg)
45+
maskedDB = DB .⊻ dbMask
46+
seedMask = MGF(maskedDB, hLen)
47+
maskedSeed = seed .⊻ seedMask
48+
EM = vcat(zeros(UInt8, 1), maskedSeed, maskedDB)
49+
return EM
50+
end
51+
52+
function rsaes_oaep_decode(msg::Vector{UInt8},
53+
key::RSAKey;
54+
label=UInt8[],
55+
MGF=ToyPublicKeys.MGF1,
56+
hash=SHA.sha1)
57+
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
58+
lHash = hash(label)
59+
hLen = lHash |> length
60+
Y = msg[1]
61+
Y != 0 && error("Y != 0") |> throw
62+
maskedSeed = msg[2:end][1:hLen]
63+
DBLen = k - hLen - 1
64+
maskedDB = msg[2:end][hLen + 1:end]
65+
seedMask = MGF(maskedDB, hLen)
66+
seed = maskedSeed .⊻ seedMask
67+
dbMask = MGF(seed, k - hLen - 1)
68+
DB = maskedDB .⊻ dbMask
69+
lHashPrime = DB[1:hLen]
70+
lHashPrime != lHash && error("lHashPrime != lHash")
71+
PSIndexInView = findfirst(a -> a==1, DB[hLen + 1:end])
72+
isnothing(PSIndexInView) && error("not EME-OAEP encoded or malformed")
73+
PSIndex = PSIndexInView + hLen
74+
PS = DB[hLen + 1:PSIndex - 1]
75+
any(PS .!= 0) && error("PS should be zero filled vector") |> throw
76+
X = DB[PSIndex]
77+
X != 1 && error("X != 1") |> throw
78+
M = DB[PSIndex + 1:end]
79+
return M
80+
end
81+
82+
function MGF1(mgfSeed::Vector{UInt8},
83+
maskLen:: Integer;
84+
hash = SHA.sha1)
85+
hLen = hash(UInt8[]) |> length
86+
maskLen >= (2 << 32) && error("mask too long") |> throw
87+
T = zeros(UInt8, 0)
88+
for counter in big"0":BigInt((ceil(maskLen / hLen) - 1))
89+
C = I2OSP(counter, 4)
90+
_T = vcat(mgfSeed, C)
91+
__T = hash(_T)
92+
T = vcat(T, __T)
93+
end
94+
return T[1:maskLen]
95+
end
96+
97+
function emsa_pss_encode(M::Vector{UInt8},
98+
emBits::Integer;
99+
MGF=ToyPublicKeys.MGF1,
100+
hash=SHA.sha1,
101+
sLen=0)
102+
hLen = hash(UInt8[]) |> length
103+
(emBits >= 8*hLen + 8 * sLen + 9) || error("emBits !>= 8hLen + 8 * sLen + 9") |> throw
104+
emLen = ceil(emBits/8) |> Integer
105+
emLen < hLen + sLen + 2 && error("encoding error") |> throw
106+
length(M) > (big"2" << 60) && error("message too long") |> throw
107+
salt = UInt8[]
108+
if sLen > 0
109+
salt = rand(UInt8, sLen)
110+
end
111+
mHash = hash(M)
112+
MPrime = vcat(zeros(UInt8, 8), mHash, salt)
113+
H = hash(MPrime)
114+
PS = zeros(UInt8, emLen - sLen - hLen - 2)
115+
DB = vcat(PS, ones(UInt8, 1), salt)
116+
dbMask = MGF(H, emLen - hLen - 1)
117+
maskedDB = DB .⊻ dbMask
118+
maskedDB[1] &= 0xFF >> (8 * emLen - emBits)
119+
EM = vcat(maskedDB, H, UInt8[0xbc])
120+
return EM
121+
end
122+
123+
function emsa_pss_verify(M::Vector{UInt8},
124+
EM::Vector{UInt8},
125+
emBits::Integer;
126+
MGF=ToyPublicKeys.MGF1,
127+
hash=SHA.sha1,
128+
sLen=0)
129+
length(M) > (big"2" << 60) && error("inconsistent") |> throw
130+
mHash = hash(M)
131+
hLen = mHash |> length
132+
emLen = ceil(emBits/8) |> Integer
133+
emLen < hLen + sLen + 2 && error("inconsistent") |> throw
134+
EM[end] != 0xbc && error("inconsistent") |> throw
135+
maskedDB = EM[1:emLen - hLen - 1]
136+
H = EM[emLen - hLen:emLen - 1]
137+
(maskedDB[1] & ~(0xFF >> (8 * emLen - emBits))) != 0 && error("inconsistent") |> throw
138+
dbMask = MGF(H, emLen - hLen - 1)
139+
DB = maskedDB .⊻ dbMask
140+
DB[1] &= 0xFF >> (8 * emLen - emBits)
141+
(DB[1:emLen - hLen - sLen - 2] .!= 0) |> any && error("inconsistent") |> throw
142+
DB[emLen - hLen - sLen - 1] != 1 && error("inconsistent") |> throw
143+
salt = DB[end - sLen + 1 : end]
144+
MPrime = MPrime = vcat(zeros(UInt8, 8), mHash, salt)
145+
HPrime = hash(MPrime)
146+
return H == HPrime
147+
end
148+
149+
function rsaes_oaep_encrypt(M::Vector{UInt8},
150+
key::RSAPublicKey;
151+
label="",
152+
hash=SHA.sha1,
153+
MGF=MGF1)
154+
EM = rsaes_oaep_encode(M, key, label=label, hash=hash, MGF=MGF)
155+
m = OS2IP(EM)
156+
c = RSAEP(pkcs1_v1_5, m, key)
157+
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
158+
C = I2OSP(c, k)
159+
return C
160+
end
161+
162+
function rsaes_oaep_decrypt(C::Vector{UInt8},
163+
key::RSAPrivateKey;
164+
label="",
165+
hash=SHA.sha1,
166+
MGF=MGF1)
167+
c = OS2IP(C)
168+
m = RSADP(pkcs1_v1_5, c, key)
169+
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
170+
EM = I2OSP(m, k)
171+
M = rsaes_oaep_decode(EM, key, label=label, hash=hash, MGF=MGF)
172+
return M
173+
end
174+
175+
function rsaes_pkvs1_v1_5_encrypt(M::String, key::RSAPublicKey)
176+
EM = pad(pkcs1_v1_5, M)
177+
m = OS2IP(EM)
178+
c = RSAEP(pkcs1_v1_5, m, key)
179+
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
180+
C = I2OSP(c, k)
181+
return C
182+
end
183+
184+
function rsaes_pkvs1_v1_5_decrypt(C::String,
185+
key::RSAPrivateKey)
186+
c = OS2IP(C)
187+
m = RSADP(pkcs1_v1_5, c, key)
188+
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
189+
EM = I2OSP(m, k)
190+
m = unpad(pkcs1_v1_5, EM)
191+
return m
192+
end
193+
194+
function rsassa_pss_sign(M::Vector{UInt8},
195+
key::RSAPrivateKey)
196+
modBits = Base.GMP.MPZ.sizeinbase(key.modulus, 2)
197+
EM = emsa_pss_encode(M, modBits - 1)
198+
m = OS2IP(EM)
199+
s = RSASP1(pkcs1_v1_5, m, key)
200+
k = (modBits/8) |> ceil |> Integer
201+
S = I2OSP(s, k)
202+
return S
203+
end
204+
205+
function rsassa_pss_verify(M::Vector{UInt8},
206+
S::Vector{UInt8},
207+
key::RSAPublicKey)
208+
modBits = Base.GMP.MPZ.sizeinbase(key.modulus, 2)
209+
k = (modBits/8) |> ceil |> Integer
210+
length(S) != k && error("invalid signature") |> throw
211+
s = OS2IP(S)
212+
m = RSAVP1(pkcs1_v1_5, s, key)
213+
emLen = ceil((modBits - 1)/8) |> Integer
214+
EM = I2OSP(m, emLen)
215+
result = emsa_pss_verify(M, EM, modBits - 1)
216+
return result
217+
end
218+
219+
function validate(key::RSAPrivateKey)
220+
(length(key.primes) >= 2) || error("length(key.primes) < 2") |> throw
221+
all((key.public_exponent > 0, key.exponent > 0)) && all(key.primes .> 0) || error("all((key.public_exponent, key.exponent) .> 0, key.primes .> 0)") |> throw
222+
(length(key.exponent) > 0) || error("length(key.primes) < 2") |> throw
223+
(prod(key.primes) == key.modulus) || error("(prod(key.primes) != key.modulus)") |> throw
224+
((key.exponent * key.public_exponent) % lcm((key.primes .- 1)...)) == 1 || error("(key.exponent * key.public_exponent) % lcm((key.exponent - 1), (key.public_exponent - 1)) != 1") |> throw
225+
(key.public_exponent * key.crt_exponents[1]) % (key.primes[1] - 1) == 1 || error(" (key.public_exponent * key.crt_exponents[1]) % (key.primes[1] - 1) != 1") |> throw
226+
(key.public_exponent * key.crt_exponents[2]) % (key.primes[2] - 1) == 1 || error("(key.public_exponent * key.crt_exponents[2]) % (key.primes[2] - 1) != 1") |> throw
227+
# is this consistent with the struct?
228+
(key.primes[2] * key.crt_coefficients[1]) % key.primes[1] == 1 || error("(key.primes[2] * key.crt_coefficients[2]) % key.primes[1] != 1") |> throw
229+
end
230+
231+
"""
232+
RSAStep(::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey)
233+
234+
Fast implementation of the RSA exponentiation step when RSAPrivateKey is provided.
235+
It uses [Chinese remainer theorem](https://en.wikipedia.org/wiki/Chinese_remainder_theorem) for very fast `exp() mod n` calculations.
236+
"""
237+
function RSAEP(v::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey)
238+
return RSAStep(v, msg, key)
239+
end
240+
241+
"""
242+
RSAStep(::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey)
243+
244+
RSA exponentiation step when only public key is available.
245+
Uses [repeated squares](https://en.wikipedia.org/wiki/Exponentiation_by_squaring)
246+
and other fast modulo exponentiation tricks in its GMP implementation (Base.GMP.MPZ.powm).
247+
"""
248+
function RSADP(v::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey)
249+
return RSAStep(v, msg, key)
250+
end
251+
252+
"""
253+
RSASP1(::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey)
254+
255+
"""
256+
function RSASP1(v::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey)
257+
return RSAStep(v, msg, key)
258+
end
259+
260+
"""
261+
RSAVP1(::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey)
262+
263+
"""
264+
function RSAVP1(v::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey)
265+
return RSAStep(v, msg, key)
266+
end

0 commit comments

Comments
 (0)