Skip to content

Commit acd3529

Browse files
authored
perf: non-native multilinear polynomial evaluation (#1087)
* perf: optimize multilinear evaluation * test: in same package * test: multilinear evaluation test * refactor: use new multi multilinear eval * fix: edge case with one var
1 parent 9761428 commit acd3529

File tree

5 files changed

+247
-47
lines changed

5 files changed

+247
-47
lines changed

Diff for: std/math/polynomial/polynomial.go

+73-29
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package polynomial
33
import (
44
"fmt"
55
"math/big"
6+
"math/bits"
67

78
"github.com/consensys/gnark/frontend"
89
"github.com/consensys/gnark/std/math/emulated"
@@ -102,45 +103,88 @@ func (p *Polynomial[FR]) EvalUnivariate(P Univariate[FR], at *emulated.Element[F
102103

103104
// EvalMultilinear evaluates multilinear polynomial at variable values at. It
104105
// returns the evaluation. The method does not mutate the inputs.
105-
func (p *Polynomial[FR]) EvalMultilinear(M Multilinear[FR], at []*emulated.Element[FR]) (*emulated.Element[FR], error) {
106-
var s *emulated.Element[FR]
107-
scaleCorrectionFactor := p.f.One()
108-
for len(M) > 1 {
109-
if len(M) >= minFoldScaledLogSize {
110-
M, s = p.foldScaled(M, at[0])
111-
scaleCorrectionFactor = p.f.Mul(scaleCorrectionFactor, s)
112-
} else {
113-
M = p.fold(M, at[0])
106+
func (p *Polynomial[FR]) EvalMultilinear(at []*emulated.Element[FR], M Multilinear[FR]) (*emulated.Element[FR], error) {
107+
ret, err := p.EvalMultilinearMany(at, M)
108+
if err != nil {
109+
return nil, err
110+
}
111+
return ret[0], nil
112+
}
113+
114+
// EvalMultilinearMany evaluates multilinear polynomials at variable values at. It
115+
// returns the evaluations. The method does not mutate the inputs.
116+
//
117+
// The method allows to share computations of computing the coefficients of the
118+
// multilinear polynomials at the given evaluation points.
119+
func (p *Polynomial[FR]) EvalMultilinearMany(at []*emulated.Element[FR], M ...Multilinear[FR]) ([]*emulated.Element[FR], error) {
120+
lenM := len(M[0])
121+
for i := range M {
122+
if len(M[i]) != lenM {
123+
return nil, fmt.Errorf("incompatible multilinear polynomial sizes")
114124
}
115-
at = at[1:]
116125
}
117-
if len(at) != 0 {
126+
mlelems := make([][]*emulated.Element[FR], len(M))
127+
for i := range M {
128+
mlelems[i] = FromSlice(M[i])
129+
}
130+
if bits.OnesCount(uint(lenM)) != 1 {
131+
return nil, fmt.Errorf("multilinear polynomial length must be a power of 2")
132+
}
133+
nbExpvars := bits.Len(uint(lenM)) - 1
134+
if len(at) != nbExpvars {
118135
return nil, fmt.Errorf("incompatible evaluation vector size")
119136
}
120-
return p.f.Mul(&M[0], scaleCorrectionFactor), nil
137+
split1 := nbExpvars / 2
138+
nbSplit1Elems := 1 << split1
139+
split2 := nbExpvars - split1
140+
nbSplit2Elems := 1 << split2
141+
partialMLEval1 := p.partialMultilinearEval(at[:split1])
142+
partialMLEval2 := p.partialMultilinearEval(at[split1:])
143+
sums := make([]*emulated.Element[FR], len(M))
144+
for k := range mlelems {
145+
partialSums := make([]*emulated.Element[FR], nbSplit2Elems)
146+
for i := range partialSums {
147+
b := make([]*emulated.Element[FR], nbSplit1Elems)
148+
for j := range b {
149+
b[j] = mlelems[k][i+j*nbSplit2Elems]
150+
}
151+
partialSums[i] = p.innerProduct(b, partialMLEval1)
152+
}
153+
sums[k] = p.innerProduct(partialSums, partialMLEval2)
154+
}
155+
return sums, nil
121156
}
122157

123-
func (p *Polynomial[FR]) fold(M Multilinear[FR], at *emulated.Element[FR]) Multilinear[FR] {
124-
mid := len(M) / 2
125-
R := make([]emulated.Element[FR], mid)
126-
for j := range R {
127-
diff := p.f.Sub(&M[mid+j], &M[j])
128-
diffAt := p.f.Mul(diff, at)
129-
R[j] = *p.f.Add(&M[j], diffAt)
158+
func (p *Polynomial[FR]) partialMultilinearEval(at []*emulated.Element[FR]) []*emulated.Element[FR] {
159+
if len(at) == 0 {
160+
return []*emulated.Element[FR]{p.f.One()}
161+
}
162+
res := []*emulated.Element[FR]{p.f.Sub(p.f.One(), at[len(at)-1]), at[len(at)-1]}
163+
at = at[:len(at)-1]
164+
for len(at) > 0 {
165+
newRes := make([]*emulated.Element[FR], len(res)*2)
166+
x := at[len(at)-1]
167+
for j := range res {
168+
resX := p.f.Mul(res[j], x)
169+
newRes[j] = p.f.Sub(res[j], resX)
170+
newRes[j+len(res)] = resX
171+
}
172+
res = newRes
173+
at = at[:len(at)-1]
130174
}
131-
return R
175+
return res
132176
}
133177

134-
func (p *Polynomial[FR]) foldScaled(M Multilinear[FR], at *emulated.Element[FR]) (Multilinear[FR], *emulated.Element[FR]) {
135-
denom := p.f.Sub(p.f.One(), at)
136-
coeff := p.f.Div(at, denom)
137-
mid := len(M) / 2
138-
R := make([]emulated.Element[FR], mid)
139-
for j := range R {
140-
tmp := p.f.Mul(&M[mid+j], coeff)
141-
R[j] = *p.f.Add(&M[j], tmp)
178+
func (p *Polynomial[FR]) innerProduct(a, b []*emulated.Element[FR]) *emulated.Element[FR] {
179+
if len(a) != len(b) {
180+
panic(fmt.Sprintf("incompatible sizes: %d and %d", len(a), len(b)))
181+
}
182+
muls := make([]*emulated.Element[FR], len(a))
183+
for i := range a {
184+
muls[i] = p.f.MulNoReduce(a[i], b[i])
142185
}
143-
return R, denom
186+
res := p.f.Sum(muls...)
187+
return res
144188
}
145189

146190
func (p *Polynomial[FR]) computeDeltaAtNaive(at *emulated.Element[FR], vLen int) []*emulated.Element[FR] {

Diff for: std/math/polynomial/polynomial_oldeval_test.go

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package polynomial
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/consensys/gnark/std/math/emulated"
7+
)
8+
9+
// evalMultilinearOld evaluates a multilinear polynomial at a given point.
10+
// This is the old version of the function, which is kept for comparison purposes.
11+
func (p *Polynomial[FR]) evalMultilinearOld(M Multilinear[FR], at []*emulated.Element[FR]) (*emulated.Element[FR], error) {
12+
var s *emulated.Element[FR]
13+
scaleCorrectionFactor := p.f.One()
14+
for len(M) > 1 {
15+
if len(M) >= minFoldScaledLogSize {
16+
M, s = p.foldScaled(M, at[0])
17+
scaleCorrectionFactor = p.f.Mul(scaleCorrectionFactor, s)
18+
} else {
19+
M = p.fold(M, at[0])
20+
}
21+
at = at[1:]
22+
}
23+
if len(at) != 0 {
24+
return nil, fmt.Errorf("incompatible evaluation vector size")
25+
}
26+
return p.f.Mul(&M[0], scaleCorrectionFactor), nil
27+
}
28+
29+
func (p *Polynomial[FR]) fold(M Multilinear[FR], at *emulated.Element[FR]) Multilinear[FR] {
30+
mid := len(M) / 2
31+
R := make([]emulated.Element[FR], mid)
32+
for j := range R {
33+
diff := p.f.Sub(&M[mid+j], &M[j])
34+
diffAt := p.f.Mul(diff, at)
35+
R[j] = *p.f.Add(&M[j], diffAt)
36+
}
37+
return R
38+
}
39+
40+
func (p *Polynomial[FR]) foldScaled(M Multilinear[FR], at *emulated.Element[FR]) (Multilinear[FR], *emulated.Element[FR]) {
41+
denom := p.f.Sub(p.f.One(), at)
42+
coeff := p.f.Div(at, denom)
43+
mid := len(M) / 2
44+
R := make([]emulated.Element[FR], mid)
45+
for j := range R {
46+
tmp := p.f.Mul(&M[mid+j], coeff)
47+
R[j] = *p.f.Add(&M[j], tmp)
48+
}
49+
return R, denom
50+
}

Diff for: std/math/polynomial/polynomial_test.go

+120-11
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
package polynomial_test
1+
package polynomial
22

33
import (
44
"testing"
55

66
"github.com/consensys/gnark/frontend"
77
"github.com/consensys/gnark/std/math/emulated"
88
"github.com/consensys/gnark/std/math/emulated/emparams"
9-
"github.com/consensys/gnark/std/math/polynomial"
109
"github.com/consensys/gnark/test"
1110
)
1211

@@ -17,7 +16,7 @@ type evalPolyCircuit[FR emulated.FieldParams] struct {
1716
}
1817

1918
func (c *evalPolyCircuit[FR]) Define(api frontend.API) error {
20-
p, err := polynomial.New[FR](api)
19+
p, err := New[FR](api)
2120
if err != nil {
2221
return err
2322
}
@@ -57,13 +56,13 @@ type evalMultiLinCircuit[FR emulated.FieldParams] struct {
5756
}
5857

5958
func (c *evalMultiLinCircuit[FR]) Define(api frontend.API) error {
60-
p, err := polynomial.New[FR](api)
59+
p, err := New[FR](api)
6160
if err != nil {
6261
return err
6362
}
6463
// M := polynomial.FromSlice(c.M)
65-
X := polynomial.FromSlice(c.At)
66-
res, err := p.EvalMultilinear(c.M, X)
64+
X := FromSlice(c.At)
65+
res, err := p.EvalMultilinear(X, c.M)
6766
if err != nil {
6867
return err
6968
}
@@ -108,12 +107,12 @@ type evalEqCircuit[FR emulated.FieldParams] struct {
108107
}
109108

110109
func (c *evalEqCircuit[FR]) Define(api frontend.API) error {
111-
p, err := polynomial.New[FR](api)
110+
p, err := New[FR](api)
112111
if err != nil {
113112
return err
114113
}
115-
X := polynomial.FromSlice(c.X)
116-
Y := polynomial.FromSlice(c.Y)
114+
X := FromSlice(c.X)
115+
Y := FromSlice(c.Y)
117116
evaluation := p.EvalEqual(X, Y)
118117
f, err := emulated.NewField[FR](api)
119118
if err != nil {
@@ -154,11 +153,11 @@ type interpolateLDECircuit[FR emulated.FieldParams] struct {
154153
}
155154

156155
func (c *interpolateLDECircuit[FR]) Define(api frontend.API) error {
157-
p, err := polynomial.New[FR](api)
156+
p, err := New[FR](api)
158157
if err != nil {
159158
return err
160159
}
161-
vals := polynomial.FromSlice(c.Values)
160+
vals := FromSlice(c.Values)
162161
res := p.InterpolateLDE(&c.At, vals)
163162
f, err := emulated.NewField[FR](api)
164163
if err != nil {
@@ -203,3 +202,113 @@ func TestInterpolateQuadraticExtension(t *testing.T) {
203202
testInterpolateLDE[emparams.BN254Fr](t, 3, []int64{1, 6, 17}, 34)
204203
testInterpolateLDE[emparams.BN254Fr](t, -1, []int64{1, 6, 17}, 2)
205204
}
205+
206+
type TestPartialMultilinearEvalCircuit[FR emulated.FieldParams] struct {
207+
At []emulated.Element[FR] `gnark:",public"`
208+
}
209+
210+
func (c *TestPartialMultilinearEvalCircuit[FR]) Define(api frontend.API) error {
211+
p, err := New[FR](api)
212+
if err != nil {
213+
return err
214+
}
215+
f, err := emulated.NewField[FR](api)
216+
if err != nil {
217+
return err
218+
}
219+
At := FromSlice(c.At)
220+
coefs := p.partialMultilinearEval(At)
221+
res := f.Zero()
222+
for i := range coefs {
223+
res = f.Add(res, coefs[i])
224+
}
225+
ones := make([]emulated.Element[FR], 1<<len(c.At))
226+
for i := range ones {
227+
ones[i] = emulated.ValueOf[FR](1)
228+
}
229+
evaled, err := p.EvalMultilinear(At, ones)
230+
if err != nil {
231+
return err
232+
}
233+
f.AssertIsEqual(res, evaled)
234+
return nil
235+
}
236+
237+
func TestPartialMultilinearEval(t *testing.T) {
238+
testPartialMultilinearEval[emparams.BN254Fr](t, []int64{2, 3, 4, 5})
239+
}
240+
241+
func testPartialMultilinearEval[FR emulated.FieldParams](t *testing.T, at []int64) {
242+
assert := test.NewAssert(t)
243+
atAssignment := make([]emulated.Element[FR], len(at))
244+
for i := range at {
245+
atAssignment[i] = emulated.ValueOf[FR](at[i])
246+
}
247+
assignment := &TestPartialMultilinearEvalCircuit[FR]{
248+
At: atAssignment,
249+
}
250+
assert.CheckCircuit(&TestPartialMultilinearEvalCircuit[FR]{At: make([]emulated.Element[FR], len(atAssignment))}, test.WithValidAssignment(assignment))
251+
}
252+
253+
type TestEvalMultilinear2Circuit[FR emulated.FieldParams] struct {
254+
M []Multilinear[FR] `gnark:",public"`
255+
At []emulated.Element[FR] `gnark:",secret"`
256+
}
257+
258+
func (c *TestEvalMultilinear2Circuit[FR]) Define(api frontend.API) error {
259+
f, err := emulated.NewField[FR](api)
260+
if err != nil {
261+
return err
262+
}
263+
p, err := New[FR](api)
264+
if err != nil {
265+
return err
266+
}
267+
X := FromSlice(c.At)
268+
res2, err := p.EvalMultilinearMany(X, c.M...)
269+
if err != nil {
270+
return err
271+
}
272+
for i := range c.M {
273+
res, err := p.evalMultilinearOld(c.M[i], X)
274+
if err != nil {
275+
return err
276+
}
277+
f.AssertIsEqual(res2[i], res)
278+
}
279+
return nil
280+
}
281+
282+
func TestEvalMultiLin2(t *testing.T) {
283+
testEvalMultiLin2[emparams.BN254Fr](t)
284+
}
285+
286+
func testEvalMultiLin2[FR emulated.FieldParams](t *testing.T) {
287+
assert := test.NewAssert(t)
288+
nbML := 4
289+
nbVar := 3
290+
nbVals := 1 << nbVar
291+
292+
M := make([]Multilinear[FR], nbML)
293+
for i := range M {
294+
M[i] = make([]emulated.Element[FR], nbVals)
295+
for j := range M[i] {
296+
M[i][j] = emulated.ValueOf[FR](1 + nbML*i + j)
297+
}
298+
}
299+
X := make([]emulated.Element[FR], 3)
300+
for i := range X {
301+
X[i] = emulated.ValueOf[FR](5 + i)
302+
}
303+
304+
// M = 2 X₀ + X₁ + 1
305+
witness := TestEvalMultilinear2Circuit[FR]{
306+
M: M,
307+
At: X,
308+
}
309+
circuit := &TestEvalMultilinear2Circuit[FR]{M: make([]Multilinear[FR], nbML), At: make([]emulated.Element[FR], nbVar)}
310+
for i := range circuit.M {
311+
circuit.M[i] = make([]emulated.Element[FR], nbVals)
312+
}
313+
assert.CheckCircuit(circuit, test.WithValidAssignment(&witness))
314+
}

Diff for: std/recursion/sumcheck/claimable_gate.go

+3-6
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,9 @@ func (g *gateClaim[FR]) AssertEvaluation(r []*emulated.Element[FR], combinationC
139139
// For that, we first have to map the random challenges to a random input to
140140
// the gate. As the inputs mapping is given by multilinear extension, then
141141
// this means evaluating the MLE at the random point.
142-
inputEvals := make([]*emulated.Element[FR], len(g.inputPreprocessors))
143-
for i := range inputEvals {
144-
inputEvals[i], err = g.p.EvalMultilinear(g.inputPreprocessors[i], r)
145-
if err != nil {
146-
return fmt.Errorf("eval multilin: %w", err)
147-
}
142+
inputEvals, err := g.p.EvalMultilinearMany(r, g.inputPreprocessors...)
143+
if err != nil {
144+
return fmt.Errorf("eval multilin: %w", err)
148145
}
149146
// now, we can evaluate the gate at the random input.
150147
gateEval := g.gate.Evaluate(g.engine, inputEvals...)

Diff for: std/recursion/sumcheck/claimable_multilinear.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func (fn *multilinearClaim[FR]) Degree(i int) int {
5353
}
5454

5555
func (fn *multilinearClaim[FR]) AssertEvaluation(r []*emulated.Element[FR], combinationCoeff *emulated.Element[FR], expectedValue *emulated.Element[FR], proof EvaluationProof) error {
56-
val, err := fn.p.EvalMultilinear(fn.ml, r)
56+
val, err := fn.p.EvalMultilinear(r, fn.ml)
5757
if err != nil {
5858
return fmt.Errorf("eval: %w", err)
5959
}

0 commit comments

Comments
 (0)