github.com/gagliardetto/solana-go@v1.11.0/keys.go (about)

     1  // Copyright 2021 github.com/gagliardetto
     2  // This file has been modified by github.com/gagliardetto
     3  //
     4  // Copyright 2020 dfuse Platform Inc.
     5  //
     6  // Licensed under the Apache License, Version 2.0 (the "License");
     7  // you may not use this file except in compliance with the License.
     8  // You may obtain a copy of the License at
     9  //
    10  //      http://www.apache.org/licenses/LICENSE-2.0
    11  //
    12  // Unless required by applicable law or agreed to in writing, software
    13  // distributed under the License is distributed on an "AS IS" BASIS,
    14  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  package solana
    19  
    20  import (
    21  	"bytes"
    22  	"crypto"
    23  	"crypto/ed25519"
    24  	crypto_rand "crypto/rand"
    25  	"crypto/sha256"
    26  	"errors"
    27  	"fmt"
    28  	"io/ioutil"
    29  	"math"
    30  	"sort"
    31  
    32  	"filippo.io/edwards25519"
    33  	"github.com/mr-tron/base58"
    34  	"go.mongodb.org/mongo-driver/bson"
    35  	"go.mongodb.org/mongo-driver/bson/bsontype"
    36  )
    37  
// PrivateKey holds raw private key bytes. Sign and PublicKey convert it to an
// ed25519.PrivateKey, so they expect the 64-byte layout produced by
// ed25519.GenerateKey; the decoding helpers in this file do not validate the
// length — TODO confirm callers never store a bare 32-byte seed here.
type PrivateKey []byte
    39  
    40  func MustPrivateKeyFromBase58(in string) PrivateKey {
    41  	out, err := PrivateKeyFromBase58(in)
    42  	if err != nil {
    43  		panic(err)
    44  	}
    45  	return out
    46  }
    47  
    48  func PrivateKeyFromBase58(privkey string) (PrivateKey, error) {
    49  	res, err := base58.Decode(privkey)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	return res, nil
    54  }
    55  
    56  func PrivateKeyFromSolanaKeygenFile(file string) (PrivateKey, error) {
    57  	content, err := ioutil.ReadFile(file)
    58  	if err != nil {
    59  		return nil, fmt.Errorf("read keygen file: %w", err)
    60  	}
    61  
    62  	var values []byte
    63  	err = json.Unmarshal(content, &values)
    64  	if err != nil {
    65  		return nil, fmt.Errorf("decode keygen file: %w", err)
    66  	}
    67  
    68  	return PrivateKey([]byte(values)), nil
    69  }
    70  
    71  func (k PrivateKey) String() string {
    72  	return base58.Encode(k)
    73  }
    74  
    75  func NewRandomPrivateKey() (PrivateKey, error) {
    76  	pub, priv, err := ed25519.GenerateKey(crypto_rand.Reader)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	var publicKey PublicKey
    81  	copy(publicKey[:], pub)
    82  	return PrivateKey(priv), nil
    83  }
    84  
    85  func (k PrivateKey) Sign(payload []byte) (Signature, error) {
    86  	p := ed25519.PrivateKey(k)
    87  	signData, err := p.Sign(crypto_rand.Reader, payload, crypto.Hash(0))
    88  	if err != nil {
    89  		return Signature{}, err
    90  	}
    91  
    92  	var signature Signature
    93  	copy(signature[:], signData)
    94  
    95  	return signature, err
    96  }
    97  
    98  func (k PrivateKey) PublicKey() PublicKey {
    99  	p := ed25519.PrivateKey(k)
   100  	pub := p.Public().(ed25519.PublicKey)
   101  
   102  	var publicKey PublicKey
   103  	copy(publicKey[:], pub)
   104  
   105  	return publicKey
   106  }
   107  
// PK is a convenience alias for PublicKey; the two names denote the same type.
type PK = PublicKey
   110  
   111  func (p PublicKey) Verify(message []byte, signature Signature) bool {
   112  	pub := ed25519.PublicKey(p[:])
   113  	return ed25519.Verify(pub, message, signature[:])
   114  }
   115  
