github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/core/native/crypto_blspoints.go (about)

     1  package native
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"math/big"
     7  
     8  	bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381"
     9  	"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
    10  )
    11  
    12  // blsPoint is a wrapper around bls12381 point types that must be used as
    13  // stackitem.Interop values and implement stackitem.Equatable interface.
    14  type blsPoint struct {
    15  	point any
    16  }
    17  
    18  var _ = stackitem.Equatable(blsPoint{})
    19  
    20  // Equals implements stackitem.Equatable interface.
    21  func (p blsPoint) Equals(other stackitem.Equatable) bool {
    22  	res, err := p.EqualsCheckType(other)
    23  	return err == nil && res
    24  }
    25  
    26  // EqualsCheckType checks whether other is of the same type as p and returns an error if not.
    27  // It also returns whether other and p are equal.
    28  func (p blsPoint) EqualsCheckType(other stackitem.Equatable) (bool, error) {
    29  	b, ok := other.(blsPoint)
    30  	if !ok {
    31  		return false, errors.New("not a bls12-381 point")
    32  	}
    33  	var (
    34  		res bool
    35  		err error
    36  	)
    37  	switch x := p.point.(type) {
    38  	case *bls12381.G1Affine:
    39  		y, ok := b.point.(*bls12381.G1Affine)
    40  		if !ok {
    41  			err = fmt.Errorf("equal: unexpected y bls12381 point type: %T vs G1Affine", y)
    42  			break
    43  		}
    44  		res = x.Equal(y)
    45  	case *bls12381.G1Jac:
    46  		y, ok := b.point.(*bls12381.G1Jac)
    47  		if !ok {
    48  			err = fmt.Errorf("equal: unexpected y bls12381 point type: %T vs G1Jac", y)
    49  			break
    50  		}
    51  		res = x.Equal(y)
    52  	case *bls12381.G2Affine:
    53  		y, ok := b.point.(*bls12381.G2Affine)
    54  		if !ok {
    55  			err = fmt.Errorf("equal: unexpected y bls12381 point type: %T vs G2Affine", y)
    56  			break
    57  		}
    58  		res = x.Equal(y)
    59  	case *bls12381.G2Jac:
    60  		y, ok := b.point.(*bls12381.G2Jac)
    61  		if !ok {
    62  			err = fmt.Errorf("equal: unexpected y bls12381 point type: %T vs G2Jac", y)
    63  			break
    64  		}
    65  		res = x.Equal(y)
    66  	case *bls12381.GT:
    67  		y, ok := b.point.(*bls12381.GT)
    68  		if !ok {
    69  			err = fmt.Errorf("equal: unexpected y bls12381 point type: %T vs GT", y)
    70  			break
    71  		}
    72  		res = x.Equal(y)
    73  	default:
    74  		err = fmt.Errorf("equal: unexpected x bls12381 point type: %T", x)
    75  	}
    76  
    77  	return res, err
    78  }
    79  
    80  // Bytes returns serialized representation of the provided point in compressed form.
    81  func (p blsPoint) Bytes() []byte {
    82  	switch p := p.point.(type) {
    83  	case *bls12381.G1Affine:
    84  		compressed := p.Bytes()
    85  		return compressed[:]
    86  	case *bls12381.G1Jac:
    87  		g1Affine := new(bls12381.G1Affine)
    88  		g1Affine.FromJacobian(p)
    89  		compressed := g1Affine.Bytes()
    90  		return compressed[:]
    91  	case *bls12381.G2Affine:
    92  		compressed := p.Bytes()
    93  		return compressed[:]
    94  	case *bls12381.G2Jac:
    95  		g2Affine := new(bls12381.G2Affine)
    96  		g2Affine.FromJacobian(p)
    97  		compressed := g2Affine.Bytes()
    98  		return compressed[:]
    99  	case *bls12381.GT:
   100  		compressed := p.Bytes()
   101  		return compressed[:]
   102  	default:
   103  		panic(errors.New("unknown bls12381 point type"))
   104  	}
   105  }
   106  
   107  // FromBytes deserializes BLS12-381 point from the given byte slice in compressed form.
   108  func (p *blsPoint) FromBytes(buf []byte) error {
   109  	switch l := len(buf); l {
   110  	case bls12381.SizeOfG1AffineCompressed:
   111  		g1Affine := new(bls12381.G1Affine)
   112  		_, err := g1Affine.SetBytes(buf)
   113  		if err != nil {
   114  			return fmt.Errorf("failed to decode bls12381 G1Affine point: %w", err)
   115  		}
   116  		p.point = g1Affine
   117  	case bls12381.SizeOfG2AffineCompressed:
   118  		g2Affine := new(bls12381.G2Affine)
   119  		_, err := g2Affine.SetBytes(buf)
   120  		if err != nil {
   121  			return fmt.Errorf("failed to decode bls12381 G2Affine point: %w", err)
   122  		}
   123  		p.point = g2Affine
   124  	case bls12381.SizeOfGT:
   125  		gt := new(bls12381.GT)
   126  		err := gt.SetBytes(buf)
   127  		if err != nil {
   128  			return fmt.Errorf("failed to decode GT point: %w", err)
   129  		}
   130  		p.point = gt
   131  	}
   132  
   133  	return nil
   134  }
   135  
   136  // blsPointAdd performs addition of two BLS12-381 points.
   137  func blsPointAdd(a, b blsPoint) (blsPoint, error) {
   138  	var (
   139  		res any
   140  		err error
   141  	)
   142  	switch x := a.point.(type) {
   143  	case *bls12381.G1Affine:
   144  		switch y := b.point.(type) {
   145  		case *bls12381.G1Affine:
   146  			xJac := new(bls12381.G1Jac)
   147  			xJac.FromAffine(x)
   148  			xJac.AddMixed(y)
   149  			res = xJac
   150  		case *bls12381.G1Jac:
   151  			yJac := new(bls12381.G1Jac)
   152  			yJac.Set(y)
   153  			yJac.AddMixed(x)
   154  			res = yJac
   155  		default:
   156  			err = fmt.Errorf("add: inconsistent bls12381 point types: %T and %T", x, y)
   157  		}
   158  	case *bls12381.G1Jac:
   159  		resJac := new(bls12381.G1Jac)
   160  		resJac.Set(x)
   161  		switch y := b.point.(type) {
   162  		case *bls12381.G1Affine:
   163  			resJac.AddMixed(y)
   164  		case *bls12381.G1Jac:
   165  			resJac.AddAssign(y)
   166  		default:
   167  			err = fmt.Errorf("add: inconsistent bls12381 point types: %T and %T", x, y)
   168  		}
   169  		res = resJac
   170  	case *bls12381.G2Affine:
   171  		switch y := b.point.(type) {
   172  		case *bls12381.G2Affine:
   173  			xJac := new(bls12381.G2Jac)
   174  			xJac.FromAffine(x)
   175  			xJac.AddMixed(y)
   176  			res = xJac
   177  		case *bls12381.G2Jac:
   178  			yJac := new(bls12381.G2Jac)
   179  			yJac.Set(y)
   180  			yJac.AddMixed(x)
   181  			res = yJac
   182  		default:
   183  			err = fmt.Errorf("add: inconsistent bls12381 point types: %T and %T", x, y)
   184  		}
   185  	case *bls12381.G2Jac:
   186  		resJac := new(bls12381.G2Jac)
   187  		resJac.Set(x)
   188  		switch y := b.point.(type) {
   189  		case *bls12381.G2Affine:
   190  			resJac.AddMixed(y)
   191  		case *bls12381.G2Jac:
   192  			resJac.AddAssign(y)
   193  		default:
   194  			err = fmt.Errorf("add: inconsistent bls12381 point types: %T and %T", x, y)
   195  		}
   196  		res = resJac
   197  	case *bls12381.GT:
   198  		resGT := new(bls12381.GT)
   199  		resGT.Set(x)
   200  		switch y := b.point.(type) {
   201  		case *bls12381.GT:
   202  			// It's multiplication, see https://github.com/neo-project/Neo.Cryptography.BLS12_381/issues/4.
   203  			resGT.Mul(x, y)
   204  		default:
   205  			err = fmt.Errorf("add: inconsistent bls12381 point types: %T and %T", x, y)
   206  		}
   207  		res = resGT
   208  	default:
   209  		err = fmt.Errorf("add: unexpected bls12381 point type: %T", x)
   210  	}
   211  
   212  	return blsPoint{point: res}, err
   213  }
   214  
   215  // blsPointAdd performs scalar multiplication of two BLS12-381 points.
   216  func blsPointMul(a blsPoint, alphaBi *big.Int) (blsPoint, error) {
   217  	var (
   218  		res any
   219  		err error
   220  	)
   221  	switch x := a.point.(type) {
   222  	case *bls12381.G1Affine:
   223  		// The result is in Jacobian form in the reference implementation.
   224  		g1Jac := new(bls12381.G1Jac)
   225  		g1Jac.FromAffine(x)
   226  		g1Jac.ScalarMultiplication(g1Jac, alphaBi)
   227  		res = g1Jac
   228  	case *bls12381.G1Jac:
   229  		g1Jac := new(bls12381.G1Jac)
   230  		g1Jac.ScalarMultiplication(x, alphaBi)
   231  		res = g1Jac
   232  	case *bls12381.G2Affine:
   233  		// The result is in Jacobian form in the reference implementation.
   234  		g2Jac := new(bls12381.G2Jac)
   235  		g2Jac.FromAffine(x)
   236  		g2Jac.ScalarMultiplication(g2Jac, alphaBi)
   237  		res = g2Jac
   238  	case *bls12381.G2Jac:
   239  		g2Jac := new(bls12381.G2Jac)
   240  		g2Jac.ScalarMultiplication(x, alphaBi)
   241  		res = g2Jac
   242  	case *bls12381.GT:
   243  		gt := new(bls12381.GT)
   244  
   245  		// C# implementation differs a bit from go's. They use double-and-add algorithm, see
   246  		// https://github.com/neo-project/Neo.Cryptography.BLS12_381/blob/844bc3a4f7d8ba2c545ace90ca124f8ada4c8d29/src/Neo.Cryptography.BLS12_381/Gt.cs#L102
   247  		// and https://en.wikipedia.org/wiki/Elliptic_curve_point_multiplication#Double-and-add,
   248  		// Pay attention that C#'s Gt.Double() squares (not doubles!) the initial GT point.
   249  		// Thus.C#'s scalar multiplication operation over Gt and Scalar is effectively an exponent.
   250  		// Go's exponent algorithm differs a bit from the C#'s double-and-add in that go's one
   251  		// uses 2-bits windowed method for multiplication. However, the resulting GT point is
   252  		// absolutely the same between two implementations.
   253  		gt.Exp(*x, alphaBi)
   254  
   255  		res = gt
   256  	default:
   257  		err = fmt.Errorf("mul: unexpected bls12381 point type: %T", x)
   258  	}
   259  
   260  	return blsPoint{point: res}, err
   261  }
   262  
   263  func blsPointPairing(a, b blsPoint) (blsPoint, error) {
   264  	var (
   265  		x *bls12381.G1Affine
   266  		y *bls12381.G2Affine
   267  	)
   268  	switch p := a.point.(type) {
   269  	case *bls12381.G1Affine:
   270  		x = p
   271  	case *bls12381.G1Jac:
   272  		x = new(bls12381.G1Affine)
   273  		x.FromJacobian(p)
   274  	default:
   275  		return blsPoint{}, fmt.Errorf("pairing: unexpected bls12381 point type (g1): %T", p)
   276  	}
   277  	switch p := b.point.(type) {
   278  	case *bls12381.G2Affine:
   279  		y = p
   280  	case *bls12381.G2Jac:
   281  		y = new(bls12381.G2Affine)
   282  		y.FromJacobian(p)
   283  	default:
   284  		return blsPoint{}, fmt.Errorf("pairing: unexpected bls12381 point type (g2): %T", p)
   285  	}
   286  
   287  	gt, err := bls12381.Pair([]bls12381.G1Affine{*x}, []bls12381.G2Affine{*y})
   288  	if err != nil {
   289  		return blsPoint{}, fmt.Errorf("failed to perform pairing operation: %w", err)
   290  	}
   291  
   292  	return blsPoint{&gt}, nil
   293  }