github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/crypto/keys/publickey.go (about)

     1  package keys
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/elliptic"
     6  	"crypto/x509"
     7  	"encoding/hex"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"math/big"
    12  
    13  	"github.com/decred/dcrd/dcrec/secp256k1/v4"
    14  	lru "github.com/hashicorp/golang-lru/v2"
    15  	"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
    16  	"github.com/nspcc-dev/neo-go/pkg/encoding/address"
    17  	"github.com/nspcc-dev/neo-go/pkg/io"
    18  	"github.com/nspcc-dev/neo-go/pkg/util"
    19  	"github.com/nspcc-dev/neo-go/pkg/vm/emit"
    20  )
    21  
    22  // coordLen is the number of bytes in serialized X or Y coordinate.
    23  const coordLen = 32
    24  
    25  // SignatureLen is the length of a standard signature for 256-bit EC key.
    26  const SignatureLen = 64
    27  
// PublicKeys is a list of public keys. It implements sort.Interface (Len, Swap
// and Less, ordering keys with PublicKey.Cmp) and can be serialized with
// Bytes / DecodeBytes.
type PublicKeys []*PublicKey
    30  
    31  var big0 = big.NewInt(0)
    32  var big3 = big.NewInt(3)
    33  
    34  // NewPublicKeysFromStrings converts an array of string-encoded P256 public keys
    35  // into an array of PublicKeys.
    36  func NewPublicKeysFromStrings(ss []string) (PublicKeys, error) {
    37  	arr := make([]*PublicKey, len(ss))
    38  	for i := range ss {
    39  		pubKey, err := NewPublicKeyFromString(ss[i])
    40  		if err != nil {
    41  			return nil, err
    42  		}
    43  		arr[i] = pubKey
    44  	}
    45  	return PublicKeys(arr), nil
    46  }
    47  
    48  func (keys PublicKeys) Len() int      { return len(keys) }
    49  func (keys PublicKeys) Swap(i, j int) { keys[i], keys[j] = keys[j], keys[i] }
    50  func (keys PublicKeys) Less(i, j int) bool {
    51  	return keys[i].Cmp(keys[j]) == -1
    52  }
    53  
    54  // DecodeBytes decodes a PublicKeys from the given slice of bytes.
    55  func (keys *PublicKeys) DecodeBytes(data []byte) error {
    56  	b := io.NewBinReaderFromBuf(data)
    57  	b.ReadArray(keys)
    58  	return b.Err
    59  }
    60  
    61  // Bytes encodes PublicKeys to the new slice of bytes.
    62  func (keys *PublicKeys) Bytes() []byte {
    63  	buf := io.NewBufBinWriter()
    64  	buf.WriteArray(*keys)
    65  	if buf.Err != nil {
    66  		panic(buf.Err)
    67  	}
    68  	return buf.Bytes()
    69  }
    70  
    71  // Contains checks whether the passed param is contained in PublicKeys.
    72  func (keys PublicKeys) Contains(pKey *PublicKey) bool {
    73  	for _, key := range keys {
    74  		if key.Equal(pKey) {
    75  			return true
    76  		}
    77  	}
    78  	return false
    79  }
    80  
    81  // Copy returns a shallow copy of the PublicKeys slice. It creates a new slice with the same elements,
    82  // but does not perform a deep copy of the elements themselves.
    83  func (keys PublicKeys) Copy() PublicKeys {
    84  	if keys == nil {
    85  		return nil
    86  	}
    87  	res := make(PublicKeys, len(keys))
    88  	copy(res, keys)
    89  	return res
    90  }
    91  
    92  // Unique returns a set of public keys.
    93  func (keys PublicKeys) Unique() PublicKeys {
    94  	unique := PublicKeys{}
    95  	for _, publicKey := range keys {
    96  		if !unique.Contains(publicKey) {
    97  			unique = append(unique, publicKey)
    98  		}
    99  	}
   100  	return unique
   101  }
   102  
