Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rsaes #59

Merged
merged 5 commits into from
Mar 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions src/padding/pkcs_1_v2_2.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
function pad(::pkcs1_v2_2_t,
msg::Union{AbstractString, AbstractVector},
msg::Vector{UInt8},
key::RSAKey;
label="",
label=UInt8[],
MGF=ToyPublicKeys.MGF1,
hash=SHA.sha1)
msg = msg |> Vector{UInt8}
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
lHash = hash(label)
hLen = lHash |> length
Expand All @@ -25,15 +24,16 @@ function pad(::pkcs1_v2_2_t,
end

function unpad(::pkcs1_v2_2_t,
msg::Union{AbstractString, AbstractVector},
msg::Vector{UInt8},
key::RSAKey;
label="",
label=UInt8[],
MGF=ToyPublicKeys.MGF1,
hash=SHA.sha1)
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
lHash = hash(label)
hLen = lHash |> length
Y = msg[1]
Y != 0 && error("Y != 0") |> throw
maskedSeed = msg[2:end][1:hLen]
DBLen = k - hLen - 1
maskedDB = msg[2:end][hLen + 1:end]
Expand All @@ -42,21 +42,27 @@ function unpad(::pkcs1_v2_2_t,
dbMask = MGF(seed, k - hLen - 1)
DB = maskedDB .⊻ dbMask
lHashPrime = DB[1:hLen]
PSLen = findfirst(Vector{UInt8}([1]), DB)
PSIndex = (PSLen |> first) - 1
PS = DB[hLen + 1:PSIndex]
X = DB[PSIndex + 1]
M = DB[PSIndex + 2:end]
lHashPrime != lHash && error("lHashPrime != lHash")
PSIndexInView = findfirst(a -> a==1, DB[hLen + 1:end])
isnothing(PSIndexInView) && error("not EME-OAEP encoded or malformed")
PSIndex = PSIndexInView + hLen
PS = DB[hLen + 1:PSIndex - 1]
any(PS .!= 0) && error("PS should be zero filled vector") |> throw
X = DB[PSIndex]
X != 1 && error("X != 1") |> throw
M = DB[PSIndex + 1:end]
return M
end

function MGF1(mgfSeed::Vector{UInt8}, maskLen:: Integer; hash = SHA.sha1)
hLen = hash("") |> length
hLen = hash(UInt8[]) |> length
maskLen >= (2 << 32) && error("mask too long") |> throw
T = Vector{UInt8}()
T = zeros(UInt8, 0)
for counter in big"0":BigInt((ceil(maskLen / hLen) - 1))
C = I2OSP(counter, 4) |> Vector{UInt8}
T = vcat(T, hash(vcat(mgfSeed, C) |> String))
C = I2OSP(counter, 4)
_T = vcat(mgfSeed, C)
__T = hash(_T)
T = vcat(T, __T)
end
return T[1:maskLen]
end
54 changes: 39 additions & 15 deletions src/rsa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,8 @@ end
Fast implementation of the RSA exponentiation step when RSAPrivateKey is provided.
It uses [Chinese remainer theorem](https://en.wikipedia.org/wiki/Chinese_remainder_theorem) for very fast `exp() mod n` calculations.
"""
function RSAEP(::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey)
if !(0 <= msg < key.modulus)
error("msg has to be 0 <= msg < n, got: msg = $msg, n = $key.modulus")
end
ret = power_crt(
msg,
key.primes[1],
key.primes[2],
key.crt_exponents[1],
key.crt_exponents[2],
key.crt_coefficients[2],
)
ret < 0 && (ret += msg) #???
return ret
function RSAEP(v::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey)
return RSAStep(v, msg, key)
end

"""
Expand All @@ -83,7 +71,7 @@ RSA exponentiation step when only public key is available.
Uses [repeated squares](https://en.wikipedia.org/wiki/Exponentiation_by_squaring)
and other fast modulo exponentiation tricks in its GMP implementation (Base.GMP.MPZ.powm).
"""
function RSADP(v::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey)
function RSADP(v::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey)
return RSAStep(v, msg, key)
end

Expand Down Expand Up @@ -165,6 +153,42 @@ function RSAStep(::pkcs1_v1_5_t, msg::String, key::RSAKey)
return transformed_msg
end

function rsaes_oaep_encrypt(M::Vector{UInt8}, key::RSAPublicKey; label="", hash=SHA.sha1, MGF=MGF1)
EM = pad(pkcs1_v2_2, M, key, label=label, hash=hash, MGF=MGF)
m = OS2IP(EM)
c = RSAEP(pkcs1_v1_5, m, key)
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
C = I2OSP(c, k)
return C
end

function rsaes_oaep_decrypt(C::Vector{UInt8}, key::RSAPrivateKey; label="", hash=SHA.sha1, MGF=MGF1)
c = OS2IP(C)
m = RSADP(pkcs1_v1_5, c, key)
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
EM = I2OSP(m, k)
M = unpad(pkcs1_v2_2, EM, key, label=label, hash=hash, MGF=MGF)
return M
end

function rsaes_pkvs1_v1_5_encrypt(M::String, key::RSAPublicKey)
EM = pad(pkcs1_v1_5, M)
m = OS2IP(EM)
c = RSAEP(pkcs1_v1_5, m, key)
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
C = I2OSP(c, k)
return C
end

function rsaes_pkvs1_v1_5_decrypt(C::String, key::RSAPrivateKey)
c = OS2IP(C)
m = RSADP(pkcs1_v1_5, c, key)
k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer
EM = I2OSP(m, k)
m = unpad(pkcs1_v1_5, EM)
return m
end

"""
encrypt(::pkcs1_v1_5_t,
msg::Union{AbstractString,AbstractVector},
Expand Down
47 changes: 39 additions & 8 deletions src/utils/string.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,57 @@
function I2OSP(x::BigInt)
return I2OSP(x, Base.GMP.MPZ.sizeinbase(x, 16))
function s_to_os(buf::String)
_buf = buf |> uppercase
it = Iterators.Stateful(_buf)
part = Base.Iterators.partition(it, 2)
return join(map(join, part), ':')
end

function I2OSP(x::BigInt, xLen::Integer)
function i2osp(x::BigInt)
return i2osp(x, Base.GMP.MPZ.sizeinbase(x, 16))
end

function i2osp(x::BigInt, xLen::Integer)
xLen |> isodd && (xLen += 1)
base_size = Base.GMP.MPZ.sizeinbase(x, 16)
base_size > xLen && throw(error("integer too big for xLen"))
buf = zeros(UInt8, xLen)
fill!(buf, '0')
buf_ptr = pointer(buf)
Base.GMP.MPZ.get_str!(buf_ptr + xLen - base_size, 16, x)
_buf = String(buf) |> uppercase
it = Iterators.Stateful(_buf)
part = Base.Iterators.partition(it, 2)
return join(map(join, part), ':')
return s_to_os(buf |> String)
end

function OS2IP(x::String)
function os2ip(x::String)
# use Cstring instead..?
buf = replace(x, ":" => "") |> lowercase |> Vector{UInt8}
push!(buf, 0)
target = BigInt(0)
Base.GMP.MPZ.set_str!(target, pointer(buf), 16) == 0 || throw(error("string not valid base 16"))
return target
end

function I2OSP(x::BigInt, xLen::Integer)
_order = 0
_endian = 0
_nails = 0
n = (Base.GMP.MPZ.sizeinbase(x, 2) / 8) |> ceil |> Integer
ret = zeros(UInt8, n)
Base.GMP.MPZ.export!(ret, x; order=_order, nails=_nails, endian=_endian)
ret_len = ret |> length
if ret_len < xLen
ret = vcat(zeros(UInt8, xLen - ret_len), ret)
elseif ret_len > xLen
error("ret_len > xLen") |> throw
end
return ret
end

function OS2IP(x::Vector{UInt8})
bi = BigInt()
_order = 0
_endian = 0
_nails = 0
Base.GMP.MPZ.import!(
bi, length(x), _order, sizeof(eltype(x)), _endian, _nails, pointer(x)
)
return bi
end
12 changes: 7 additions & 5 deletions test/padding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@ end
test_vector = Vector{UInt8}([1,2,3])
Random.seed!(42)
padded = ToyPublicKeys.pad(ToyPublicKeys.pkcs1_v1_5, test_vector)
@test ToyPublicKeys.unpad(ToyPublicKeys.pkcs1_v1_5, padded) == test_vector
unpadded = ToyPublicKeys.unpad(ToyPublicKeys.pkcs1_v1_5, padded)
@test unpadded == test_vector
end

@testset "padding/pkcs_1_v2_2 pad(unpad) is identity" begin
test_vector = Vector{UInt8}([1,2,3])
test_vector = Vector{UInt8}([3,2,1])
Random.seed!(42)
private_key, public_key = ToyPublicKeys.generate_rsa_key_pair(ToyPublicKeys.pkcs1_v1_5, 2048)
padded = ToyPublicKeys.pad(ToyPublicKeys.pkcs1_v2_2,
test_vector,
public_key)
@test test_vector == ToyPublicKeys.unpad(ToyPublicKeys.pkcs1_v2_2,
padded,
public_key)
unpadded = ToyPublicKeys.unpad(ToyPublicKeys.pkcs1_v2_2,
padded,
public_key)
@test unpadded == test_vector
end
9 changes: 9 additions & 0 deletions test/rsa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,12 @@ end
signature = ToyPublicKeys.RSASP1(ToyPublicKeys.pkcs1_v1_5, msg, private_key)
@test ToyPublicKeys.RSAVP1(ToyPublicKeys.pkcs1_v1_5, signature, public_key) == msg
end

@testset "rsaes_oaep_decrypt(rsaes_oaep_encrypt) is true" begin
Random.seed!(42)
private_key, public_key = ToyPublicKeys.generate_rsa_key_pair(ToyPublicKeys.pkcs1_v1_5, 2048)
msg = Vector{UInt8}("123")
C = ToyPublicKeys.rsaes_oaep_encrypt(msg, public_key)
ret = ToyPublicKeys.rsaes_oaep_decrypt(C, private_key)
@test ret == msg
end
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using ToyPublicKeys
using Test
import Random

include("utils.jl")
include("padding.jl")
include("rsa.jl")
include("dh.jl")
include("padding.jl")
include("utils.jl")
14 changes: 7 additions & 7 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
@testset "I2OSP" begin
@testset "i2osp" begin
b = big"255"
@test b |> ToyPublicKeys.I2OSP == "FF"
@test b |> x -> ToyPublicKeys.I2OSP(x, 3) == "00:FF"
@test b |> ToyPublicKeys.i2osp == "FF"
@test b |> x -> ToyPublicKeys.i2osp(x, 3) == "00:FF"
end

@testset "OS2IP" begin
@testset "os2ip" begin
b = "FF"
@test b |> ToyPublicKeys.OS2IP == big"255"
@test b |> ToyPublicKeys.os2ip == big"255"
end

@testset "I2OSP |> OS2IP" begin
@testset "i2osp |> os2ip" begin
b = big"255"
@test b |> ToyPublicKeys.I2OSP |> ToyPublicKeys.OS2IP == b
@test b |> ToyPublicKeys.i2osp |> ToyPublicKeys.os2ip == b
end