Skip to content

Implement Flag-Based Verification for Recursive SNARK Proofs & In-Circuit Signatures #1432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
16 changes: 15 additions & 1 deletion std/algebra/emulated/fields_bw6761/e6.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,21 @@ func (e Ext6) Sub(x, y *E6) *E6 {
}
}

func (e Ext6) IsZero(x *E6) frontend.Variable {
isZero := e.fp.IsZero(&x.A0)
isZero = e.api.And(isZero, e.fp.IsZero(&x.A1))
isZero = e.api.And(isZero, e.fp.IsZero(&x.A2))
isZero = e.api.And(isZero, e.fp.IsZero(&x.A3))
isZero = e.api.And(isZero, e.fp.IsZero(&x.A4))
isZero = e.api.And(isZero, e.fp.IsZero(&x.A5))
return isZero
}

func (e Ext6) IsEqual(x, y *E6) frontend.Variable {
diff := e.Sub(x, y)
return e.IsZero(diff)
}

func (e Ext6) Double(x *E6) *E6 {
two := big.NewInt(2)
a0 := e.fp.MulConst(&x.A0, two)
Expand Down Expand Up @@ -1106,7 +1121,6 @@ func (e Ext6) AssertIsEqual(a, b *E6) {
e.fp.AssertIsEqual(&a.A3, &b.A3)
e.fp.AssertIsEqual(&a.A4, &b.A4)
e.fp.AssertIsEqual(&a.A5, &b.A5)

}

func (e Ext6) Copy(x *E6) *E6 {
Expand Down
5 changes: 5 additions & 0 deletions std/algebra/emulated/sw_bls12381/pairing.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ func (pr Pairing) AssertIsEqual(x, y *GTEl) {
pr.Ext12.AssertIsEqual(x, y)
}

func (pr Pairing) IsEqual(x, y *GTEl) frontend.Variable {
diff := pr.Ext12.Sub(x, y)
return pr.Ext12.IsEqual(diff, pr.Ext12.Zero())
}

func (pr Pairing) MuxG2(sel frontend.Variable, inputs ...*G2Affine) *G2Affine {
if len(inputs) == 0 {
return nil
Expand Down
4 changes: 4 additions & 0 deletions std/algebra/emulated/sw_bw6761/pairing.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ func (pr Pairing) AssertIsEqual(x, y *GTEl) {
pr.Ext6.AssertIsEqual(x, y)
}

func (pr Pairing) IsEqual(x, y *GTEl) frontend.Variable {
return pr.Ext6.IsEqual(x, y)
}

func (pr Pairing) MuxG2(sel frontend.Variable, inputs ...*G2Affine) *G2Affine {
if len(inputs) == 0 {
return nil
Expand Down
3 changes: 3 additions & 0 deletions std/algebra/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ type Pairing[G1El G1ElementT, G2El G2ElementT, GtEl GtElementT] interface {
// AssertIsEqual asserts the equality of the inputs.
AssertIsEqual(*GtEl, *GtEl)

// IsEqual returns 1 if both inputs are equal, 0 otherwise.
IsEqual(*GtEl, *GtEl) frontend.Variable

// AssertIsOnG1 asserts that the input is on the G1 curve.
AssertIsOnG1(*G1El)

Expand Down
14 changes: 14 additions & 0 deletions std/algebra/native/fields_bls12377/e12.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,20 @@ func (e *E12) Mul(api frontend.API, e1, e2 E12) *E12 {
return e
}

func (e *E12) IsZero(api frontend.API) frontend.Variable {
isZero := e.C0.B0.IsZero(api)
isZero = api.And(isZero, e.C0.B1.IsZero(api))
isZero = api.And(isZero, e.C0.B2.IsZero(api))
isZero = api.And(isZero, e.C1.B0.IsZero(api))
isZero = api.And(isZero, e.C1.B1.IsZero(api))
isZero = api.And(isZero, e.C1.B2.IsZero(api))
return isZero
}

func (e *E12) IsEqual(api frontend.API, x, y *E12) frontend.Variable {
return e.Sub(api, *x, *y).IsZero(api)
}

// Square squares an element in Fp12
func (e *E12) Square(api frontend.API, x E12) *E12 {

Expand Down
11 changes: 11 additions & 0 deletions std/algebra/native/fields_bls24315/e24.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,17 @@ func (e *E24) Mul(api frontend.API, e1, e2 E24) *E24 {
return e
}

func (e *E24) IsZero(api frontend.API) frontend.Variable {
isZero := e.D0.C0.IsZero(api)
isZero = api.And(isZero, e.D0.C1.IsZero(api))
isZero = api.And(isZero, e.D0.C2.IsZero(api))
isZero = api.And(isZero, e.D1.C0.IsZero(api))
isZero = api.And(isZero, e.D1.C1.IsZero(api))
isZero = api.And(isZero, e.D1.C2.IsZero(api))
return isZero
}


// Square squares an element in Fp24
func (e *E24) Square(api frontend.API, x E24) *E24 {

Expand Down
4 changes: 4 additions & 0 deletions std/algebra/native/sw_bls12377/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ func (p *Pairing) AssertIsEqual(e1, e2 *GT) {
e1.AssertIsEqual(p.api, *e2)
}

func (pr *Pairing) IsEqual(e1, e2 *GT) frontend.Variable {
return e1.IsEqual(pr.api, e1, e2)
}

func (pr Pairing) MuxG2(sel frontend.Variable, inputs ...*G2Affine) *G2Affine {
if len(inputs) == 0 {
return nil
Expand Down
4 changes: 4 additions & 0 deletions std/algebra/native/sw_bls24315/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,10 @@ func (pr Pairing) MuxGt(sel frontend.Variable, inputs ...*GT) *GT {
return &ret
}

func (pr *Pairing) IsEqual(e1, e2 *GT) frontend.Variable {
return e1.Sub(pr.api, *e1, *e2).IsZero(pr.api)
}

func (p *Pairing) AssertIsOnG1(P *G1Affine) {
panic("not implemented")
}
Expand Down
5 changes: 5 additions & 0 deletions std/algebra/native/twistededwards/curve.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ func (c *curve) Neg(p1 Point) Point {
p.neg(c.api, &p1)
return p
}

func (c *curve) IsOnCurve(p1 Point) frontend.Variable {
return p1.isOnCurve(c.api, c.params)
}

func (c *curve) AssertIsOnCurve(p1 Point) {
p1.assertIsOnCurve(c.api, c.params)
}
Expand Down
10 changes: 8 additions & 2 deletions std/algebra/native/twistededwards/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ func (p *Point) neg(api frontend.API, p1 *Point) *Point {
// assertIsOnCurve checks if a point is on the reduced twisted Edwards curve
// a*x² + y² = 1 + d*x²*y².
func (p *Point) assertIsOnCurve(api frontend.API, curve *CurveParams) {
flag := p.isOnCurve(api, curve)
api.AssertIsEqual(flag, 1)
}

// isOnCurve returns 1 if a point is on the reduced twisted Edwards curve
// a*x² + y² = 1 + d*x²*y², 0 otherwise.
func (p *Point) isOnCurve(api frontend.API, curve *CurveParams) frontend.Variable {

xx := api.Mul(p.X, p.X)
yy := api.Mul(p.Y, p.Y)
Expand All @@ -25,8 +32,7 @@ func (p *Point) assertIsOnCurve(api frontend.API, curve *CurveParams) {
dxxyy := api.Mul(dxx, yy)
rhs := api.Add(dxxyy, 1)

api.AssertIsEqual(lhs, rhs)

return api.IsZero(api.Sub(lhs, rhs))
}

// add Adds two points on a twisted edwards curve (eg jubjub)
Expand Down
1 change: 1 addition & 0 deletions std/algebra/native/twistededwards/twistededwards.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type Curve interface {
Add(p1, p2 Point) Point
Double(p1 Point) Point
Neg(p1 Point) Point
IsOnCurve(p1 Point) frontend.Variable
AssertIsOnCurve(p1 Point)
ScalarMul(p1 Point, scalar frontend.Variable) Point
DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point
Expand Down
32 changes: 21 additions & 11 deletions std/recursion/groth16/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,16 +607,27 @@ func NewVerifier[FR emulated.FieldParams, G1El algebra.G1ElementT, G2El algebra.
// AssertProof asserts that the SNARK proof holds for the given witness and
// verifying key.
func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El, GtEl], proof Proof[G1El, G2El], witness Witness[FR], opts ...VerifierOption) error {
flag, err := v.ProofIsValid(vk, proof, witness, opts...)
if err != nil {
return err
}
v.api.AssertIsEqual(flag, 1)
return nil
}

// ProofIsValid returns 1 if the SNARK proof holds for the given witness and
// verifying key, and 0 otherwise.
func (v *Verifier[FR, G1El, G2El, GtEl]) ProofIsValid(vk VerifyingKey[G1El, G2El, GtEl], proof Proof[G1El, G2El], witness Witness[FR], opts ...VerifierOption) (frontend.Variable, error) {
if len(vk.CommitmentKeys) != len(proof.Commitments) {
return fmt.Errorf("invalid number of commitments, got %d, expected %d", len(proof.Commitments), len(vk.CommitmentKeys))
return 0, fmt.Errorf("invalid number of commitments, got %d, expected %d", len(proof.Commitments), len(vk.CommitmentKeys))
}
if len(vk.CommitmentKeys) != len(vk.PublicAndCommitmentCommitted) {
return fmt.Errorf("invalid number of commitment keys, got %d, expected %d", len(vk.CommitmentKeys), len(vk.PublicAndCommitmentCommitted))
return 0, fmt.Errorf("invalid number of commitment keys, got %d, expected %d", len(vk.CommitmentKeys), len(vk.PublicAndCommitmentCommitted))
}
var fr FR
nbPublicVars := len(vk.G1.K) - len(vk.PublicAndCommitmentCommitted)
if len(witness.Public) != nbPublicVars-1 {
return fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(witness.Public), len(vk.G1.K)-1)
return 0, fmt.Errorf("invalid witness size, got %d, expected %d (public - ONE_WIRE)", len(witness.Public), len(vk.G1.K)-1)
}

inP := make([]*G1El, len(vk.G1.K)-1) // first is for the one wire, we add it manually after MSM
Expand All @@ -630,11 +641,11 @@ func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El,

opt, err := newCfg(opts...)
if err != nil {
return fmt.Errorf("apply options: %w", err)
return 0, fmt.Errorf("apply options: %w", err)
}
hashToField, err := recursion.NewHash(v.api, fr.Modulus(), true)
if err != nil {
return fmt.Errorf("hash to field: %w", err)
return 0, fmt.Errorf("hash to field: %w", err)
}

maxNbPublicCommitted := 0
Expand Down Expand Up @@ -663,16 +674,16 @@ func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El,
// explicitly do not verify the commitment as there is nothing
case 1:
if err = v.commitment.AssertCommitment(proof.Commitments[0], proof.CommitmentPok, vk.CommitmentKeys[0], opt.pedopt...); err != nil {
return fmt.Errorf("assert commitment: %w", err)
return 0, fmt.Errorf("assert commitment: %w", err)
}
default:
// TODO: we support only a single commitment in the recursion for now
return fmt.Errorf("multiple commitments are not supported")
return 0, fmt.Errorf("multiple commitments are not supported")
}

kSum, err := v.curve.MultiScalarMul(inP, inS, opt.algopt...)
if err != nil {
return fmt.Errorf("multi scalar mul: %w", err)
return 0, fmt.Errorf("multi scalar mul: %w", err)
}
kSum = v.curve.Add(kSum, &vk.G1.K[0])

Expand All @@ -687,10 +698,9 @@ func (v *Verifier[FR, G1El, G2El, GtEl]) AssertProof(vk VerifyingKey[G1El, G2El,
}
pairing, err := v.pairing.Pair([]*G1El{kSum, &proof.Krs, &proof.Ar}, []*G2El{&vk.G2.GammaNeg, &vk.G2.DeltaNeg, &proof.Bs})
if err != nil {
return fmt.Errorf("pairing: %w", err)
return 0, fmt.Errorf("pairing: %w", err)
}
v.pairing.AssertIsEqual(pairing, &vk.E)
return nil
return v.pairing.IsEqual(pairing, &vk.E), nil
}

// SwitchVerification key switches the verification key based on the provided
Expand Down
19 changes: 18 additions & 1 deletion std/signature/ecdsa/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ type PublicKey[Base, Scalar emulated.FieldParams] sw_emulated.AffinePoint[Base]
//
// We assume that the message msg is already hashed to the scalar field.
func (pk PublicKey[T, S]) Verify(api frontend.API, params sw_emulated.CurveParams, msg *emulated.Element[S], sig *Signature[S]) {
flag := pk.SignIsValid(api, params, msg, sig)
api.AssertIsEqual(flag, 1)
}

// SignIsValid returns 1 if the signature sig verifies for the message msg and
// public key pk or 0 if not. The curve parameters params define the elliptic
// curve.
//
// We assume that the message msg is already hashed to the scalar field.
func (pk PublicKey[T, S]) SignIsValid(api frontend.API, params sw_emulated.CurveParams, msg *emulated.Element[S], sig *Signature[S]) frontend.Variable {
cr, err := sw_emulated.New[T, S](api, params)
if err != nil {
panic(err)
Expand All @@ -43,7 +53,14 @@ func (pk PublicKey[T, S]) Verify(api frontend.API, params sw_emulated.CurveParam
if len(rbits) != len(qxBits) {
panic("non-equal lengths")
}
// store 1 to expect equality
res := frontend.Variable(1)
for i := range rbits {
api.AssertIsEqual(rbits[i], qxBits[i])
// calc the difference between the bits
diff := api.Sub(rbits[i], qxBits[i])
// update the result with the AND of the previous result and the
// equality between the bits (diff == 0)
res = api.And(res, api.IsZero(diff))
}
return res
}
28 changes: 21 additions & 7 deletions std/signature/eddsa/eddsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,22 @@ type Signature struct {
S frontend.Variable
}

// Verify verifies an eddsa signature using MiMC hash function
// Verify checks that an eddsa signature verifies for the message msg and
// public key pk provided using MiMC hash function.
// cf https://en.wikipedia.org/wiki/EdDSA
func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pubKey PublicKey, hash hash.FieldHasher) error {
isValid, err := SignIsValid(curve, sig, msg, pubKey, hash)
if err != nil {
return err
}
curve.API().AssertIsEqual(isValid, 1)
return nil
}

// SignIsValid returns 1 if the signature sig verifies an eddsa signature
// using MiMC hash function for the message msg and public key pk or 0 if not.
// cf https://en.wikipedia.org/wiki/EdDSA
func SignIsValid(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pubKey PublicKey, hash hash.FieldHasher) (frontend.Variable, error) {
// compute H(R, A, M)
hash.Write(sig.R.X)
hash.Write(sig.R.Y)
Expand All @@ -56,7 +68,9 @@ func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pu
//[S]G-[H(R,A,M)]*A
_A := curve.Neg(pubKey.A)
Q := curve.DoubleBaseScalarMul(base, _A, sig.S, hRAM)
curve.AssertIsOnCurve(Q)
// check if Q is on the curve, if not multiply by 0
isOnCurve := curve.IsOnCurve(Q)
Q = curve.ScalarMul(Q, isOnCurve)

//[S]G-[H(R,A,M)]*A-R
Q = curve.Add(curve.Neg(Q), sig.R)
Expand All @@ -66,7 +80,7 @@ func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pu
if !curve.Params().Cofactor.IsUint64() {
err := errors.New("invalid cofactor")
log.Err(err).Str("cofactor", curve.Params().Cofactor.String()).Send()
return err
return nil, err
}
cofactor := curve.Params().Cofactor.Uint64()
switch cofactor {
Expand All @@ -78,10 +92,10 @@ func Verify(curve twistededwards.Curve, sig Signature, msg frontend.Variable, pu
log.Warn().Str("cofactor", curve.Params().Cofactor.String()).Msg("curve cofactor is not implemented")
}

curve.API().AssertIsEqual(Q.X, 0)
curve.API().AssertIsEqual(Q.Y, 1)

return nil
zeroX := curve.API().IsZero(Q.X)
oneY := curve.API().IsZero(curve.API().Sub(Q.Y, 1))
expectedPoint := curve.API().And(zeroX, oneY)
return curve.API().And(isOnCurve, expectedPoint), nil
}

// Assign is a helper to assigned a compressed binary public key representation into its uncompressed form
Expand Down