// PublicKey represents a public key and provides a high level
// API around ecdsa.PublicKey. A key with both X and Y set to nil represents
// the point at infinity (see IsInfinity).
type PublicKey ecdsa.PublicKey
   106  
   107  // Equal returns true in case public keys are equal.
   108  func (p *PublicKey) Equal(key *PublicKey) bool {
   109  	return p.X.Cmp(key.X) == 0 && p.Y.Cmp(key.Y) == 0
   110  }
   111  
   112  // Cmp compares two keys.
   113  func (p *PublicKey) Cmp(key *PublicKey) int {
   114  	xCmp := p.X.Cmp(key.X)
   115  	if xCmp != 0 {
   116  		return xCmp
   117  	}
   118  	return p.Y.Cmp(key.Y)
   119  }
   120  
   121  // NewPublicKeyFromString returns a public key created from the
   122  // given hex string public key representation in compressed form.
   123  func NewPublicKeyFromString(s string) (*PublicKey, error) {
   124  	b, err := hex.DecodeString(s)
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  	return NewPublicKeyFromBytes(b, elliptic.P256())
   129  }
   130  
   131  // keycache is a simple lru cache for P256 keys that avoids Y calculation overhead
   132  // for known keys.
   133  var keycache *lru.Cache[string, *PublicKey]
   134  
   135  func init() {
   136  	// Less than 100K, probably enough for our purposes.
   137  	keycache, _ = lru.New[string, *PublicKey](1024)
   138  }
   139  
   140  // NewPublicKeyFromBytes returns a public key created from b using the given EC.
   141  func NewPublicKeyFromBytes(b []byte, curve elliptic.Curve) (*PublicKey, error) {
   142  	pubKey, ok := keycache.Get(string(b))
   143  	if ok && pubKey.Curve == curve {
   144  		return pubKey, nil
   145  	}
   146  	pubKey = new(PublicKey)
   147  	pubKey.Curve = curve
   148  	if err := pubKey.DecodeBytes(b); err != nil {
   149  		return nil, err
   150  	}
   151  	keycache.Add(string(b), pubKey)
   152  	return pubKey, nil
   153  }
   154  
   155  // getBytes serializes X and Y using compressed or uncompressed format.
   156  func (p *PublicKey) getBytes(compressed bool) []byte {
   157  	if p.IsInfinity() {
   158  		return []byte{0x00}
   159  	}
   160  
   161  	if compressed {
   162  		return elliptic.MarshalCompressed(p.Curve, p.X, p.Y)
   163  	}
   164  	return elliptic.Marshal(p.Curve, p.X, p.Y)
   165  }
   166  
   167  // Bytes returns byte array representation of the public key in compressed
   168  // form (33 bytes with 0x02 or 0x03 prefix, except infinity which is always 0).
   169  func (p *PublicKey) Bytes() []byte {
   170  	return p.getBytes(true)
   171  }
   172  
   173  // UncompressedBytes returns byte array representation of the public key in
   174  // uncompressed form (65 bytes with 0x04 prefix, except infinity which is
   175  // always 0).
   176  func (p *PublicKey) UncompressedBytes() []byte {
   177  	return p.getBytes(false)
   178  }
   179  
   180  // NewPublicKeyFromASN1 returns a NEO PublicKey from the ASN.1 serialized key.
   181  func NewPublicKeyFromASN1(data []byte) (*PublicKey, error) {
   182  	var (
   183  		err    error
   184  		pubkey any
   185  	)
   186  	if pubkey, err = x509.ParsePKIXPublicKey(data); err != nil {
   187  		return nil, err
   188  	}
   189  	pk, ok := pubkey.(*ecdsa.PublicKey)
   190  	if !ok {
   191  		return nil, errors.New("given bytes aren't ECDSA public key")
   192  	}
   193  	result := PublicKey(*pk)
   194  	return &result, nil
   195  }
   196  
   197  // decodeCompressedY performs decompression of Y coordinate for the given X and Y's least significant bit.
   198  // We use here a short-form Weierstrass curve (https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html)
   199  // y² = x³ + ax + b. Two types of elliptic curves are supported:
   200  // 1. Secp256k1 (Koblitz curve): y² = x³ + b,
   201  // 2. Secp256r1 (Random curve): y² = x³ - 3x + b.
   202  // To decode a compressed curve point, we perform the following operation: y = sqrt(x³ + ax + b mod p)
   203  // where `p` denotes the order of the underlying curve field.
   204  func decodeCompressedY(x *big.Int, ylsb uint, curve elliptic.Curve) (*big.Int, error) {
   205  	var a *big.Int
   206  	switch curve.(type) {
   207  	case *secp256k1.KoblitzCurve:
   208  		a = big0
   209  	default:
   210  		a = big3
   211  	}
   212  	cp := curve.Params()
   213  	xCubed := new(big.Int).Exp(x, big3, cp.P)
   214  	aX := new(big.Int).Mul(x, a)
   215  	aX.Mod(aX, cp.P)
   216  	ySquared := new(big.Int).Sub(xCubed, aX)
   217  	ySquared.Add(ySquared, cp.B)
   218  	ySquared.Mod(ySquared, cp.P)
   219  	y := new(big.Int).ModSqrt(ySquared, cp.P)
   220  	if y == nil {
   221  		return nil, errors.New("error computing Y for compressed point")
   222  	}
   223  	if y.Bit(0) != ylsb {
   224  		y.Neg(y)
   225  		y.Mod(y, cp.P)
   226  	}
   227  	return y, nil
   228  }
   229  
   230  // DecodeBytes decodes a PublicKey from the given slice of bytes.
   231  func (p *PublicKey) DecodeBytes(data []byte) error {
   232  	b := io.NewBinReaderFromBuf(data)
   233  	p.DecodeBinary(b)
   234  	if b.Err != nil {
   235  		return b.Err
   236  	}
   237  
   238  	if b.Len() != 0 {
   239  		return errors.New("extra data")
   240  	}
   241  	return nil
   242  }
   243  
   244  // DecodeBinary decodes a PublicKey from the given BinReader using information
   245  // about the EC curve to decompress Y point. Secp256r1 is a default value for EC curve.
   246  func (p *PublicKey) DecodeBinary(r *io.BinReader) {
   247  	var prefix uint8
   248  	var x, y *big.Int
   249  	var err error
   250  
   251  	prefix = uint8(r.ReadB())
   252  	if r.Err != nil {
   253  		return
   254  	}
   255  
   256  	if p.Curve == nil {
   257  		p.Curve = elliptic.P256()
   258  	}
   259  	curve := p.Curve
   260  	curveParams := p.Params()
   261  	// Infinity
   262  	switch prefix {
   263  	case 0x00:
   264  		// noop, initialized to nil
   265  		return
   266  	case 0x02, 0x03:
   267  		// Compressed public keys
   268  		xbytes := make([]byte, coordLen)
   269  		r.ReadBytes(xbytes)
   270  		if r.Err != nil {
   271  			return
   272  		}
   273  		x = new(big.Int).SetBytes(xbytes)
   274  		ylsb := uint(prefix & 0x1)
   275  		y, err = decodeCompressedY(x, ylsb, curve)
   276  		if err != nil {
   277  			r.Err = err
   278  			return
   279  		}
   280  	case 0x04:
   281  		xbytes := make([]byte, coordLen)
   282  		ybytes := make([]byte, coordLen)
   283  		r.ReadBytes(xbytes)
   284  		r.ReadBytes(ybytes)
   285  		if r.Err != nil {
   286  			return
   287  		}
   288  		x = new(big.Int).SetBytes(xbytes)
   289  		y = new(big.Int).SetBytes(ybytes)
   290  		if !curve.IsOnCurve(x, y) {
   291  			r.Err = errors.New("encoded point is not on the P256 curve")
   292  			return
   293  		}
   294  	default:
   295  		r.Err = fmt.Errorf("invalid prefix %d", prefix)
   296  		return
   297  	}
   298  	if x.Cmp(curveParams.P) >= 0 || y.Cmp(curveParams.P) >= 0 {
   299  		r.Err = errors.New("enccoded point is not correct (X or Y is bigger than P")
   300  		return
   301  	}
   302  	p.X, p.Y = x, y
   303  }
   304  
   305  // EncodeBinary encodes a PublicKey to the given BinWriter.
   306  func (p *PublicKey) EncodeBinary(w *io.BinWriter) {
   307  	w.WriteBytes(p.Bytes())
   308  }
   309  
   310  // GetVerificationScript returns NEO VM bytecode with CHECKSIG command for the
   311  // public key.
   312  func (p *PublicKey) GetVerificationScript() []byte {
   313  	b := p.Bytes()
   314  	buf := io.NewBufBinWriter()
   315  	if address.Prefix == address.NEO2Prefix {
   316  		buf.WriteB(0x21) // PUSHBYTES33
   317  		buf.WriteBytes(p.Bytes())
   318  		buf.WriteB(0xAC) // CHECKSIG
   319  		return buf.Bytes()
   320  	}
   321  	emit.CheckSig(buf.BinWriter, b)
   322  
   323  	return buf.Bytes()
   324  }
   325  
   326  // GetScriptHash returns a Hash160 of verification script for the key.
   327  func (p *PublicKey) GetScriptHash() util.Uint160 {
   328  	return hash.Hash160(p.GetVerificationScript())
   329  }
   330  
   331  // Address returns a base58-encoded NEO-specific address based on the key hash.
   332  func (p *PublicKey) Address() string {
   333  	return address.Uint160ToString(p.GetScriptHash())
   334  }
   335  
   336  // Verify returns true if the signature is valid and corresponds
   337  // to the hash and public key.
   338  func (p *PublicKey) Verify(signature []byte, hash []byte) bool {
   339  	if p.X == nil || p.Y == nil || len(signature) != SignatureLen {
   340  		return false
   341  	}
   342  	rBytes := new(big.Int).SetBytes(signature[0:32])
   343  	sBytes := new(big.Int).SetBytes(signature[32:64])
   344  	return ecdsa.Verify((*ecdsa.PublicKey)(p), hash, rBytes, sBytes)
   345  }
   346  
   347  // VerifyHashable returns true if the signature is valid and corresponds
   348  // to the hash and public key.
   349  func (p *PublicKey) VerifyHashable(signature []byte, net uint32, hh hash.Hashable) bool {
   350  	var digest = hash.NetSha256(net, hh)
   351  	return p.Verify(signature, digest[:])
   352  }
   353  
   354  // IsInfinity checks if the key is infinite (null, basically).
   355  func (p *PublicKey) IsInfinity() bool {
   356  	return p.X == nil && p.Y == nil
   357  }
   358  
   359  // String implements the Stringer interface.
   360  func (p *PublicKey) String() string {
   361  	if p.IsInfinity() {
   362  		return "00"
   363  	}
   364  	bx := hex.EncodeToString(p.X.Bytes())
   365  	by := hex.EncodeToString(p.Y.Bytes())
   366  	return fmt.Sprintf("%s%s", bx, by)
   367  }
   368  
   369  // MarshalJSON implements the json.Marshaler interface.
   370  func (p PublicKey) MarshalJSON() ([]byte, error) {
   371  	return json.Marshal(p.StringCompressed())
   372  }
   373  
   374  // UnmarshalJSON implements the json.Unmarshaler interface.
   375  func (p *PublicKey) UnmarshalJSON(data []byte) error {
   376  	l := len(data)
   377  	if l < 2 || data[0] != '"' || data[l-1] != '"' {
   378  		return errors.New("wrong format")
   379  	}
   380  
   381  	bytes := make([]byte, hex.DecodedLen(l-2))
   382  	_, err := hex.Decode(bytes, data[1:l-1])
   383  	if err != nil {
   384  		return err
   385  	}
   386  	err = p.DecodeBytes(bytes)
   387  	if err != nil {
   388  		return err
   389  	}
   390  
   391  	return nil
   392  }
   393  
   394  // MarshalYAML implements the YAML marshaler interface.
   395  func (p *PublicKey) MarshalYAML() (any, error) {
   396  	return p.StringCompressed(), nil
   397  }
   398  
   399  // UnmarshalYAML implements the YAML unmarshaler interface.
   400  func (p *PublicKey) UnmarshalYAML(unmarshal func(any) error) error {
   401  	var s string
   402  	err := unmarshal(&s)
   403  	if err != nil {
   404  		return err
   405  	}
   406  
   407  	b, err := hex.DecodeString(s)
   408  	if err != nil {
   409  		return fmt.Errorf("failed to decode public key from hex bytes: %w", err)
   410  	}
   411  	return p.DecodeBytes(b)
   412  }
   413  
   414  // StringCompressed returns the hex string representation of the public key
   415  // in its compressed form.
   416  func (p *PublicKey) StringCompressed() string {
   417  	return hex.EncodeToString(p.Bytes())
   418  }