github.com/zhiqiangxu/util@v0.0.0-20230112053021-0a7aee056cd5/crypto/vrf/p256/unmarshal.go (about)

     1  package p256
     2  
     3  import (
     4  	"crypto/elliptic"
     5  	"errors"
     6  	"math/big"
     7  )
     8  
     9  var (
    10  	errInvalidData = errors.New("invalid data")
    11  )
    12  
    13  // Unmarshal a compressed point in the form specified in section 4.3.6 of ANSI X9.62.
    14  func Unmarshal(curve elliptic.Curve, data []byte) (x, y *big.Int, err error) {
    15  	if (data[0] &^ 1) != 2 {
    16  		err = errInvalidData
    17  		return
    18  	}
    19  
    20  	byteLen := (curve.Params().BitSize + 7) >> 3
    21  	if len(data) != 1+byteLen {
    22  		err = errInvalidData
    23  		return
    24  	}
    25  
    26  	// Based on Routine 2.2.4 in NIST Mathematical routines paper
    27  	params := curve.Params()
    28  	tx := new(big.Int).SetBytes(data[1 : 1+byteLen])
    29  	y2 := y2(params, tx)
    30  	sqrt := defaultSqrt
    31  	ty := sqrt(y2, params.P)
    32  	if ty == nil {
    33  		// "y^2" is not a square: invalid point
    34  		err = errInvalidData
    35  		return
    36  	}
    37  
    38  	var y2c big.Int
    39  	y2c.Mul(ty, ty).Mod(&y2c, params.P)
    40  	if y2c.Cmp(y2) != 0 {
    41  		// sqrt(y2)^2 != y2: invalid point
    42  		err = errInvalidData
    43  		return
    44  	}
    45  
    46  	if ty.Bit(0) != uint(data[0]&1) {
    47  		ty.Sub(params.P, ty)
    48  	}
    49  
    50  	x, y = tx, ty // valid point: return it
    51  	return
    52  }
    53  
    54  // Use the curve equation to calculate y² given x.
    55  // only applies to curves of the form y² = x³ - 3x + b.
    56  func y2(curve *elliptic.CurveParams, x *big.Int) *big.Int {
    57  	// y² = x³ - 3x + b
    58  	x3 := new(big.Int).Mul(x, x)
    59  	x3.Mul(x3, x)
    60  
    61  	threeX := new(big.Int).Lsh(x, 1)
    62  	threeX.Add(threeX, x)
    63  
    64  	x3.Sub(x3, threeX)
    65  	x3.Add(x3, curve.B)
    66  	x3.Mod(x3, curve.P)
    67  	return x3
    68  }
    69  
    70  func defaultSqrt(x, p *big.Int) *big.Int {
    71  	var r big.Int
    72  	if nil == r.ModSqrt(x, p) {
    73  		return nil // x is not a square
    74  	}
    75  	return &r
    76  }