Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate the use of points and scalars from secp256k1_ecmult_strauss_batch. #900

229 changes: 135 additions & 94 deletions src/ecmult_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

/* The number of objects allocated on the scratch space for ecmult_multi algorithms */
#define PIPPENGER_SCRATCH_OBJECTS 6
#define STRAUSS_SCRATCH_OBJECTS 7
#define STRAUSS_SCRATCH_OBJECTS 3

#define PIPPENGER_MAX_BUCKET_WINDOW 12

Expand All @@ -56,44 +56,51 @@

#define ECMULT_MAX_POINTS_PER_BATCH 5000000

/** Fill a table 'prej' with precomputed odd multiples of a. Prej will contain
* the values [1*a,3*a,...,(2*n-1)*a], so it space for n values. zr[0] will
* contain prej[0].z / a.z. The other zr[i] values = prej[i].z / prej[i-1].z.
* Prej's Z values are undefined, except for the last value.
/** Fill a table 'pre_a' with precomputed odd multiples of a.
* pre_a will contain [1*a,3*a,...,(2*n-1)*a], so it needs space for n group elements.
* zr needs space for n field elements.
*
* Although pre_a is an array of _ge rather than _gej, it actually represents elements
* in Jacobian coordinates with their z coordinates omitted. The omitted z-coordinates
* can be recovered using z and zr. Using the notation z(b) to represent the omitted
* z coordinate of b:
* - z(pre_a[n-1]) = 'z'
* - z(pre_a[i-1]) = z(pre_a[i]) / zr[i] for n > i > 0
*
* Lastly the zr[0] value, which isn't used above, is set so that:
* - a.z = z(pre_a[0]) / zr[0]
*/
static void secp256k1_ecmult_odd_multiples_table(int n, secp256k1_gej *prej, secp256k1_fe *zr, const secp256k1_gej *a) {
secp256k1_gej d;
secp256k1_ge a_ge, d_ge;
static void secp256k1_ecmult_odd_multiples_table(int n, secp256k1_ge *pre_a, secp256k1_fe *zr, secp256k1_fe *z, const secp256k1_gej *a) {
secp256k1_gej d, ai;
secp256k1_ge d_ge;
int i;

VERIFY_CHECK(!a->infinity);

secp256k1_gej_double_var(&d, a, NULL);

/*
* Perform the additions on an isomorphism where 'd' is affine: drop the z coordinate
* of 'd', and scale the 1P starting value's x/y coordinates without changing its z.
* Perform the additions using an isomorphism that divides the z coordinate by the
* constant d.z. The group law is the same in the image, and since 'd' maps to an
* affine point, addition becomes more efficient (using mixed coordinates).
*
* phi(x, y, z) = (x, y, z/d.z)
* phi(d) = (d.x, d.y, 1)
* phi(a) = (a.x, a.y, a.z/d.z) = (a.x*d.z^2, a.y*d.z^3, a.z)
*/
d_ge.x = d.x;
d_ge.y = d.y;
d_ge.infinity = 0;

secp256k1_ge_set_gej_zinv(&a_ge, a, &d.z);
prej[0].x = a_ge.x;
prej[0].y = a_ge.y;
prej[0].z = a->z;
prej[0].infinity = 0;
secp256k1_ge_set_xy(&d_ge, &d.x, &d.y);
secp256k1_ge_set_gej_zinv(&pre_a[0], a, &d.z);
secp256k1_gej_set_ge(&ai, &pre_a[0]);
ai.z = a->z;

zr[0] = d.z;
for (i = 1; i < n; i++) {
secp256k1_gej_add_ge_var(&prej[i], &prej[i-1], &d_ge, &zr[i]);
secp256k1_gej_add_ge_var(&ai, &ai, &d_ge, &zr[i]);
secp256k1_ge_set_xy(&pre_a[i], &ai.x, &ai.y);
}

/*
* Each point in 'prej' has a z coordinate too small by a factor of 'd.z'. Only
* the final point's z coordinate is actually used though, so just update that.
*/
secp256k1_fe_mul(&prej[n-1].z, &prej[n-1].z, &d.z);
/* Multiply by d.z to undo the isomorphism that makes 'd' affine. */
secp256k1_fe_mul(z, &ai.z, &d.z);
}

/** Fill a table 'pre' with precomputed odd multiples of a.
Expand All @@ -106,13 +113,11 @@ static void secp256k1_ecmult_odd_multiples_table(int n, secp256k1_gej *prej, sec
* and use the precomputed table in <ecmult_static_pre_g.h> for G.
*/
static void secp256k1_ecmult_odd_multiples_table_globalz_windowa(secp256k1_ge *pre, secp256k1_fe *globalz, const secp256k1_gej *a) {
secp256k1_gej prej[ECMULT_TABLE_SIZE(WINDOW_A)];
secp256k1_fe zr[ECMULT_TABLE_SIZE(WINDOW_A)];

/* Compute the odd multiples in Jacobian form. */
secp256k1_ecmult_odd_multiples_table(ECMULT_TABLE_SIZE(WINDOW_A), prej, zr, a);
/* Bring them to the same Z denominator. */
secp256k1_ge_globalz_set_table_gej(ECMULT_TABLE_SIZE(WINDOW_A), pre, globalz, prej, zr);
secp256k1_ecmult_odd_multiples_table(ECMULT_TABLE_SIZE(WINDOW_A), pre, zr, globalz, a);
secp256k1_ge_globalz_fixup_table(ECMULT_TABLE_SIZE(WINDOW_A), pre, zr);
}

/** The following two macros retrieve a particular odd multiple from a table
Expand All @@ -129,6 +134,20 @@ static void secp256k1_ecmult_odd_multiples_table_globalz_windowa(secp256k1_ge *p
} \
} while(0)

/** Retrieve the odd multiple selected by n for the lambda half of a split scalar.
 *
 *  r   - output group element; its x, y and infinity fields are all written.
 *  pre - table of odd multiples; only the y coordinates are read from it.
 *  aux - table of field elements parallel to pre, supplying the x coordinates
 *        (the Strauss code fills it with pre[i].x * beta before using this macro).
 *  n   - odd wNAF digit with |n| <= 2^(w-1) - 1 (checked with VERIFY_CHECK).
 *  w   - window size, used only for the range checks.
 *
 *  Entry (|n|-1)/2 of both tables is used; for negative n the y coordinate is
 *  negated. The arguments are expanded several times, so they must be free of
 *  side effects.
 */
#define ECMULT_TABLE_GET_GE_LAMBDA(r,pre,aux,n,w) do { \
VERIFY_CHECK(((n) & 1) == 1); \
VERIFY_CHECK((n) >= -((1 << ((w)-1)) - 1)); \
VERIFY_CHECK((n) <= ((1 << ((w)-1)) - 1)); \
if ((n) > 0) { \
(r)->x = (aux)[((n)-1)/2]; \
(r)->y = (pre)[((n)-1)/2].y; \
} else { \
(r)->x = (aux)[(-(n)-1)/2]; \
secp256k1_fe_negate(&((r)->y), &((pre)[(-(n)-1)/2].y), 1); \
} \
(r)->infinity = 0; \
} while(0)

