github.com/carter-ya/go-ethereum@v0.0.0-20230628080049-d2309be3983b/tests/fuzzers/bls12381/bls12381_fuzz.go (about)

     1  // Copyright 2021 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  //go:build gofuzz
    18  // +build gofuzz
    19  
    20  package bls
    21  
    22  import (
    23  	"bytes"
    24  	"crypto/rand"
    25  	"fmt"
    26  	"io"
    27  	"math/big"
    28  
    29  	gnark "github.com/consensys/gnark-crypto/ecc/bls12-381"
    30  	"github.com/consensys/gnark-crypto/ecc/bls12-381/fp"
    31  	"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
    32  	"github.com/ethereum/go-ethereum/common"
    33  	"github.com/ethereum/go-ethereum/crypto/bls12381"
    34  	blst "github.com/supranational/blst/bindings/go"
    35  )
    36  
    37  func FuzzCrossPairing(data []byte) int {
    38  	input := bytes.NewReader(data)
    39  
    40  	// get random G1 points
    41  	kpG1, cpG1, blG1, err := getG1Points(input)
    42  	if err != nil {
    43  		return 0
    44  	}
    45  
    46  	// get random G2 points
    47  	kpG2, cpG2, blG2, err := getG2Points(input)
    48  	if err != nil {
    49  		return 0
    50  	}
    51  
    52  	// compute pairing using geth
    53  	engine := bls12381.NewPairingEngine()
    54  	engine.AddPair(kpG1, kpG2)
    55  	kResult := engine.Result()
    56  
    57  	// compute pairing using gnark
    58  	cResult, err := gnark.Pair([]gnark.G1Affine{*cpG1}, []gnark.G2Affine{*cpG2})
    59  	if err != nil {
    60  		panic(fmt.Sprintf("gnark/bls12381 encountered error: %v", err))
    61  	}
    62  
    63  	// compare result
    64  	if !(bytes.Equal(cResult.Marshal(), bls12381.NewGT().ToBytes(kResult))) {
    65  		panic("pairing mismatch gnark / geth ")
    66  	}
    67  
    68  	// compute pairing using blst
    69  	blstResult := blst.Fp12MillerLoop(blG2, blG1)
    70  	blstResult.FinalExp()
    71  	res := massageBLST(blstResult.ToBendian())
    72  	if !(bytes.Equal(res, bls12381.NewGT().ToBytes(kResult))) {
    73  		panic("pairing mismatch blst / geth")
    74  	}
    75  
    76  	return 1
    77  }
    78  
    79  func massageBLST(in []byte) []byte {
    80  	out := make([]byte, len(in))
    81  	len := 12 * 48
    82  	// 1
    83  	copy(out[0:], in[len-1*48:len])
    84  	copy(out[1*48:], in[len-2*48:len-1*48])
    85  	// 2
    86  	copy(out[6*48:], in[len-3*48:len-2*48])
    87  	copy(out[7*48:], in[len-4*48:len-3*48])
    88  	// 3
    89  	copy(out[2*48:], in[len-5*48:len-4*48])
    90  	copy(out[3*48:], in[len-6*48:len-5*48])
    91  	// 4
    92  	copy(out[8*48:], in[len-7*48:len-6*48])
    93  	copy(out[9*48:], in[len-8*48:len-7*48])
    94  	// 5
    95  	copy(out[4*48:], in[len-9*48:len-8*48])
    96  	copy(out[5*48:], in[len-10*48:len-9*48])
    97  	// 6
    98  	copy(out[10*48:], in[len-11*48:len-10*48])
    99  	copy(out[11*48:], in[len-12*48:len-11*48])
   100  	return out
   101  }
   102  
   103  func FuzzCrossG1Add(data []byte) int {
   104  	input := bytes.NewReader(data)
   105  
   106  	// get random G1 points
   107  	kp1, cp1, bl1, err := getG1Points(input)
   108  	if err != nil {
   109  		return 0
   110  	}
   111  
   112  	// get random G1 points
   113  	kp2, cp2, bl2, err := getG1Points(input)
   114  	if err != nil {
   115  		return 0
   116  	}
   117  
   118  	// compute kp = kp1 + kp2
   119  	g1 := bls12381.NewG1()
   120  	kp := bls12381.PointG1{}
   121  	g1.Add(&kp, kp1, kp2)
   122  
   123  	// compute cp = cp1 + cp2
   124  	_cp1 := new(gnark.G1Jac).FromAffine(cp1)
   125  	_cp2 := new(gnark.G1Jac).FromAffine(cp2)
   126  	cp := new(gnark.G1Affine).FromJacobian(_cp1.AddAssign(_cp2))
   127  
   128  	// compare result
   129  	if !(bytes.Equal(cp.Marshal(), g1.ToBytes(&kp))) {
   130  		panic("G1 point addition mismatch gnark / geth ")
   131  	}
   132  
   133  	bl3 := blst.P1AffinesAdd([]*blst.P1Affine{bl1, bl2})
   134  	if !(bytes.Equal(cp.Marshal(), bl3.Serialize())) {
   135  		panic("G1 point addition mismatch blst / geth ")
   136  	}
   137  
   138  	return 1
   139  }
   140  
   141  func FuzzCrossG2Add(data []byte) int {
   142  	input := bytes.NewReader(data)
   143  
   144  	// get random G2 points
   145  	kp1, cp1, bl1, err := getG2Points(input)
   146  	if err != nil {
   147  		return 0
   148  	}
   149  
   150  	// get random G2 points
   151  	kp2, cp2, bl2, err := getG2Points(input)
   152  	if err != nil {
   153  		return 0
   154  	}
   155  
   156  	// compute kp = kp1 + kp2
   157  	g2 := bls12381.NewG2()
   158  	kp := bls12381.PointG2{}
   159  	g2.Add(&kp, kp1, kp2)
   160  
   161  	// compute cp = cp1 + cp2
   162  	_cp1 := new(gnark.G2Jac).FromAffine(cp1)
   163  	_cp2 := new(gnark.G2Jac).FromAffine(cp2)
   164  	cp := new(gnark.G2Affine).FromJacobian(_cp1.AddAssign(_cp2))
   165  
   166  	// compare result
   167  	if !(bytes.Equal(cp.Marshal(), g2.ToBytes(&kp))) {
   168  		panic("G2 point addition mismatch gnark / geth ")
   169  	}
   170  
   171  	bl3 := blst.P2AffinesAdd([]*blst.P2Affine{bl1, bl2})
   172  	if !(bytes.Equal(cp.Marshal(), bl3.Serialize())) {
   173  		panic("G1 point addition mismatch blst / geth ")
   174  	}
   175  
   176  	return 1
   177  }
   178  
   179  func FuzzCrossG1MultiExp(data []byte) int {
   180  	var (
   181  		input        = bytes.NewReader(data)
   182  		gethScalars  []*big.Int
   183  		gnarkScalars []fr.Element
   184  		gethPoints   []*bls12381.PointG1
   185  		gnarkPoints  []gnark.G1Affine
   186  	)
   187  	// n random scalars (max 17)
   188  	for i := 0; i < 17; i++ {
   189  		// note that geth/crypto/bls12381 works only with scalars <= 32bytes
   190  		s, err := randomScalar(input, fr.Modulus())
   191  		if err != nil {
   192  			break
   193  		}
   194  		// get a random G1 point as basis
   195  		kp1, cp1, _, err := getG1Points(input)
   196  		if err != nil {
   197  			break
   198  		}
   199  		gethScalars = append(gethScalars, s)
   200  		var gnarkScalar = &fr.Element{}
   201  		gnarkScalar = gnarkScalar.SetBigInt(s).FromMont()
   202  		gnarkScalars = append(gnarkScalars, *gnarkScalar)
   203  
   204  		gethPoints = append(gethPoints, new(bls12381.PointG1).Set(kp1))
   205  		gnarkPoints = append(gnarkPoints, *cp1)
   206  	}
   207  	if len(gethScalars) == 0 {
   208  		return 0
   209  	}
   210  	// compute multi exponentiation
   211  	g1 := bls12381.NewG1()
   212  	kp := bls12381.PointG1{}
   213  	if _, err := g1.MultiExp(&kp, gethPoints, gethScalars); err != nil {
   214  		panic(fmt.Sprintf("G1 multi exponentiation errored (geth): %v", err))
   215  	}
   216  	// note that geth/crypto/bls12381.MultiExp mutates the scalars slice (and sets all the scalars to zero)
   217  
   218  	// gnark multi exp
   219  	cp := new(gnark.G1Affine)
   220  	cp.MultiExp(gnarkPoints, gnarkScalars)
   221  
   222  	// compare result
   223  	if !(bytes.Equal(cp.Marshal(), g1.ToBytes(&kp))) {
   224  		panic("G1 multi exponentiation mismatch gnark / geth ")
   225  	}
   226  
   227  	return 1
   228  }
   229  
   230  func getG1Points(input io.Reader) (*bls12381.PointG1, *gnark.G1Affine, *blst.P1Affine, error) {
   231  	// sample a random scalar
   232  	s, err := randomScalar(input, fp.Modulus())
   233  	if err != nil {
   234  		return nil, nil, nil, err
   235  	}
   236  
   237  	// compute a random point
   238  	cp := new(gnark.G1Affine)
   239  	_, _, g1Gen, _ := gnark.Generators()
   240  	cp.ScalarMultiplication(&g1Gen, s)
   241  	cpBytes := cp.Marshal()
   242  
   243  	// marshal gnark point -> geth point
   244  	g1 := bls12381.NewG1()
   245  	kp, err := g1.FromBytes(cpBytes)
   246  	if err != nil {
   247  		panic(fmt.Sprintf("Could not marshal gnark.G1 -> geth.G1: %v", err))
   248  	}
   249  	if !bytes.Equal(g1.ToBytes(kp), cpBytes) {
   250  		panic("bytes(gnark.G1) != bytes(geth.G1)")
   251  	}
   252  
   253  	// marshal gnark point -> blst point
   254  	scalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32))
   255  	p1 := new(blst.P1Affine).From(scalar)
   256  	if !bytes.Equal(p1.Serialize(), cpBytes) {
   257  		panic("bytes(blst.G1) != bytes(geth.G1)")
   258  	}
   259  
   260  	return kp, cp, p1, nil
   261  }
   262  
   263  func getG2Points(input io.Reader) (*bls12381.PointG2, *gnark.G2Affine, *blst.P2Affine, error) {
   264  	// sample a random scalar
   265  	s, err := randomScalar(input, fp.Modulus())
   266  	if err != nil {
   267  		return nil, nil, nil, err
   268  	}
   269  
   270  	// compute a random point
   271  	cp := new(gnark.G2Affine)
   272  	_, _, _, g2Gen := gnark.Generators()
   273  	cp.ScalarMultiplication(&g2Gen, s)
   274  	cpBytes := cp.Marshal()
   275  
   276  	// marshal gnark point -> geth point
   277  	g2 := bls12381.NewG2()
   278  	kp, err := g2.FromBytes(cpBytes)
   279  	if err != nil {
   280  		panic(fmt.Sprintf("Could not marshal gnark.G2 -> geth.G2: %v", err))
   281  	}
   282  	if !bytes.Equal(g2.ToBytes(kp), cpBytes) {
   283  		panic("bytes(gnark.G2) != bytes(geth.G2)")
   284  	}
   285  
   286  	// marshal gnark point -> blst point
   287  	// Left pad the scalar to 32 bytes
   288  	scalar := new(blst.Scalar).FromBEndian(common.LeftPadBytes(s.Bytes(), 32))
   289  	p2 := new(blst.P2Affine).From(scalar)
   290  	if !bytes.Equal(p2.Serialize(), cpBytes) {
   291  		panic("bytes(blst.G2) != bytes(geth.G2)")
   292  	}
   293  
   294  	return kp, cp, p2, nil
   295  }
   296  
   297  func randomScalar(r io.Reader, max *big.Int) (k *big.Int, err error) {
   298  	for {
   299  		k, err = rand.Int(r, max)
   300  		if err != nil || k.Sign() > 0 {
   301  			return
   302  		}
   303  	}
   304  }