// PublicKey is a fixed-size key of PublicKeyLength bytes. Its text, JSON and
// BSON forms in this file use base58 encoding.
type PublicKey [PublicKeyLength]byte
   117  
   118  func PublicKeyFromBytes(in []byte) (out PublicKey) {
   119  	byteCount := len(in)
   120  	if byteCount == 0 {
   121  		return
   122  	}
   123  
   124  	max := PublicKeyLength
   125  	if byteCount < max {
   126  		max = byteCount
   127  	}
   128  
   129  	copy(out[:], in[0:max])
   130  	return
   131  }
   132  
   133  // MPK is a convenience alias for MustPublicKeyFromBase58
   134  func MPK(in string) PublicKey {
   135  	return MustPublicKeyFromBase58(in)
   136  }
   137  
   138  func MustPublicKeyFromBase58(in string) PublicKey {
   139  	out, err := PublicKeyFromBase58(in)
   140  	if err != nil {
   141  		panic(err)
   142  	}
   143  	return out
   144  }
   145  
   146  func PublicKeyFromBase58(in string) (out PublicKey, err error) {
   147  	val, err := base58.Decode(in)
   148  	if err != nil {
   149  		return out, fmt.Errorf("decode: %w", err)
   150  	}
   151  
   152  	if len(val) != PublicKeyLength {
   153  		return out, fmt.Errorf("invalid length, expected %v, got %d", PublicKeyLength, len(val))
   154  	}
   155  
   156  	copy(out[:], val)
   157  	return
   158  }
   159  
   160  func (p PublicKey) MarshalText() ([]byte, error) {
   161  	return []byte(base58.Encode(p[:])), nil
   162  }
   163  
   164  func (p *PublicKey) UnmarshalText(data []byte) error {
   165  	return p.Set(string(data))
   166  }
   167  
   168  func (p PublicKey) MarshalJSON() ([]byte, error) {
   169  	return json.Marshal(base58.Encode(p[:]))
   170  }
   171  
   172  func (p *PublicKey) UnmarshalJSON(data []byte) (err error) {
   173  	var s string
   174  	if err := json.Unmarshal(data, &s); err != nil {
   175  		return err
   176  	}
   177  
   178  	*p, err = PublicKeyFromBase58(s)
   179  	if err != nil {
   180  		return fmt.Errorf("invalid public key %q: %w", s, err)
   181  	}
   182  	return
   183  }
   184  
   185  // MarshalBSON implements the bson.Marshaler interface.
   186  func (p PublicKey) MarshalBSON() ([]byte, error) {
   187  	return bson.Marshal(p.String())
   188  }
   189  
   190  // UnmarshalBSON implements the bson.Unmarshaler interface.
   191  func (p *PublicKey) UnmarshalBSON(data []byte) (err error) {
   192  	var s string
   193  	if err := bson.Unmarshal(data, &s); err != nil {
   194  		return err
   195  	}
   196  
   197  	*p, err = PublicKeyFromBase58(s)
   198  	if err != nil {
   199  		return fmt.Errorf("invalid public key %q: %w", s, err)
   200  	}
   201  	return nil
   202  }
   203  
   204  // MarshalBSONValue implements the bson.ValueMarshaler interface.
   205  func (p PublicKey) MarshalBSONValue() (bsontype.Type, []byte, error) {
   206  	return bson.MarshalValue(p.String())
   207  }
   208  
   209  // UnmarshalBSONValue implements the bson.ValueUnmarshaler interface.
   210  func (p *PublicKey) UnmarshalBSONValue(t bsontype.Type, data []byte) (err error) {
   211  	var s string
   212  	if err := bson.Unmarshal(data, &s); err != nil {
   213  		return err
   214  	}
   215  
   216  	*p, err = PublicKeyFromBase58(s)
   217  	if err != nil {
   218  		return fmt.Errorf("invalid public key %q: %w", s, err)
   219  	}
   220  	return nil
   221  }
   222  
   223  func (p PublicKey) Equals(pb PublicKey) bool {
   224  	return p == pb
   225  }
   226  
   227  // IsAnyOf checks if p is equals to any of the provided keys.
   228  func (p PublicKey) IsAnyOf(keys ...PublicKey) bool {
   229  	for _, k := range keys {
   230  		if p.Equals(k) {
   231  			return true
   232  		}
   233  	}
   234  	return false
   235  }
   236  
   237  // ToPointer returns a pointer to the pubkey.
   238  func (p PublicKey) ToPointer() *PublicKey {
   239  	return &p
   240  }
   241  
   242  func (p PublicKey) Bytes() []byte {
   243  	return []byte(p[:])
   244  }
   245  
   246  // Check if a `Pubkey` is on the ed25519 curve.
   247  func (p PublicKey) IsOnCurve() bool {
   248  	return IsOnCurve(p[:])
   249  }
   250  
