Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add pkcs1 version as type dispatch #50

Merged
merged 1 commit into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions docs/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,24 @@
import Random
using ToyPublicKeys
Random.seed!(42)
private_key, public_key = ToyPublicKeys.generate_rsa_key_pair(2048)
private_key, public_key = ToyPublicKeys.generate_rsa_key_pair(ToyPublicKeys.pkcs1_v1_5, 2048)
msg = "Super secret message!"
println(msg)
encrypted = ToyPublicKeys.encrypt(msg, public_key)
encrypted = ToyPublicKeys.encrypt(ToyPublicKeys.pkcs1_v1_5, msg, public_key)
println(encrypted)
decrypted = ToyPublicKeys.decrypt(encrypted, private_key)
decrypted = ToyPublicKeys.decrypt(ToyPublicKeys.pkcs1_v1_5, encrypted, private_key)
println(decrypted)
```
## Signatures and their verification
```@example
import Random
using ToyPublicKeys
Random.seed!(42)
private_key, public_key = ToyPublicKeys.generate_rsa_key_pair(2048)
private_key, public_key = ToyPublicKeys.generate_rsa_key_pair(ToyPublicKeys.pkcs1_v1_5, 2048)
msg = "Super authentic message!"
println(msg)
signature = ToyPublicKeys.sign(msg, private_key)
signature = ToyPublicKeys.sign(ToyPublicKeys.pkcs1_v1_5, msg, private_key)
println(signature)
authentic = ToyPublicKeys.verify_signature(msg, signature, public_key)
authentic = ToyPublicKeys.verify_signature(ToyPublicKeys.pkcs1_v1_5, msg, signature, public_key)
println(authentic)
```
46 changes: 13 additions & 33 deletions src/padding/pkcs_1_v1_5.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,12 @@
const padding_pkcs_1_v1_5_num_c_chars = 3
const padding_pkcs_1_v1_5_pad_start = 3
"""
is_padded(msg::AbstractVector{T}) where {T<:Base.BitInteger}

Checks for magic bytes of [PKCS#1 v1.5 padding](https://www.rfc-editor.org/rfc/rfc2313#section-8.1).
"""
function is_padded(msg::AbstractVector{T}) where {T<:Base.BitInteger}
if length(msg) < 3
return false
elseif msg[1] != T(0) && msg[2] != T(3)
return false
elseif nothing ==
findfirst(==(T(0)), view(msg, padding_pkcs_1_v1_5_pad_start:length(msg)))
return false
else
return true
end
end

"""
pad(msg::AbstractVector{T}, pad_length=32) where {T<:Base.BitInteger}
pad(::pkcs1_v1_5_t, msg::AbstractVector{T}, pad_length=32) where {T<:Base.BitInteger}

Core implementation of the [PKCS#1 v1.5 padding](https://www.rfc-editor.org/rfc/rfc2313#section-8.1).
"""
function pad(msg::AbstractVector{T}, pad_length=32) where {T<:Base.BitInteger}
function pad(::pkcs1_v1_5_t, msg::AbstractVector{T}, pad_length=32) where {T<:Base.BitInteger}
pad_length > 8 || throw(error("Will not create pad with length < 8"))
buff = rand(T(1):T(typemax(T)), pad_length + 3)
buff[1] = 0
Expand All @@ -34,39 +17,36 @@ function pad(msg::AbstractVector{T}, pad_length=32) where {T<:Base.BitInteger}
end

"""
pad(msg::T, pad_length=32) where {T<:AbstractString}
pad(::pkcs1_v1_5_t, msg::T, pad_length=32) where {T<:AbstractString}

Wrapper for the core pad function.
"""
function pad(msg::T, pad_length=32) where {T<:AbstractString}
function pad(::pkcs1_v1_5_t, msg::T, pad_length=32) where {T<:AbstractString}
pad_length > 8 || throw(error("Will not create pad with length < 8"))
msg_cu = codeunits(msg)
msg_padded = pad(msg_cu, pad_length)
msg_padded = pad(pkcs1_v1_5, msg_cu, pad_length)
return T(msg_padded)
end

"""
unpad(msg::AbstractVector{T}) where {T<:Base.BitInteger}
unpad(::pkcs1_v1_5_t, msg::AbstractVector{T}) where {T<:Base.BitInteger}

Core implementation for the [PKCS#1 v1.5 pad](https://www.rfc-editor.org/rfc/rfc2313#section-8.1) unwrapping.
"""
function unpad(msg::AbstractVector{T}) where {T<:Base.BitInteger}
if !is_padded(msg)
error("Not padded: $msg")
end
pos =
findfirst(==(T(0)), view(msg, padding_pkcs_1_v1_5_pad_start:length(msg))) +
padding_pkcs_1_v1_5_num_c_chars
function unpad(::pkcs1_v1_5_t, msg::AbstractVector{T}) where {T<:Base.BitInteger}
pos = findfirst(==(T(0)),
view(msg,
padding_pkcs_1_v1_5_pad_start:length(msg))) + padding_pkcs_1_v1_5_num_c_chars
return view(msg, pos:length(msg))
end

"""
unpad(msg::T) where {T<:AbstractString}
unpad(::pkcs1_v1_5_t, msg::T) where {T<:AbstractString}

Wrapper for the core unpad function.
"""
function unpad(msg::T) where {T<:AbstractString}
function unpad(::pkcs1_v1_5_t, msg::T) where {T<:AbstractString}
msg_cu = codeunits(msg)
msg_unpadded = unpad(msg_cu)
msg_unpadded = unpad(pkcs1_v1_5, msg_cu)
return T(msg_unpadded)
end
3 changes: 0 additions & 3 deletions src/padding/pkcs_1_v2_2.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
struct pkcs1_v2_2_t end
const pkcs1_v2_2 = pkcs1_v2_2_t()

function pad(::pkcs1_v2_2_t,
msg::Union{AbstractString, AbstractVector},
key::RSAKey;
Expand Down
85 changes: 48 additions & 37 deletions src/rsa.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
using SHA
# NOTE: This RSA implementation tries to follow RFC 2313, however it is not conformant with it. Future work: conform this with RFC 2313 or better, with RFC 2437

struct pkcs1_v1_5_t end
const pkcs1_v1_5 = pkcs1_v1_5_t()

struct pkcs1_v2_2_t end
const pkcs1_v2_2 = pkcs1_v2_2_t()

"""
RSAPrivateKey

Expand Down Expand Up @@ -37,12 +43,12 @@ Union of RSAPrivateKey and RSAPublicKey for methods, that do not require specifi
const RSAKey = Union{RSAPrivateKey,RSAPublicKey}

"""
RSAStep(msg::BigInt, key::RSAPrivateKey)
RSAStep(::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey)

Fast implementation of the RSA exponentiation step when RSAPrivateKey is provided.
It uses [Chinese remainer theorem](https://en.wikipedia.org/wiki/Chinese_remainder_theorem) for very fast `exp() mod n` calculations.
"""
function RSAStep(msg::BigInt, key::RSAPrivateKey)
function RSAStep(::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey)
if !(0 <= msg < key.modulus)
error("msg has to be 0 <= msg < n, got: msg = $msg, n = $key.modulus")
end
Expand All @@ -58,26 +64,26 @@ function RSAStep(msg::BigInt, key::RSAPrivateKey)
end

"""
RSAStep(msg::BigInt, key::RSAPublicKey)
RSAStep(::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey)

RSA exponentiation step when only public key is available.
Uses [repeated squares](https://en.wikipedia.org/wiki/Exponentiation_by_squaring)
and other fast modulo exponentiation tricks in its GMP implementation (Base.GMP.MPZ.powm).
"""
function RSAStep(msg::BigInt, key::RSAPublicKey)
function RSAStep(::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey)
if !(0 <= msg < key.modulus)
error("msg has to be 0 <= msg < n, got: msg = $msg, n = $key.modulus")
end
return Base.GMP.MPZ.powm(msg, key.exponent, key.modulus)
end

"""
RSAStep(msg::AbstractVector{T}, key::RSAKey) where {T<:Base.BitInteger}
RSAStep(::pkcs1_v1_5_t, msg::AbstractVector{T}, key::RSAKey) where {T<:Base.BitInteger}

RSA exponentiation step for AbstractVectors (arbitrary buffers).
Only prepares the buffer for [`RSAStep(msg::BigInt, key::RSAPublicKey)`](@ref).
"""
function RSAStep(msg::AbstractVector{T}, key::RSAKey) where {T<:Base.BitInteger}
function RSAStep(::pkcs1_v1_5_t, msg::AbstractVector{T}, key::RSAKey) where {T<:Base.BitInteger}
msg_bi = BigInt()
# https://gmplib.org/manual/Integer-Import-and-Export#index-mpz_005fimport
# void mpz_import (mpz_t rop, size_t count, int order, size_t size, int endian, size_t nails, const void *op)
Expand All @@ -87,7 +93,7 @@ function RSAStep(msg::AbstractVector{T}, key::RSAKey) where {T<:Base.BitInteger}
Base.GMP.MPZ.import!(
msg_bi, length(msg), _order, sizeof(eltype(msg)), _endian, _nails, pointer(msg)
)
result = RSAStep(msg_bi, key)
result = RSAStep(pkcs1_v1_5, msg_bi, key)
# https://gmplib.org/manual/Integer-Import-and-Export#index-mpz_005fexport
# void * mpz_export (void *rop, size_t *countp, int order, size_t size, int endian, size_t nails, const mpz_t op)
msg_buf = Vector{T}(undef, result.size)
Expand All @@ -96,58 +102,63 @@ function RSAStep(msg::AbstractVector{T}, key::RSAKey) where {T<:Base.BitInteger}
end

"""
RSAStep(msg::String, key::RSAKey)
RSAStep(::pkcs1_v1_5_t, msg::String, key::RSAKey)

RSA exponentiation step for Strings.
Only prepares the buffer for [`RSAStep(msg::BigInt, key::RSAPublicKey)`](@ref).
"""
function RSAStep(msg::String, key::RSAKey)
function RSAStep(::pkcs1_v1_5_t, msg::String, key::RSAKey)
msg_cu = codeunits(msg)
result = RSAStep(msg_cu, key)
result = RSAStep(pkcs1_v1_5, msg_cu, key)
transformed_msg = String(result)
return transformed_msg
end

"""
encrypt(msg::Union{AbstractString,AbstractVector}, key::RSAPublicKey; pad_length=32)
encrypt(::pkcs1_v1_5_t,
msg::Union{AbstractString,AbstractVector},
key::RSAPublicKey
; pad_length=32)

RSA encryption function with [PKCS#1 v1.5 padding](https://www.rfc-editor.org/rfc/rfc2313#section-8.1).
"""
function encrypt(
msg::Union{AbstractString,AbstractVector}, key::RSAPublicKey; pad_length=32
function encrypt(::pkcs1_v1_5_t,
msg::Union{AbstractString,AbstractVector},
key::RSAPublicKey
; pad_length=32
)
msg_padded = ToyPublicKeys.pad(msg, pad_length)
return RSAStep(msg_padded, key)
msg_padded = ToyPublicKeys.pad(pkcs1_v1_5, msg, pad_length)
return RSAStep(pkcs1_v1_5, msg_padded, key)
end

"""
ecrypt(msg::AbstractString, key::RSAPrivateKey)
decrypt(::pkcs1_v1_5_t, msg::AbstractString, key::RSAPrivateKey)

RSA decryption function for strings, expects [PKCS#1 v1.5 padding](https://www.rfc-editor.org/rfc/rfc2313#section-8.1).
"""
function decrypt(msg::AbstractString, key::RSAPrivateKey)
function decrypt(::pkcs1_v1_5_t, msg::AbstractString, key::RSAPrivateKey)
msg_ = codeunits(msg)
msg_decr = RSAStep(msg_, key)
unpaded = ToyPublicKeys.unpad(vcat(typeof(msg_decr)([0]), msg_decr)) # todo: leading zero is ignored, gotta deal with this
msg_decr = RSAStep(pkcs1_v1_5, msg_, key)
unpaded = ToyPublicKeys.unpad(pkcs1_v1_5, vcat(typeof(msg_decr)([0]), msg_decr)) # todo: leading zero is ignored, gotta deal with this
return String(unpaded)
end

"""
ecrypt(msg::AbstractString, key::RSAPrivateKey)
decrypt(::pkcs1_v1_5_t, msg::AbstractVector, key::RSAPrivateKey)

RSA decryption function for vectors (arbitrary buffers), expects [PKCS#1 v1.5 padding](https://www.rfc-editor.org/rfc/rfc2313#section-8.1).
"""
function decrypt(msg::AbstractVector, key::RSAPrivateKey)
msg_decr = RSAStep(msg, key)
return ToyPublicKeys.unpad(vcat(typeof(msg_decr)([0]), msg_decr)) # todo: leading zero is ignored, gotta deal with this
function decrypt(::pkcs1_v1_5_t, msg::AbstractVector, key::RSAPrivateKey)
msg_decr = RSAStep(pkcs1_v1_5, msg, key)
return ToyPublicKeys.unpad(pkcs1_v1_5, vcat(typeof(msg_decr)([0]), msg_decr)) # todo: leading zero is ignored, gotta deal with this
end

"""
generate_rsa_key_pair(bits::Integer)
generate_rsa_key_pair(::pkcs1_v1_5_t, bits::Integer)

RSA key pair constructor (hopefully) according to [RFC 2313](https://www.rfc-editor.org/rfc/rfc2313.txt)
"""
function generate_rsa_key_pair(bits::Integer)
function generate_rsa_key_pair(::pkcs1_v1_5_t, bits::Integer)
bits <= 0 && error("bits <= 0")
# todo: not enough bit size
e = big"65537"
Expand All @@ -174,36 +185,36 @@ function generate_rsa_key_pair(bits::Integer)
end

"""
sign(msg::String, key::RSAPrivateKey; pad_length=32)
sign(::pkcs1_v1_5_t, msg::String, key::RSAPrivateKey; pad_length=32)

Sign string with RSA key.
"""
function sign(msg::String, key::RSAPrivateKey; pad_length=32)
function sign(::pkcs1_v1_5_t, msg::String, key::RSAPrivateKey; pad_length=32)
digest = SHA.sha256(msg)
msg_padded = ToyPublicKeys.pad(digest, pad_length)
return String(RSAStep(msg_padded, key))
msg_padded = ToyPublicKeys.pad(pkcs1_v1_5, digest, pad_length)
return String(RSAStep(pkcs1_v1_5, msg_padded, key))
end

"""
sign(msg::AbstractVector, key::RSAPrivateKey; pad_length=32)
sign(::pkcs1_v1_5_t, msg::AbstractVector, key::RSAPrivateKey; pad_length=32)

Sign AbstractVector (arbitrary buffer using [SHA256](https://en.wikipedia.org/wiki/SHA-2)) with RSA key.
"""
function sign(msg::AbstractVector, key::RSAPrivateKey; pad_length=32)
function sign(::pkcs1_v1_5_t, msg::AbstractVector, key::RSAPrivateKey; pad_length=32)
digest = SHA.sha256(String(msg))
msg_padded = ToyPublicKeys.pad(digest, pad_length)
return RSAStep(msg_padded, key)
msg_padded = ToyPublicKeys.pad(pkcs1_v1_5, digest, pad_length)
return RSAStep(pkcs1_v1_5, msg_padded, key)
end

"""
verify_signature(msg::String, signature::String, key::RSAPublicKey)
verify_signature(::pkcs1_v1_5_t, msg::String, signature::String, key::RSAPublicKey)

Verify the signature.
"""
function verify_signature(msg::String, signature::String, key::RSAPublicKey)
function verify_signature(::pkcs1_v1_5_t, msg::String, signature::String, key::RSAPublicKey)
signature_ = codeunits(signature)
signature_decr = ToyPublicKeys.RSAStep(signature_, key)
unpaded_hash = ToyPublicKeys.unpad(vcat(typeof(signature_decr)([0]), signature_decr)) # todo: leading zero is ignored, gotta deal with this
signature_decr = ToyPublicKeys.RSAStep(pkcs1_v1_5, signature_, key)
unpaded_hash = ToyPublicKeys.unpad(pkcs1_v1_5, vcat(typeof(signature_decr)([0]), signature_decr)) # todo: leading zero is ignored, gotta deal with this
digest = SHA.sha256(msg)
return unpaded_hash == digest
end
12 changes: 6 additions & 6 deletions test/padding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,30 @@
test_vector = Vector{UInt8}([1,2,3])
Random.seed!(42)
padded_vector_correct = UInt8[0x00, 0x02, 0x7a, 0xb4, 0xac, 0x2b, 0x9d, 0xab, 0x75, 0x4d, 0xa9, 0xa4, 0x58, 0x45, 0x84, 0x17, 0x46, 0x31, 0x6d, 0x7c, 0x15, 0x84, 0x33, 0x9a, 0x66, 0x51, 0x6f, 0xdb, 0x52, 0x90, 0x53, 0x29, 0xb9, 0x5f, 0x00, 0x01, 0x02, 0x03]
padded_vec = ToyPublicKeys.pad(test_vector)
padded_vec = ToyPublicKeys.pad(ToyPublicKeys.pkcs1_v1_5, test_vector)
@test padded_vector_correct == padded_vec
unpadded_vector = ToyPublicKeys.unpad(padded_vector_correct)
unpadded_vector = ToyPublicKeys.unpad(ToyPublicKeys.pkcs1_v1_5, padded_vector_correct)
@test test_vector == unpadded_vector
end

@testset "padding/pkcs_1_v1_5 throws error for pad length < 8" begin
test_vector = Vector{UInt8}([1,2,3])
Random.seed!(42)
padded_vector_correct = UInt8[0x00, 0x02, 0x7a, 0xb4, 0xac, 0x2b, 0x9d, 0xab, 0x75, 0x4d, 0xa9, 0xa4, 0x58, 0x45, 0x84, 0x17, 0x46, 0x31, 0x6d, 0x7c, 0x15, 0x84, 0x33, 0x9a, 0x66, 0x51, 0x6f, 0xdb, 0x52, 0x90, 0x53, 0x29, 0xb9, 0x5f, 0x00, 0x01, 0x02, 0x03]
@test_throws ErrorException ToyPublicKeys.pad(test_vector, 7)
@test_throws ErrorException ToyPublicKeys.pad(ToyPublicKeys.pkcs1_v1_5, test_vector, 7)
end

@testset "padding/pkcs_1_v1_5 pad(unpad) is identity" begin
test_vector = Vector{UInt8}([1,2,3])
Random.seed!(42)
padded = ToyPublicKeys.pad(test_vector)
@test ToyPublicKeys.unpad(padded) == test_vector
padded = ToyPublicKeys.pad(ToyPublicKeys.pkcs1_v1_5, test_vector)
@test ToyPublicKeys.unpad(ToyPublicKeys.pkcs1_v1_5, padded) == test_vector
end

@testset "padding/pkcs_1_v2_2 pad(unpad) is identity" begin
test_vector = Vector{UInt8}([1,2,3])
Random.seed!(42)
private_key, public_key = ToyPublicKeys.generate_rsa_key_pair(2048)
private_key, public_key = ToyPublicKeys.generate_rsa_key_pair(ToyPublicKeys.pkcs1_v1_5, 2048)
padded = ToyPublicKeys.pad(ToyPublicKeys.pkcs1_v2_2,
test_vector,
public_key)
Expand Down
Loading