diff --git a/lib/evmone_precompiles/modexp.cpp b/lib/evmone_precompiles/modexp.cpp index e45c4ddeca..b590b5e908 100644 --- a/lib/evmone_precompiles/modexp.cpp +++ b/lib/evmone_precompiles/modexp.cpp @@ -34,41 +34,63 @@ constexpr unsigned ctz(const intx::uint& x) noexcept return tz; } +class Exponent +{ + const uint8_t* data_ = nullptr; + size_t num_bits_ = 0; + +public: + explicit Exponent(std::span bytes) + { + const auto it = std::ranges::find_if(bytes, [](auto x) { return x != 0; }); + const auto trimmed_bytes = std::span{it, bytes.end()}; + num_bits_ = trimmed_bytes.empty() ? 0 : + static_cast(std::bit_width(trimmed_bytes[0])) + + (trimmed_bytes.size() - 1) * 8; + data_ = trimmed_bytes.data(); + } + + [[nodiscard]] size_t num_bits() const noexcept { return num_bits_; } + + bool operator[](size_t index) const noexcept + { + // TODO: Replace this with a custom iterator type. + const auto exp_size = (num_bits_ + 7) / 8; + const auto byte_index = index / 8; + const auto byte = data_[exp_size - 1 - byte_index]; + const auto bit_index = index % 8; + const auto bit = (byte >> bit_index) & 1; + return bit != 0; + } +}; + template -UIntT modexp_odd(const UIntT& base, std::span exp, const UIntT& mod) noexcept +UIntT modexp_odd(const UIntT& base, Exponent exp, const UIntT& mod) noexcept { const evmmax::ModArith arith{mod}; const auto base_mont = arith.to_mont(base); auto ret = arith.to_mont(1); - for (const auto e : exp) + for (auto i = exp.num_bits(); i != 0; --i) { - for (size_t i = 8; i != 0; --i) - { - ret = arith.mul(ret, ret); - const auto bit = (e >> (i - 1)) & 1; - if (bit != 0) - ret = arith.mul(ret, base_mont); - } + ret = arith.mul(ret, ret); + if (exp[i - 1]) + ret = arith.mul(ret, base_mont); } return arith.from_mont(ret); } template -UIntT modexp_pow2(const UIntT& base, std::span exp, unsigned k) noexcept +UIntT modexp_pow2(const UIntT& base, Exponent exp, unsigned k) noexcept { assert(k != 0); // Modulus of 1 should be covered as "odd". UIntT ret = 1; - for (auto e : exp) + for (auto i = exp.num_bits(); i != 0; --i) { - for (size_t i = 8; i != 0; --i) - { - ret *= ret; - const auto bit = (e >> (i - 1)) & 1; - if (bit != 0) - ret *= base; - } + ret *= ret; + if (exp[i - 1]) + ret *= base; } const auto mod_pow2_mask = (UIntT{1} << k) - 1; @@ -102,7 +124,7 @@ UIntT load(std::span data) noexcept } template -void modexp_impl(std::span base_bytes, std::span exp, +void modexp_impl(std::span base_bytes, Exponent exp, std::span mod_bytes, uint8_t* output) noexcept { using UIntT = intx::uint; @@ -142,20 +164,19 @@ void modexp(std::span base, std::span exp, assert(base.size() <= MAX_INPUT_SIZE); assert(mod.size() <= MAX_INPUT_SIZE); - const auto it = std::ranges::find_if(exp, [](auto x) { return x != 0; }); - exp = std::span{it, exp.end()}; + const Exponent exp_obj{exp}; if (const auto size = std::max(mod.size(), base.size()); size <= 16) - modexp_impl<16>(base, exp, mod, output); + modexp_impl<16>(base, exp_obj, mod, output); else if (size <= 32) - modexp_impl<32>(base, exp, mod, output); + modexp_impl<32>(base, exp_obj, mod, output); else if (size <= 64) - modexp_impl<64>(base, exp, mod, output); + modexp_impl<64>(base, exp_obj, mod, output); else if (size <= 128) - modexp_impl<128>(base, exp, mod, output); + modexp_impl<128>(base, exp_obj, mod, output); else if (size <= 256) - modexp_impl<256>(base, exp, mod, output); + modexp_impl<256>(base, exp_obj, mod, output); else - modexp_impl(base, exp, mod, output); + modexp_impl(base, exp_obj, mod, output); } } // namespace evmone::crypto