// zeroPublicKey is the all-zero key that IsZero compares against.
var zeroPublicKey = PublicKey{}
   252  
   253  // IsZero returns whether the public key is zero.
   254  // NOTE: the System Program public key is also zero.
   255  func (p PublicKey) IsZero() bool {
   256  	return p == zeroPublicKey
   257  }
   258  
   259  func (p *PublicKey) Set(s string) (err error) {
   260  	*p, err = PublicKeyFromBase58(s)
   261  	if err != nil {
   262  		return fmt.Errorf("invalid public key %s: %w", s, err)
   263  	}
   264  	return
   265  }
   266  
   267  func (p PublicKey) String() string {
   268  	return base58.Encode(p[:])
   269  }
   270  
   271  // Short returns a shortened pubkey string,
   272  // only including the first n chars, ellipsis, and the last n characters.
   273  // NOTE: this is ONLY for visual representation for humans,
   274  // and cannot be used for anything else.
   275  func (p PublicKey) Short(n int) string {
   276  	return formatShortPubkey(n, p)
   277  }
   278  
   279  func formatShortPubkey(n int, pubkey PublicKey) string {
   280  	str := pubkey.String()
   281  	if n > (len(str)/2)-1 {
   282  		n = (len(str) / 2) - 1
   283  	}
   284  	if n < 2 {
   285  		n = 2
   286  	}
   287  	return str[:n] + "..." + str[len(str)-n:]
   288  }
   289  