#define ECMULT_TABLE_GET_GE_STORAGE(r,pre,n,w) do { \
VERIFY_CHECK(((n) & 1) == 1); \
VERIFY_CHECK((n) >= -((1 << ((w)-1)) - 1)); \
Expand Down Expand Up @@ -201,23 +220,22 @@ static int secp256k1_ecmult_wnaf(int *wnaf, int len, const secp256k1_scalar *a,
}

struct secp256k1_strauss_point_state {
secp256k1_scalar na_1, na_lam;
int wnaf_na_1[129];
int wnaf_na_lam[129];
int bits_na_1;
int bits_na_lam;
size_t input_pos;
};

struct secp256k1_strauss_state {
secp256k1_gej* prej;
secp256k1_fe* zr;
/* aux is used to hold z-ratios, and then used to hold pre_a[i].x * BETA values. */
secp256k1_fe* aux;
secp256k1_ge* pre_a;
secp256k1_ge* pre_a_lam;
struct secp256k1_strauss_point_state* ps;
};

static void secp256k1_ecmult_strauss_wnaf(const struct secp256k1_strauss_state *state, secp256k1_gej *r, size_t num, const secp256k1_gej *a, const secp256k1_scalar *na, const secp256k1_scalar *ng) {
typedef int (secp256k1_ecmult_strauss_multi_callback)(secp256k1_scalar *sc, secp256k1_gej *pt, size_t idx, void *data);

static int secp256k1_ecmult_strauss_wnaf(const struct secp256k1_strauss_state *state, secp256k1_gej *r, size_t num, secp256k1_ecmult_strauss_multi_callback cb, void *cbdata, size_t cb_offset, const secp256k1_scalar *ng) {
secp256k1_ge tmpa;
secp256k1_fe Z;
/* Split G factors. */
Expand All @@ -231,17 +249,23 @@ static void secp256k1_ecmult_strauss_wnaf(const struct secp256k1_strauss_state *
size_t np;
size_t no = 0;

secp256k1_fe_set_int(&Z, 1);
for (np = 0; np < num; ++np) {
if (secp256k1_scalar_is_zero(&na[np]) || secp256k1_gej_is_infinity(&a[np])) {
secp256k1_gej a;
secp256k1_fe az;
secp256k1_scalar na, na_1, na_lam;

if (!cb(&na, &a, np + cb_offset, cbdata)) return 0;
if (secp256k1_scalar_is_zero(&na) || secp256k1_gej_is_infinity(&a)) {
continue;
}
state->ps[no].input_pos = np;

/* split na into na_1 and na_lam (where na = na_1 + na_lam*lambda, and na_1 and na_lam are ~128 bit) */
secp256k1_scalar_split_lambda(&state->ps[no].na_1, &state->ps[no].na_lam, &na[np]);
secp256k1_scalar_split_lambda(&na_1, &na_lam, &na);

/* build wnaf representation for na_1 and na_lam. */
state->ps[no].bits_na_1 = secp256k1_ecmult_wnaf(state->ps[no].wnaf_na_1, 129, &state->ps[no].na_1, WINDOW_A);
state->ps[no].bits_na_lam = secp256k1_ecmult_wnaf(state->ps[no].wnaf_na_lam, 129, &state->ps[no].na_lam, WINDOW_A);
state->ps[no].bits_na_1 = secp256k1_ecmult_wnaf(state->ps[no].wnaf_na_1, 129, &na_1, WINDOW_A);
state->ps[no].bits_na_lam = secp256k1_ecmult_wnaf(state->ps[no].wnaf_na_lam, 129, &na_lam, WINDOW_A);
VERIFY_CHECK(state->ps[no].bits_na_1 <= 129);
VERIFY_CHECK(state->ps[no].bits_na_lam <= 129);
if (state->ps[no].bits_na_1 > bits) {
Expand All @@ -250,40 +274,36 @@ static void secp256k1_ecmult_strauss_wnaf(const struct secp256k1_strauss_state *
if (state->ps[no].bits_na_lam > bits) {
bits = state->ps[no].bits_na_lam;
}
++no;
}

/* Calculate odd multiples of a.
* All multiples are brought to the same Z 'denominator', which is stored
* in Z. Due to secp256k1' isomorphism we can do all operations pretending
* that the Z coordinate was 1, use affine addition formulae, and correct
* the Z coordinate of the result once at the end.
* The exception is the precomputed G table points, which are actually
* affine. Compared to the base used for other points, they have a Z ratio
* of 1/Z, so we can use secp256k1_gej_add_zinv_var, which uses the same
* isomorphism to efficiently add with a known Z inverse.
*/
if (no > 0) {
/* Compute the odd multiples in Jacobian form. */
secp256k1_ecmult_odd_multiples_table(ECMULT_TABLE_SIZE(WINDOW_A), state->prej, state->zr, &a[state->ps[0].input_pos]);
for (np = 1; np < no; ++np) {
secp256k1_gej tmp = a[state->ps[np].input_pos];
/* Calculate odd multiples of a.
* All multiples are brought to the same Z 'denominator', which is stored
* in Z. Due to secp256k1's isomorphism we can do all operations pretending
* that the Z coordinate was 1, use affine addition formulae, and correct
* the Z coordinate of the result once at the end.
* The exception is the precomputed G table points, which are actually
* affine. Compared to the base used for other points, they have a Z ratio
* of 1/Z, so we can use secp256k1_gej_add_zinv_var, which uses the same
* isomorphism to efficiently add with a known Z inverse.
*/
az = a.z;
if (no) {
#ifdef VERIFY
secp256k1_fe_normalize_var(&(state->prej[(np - 1) * ECMULT_TABLE_SIZE(WINDOW_A) + ECMULT_TABLE_SIZE(WINDOW_A) - 1].z));
secp256k1_fe_normalize_var(&Z);
#endif
secp256k1_gej_rescale(&tmp, &(state->prej[(np - 1) * ECMULT_TABLE_SIZE(WINDOW_A) + ECMULT_TABLE_SIZE(WINDOW_A) - 1].z));
secp256k1_ecmult_odd_multiples_table(ECMULT_TABLE_SIZE(WINDOW_A), state->prej + np * ECMULT_TABLE_SIZE(WINDOW_A), state->zr + np * ECMULT_TABLE_SIZE(WINDOW_A), &tmp);
secp256k1_fe_mul(state->zr + np * ECMULT_TABLE_SIZE(WINDOW_A), state->zr + np * ECMULT_TABLE_SIZE(WINDOW_A), &(a[state->ps[np].input_pos].z));
secp256k1_gej_rescale(&a, &Z);
}
/* Bring them to the same Z denominator. */
secp256k1_ge_globalz_set_table_gej(ECMULT_TABLE_SIZE(WINDOW_A) * no, state->pre_a, &Z, state->prej, state->zr);
} else {
secp256k1_fe_set_int(&Z, 1);
secp256k1_ecmult_odd_multiples_table(ECMULT_TABLE_SIZE(WINDOW_A), state->pre_a + no * ECMULT_TABLE_SIZE(WINDOW_A), state->aux + no * ECMULT_TABLE_SIZE(WINDOW_A), &Z, &a);
if (no) secp256k1_fe_mul(state->aux + no * ECMULT_TABLE_SIZE(WINDOW_A), state->aux + no * ECMULT_TABLE_SIZE(WINDOW_A), &az);

++no;
}

/* Bring them to the same Z denominator. */
secp256k1_ge_globalz_fixup_table(ECMULT_TABLE_SIZE(WINDOW_A) * no, state->pre_a, state->aux);

for (np = 0; np < no; ++np) {
for (i = 0; i < ECMULT_TABLE_SIZE(WINDOW_A); i++) {
secp256k1_ge_mul_lambda(&state->pre_a_lam[np * ECMULT_TABLE_SIZE(WINDOW_A) + i], &state->pre_a[np * ECMULT_TABLE_SIZE(WINDOW_A) + i]);
secp256k1_fe_mul(&state->aux[np * ECMULT_TABLE_SIZE(WINDOW_A) + i], &state->pre_a[np * ECMULT_TABLE_SIZE(WINDOW_A) + i].x, &secp256k1_const_beta);
}
}

Expand Down Expand Up @@ -313,7 +333,7 @@ static void secp256k1_ecmult_strauss_wnaf(const struct secp256k1_strauss_state *
secp256k1_gej_add_ge_var(r, r, &tmpa, NULL);
}
if (i < state->ps[np].bits_na_lam && (n = state->ps[np].wnaf_na_lam[i])) {
ECMULT_TABLE_GET_GE(&tmpa, state->pre_a_lam + np * ECMULT_TABLE_SIZE(WINDOW_A), n, WINDOW_A);
ECMULT_TABLE_GET_GE_LAMBDA(&tmpa, state->pre_a + np * ECMULT_TABLE_SIZE(WINDOW_A), state->aux + np * ECMULT_TABLE_SIZE(WINDOW_A), n, WINDOW_A);
secp256k1_gej_add_ge_var(r, r, &tmpa, NULL);
}
}
Expand All @@ -330,34 +350,66 @@ static void secp256k1_ecmult_strauss_wnaf(const struct secp256k1_strauss_state *
if (!r->infinity) {
secp256k1_fe_mul(&r->z, &r->z, &Z);
}

return 1;
}

/* Callback context for secp256k1_ecmult_array_cb: plain arrays of scalars and points. */
struct secp256k1_ecmult_array_cb_data {
/* Scalars, indexed by the idx argument passed to the callback. */
const secp256k1_scalar *na;
/* Points in Jacobian coordinates, indexed the same way. The callback also accepts NULL here. */
const secp256k1_gej *a;
};

/** Strauss input callback that reads scalar/point pairs from plain arrays.
 *
 *  data must point to a struct secp256k1_ecmult_array_cb_data. The scalar at
 *  position idx is always copied to *sc. When a point array is present, the
 *  point at position idx is copied to *pt and 1 is returned. When the point
 *  array is NULL, *pt is left untouched and the return value is nonzero only
 *  if the scalar is zero.
 */
static int secp256k1_ecmult_array_cb(secp256k1_scalar *sc, secp256k1_gej *pt, size_t idx, void *data) {
    struct secp256k1_ecmult_array_cb_data *ctx = data;

    *sc = ctx->na[idx];
    if (!ctx->a) {
        /* No points supplied: only a zero scalar counts as success. */
        return secp256k1_scalar_is_zero(sc);
    }
    *pt = ctx->a[idx];
    return 1;
}

static void secp256k1_ecmult(secp256k1_gej *r, const secp256k1_gej *a, const secp256k1_scalar *na, const secp256k1_scalar *ng) {
secp256k1_gej prej[ECMULT_TABLE_SIZE(WINDOW_A)];
secp256k1_fe zr[ECMULT_TABLE_SIZE(WINDOW_A)];
secp256k1_fe aux[ECMULT_TABLE_SIZE(WINDOW_A)];
secp256k1_ge pre_a[ECMULT_TABLE_SIZE(WINDOW_A)];
struct secp256k1_strauss_point_state ps[1];
secp256k1_ge pre_a_lam[ECMULT_TABLE_SIZE(WINDOW_A)];
struct secp256k1_strauss_state state;
struct secp256k1_ecmult_array_cb_data data;

state.prej = prej;
state.zr = zr;
state.aux = aux;
state.pre_a = pre_a;
state.pre_a_lam = pre_a_lam;
state.ps = ps;
secp256k1_ecmult_strauss_wnaf(&state, r, 1, a, na, ng);
data.na = na;
data.a = a;
secp256k1_ecmult_strauss_wnaf(&state, r, 1, &secp256k1_ecmult_array_cb, &data, 0, ng);
}

static size_t secp256k1_strauss_scratch_size(size_t n_points) {
static const size_t point_size = (2 * sizeof(secp256k1_ge) + sizeof(secp256k1_gej) + sizeof(secp256k1_fe)) * ECMULT_TABLE_SIZE(WINDOW_A) + sizeof(struct secp256k1_strauss_point_state) + sizeof(secp256k1_gej) + sizeof(secp256k1_scalar);
static const size_t point_size = (sizeof(secp256k1_ge) + sizeof(secp256k1_fe)) * ECMULT_TABLE_SIZE(WINDOW_A) + sizeof(struct secp256k1_strauss_point_state);
return n_points*point_size;
}

/* Callback context for secp256k1_ecmult_adaptor_cb: a wrapped secp256k1_ecmult_multi_callback. */
struct secp256k1_ecmult_adaptor_cb_data {
/* Wrapped callback; it is handed a secp256k1_ge to fill in. */
secp256k1_ecmult_multi_callback *cb;
/* Opaque pointer forwarded unchanged as the wrapped callback's last argument. */
void *data;
};

/** Adapt a secp256k1_ecmult_multi_callback (which yields secp256k1_ge points)
 *  to the Strauss callback signature (which expects secp256k1_gej points).
 *
 *  data must point to a struct secp256k1_ecmult_adaptor_cb_data. The wrapped
 *  callback is invoked with the same sc and idx; its return value is passed
 *  through unchanged. *pt is only written when that return value is nonzero.
 */
static int secp256k1_ecmult_adaptor_cb(secp256k1_scalar *sc, secp256k1_gej *pt, size_t idx, void *data) {
    struct secp256k1_ecmult_adaptor_cb_data *wrapped = data;
    secp256k1_ge affine;
    int rc;

    rc = wrapped->cb(sc, &affine, idx, wrapped->data);
    if (rc == 0) {
        return rc;
    }

    /* Convert the point produced by the wrapped callback to Jacobian form. */
    secp256k1_gej_set_ge(pt, &affine);
    return rc;
}

static int secp256k1_ecmult_strauss_batch(const secp256k1_callback* error_callback, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n_points, size_t cb_offset) {
secp256k1_gej* points;
secp256k1_scalar* scalars;
struct secp256k1_ecmult_adaptor_cb_data adaptor_data;
struct secp256k1_strauss_state state;
size_t i;
const size_t scratch_checkpoint = secp256k1_scratch_checkpoint(error_callback, scratch);

secp256k1_gej_set_infinity(r);
Expand All @@ -368,28 +420,17 @@ static int secp256k1_ecmult_strauss_batch(const secp256k1_callback* error_callba
/* We allocate STRAUSS_SCRATCH_OBJECTS objects on the scratch space. If these
* allocations change, make sure to update the STRAUSS_SCRATCH_OBJECTS
* constant and strauss_scratch_size accordingly. */
points = (secp256k1_gej*)secp256k1_scratch_alloc(error_callback, scratch, n_points * sizeof(secp256k1_gej));
scalars = (secp256k1_scalar*)secp256k1_scratch_alloc(error_callback, scratch, n_points * sizeof(secp256k1_scalar));
state.prej = (secp256k1_gej*)secp256k1_scratch_alloc(error_callback, scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_gej));
state.zr = (secp256k1_fe*)secp256k1_scratch_alloc(error_callback, scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_fe));
adaptor_data.cb = cb;
adaptor_data.data = cbdata;
state.aux = (secp256k1_fe*)secp256k1_scratch_alloc(error_callback, scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_fe));
state.pre_a = (secp256k1_ge*)secp256k1_scratch_alloc(error_callback, scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_ge));
state.pre_a_lam = (secp256k1_ge*)secp256k1_scratch_alloc(error_callback, scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_ge));
state.ps = (struct secp256k1_strauss_point_state*)secp256k1_scratch_alloc(error_callback, scratch, n_points * sizeof(struct secp256k1_strauss_point_state));

if (points == NULL || scalars == NULL || state.prej == NULL || state.zr == NULL || state.pre_a == NULL || state.pre_a_lam == NULL || state.ps == NULL) {
if (state.aux == NULL || state.pre_a == NULL || state.ps == NULL ||
!secp256k1_ecmult_strauss_wnaf(&state, r, n_points, &secp256k1_ecmult_adaptor_cb, &adaptor_data, cb_offset, inp_g_sc)) {
Copy link
Contributor

@robot-dreams robot-dreams Dec 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're checking the return value of secp256k1_ecmult_strauss_wnaf when you call it here, would you also need to check it when you call it at the end of secp256k1_ecmult?

My concern is that if the callback (and thus the call to secp256k1_ecmult_strauss_wnaf) "fails", this could either (1) introduce a way for secp256k1_ecmult to fail unexpectedly, or (2) force secp256k1_ecmult to return int instead of void, and either way this would break the "public" API.

secp256k1_scratch_apply_checkpoint(error_callback, scratch, scratch_checkpoint);
return 0;
}

for (i = 0; i < n_points; i++) {
secp256k1_ge point;
if (!cb(&scalars[i], &point, i+cb_offset, cbdata)) {
secp256k1_scratch_apply_checkpoint(error_callback, scratch, scratch_checkpoint);
return 0;
}
secp256k1_gej_set_ge(&points[i], &point);
}
secp256k1_ecmult_strauss_wnaf(&state, r, n_points, points, scalars, inp_g_sc);
secp256k1_scratch_apply_checkpoint(error_callback, scratch, scratch_checkpoint);
return 1;
}
Expand Down
4 changes: 4 additions & 0 deletions src/field_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,9 @@ static int secp256k1_fe_sqrt(secp256k1_fe *r, const secp256k1_fe *a) {
}

/* The field element one. */
static const secp256k1_fe secp256k1_fe_one = SECP256K1_FE_CONST(0, 0, 0, 0, 0, 0, 0, 1);
/* The field constant beta. The Strauss ecmult code multiplies table x coordinates
 * by it to obtain the x coordinates used for the lambda half of a split scalar. */
static const secp256k1_fe secp256k1_const_beta = SECP256K1_FE_CONST(
0x7ae96a2bul, 0x657c0710ul, 0x6e64479eul, 0xac3434e9ul,
0x9cf04975ul, 0x12f58995ul, 0xc1396c28ul, 0x719501eeul
);

#endif /* SECP256K1_FIELD_IMPL_H */
Loading