github.com/ethereum/go-ethereum@v1.16.1/tests/fuzzers/bls12381/bls12381_fuzz.go (about)

     1  // Copyright 2021 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  //go:build cgo
    18  // +build cgo
    19  
    20  package bls
    21  
    22  import (
    23  	"bytes"
    24  	"crypto/rand"
    25  	"fmt"
    26  	"io"
    27  	"math/big"
    28  
    29  	"github.com/consensys/gnark-crypto/ecc"
    30  	gnark "github.com/consensys/gnark-crypto/ecc/bls12-381"
    31  	"github.com/consensys/gnark-crypto/ecc/bls12-381/fp"
    32  	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
    33  	"github.com/ethereum/go-ethereum/common"
    34  	blst "github.com/supranational/blst/bindings/go"
    35  )
    36  
    37  func fuzzG1SubgroupChecks(data []byte) int {
    38  	input := bytes.NewReader(data)
    39  	cpG1, blG1, err := getG1Points(input)
    40  	if err != nil {
    41  		return 0
    42  	}
    43  	inSubGroupGnark := cpG1.IsInSubGroup()
    44  	inSubGroupBLST := blG1.InG1()
    45  	if inSubGroupGnark != inSubGroupBLST {
    46  		panic(fmt.Sprintf("differing subgroup check, gnark %v, blst %v", inSubGroupGnark, inSubGroupBLST))
    47  	}
    48  	return 1
    49  }
    50  
    51  func fuzzG2SubgroupChecks(data []byte) int {
    52  	input := bytes.NewReader(data)
    53  	gpG2, blG2, err := getG2Points(input)
    54  	if err != nil {
    55  		return 0
    56  	}
    57  	inSubGroupGnark := gpG2.IsInSubGroup()
    58  	inSubGroupBLST := blG2.InG2()
    59  	if inSubGroupGnark != inSubGroupBLST {
    60  		panic(fmt.Sprintf("differing subgroup check, gnark %v, blst %v", inSubGroupGnark, inSubGroupBLST))
    61  	}
    62  	return 1
    63  }
    64  
    65  func fuzzCrossPairing(data []byte) int {
    66  	input := bytes.NewReader(data)
    67  
    68  	// get random G1 points
    69  	cpG1, blG1, err := getG1Points(input)
    70  	if err != nil {
    71  		return 0
    72  	}
    73  
    74  	// get random G2 points
    75  	cpG2, blG2, err := getG2Points(input)
    76  	if err != nil {
    77  		return 0
    78  	}
    79  
    80  	// compute pairing using gnark
    81  	cResult, err := gnark.Pair([]gnark.G1Affine{*cpG1}, []gnark.G2Affine{*cpG2})
    82  	if err != nil {
    83  		panic(fmt.Sprintf("gnark/bls12381 encountered error: %v", err))
    84  	}
    85  
    86  	// compute pairing using blst
    87  	blstResult := blst.Fp12MillerLoop(blG2, blG1)
    88  	blstResult.FinalExp()
    89  	res := massageBLST(blstResult.ToBendian())
    90  	if !(bytes.Equal(res, cResult.Marshal())) {
    91  		panic("pairing mismatch blst / geth")
    92  	}
    93  
    94  	return 1
    95  }
    96  
    97  func massageBLST(in []byte) []byte {
    98  	out := make([]byte, len(in))
    99  	len := 12 * 48
   100  	// 1
   101  	copy(out[0:], in[len-1*48:len])
   102  	copy(out[1*48:], in[len-2*48:len-1*48])
   103  	// 2
   104  	copy(out[6*48:], in[len-3*48:len-2*48])
   105  	copy(out[7*48:], in[len-4*48:len-3*48])
   106  	// 3
   107  	copy(out[2*48:], in[len-5*48:len-4*48])
   108  	copy(out[3*48:], in[len-6*48:len-5*48])
   109  	// 4
   110  	copy(out[8*48:], in[len-7*48:len-6*48])
   111  	copy(out[9*48:], in[len-8*48:len-7*48])
   112  	// 5
   113  	copy(out[4*48:], in[len-9*48:len-8*48])
   114  	copy(out[5*48:], in[len-10*48:len-9*48])
   115  	// 6
   116  	copy(out[10*48:], in[len-11*48:len-10*48])
   117  	copy(out[11*48:], in[len-12*48:len-11*48])
   118  	return out
   119  }
   120  
   121  func fuzzCrossG1Add(data []byte) int {
   122  	input := bytes.NewReader(data)
   123  
   124  	// get random G1 points
   125  	cp1, bl1, err := getG1Points(input)
   126  	if err != nil {
   127  		return 0
   128  	}
   129  
   130  	// get random G1 points
   131  	cp2, bl2, err := getG1Points(input)
   132  	if err != nil {
   133  		return 0
   134  	}
   135  
   136  	// compute cp = cp1 + cp2
   137  	_cp1 := new(gnark.G1Jac).FromAffine(cp1)
   138  	_cp2 := new(gnark.G1Jac).FromAffine(cp2)
   139  	cp := new(gnark.G1Affine).FromJacobian(_cp1.AddAssign(_cp2))
   140  
   141  	bl3 := blst.P1AffinesAdd([]*blst.P1Affine{bl1, bl2})
   142  	if !(bytes.Equal(cp.Marshal(), bl3.Serialize())) {
   143  		panic("G1 point addition mismatch blst / geth ")
   144  	}
   145  
   146  	return 1
   147  }
   148  
   149  func fuzzCrossG2Add(data []byte) int {
   150  	input := bytes.NewReader(data)
   151  
   152  	// get random G2 points
   153  	gp1, bl1, err := getG2Points(input)
   154  	if err != nil {
   155  		return 0
   156  	}
   157  
   158  	// get random G2 points
   159  	gp2, bl2, err := getG2Points(input)
   160  	if err != nil {
   161  		return 0
   162  	}
   163  
   164  	// compute cp = cp1 + cp2
   165  	_gp1 := new(gnark.G2Jac).FromAffine(gp1)
   166  	_gp2 := new(gnark.G2Jac).FromAffine(gp2)
   167  	gp := new(gnark.G2Affine).FromJacobian(_gp1.AddAssign(_gp2))
   168  
   169  	bl3 := blst.P2AffinesAdd([]*blst.P2Affine{bl1, bl2})
   170  	if !(bytes.Equal(gp.Marshal(), bl3.Serialize())) {
   171  		panic("G2 point addition mismatch blst / geth ")
   172  	}
   173  
   174  	return 1
   175  }
   176  
   177  func fuzzCrossG1MultiExp(data []byte) int {
   178  	var (
   179  		input        = bytes.NewReader(data)
   180  		gnarkScalars []fr.Element
   181  		gnarkPoints  []gnark.G1Affine
   182  		blstScalars  []*blst.Scalar
   183  		blstPoints   []*blst.P1Affine
   184  	)
   185  	// n random scalars (max 17)
   186  	for i := 0; i < 17; i++ {
   187  		// note that geth/crypto/bls12381 works only with scalars <= 32bytes
   188  		s, err := randomScalar(input, fr.Modulus())
   189  		if err != nil {
   190  			break
   191  		}
   192  		// get a random G1 point as basis
   193  		cp1, bl1, err := getG1Points(input)
   194  		if err != nil {
   195  			break
   196  		}
   197  
   198  		gnarkScalar := new(fr.Element).SetBigInt(s)
   199  		gnarkScalars = append(gnarkScalars, *gnarkScalar)
   200  		gnarkPoints = append(gnarkPoints, *cp1)
   201  
   202  		blstScalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32))
   203  		blstScalars = append(blstScalars, blstScalar)
   204  		blstPoints = append(blstPoints, bl1)
   205  	}
   206  
   207  	if len(gnarkScalars) == 0 || len(gnarkScalars) != len(gnarkPoints) {
   208  		return 0
   209  	}
   210  
   211  	// gnark multi exp
   212  	cp := new(gnark.G1Affine)
   213  	cp.MultiExp(gnarkPoints, gnarkScalars, ecc.MultiExpConfig{})
   214  
   215  	expectedGnark := multiExpG1Gnark(gnarkPoints, gnarkScalars)
   216  	if !bytes.Equal(cp.Marshal(), expectedGnark.Marshal()) {
   217  		panic("g1 multi exponentiation mismatch")
   218  	}
   219  
   220  	// blst multi exp
   221  	expectedBlst := blst.P1AffinesMult(blstPoints, blstScalars, 256).ToAffine()
   222  	if !bytes.Equal(cp.Marshal(), expectedBlst.Serialize()) {
   223  		panic("g1 multi exponentiation mismatch, gnark/blst")
   224  	}
   225  	return 1
   226  }
   227  
   228  func fuzzCrossG2MultiExp(data []byte) int {
   229  	var (
   230  		input        = bytes.NewReader(data)
   231  		gnarkScalars []fr.Element
   232  		gnarkPoints  []gnark.G2Affine
   233  		blstScalars  []*blst.Scalar
   234  		blstPoints   []*blst.P2Affine
   235  	)
   236  	// n random scalars (max 17)
   237  	for i := 0; i < 17; i++ {
   238  		// note that geth/crypto/bls12381 works only with scalars <= 32bytes
   239  		s, err := randomScalar(input, fr.Modulus())
   240  		if err != nil {
   241  			break
   242  		}
   243  		// get a random G1 point as basis
   244  		cp1, bl1, err := getG2Points(input)
   245  		if err != nil {
   246  			break
   247  		}
   248  
   249  		gnarkScalar := new(fr.Element).SetBigInt(s)
   250  		gnarkScalars = append(gnarkScalars, *gnarkScalar)
   251  		gnarkPoints = append(gnarkPoints, *cp1)
   252  
   253  		blstScalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32))
   254  		blstScalars = append(blstScalars, blstScalar)
   255  		blstPoints = append(blstPoints, bl1)
   256  	}
   257  
   258  	if len(gnarkScalars) == 0 || len(gnarkScalars) != len(gnarkPoints) {
   259  		return 0
   260  	}
   261  
   262  	// gnark multi exp
   263  	cp := new(gnark.G2Affine)
   264  	cp.MultiExp(gnarkPoints, gnarkScalars, ecc.MultiExpConfig{})
   265  
   266  	expectedGnark := multiExpG2Gnark(gnarkPoints, gnarkScalars)
   267  	if !bytes.Equal(cp.Marshal(), expectedGnark.Marshal()) {
   268  		panic("g1 multi exponentiation mismatch")
   269  	}
   270  
   271  	// blst multi exp
   272  	expectedBlst := blst.P2AffinesMult(blstPoints, blstScalars, 256).ToAffine()
   273  	if !bytes.Equal(cp.Marshal(), expectedBlst.Serialize()) {
   274  		panic("g1 multi exponentiation mismatch, gnark/blst")
   275  	}
   276  	return 1
   277  }
   278  
   279  func getG1Points(input io.Reader) (*gnark.G1Affine, *blst.P1Affine, error) {
   280  	// sample a random scalar
   281  	s, err := randomScalar(input, fp.Modulus())
   282  	if err != nil {
   283  		return nil, nil, err
   284  	}
   285  
   286  	// compute a random point
   287  	cp := new(gnark.G1Affine)
   288  	_, _, g1Gen, _ := gnark.Generators()
   289  	cp.ScalarMultiplication(&g1Gen, s)
   290  	cpBytes := cp.Marshal()
   291  
   292  	// marshal gnark point -> blst point
   293  	scalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32))
   294  	p1 := new(blst.P1Affine).From(scalar)
   295  	blstRes := p1.Serialize()
   296  	if !bytes.Equal(blstRes, cpBytes) {
   297  		panic(fmt.Sprintf("bytes(blst.G1) != bytes(geth.G1)\nblst.G1: %x\ngeth.G1: %x\n", blstRes, cpBytes))
   298  	}
   299  
   300  	return cp, p1, nil
   301  }
   302  
   303  func getG2Points(input io.Reader) (*gnark.G2Affine, *blst.P2Affine, error) {
   304  	// sample a random scalar
   305  	s, err := randomScalar(input, fp.Modulus())
   306  	if err != nil {
   307  		return nil, nil, err
   308  	}
   309  
   310  	// compute a random point
   311  	gp := new(gnark.G2Affine)
   312  	_, _, _, g2Gen := gnark.Generators()
   313  	gp.ScalarMultiplication(&g2Gen, s)
   314  	cpBytes := gp.Marshal()
   315  
   316  	// marshal gnark point -> blst point
   317  	// Left pad the scalar to 32 bytes
   318  	scalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32))
   319  	p2 := new(blst.P2Affine).From(scalar)
   320  	if !bytes.Equal(p2.Serialize(), cpBytes) {
   321  		panic("bytes(blst.G2) != bytes(bls12381.G2)")
   322  	}
   323  
   324  	return gp, p2, nil
   325  }
   326  
   327  func randomScalar(r io.Reader, max *big.Int) (k *big.Int, err error) {
   328  	for {
   329  		k, err = rand.Int(r, max)
   330  		if err != nil || k.Sign() > 0 {
   331  			return
   332  		}
   333  	}
   334  }
   335  
   336  // multiExpG1Gnark is a naive implementation of G1 multi-exponentiation
   337  func multiExpG1Gnark(gs []gnark.G1Affine, scalars []fr.Element) gnark.G1Affine {
   338  	res := gnark.G1Affine{}
   339  	for i := 0; i < len(gs); i++ {
   340  		tmp := new(gnark.G1Affine)
   341  		sb := scalars[i].Bytes()
   342  		scalarBytes := new(big.Int).SetBytes(sb[:])
   343  		tmp.ScalarMultiplication(&gs[i], scalarBytes)
   344  		res.Add(&res, tmp)
   345  	}
   346  	return res
   347  }
   348  
   349  // multiExpG2Gnark is a naive implementation of G2 multi-exponentiation
   350  func multiExpG2Gnark(gs []gnark.G2Affine, scalars []fr.Element) gnark.G2Affine {
   351  	res := gnark.G2Affine{}
   352  	for i := 0; i < len(gs); i++ {
   353  		tmp := new(gnark.G2Affine)
   354  		sb := scalars[i].Bytes()
   355  		scalarBytes := new(big.Int).SetBytes(sb[:])
   356  		tmp.ScalarMultiplication(&gs[i], scalarBytes)
   357  		res.Add(&res, tmp)
   358  	}
   359  	return res
   360  }