// PublicKeySlice is a list of public keys with set-like helpers; it also
// implements sort.Interface, ordering keys by their raw bytes.
type PublicKeySlice []PublicKey
   291  
   292  // UniqueAppend appends the provided pubkey only if it is not
   293  // already present in the slice.
   294  // Returns true when the provided pubkey wasn't already present.
   295  func (slice *PublicKeySlice) UniqueAppend(pubkey PublicKey) bool {
   296  	if !slice.Has(pubkey) {
   297  		slice.Append(pubkey)
   298  		return true
   299  	}
   300  	return false
   301  }
   302  
   303  func (slice *PublicKeySlice) Append(pubkeys ...PublicKey) {
   304  	*slice = append(*slice, pubkeys...)
   305  }
   306  
   307  func (slice PublicKeySlice) Has(pubkey PublicKey) bool {
   308  	for _, key := range slice {
   309  		if key.Equals(pubkey) {
   310  			return true
   311  		}
   312  	}
   313  	return false
   314  }
   315  
   316  func (slice PublicKeySlice) Len() int {
   317  	return len(slice)
   318  }
   319  
   320  func (slice PublicKeySlice) Less(i, j int) bool {
   321  	return bytes.Compare(slice[i][:], slice[j][:]) < 0
   322  }
   323  
   324  func (slice PublicKeySlice) Swap(i, j int) {
   325  	slice[i], slice[j] = slice[j], slice[i]
   326  }
   327  
   328  // Sort sorts the slice.
   329  func (slice PublicKeySlice) Sort() {
   330  	sort.Sort(slice)
   331  }
   332  
   333  // Dedupe returns a new slice with all duplicate pubkeys removed.
   334  func (slice PublicKeySlice) Dedupe() PublicKeySlice {
   335  	slice.Sort()
   336  	deduped := make(PublicKeySlice, 0)
   337  	for i := 0; i < len(slice); i++ {
   338  		if i == 0 || !slice[i].Equals(slice[i-1]) {
   339  			deduped = append(deduped, slice[i])
   340  		}
   341  	}
   342  	return deduped
   343  }
   344  
   345  // Contains returns true if the slice contains the provided pubkey.
   346  func (slice PublicKeySlice) Contains(pubkey PublicKey) bool {
   347  	for _, key := range slice {
   348  		if key.Equals(pubkey) {
   349  			return true
   350  		}
   351  	}
   352  	return false
   353  }
   354  
   355  // ContainsAll returns true if all the provided pubkeys are present in the slice.
   356  func (slice PublicKeySlice) ContainsAll(pubkeys PublicKeySlice) bool {
   357  	for _, pubkey := range pubkeys {
   358  		if !slice.Contains(pubkey) {
   359  			return false
   360  		}
   361  	}
   362  	return true
   363  }
   364  
   365  // ContainsAny returns true if any of the provided pubkeys are present in the slice.
   366  func (slice PublicKeySlice) ContainsAny(pubkeys ...PublicKey) bool {
   367  	for _, pubkey := range pubkeys {
   368  		if slice.Contains(pubkey) {
   369  			return true
   370  		}
   371  	}
   372  	return false
   373  }
   374  
   375  func (slice PublicKeySlice) ToBase58() []string {
   376  	out := make([]string, len(slice))
   377  	for i, pubkey := range slice {
   378  		out[i] = pubkey.String()
   379  	}
   380  	return out
   381  }
   382  
   383  func (slice PublicKeySlice) ToBytes() [][]byte {
   384  	out := make([][]byte, len(slice))
   385  	for i, pubkey := range slice {
   386  		out[i] = pubkey.Bytes()
   387  	}
   388  	return out
   389  }
   390  
   391  func (slice PublicKeySlice) ToPointers() []*PublicKey {
   392  	out := make([]*PublicKey, len(slice))
   393  	for i, pubkey := range slice {
   394  		out[i] = pubkey.ToPointer()
   395  	}
   396  	return out
   397  }
   398  
   399  // Removed returns the elements that are present in `a` but not in `b`.
   400  func (a PublicKeySlice) Removed(b PublicKeySlice) PublicKeySlice {
   401  	var diff PublicKeySlice
   402  	for _, pubkey := range a {
   403  		if !b.Contains(pubkey) {
   404  			diff = append(diff, pubkey)
   405  		}
   406  	}
   407  	return diff.Dedupe()
   408  }
   409  
   410  // Added returns the elements that are present in `b` but not in `a`.
   411  func (a PublicKeySlice) Added(b PublicKeySlice) PublicKeySlice {
   412  	return b.Removed(a)
   413  }
   414  
   415  // Intersect returns the intersection of two PublicKeySlices, i.e. the elements
   416  // that are in both PublicKeySlices.
   417  // The returned PublicKeySlice is sorted and deduped.
   418  func (prev PublicKeySlice) Intersect(next PublicKeySlice) PublicKeySlice {
   419  	var intersect PublicKeySlice
   420  	for _, pubkey := range prev {
   421  		if next.Contains(pubkey) {
   422  			intersect = append(intersect, pubkey)
   423  		}
   424  	}
   425  	return intersect.Dedupe()
   426  }
   427  
   428  // Equals returns true if the two PublicKeySlices are equal (same order of same keys).
   429  func (slice PublicKeySlice) Equals(other PublicKeySlice) bool {
   430  	if len(slice) != len(other) {
   431  		return false
   432  	}
   433  	for i, pubkey := range slice {
   434  		if !pubkey.Equals(other[i]) {
   435  			return false
   436  		}
   437  	}
   438  	return true
   439  }
   440  
   441  // Same returns true if the two slices contain the same public keys,
   442  // but not necessarily in the same order.
   443  func (slice PublicKeySlice) Same(other PublicKeySlice) bool {
   444  	if len(slice) != len(other) {
   445  		return false
   446  	}
   447  	for _, pubkey := range slice {
   448  		if !other.Contains(pubkey) {
   449  			return false
   450  		}
   451  	}
   452  	return true
   453  }
   454  
   455  // Split splits the slice into chunks of the specified size.
   456  func (slice PublicKeySlice) Split(chunkSize int) []PublicKeySlice {
   457  	divided := make([]PublicKeySlice, 0)
   458  	if len(slice) == 0 || chunkSize < 1 {
   459  		return divided
   460  	}
   461  	if len(slice) == 1 {
   462  		return append(divided, slice)
   463  	}
   464  
   465  	for i := 0; i < len(slice); i += chunkSize {
   466  		end := i + chunkSize
   467  
   468  		if end > len(slice) {
   469  			end = len(slice)
   470  		}
   471  
   472  		divided = append(divided, slice[i:end])
   473  	}
   474  
   475  	return divided
   476  }
   477  
   478  // Last returns the last element of the slice.
   479  // Returns nil if the slice is empty.
   480  func (slice PublicKeySlice) Last() *PublicKey {
   481  	if len(slice) == 0 {
   482  		return nil
   483  	}
   484  	return slice[len(slice)-1].ToPointer()
   485  }
   486  
   487  // First returns the first element of the slice.
   488  // Returns nil if the slice is empty.
   489  func (slice PublicKeySlice) First() *PublicKey {
   490  	if len(slice) == 0 {
   491  		return nil
   492  	}
   493  	return slice[0].ToPointer()
   494  }
   495  
   496  // GetAddedRemoved compares to the `next` pubkey slice, and returns
   497  // two slices:
   498  // - `added` is the slice of pubkeys that are present in `next` but NOT present in `previous`.
   499  // - `removed` is the slice of pubkeys that are present in `previous` but are NOT present in `next`.
   500  func (prev PublicKeySlice) GetAddedRemoved(next PublicKeySlice) (added PublicKeySlice, removed PublicKeySlice) {
   501  	return next.Removed(prev), prev.Removed(next)
   502  }
   503  
   504  // GetAddedRemovedPubkeys accepts two slices of pubkeys (`previous` and `next`), and returns
   505  // two slices:
   506  // - `added` is the slice of pubkeys that are present in `next` but NOT present in `previous`.
   507  // - `removed` is the slice of pubkeys that are present in `previous` but are NOT present in `next`.
   508  func GetAddedRemovedPubkeys(previous PublicKeySlice, next PublicKeySlice) (added PublicKeySlice, removed PublicKeySlice) {
   509  	added = make(PublicKeySlice, 0)
   510  	removed = make(PublicKeySlice, 0)
   511  
   512  	for _, prev := range previous {
   513  		if !next.Has(prev) {
   514  			removed = append(removed, prev)
   515  		}
   516  	}
   517  
   518  	for _, nx := range next {
   519  		if !previous.Has(nx) {
   520  			added = append(added, nx)
   521  		}
   522  	}
   523  
   524  	return
   525  }
   526  
