|
| 1 | +function I2OSP(x::BigInt, xLen::Integer) |
| 2 | + _order = 0 |
| 3 | + _endian = 0 |
| 4 | + _nails = 0 |
| 5 | + n = (Base.GMP.MPZ.sizeinbase(x, 2) / 8) |> ceil |> Integer |
| 6 | + ret = zeros(UInt8, n) |
| 7 | + Base.GMP.MPZ.export!(ret, x; order=_order, nails=_nails, endian=_endian) |
| 8 | + ret_len = ret |> length |
| 9 | + if ret_len < xLen |
| 10 | + ret = vcat(zeros(UInt8, xLen - ret_len), ret) |
| 11 | + elseif ret_len > xLen |
| 12 | + error("ret_len > xLen") |> throw |
| 13 | + end |
| 14 | + return ret |
| 15 | +end |
| 16 | + |
| 17 | +function OS2IP(x::Vector{UInt8}) |
| 18 | + bi = BigInt() |
| 19 | + _order = 0 |
| 20 | + _endian = 0 |
| 21 | + _nails = 0 |
| 22 | + Base.GMP.MPZ.import!( |
| 23 | + bi, length(x), _order, sizeof(eltype(x)), _endian, _nails, pointer(x) |
| 24 | + ) |
| 25 | + return bi |
| 26 | +end |
| 27 | + |
| 28 | +function rsaes_oaep_encode(msg::Vector{UInt8}, |
| 29 | + key::RSAKey; |
| 30 | + label=UInt8[], |
| 31 | + MGF=ToyPublicKeys.MGF1, |
| 32 | + hash=SHA.sha1) |
| 33 | + k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer |
| 34 | + lHash = hash(label) |
| 35 | + hLen = lHash |> length |
| 36 | + LLen = label |> length |
| 37 | + LLen > (big"2" << 60) && throw(error("label too long")) |
| 38 | + mLen = msg |> length |
| 39 | + mLen > k - 2 * hLen - 2 && throw(error("message too long")) |
| 40 | + pLen = (k - mLen - 2 * hLen - 2) |
| 41 | + PS = zeros(UInt8, pLen) |
| 42 | + seed = rand(UInt8, hLen) |
| 43 | + dbMask = MGF(seed, k - hLen - 1, hash=hash) |
| 44 | + DB = vcat(lHash, PS, ones(UInt8, 1), msg) |
| 45 | + maskedDB = DB .⊻ dbMask |
| 46 | + seedMask = MGF(maskedDB, hLen) |
| 47 | + maskedSeed = seed .⊻ seedMask |
| 48 | + EM = vcat(zeros(UInt8, 1), maskedSeed, maskedDB) |
| 49 | + return EM |
| 50 | +end |
| 51 | + |
| 52 | +function rsaes_oaep_decode(msg::Vector{UInt8}, |
| 53 | + key::RSAKey; |
| 54 | + label=UInt8[], |
| 55 | + MGF=ToyPublicKeys.MGF1, |
| 56 | + hash=SHA.sha1) |
| 57 | + k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer |
| 58 | + lHash = hash(label) |
| 59 | + hLen = lHash |> length |
| 60 | + Y = msg[1] |
| 61 | + Y != 0 && error("Y != 0") |> throw |
| 62 | + maskedSeed = msg[2:end][1:hLen] |
| 63 | + DBLen = k - hLen - 1 |
| 64 | + maskedDB = msg[2:end][hLen + 1:end] |
| 65 | + seedMask = MGF(maskedDB, hLen) |
| 66 | + seed = maskedSeed .⊻ seedMask |
| 67 | + dbMask = MGF(seed, k - hLen - 1) |
| 68 | + DB = maskedDB .⊻ dbMask |
| 69 | + lHashPrime = DB[1:hLen] |
| 70 | + lHashPrime != lHash && error("lHashPrime != lHash") |
| 71 | + PSIndexInView = findfirst(a -> a==1, DB[hLen + 1:end]) |
| 72 | + isnothing(PSIndexInView) && error("not EME-OAEP encoded or malformed") |
| 73 | + PSIndex = PSIndexInView + hLen |
| 74 | + PS = DB[hLen + 1:PSIndex - 1] |
| 75 | + any(PS .!= 0) && error("PS should be zero filled vector") |> throw |
| 76 | + X = DB[PSIndex] |
| 77 | + X != 1 && error("X != 1") |> throw |
| 78 | + M = DB[PSIndex + 1:end] |
| 79 | + return M |
| 80 | +end |
| 81 | + |
| 82 | +function MGF1(mgfSeed::Vector{UInt8}, |
| 83 | + maskLen:: Integer; |
| 84 | + hash = SHA.sha1) |
| 85 | + hLen = hash(UInt8[]) |> length |
| 86 | + maskLen >= (2 << 32) && error("mask too long") |> throw |
| 87 | + T = zeros(UInt8, 0) |
| 88 | + for counter in big"0":BigInt((ceil(maskLen / hLen) - 1)) |
| 89 | + C = I2OSP(counter, 4) |
| 90 | + _T = vcat(mgfSeed, C) |
| 91 | + __T = hash(_T) |
| 92 | + T = vcat(T, __T) |
| 93 | + end |
| 94 | + return T[1:maskLen] |
| 95 | +end |
| 96 | + |
| 97 | +function emsa_pss_encode(M::Vector{UInt8}, |
| 98 | + emBits::Integer; |
| 99 | + MGF=ToyPublicKeys.MGF1, |
| 100 | + hash=SHA.sha1, |
| 101 | + sLen=0) |
| 102 | + hLen = hash(UInt8[]) |> length |
| 103 | + (emBits >= 8*hLen + 8 * sLen + 9) || error("emBits !>= 8hLen + 8 * sLen + 9") |> throw |
| 104 | + emLen = ceil(emBits/8) |> Integer |
| 105 | + emLen < hLen + sLen + 2 && error("encoding error") |> throw |
| 106 | + length(M) > (big"2" << 60) && error("message too long") |> throw |
| 107 | + salt = UInt8[] |
| 108 | + if sLen > 0 |
| 109 | + salt = rand(UInt8, sLen) |
| 110 | + end |
| 111 | + mHash = hash(M) |
| 112 | + MPrime = vcat(zeros(UInt8, 8), mHash, salt) |
| 113 | + H = hash(MPrime) |
| 114 | + PS = zeros(UInt8, emLen - sLen - hLen - 2) |
| 115 | + DB = vcat(PS, ones(UInt8, 1), salt) |
| 116 | + dbMask = MGF(H, emLen - hLen - 1) |
| 117 | + maskedDB = DB .⊻ dbMask |
| 118 | + maskedDB[1] &= 0xFF >> (8 * emLen - emBits) |
| 119 | + EM = vcat(maskedDB, H, UInt8[0xbc]) |
| 120 | + return EM |
| 121 | +end |
| 122 | + |
| 123 | +function emsa_pss_verify(M::Vector{UInt8}, |
| 124 | + EM::Vector{UInt8}, |
| 125 | + emBits::Integer; |
| 126 | + MGF=ToyPublicKeys.MGF1, |
| 127 | + hash=SHA.sha1, |
| 128 | + sLen=0) |
| 129 | + length(M) > (big"2" << 60) && error("inconsistent") |> throw |
| 130 | + mHash = hash(M) |
| 131 | + hLen = mHash |> length |
| 132 | + emLen = ceil(emBits/8) |> Integer |
| 133 | + emLen < hLen + sLen + 2 && error("inconsistent") |> throw |
| 134 | + EM[end] != 0xbc && error("inconsistent") |> throw |
| 135 | + maskedDB = EM[1:emLen - hLen - 1] |
| 136 | + H = EM[emLen - hLen:emLen - 1] |
| 137 | + (maskedDB[1] & ~(0xFF >> (8 * emLen - emBits))) != 0 && error("inconsistent") |> throw |
| 138 | + dbMask = MGF(H, emLen - hLen - 1) |
| 139 | + DB = maskedDB .⊻ dbMask |
| 140 | + DB[1] &= 0xFF >> (8 * emLen - emBits) |
| 141 | + (DB[1:emLen - hLen - sLen - 2] .!= 0) |> any && error("inconsistent") |> throw |
| 142 | + DB[emLen - hLen - sLen - 1] != 1 && error("inconsistent") |> throw |
| 143 | + salt = DB[end - sLen + 1 : end] |
| 144 | + MPrime = MPrime = vcat(zeros(UInt8, 8), mHash, salt) |
| 145 | + HPrime = hash(MPrime) |
| 146 | + return H == HPrime |
| 147 | +end |
| 148 | + |
| 149 | +function rsaes_oaep_encrypt(M::Vector{UInt8}, |
| 150 | + key::RSAPublicKey; |
| 151 | + label="", |
| 152 | + hash=SHA.sha1, |
| 153 | + MGF=MGF1) |
| 154 | + EM = rsaes_oaep_encode(M, key, label=label, hash=hash, MGF=MGF) |
| 155 | + m = OS2IP(EM) |
| 156 | + c = RSAEP(pkcs1_v1_5, m, key) |
| 157 | + k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer |
| 158 | + C = I2OSP(c, k) |
| 159 | + return C |
| 160 | +end |
| 161 | + |
| 162 | +function rsaes_oaep_decrypt(C::Vector{UInt8}, |
| 163 | + key::RSAPrivateKey; |
| 164 | + label="", |
| 165 | + hash=SHA.sha1, |
| 166 | + MGF=MGF1) |
| 167 | + c = OS2IP(C) |
| 168 | + m = RSADP(pkcs1_v1_5, c, key) |
| 169 | + k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer |
| 170 | + EM = I2OSP(m, k) |
| 171 | + M = rsaes_oaep_decode(EM, key, label=label, hash=hash, MGF=MGF) |
| 172 | + return M |
| 173 | +end |
| 174 | + |
| 175 | +function rsaes_pkvs1_v1_5_encrypt(M::String, key::RSAPublicKey) |
| 176 | + EM = pad(pkcs1_v1_5, M) |
| 177 | + m = OS2IP(EM) |
| 178 | + c = RSAEP(pkcs1_v1_5, m, key) |
| 179 | + k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer |
| 180 | + C = I2OSP(c, k) |
| 181 | + return C |
| 182 | +end |
| 183 | + |
| 184 | +function rsaes_pkvs1_v1_5_decrypt(C::String, |
| 185 | + key::RSAPrivateKey) |
| 186 | + c = OS2IP(C) |
| 187 | + m = RSADP(pkcs1_v1_5, c, key) |
| 188 | + k = (Base.GMP.MPZ.sizeinbase(key.modulus, 2)/8) |> ceil |> Integer |
| 189 | + EM = I2OSP(m, k) |
| 190 | + m = unpad(pkcs1_v1_5, EM) |
| 191 | + return m |
| 192 | +end |
| 193 | + |
| 194 | +function rsassa_pss_sign(M::Vector{UInt8}, |
| 195 | + key::RSAPrivateKey) |
| 196 | + modBits = Base.GMP.MPZ.sizeinbase(key.modulus, 2) |
| 197 | + EM = emsa_pss_encode(M, modBits - 1) |
| 198 | + m = OS2IP(EM) |
| 199 | + s = RSASP1(pkcs1_v1_5, m, key) |
| 200 | + k = (modBits/8) |> ceil |> Integer |
| 201 | + S = I2OSP(s, k) |
| 202 | + return S |
| 203 | +end |
| 204 | + |
| 205 | +function rsassa_pss_verify(M::Vector{UInt8}, |
| 206 | + S::Vector{UInt8}, |
| 207 | + key::RSAPublicKey) |
| 208 | + modBits = Base.GMP.MPZ.sizeinbase(key.modulus, 2) |
| 209 | + k = (modBits/8) |> ceil |> Integer |
| 210 | + length(S) != k && error("invalid signature") |> throw |
| 211 | + s = OS2IP(S) |
| 212 | + m = RSAVP1(pkcs1_v1_5, s, key) |
| 213 | + emLen = ceil((modBits - 1)/8) |> Integer |
| 214 | + EM = I2OSP(m, emLen) |
| 215 | + result = emsa_pss_verify(M, EM, modBits - 1) |
| 216 | + return result |
| 217 | +end |
| 218 | + |
| 219 | +function validate(key::RSAPrivateKey) |
| 220 | + (length(key.primes) >= 2) || error("length(key.primes) < 2") |> throw |
| 221 | + all((key.public_exponent > 0, key.exponent > 0)) && all(key.primes .> 0) || error("all((key.public_exponent, key.exponent) .> 0, key.primes .> 0)") |> throw |
| 222 | + (length(key.exponent) > 0) || error("length(key.primes) < 2") |> throw |
| 223 | + (prod(key.primes) == key.modulus) || error("(prod(key.primes) != key.modulus)") |> throw |
| 224 | + ((key.exponent * key.public_exponent) % lcm((key.primes .- 1)...)) == 1 || error("(key.exponent * key.public_exponent) % lcm((key.exponent - 1), (key.public_exponent - 1)) != 1") |> throw |
| 225 | + (key.public_exponent * key.crt_exponents[1]) % (key.primes[1] - 1) == 1 || error(" (key.public_exponent * key.crt_exponents[1]) % (key.primes[1] - 1) != 1") |> throw |
| 226 | + (key.public_exponent * key.crt_exponents[2]) % (key.primes[2] - 1) == 1 || error("(key.public_exponent * key.crt_exponents[2]) % (key.primes[2] - 1) != 1") |> throw |
| 227 | + # is this consistent with the struct? |
| 228 | + (key.primes[2] * key.crt_coefficients[1]) % key.primes[1] == 1 || error("(key.primes[2] * key.crt_coefficients[2]) % key.primes[1] != 1") |> throw |
| 229 | +end |
| 230 | + |
| 231 | +""" |
| 232 | + RSAStep(::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey) |
| 233 | +
|
| 234 | +Fast implementation of the RSA exponentiation step when RSAPrivateKey is provided. |
| 235 | +It uses [Chinese remainer theorem](https://en.wikipedia.org/wiki/Chinese_remainder_theorem) for very fast `exp() mod n` calculations. |
| 236 | +""" |
| 237 | +function RSAEP(v::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey) |
| 238 | + return RSAStep(v, msg, key) |
| 239 | +end |
| 240 | + |
| 241 | +""" |
| 242 | + RSAStep(::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey) |
| 243 | +
|
| 244 | +RSA exponentiation step when only public key is available. |
| 245 | +Uses [repeated squares](https://en.wikipedia.org/wiki/Exponentiation_by_squaring) |
| 246 | +and other fast modulo exponentiation tricks in its GMP implementation (Base.GMP.MPZ.powm). |
| 247 | +""" |
| 248 | +function RSADP(v::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey) |
| 249 | + return RSAStep(v, msg, key) |
| 250 | +end |
| 251 | + |
| 252 | +""" |
| 253 | + RSASP1(::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey) |
| 254 | +
|
| 255 | +""" |
| 256 | +function RSASP1(v::pkcs1_v1_5_t, msg::BigInt, key::RSAPrivateKey) |
| 257 | + return RSAStep(v, msg, key) |
| 258 | +end |
| 259 | + |
| 260 | +""" |
| 261 | + RSAVP1(::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey) |
| 262 | +
|
| 263 | +""" |
| 264 | +function RSAVP1(v::pkcs1_v1_5_t, msg::BigInt, key::RSAPublicKey) |
| 265 | + return RSAStep(v, msg, key) |
| 266 | +end |
0 commit comments