Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 48 additions & 27 deletions lib/evmone_precompiles/modexp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,41 +34,63 @@ constexpr unsigned ctz(const intx::uint<N>& x) noexcept
return tz;
}

class Exponent
{
const uint8_t* data_ = nullptr;
size_t num_bits_ = 0;

public:
explicit Exponent(std::span<const uint8_t> bytes)
{
const auto it = std::ranges::find_if(bytes, [](auto x) { return x != 0; });
const auto trimmed_bytes = std::span{it, bytes.end()};
num_bits_ = trimmed_bytes.empty() ? 0 :
static_cast<size_t>(std::bit_width(trimmed_bytes[0])) +
(trimmed_bytes.size() - 1) * 8;
data_ = trimmed_bytes.data();
}

[[nodiscard]] size_t num_bits() const noexcept { return num_bits_; }

bool operator[](size_t index) const noexcept
{
// TODO: Replace this with a custom iterator type.
const auto exp_size = (num_bits_ + 7) / 8;
const auto byte_index = index / 8;
const auto byte = data_[exp_size - 1 - byte_index];
const auto bit_index = index % 8;
const auto bit = (byte >> bit_index) & 1;
return bit != 0;
}
};

template <typename UIntT>
UIntT modexp_odd(const UIntT& base, std::span<const uint8_t> exp, const UIntT& mod) noexcept
UIntT modexp_odd(const UIntT& base, Exponent exp, const UIntT& mod) noexcept
{
const evmmax::ModArith<UIntT> arith{mod};
const auto base_mont = arith.to_mont(base);

auto ret = arith.to_mont(1);
for (const auto e : exp)
for (auto i = exp.num_bits(); i != 0; --i)
{
for (size_t i = 8; i != 0; --i)
{
ret = arith.mul(ret, ret);
const auto bit = (e >> (i - 1)) & 1;
if (bit != 0)
ret = arith.mul(ret, base_mont);
}
ret = arith.mul(ret, ret);
if (exp[i - 1])
ret = arith.mul(ret, base_mont);
}

return arith.from_mont(ret);
}

template <typename UIntT>
UIntT modexp_pow2(const UIntT& base, std::span<const uint8_t> exp, unsigned k) noexcept
UIntT modexp_pow2(const UIntT& base, Exponent exp, unsigned k) noexcept
{
assert(k != 0); // Modulus of 1 should be covered as "odd".
UIntT ret = 1;
for (auto e : exp)
for (auto i = exp.num_bits(); i != 0; --i)
{
for (size_t i = 8; i != 0; --i)
{
ret *= ret;
const auto bit = (e >> (i - 1)) & 1;
if (bit != 0)
ret *= base;
}
ret *= ret;
if (exp[i - 1])
ret *= base;
}

const auto mod_pow2_mask = (UIntT{1} << k) - 1;
Expand Down Expand Up @@ -102,7 +124,7 @@ UIntT load(std::span<const uint8_t> data) noexcept
}

template <size_t Size>
void modexp_impl(std::span<const uint8_t> base_bytes, std::span<const uint8_t> exp,
void modexp_impl(std::span<const uint8_t> base_bytes, Exponent exp,
std::span<const uint8_t> mod_bytes, uint8_t* output) noexcept
{
using UIntT = intx::uint<Size * 8>;
Expand Down Expand Up @@ -142,20 +164,19 @@ void modexp(std::span<const uint8_t> base, std::span<const uint8_t> exp,
assert(base.size() <= MAX_INPUT_SIZE);
assert(mod.size() <= MAX_INPUT_SIZE);

const auto it = std::ranges::find_if(exp, [](auto x) { return x != 0; });
exp = std::span{it, exp.end()};
const Exponent exp_obj{exp};

if (const auto size = std::max(mod.size(), base.size()); size <= 16)
modexp_impl<16>(base, exp, mod, output);
modexp_impl<16>(base, exp_obj, mod, output);
else if (size <= 32)
modexp_impl<32>(base, exp, mod, output);
modexp_impl<32>(base, exp_obj, mod, output);
else if (size <= 64)
modexp_impl<64>(base, exp, mod, output);
modexp_impl<64>(base, exp_obj, mod, output);
else if (size <= 128)
modexp_impl<128>(base, exp, mod, output);
modexp_impl<128>(base, exp_obj, mod, output);
else if (size <= 256)
modexp_impl<256>(base, exp, mod, output);
modexp_impl<256>(base, exp_obj, mod, output);
else
modexp_impl<MAX_INPUT_SIZE>(base, exp, mod, output);
modexp_impl<MAX_INPUT_SIZE>(base, exp_obj, mod, output);
}
} // namespace evmone::crypto