// nativeProgramIDs lists the built-in program and sysvar IDs that
// isNativeProgramID treats as native. Lookups are a linear scan.
var nativeProgramIDs = PublicKeySlice{
	BPFLoaderProgramID,
	BPFLoaderDeprecatedProgramID,
	FeatureProgramID,
	ConfigProgramID,
	StakeProgramID,
	VoteProgramID,
	Secp256k1ProgramID,
	SystemProgramID,
	SysVarClockPubkey,
	SysVarEpochSchedulePubkey,
	SysVarFeesPubkey,
	SysVarInstructionsPubkey,
	SysVarRecentBlockHashesPubkey,
	SysVarRentPubkey,
	SysVarRewardsPubkey,
	SysVarSlotHashesPubkey,
	SysVarSlotHistoryPubkey,
	SysVarStakeHistoryPubkey,
}
   547  
   548  // https://github.com/solana-labs/solana/blob/216983c50e0a618facc39aa07472ba6d23f1b33a/sdk/program/src/pubkey.rs#L372
   549  func isNativeProgramID(key PublicKey) bool {
   550  	return nativeProgramIDs.Has(key)
   551  }
   552  
const (
	// PublicKeyLength is the number of bytes in a public key.
	PublicKeyLength = 32
	// MaxSeedLength is the maximum length, in bytes, of a single seed used for
	// address derivation (CreateWithSeed, CreateProgramAddress).
	MaxSeedLength = 32
	// MaxSeeds is the maximum number of seeds CreateProgramAddress accepts.
	MaxSeeds = 16
	// SignatureLength is the number of bytes in a signature.
	SignatureLength = 64

	// // Maximum string length of a base58 encoded pubkey.
	// MaxBase58Length = 44
)
   566  
   567  // Ported from https://github.com/solana-labs/solana/blob/216983c50e0a618facc39aa07472ba6d23f1b33a/sdk/program/src/pubkey.rs#L159
   568  func CreateWithSeed(base PublicKey, seed string, owner PublicKey) (PublicKey, error) {
   569  	if len(seed) > MaxSeedLength {
   570  		return PublicKey{}, ErrMaxSeedLengthExceeded
   571  	}
   572  
   573  	// let owner = owner.as_ref();
   574  	// if owner.len() >= PDA_MARKER.len() {
   575  	//     let slice = &owner[owner.len() - PDA_MARKER.len()..];
   576  	//     if slice == PDA_MARKER {
   577  	//         return Err(PubkeyError::IllegalOwner);
   578  	//     }
   579  	// }
   580  
   581  	b := make([]byte, 0, 64+len(seed))
   582  	b = append(b, base[:]...)
   583  	b = append(b, seed[:]...)
   584  	b = append(b, owner[:]...)
   585  	hash := sha256.Sum256(b)
   586  	return PublicKeyFromBytes(hash[:]), nil
   587  }
   588  
// PDA_MARKER is appended after the seeds and program ID when hashing a
// program-derived address (see CreateProgramAddress).
const PDA_MARKER = "ProgramDerivedAddress"

// ErrMaxSeedLengthExceeded is returned when a seed exceeds MaxSeedLength bytes
// or when more than MaxSeeds seeds are supplied.
var ErrMaxSeedLengthExceeded = errors.New("max seed length exceeded")
   592  
   593  // Create a program address.
   594  // Ported from https://github.com/solana-labs/solana/blob/216983c50e0a618facc39aa07472ba6d23f1b33a/sdk/program/src/pubkey.rs#L204
   595  func CreateProgramAddress(seeds [][]byte, programID PublicKey) (PublicKey, error) {
   596  	if len(seeds) > MaxSeeds {
   597  		return PublicKey{}, ErrMaxSeedLengthExceeded
   598  	}
   599  
   600  	for _, seed := range seeds {
   601  		if len(seed) > MaxSeedLength {
   602  			return PublicKey{}, ErrMaxSeedLengthExceeded
   603  		}
   604  	}
   605  
   606  	buf := []byte{}
   607  	for _, seed := range seeds {
   608  		buf = append(buf, seed...)
   609  	}
   610  
   611  	buf = append(buf, programID[:]...)
   612  	buf = append(buf, []byte(PDA_MARKER)...)
   613  	hash := sha256.Sum256(buf)
   614  
   615  	if IsOnCurve(hash[:]) {
   616  		return PublicKey{}, errors.New("invalid seeds; address must fall off the curve")
   617  	}
   618  
   619  	return PublicKeyFromBytes(hash[:]), nil
   620  }
   621  
   622  // Check if the provided `b` is on the ed25519 curve.
   623  func IsOnCurve(b []byte) bool {
   624  	_, err := new(edwards25519.Point).SetBytes(b)
   625  	isOnCurve := err == nil
   626  	return isOnCurve
   627  }
   628  
   629  // Find a valid program address and its corresponding bump seed.
   630  func FindProgramAddress(seed [][]byte, programID PublicKey) (PublicKey, uint8, error) {
   631  	var address PublicKey
   632  	var err error
   633  	bumpSeed := uint8(math.MaxUint8)
   634  	for bumpSeed != 0 {
   635  		address, err = CreateProgramAddress(append(seed, []byte{byte(bumpSeed)}), programID)
   636  		if err == nil {
   637  			return address, bumpSeed, nil
   638  		}
   639  		bumpSeed--
   640  	}
   641  	return PublicKey{}, bumpSeed, errors.New("unable to find a valid program address")
   642  }
   643  
   644  func FindAssociatedTokenAddress(
   645  	wallet PublicKey,
   646  	mint PublicKey,
   647  ) (PublicKey, uint8, error) {
   648  	return findAssociatedTokenAddressAndBumpSeed(
   649  		wallet,
   650  		mint,
   651  		SPLAssociatedTokenAccountProgramID,
   652  	)
   653  }
   654  
   655  func findAssociatedTokenAddressAndBumpSeed(
   656  	walletAddress PublicKey,
   657  	splTokenMintAddress PublicKey,
   658  	programID PublicKey,
   659  ) (PublicKey, uint8, error) {
   660  	return FindProgramAddress([][]byte{
   661  		walletAddress[:],
   662  		TokenProgramID[:],
   663  		splTokenMintAddress[:],
   664  	},
   665  		programID,
   666  	)
   667  }
   668  
   669  // FindTokenMetadataAddress returns the token metadata program-derived address given a SPL token mint address.
   670  func FindTokenMetadataAddress(mint PublicKey) (PublicKey, uint8, error) {
   671  	seed := [][]byte{
   672  		[]byte("metadata"),
   673  		TokenMetadataProgramID[:],
   674  		mint[:],
   675  	}
   676  	return FindProgramAddress(seed, TokenMetadataProgramID)
   677  }