github.com/gagliardetto/solana-go@v1.11.0/keys.go (about) 1 // Copyright 2021 github.com/gagliardetto 2 // This file has been modified by github.com/gagliardetto 3 // 4 // Copyright 2020 dfuse Platform Inc. 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // 10 // http://www.apache.org/licenses/LICENSE-2.0 11 // 12 // Unless required by applicable law or agreed to in writing, software 13 // distributed under the License is distributed on an "AS IS" BASIS, 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 18 package solana 19 20 import ( 21 "bytes" 22 "crypto" 23 "crypto/ed25519" 24 crypto_rand "crypto/rand" 25 "crypto/sha256" 26 "errors" 27 "fmt" 28 "io/ioutil" 29 "math" 30 "sort" 31 32 "filippo.io/edwards25519" 33 "github.com/mr-tron/base58" 34 "go.mongodb.org/mongo-driver/bson" 35 "go.mongodb.org/mongo-driver/bson/bsontype" 36 ) 37 38 type PrivateKey []byte 39 40 func MustPrivateKeyFromBase58(in string) PrivateKey { 41 out, err := PrivateKeyFromBase58(in) 42 if err != nil { 43 panic(err) 44 } 45 return out 46 } 47 48 func PrivateKeyFromBase58(privkey string) (PrivateKey, error) { 49 res, err := base58.Decode(privkey) 50 if err != nil { 51 return nil, err 52 } 53 return res, nil 54 } 55 56 func PrivateKeyFromSolanaKeygenFile(file string) (PrivateKey, error) { 57 content, err := ioutil.ReadFile(file) 58 if err != nil { 59 return nil, fmt.Errorf("read keygen file: %w", err) 60 } 61 62 var values []byte 63 err = json.Unmarshal(content, &values) 64 if err != nil { 65 return nil, fmt.Errorf("decode keygen file: %w", err) 66 } 67 68 return PrivateKey([]byte(values)), nil 69 } 70 71 func (k PrivateKey) String() string { 72 return base58.Encode(k) 73 } 74 75 func NewRandomPrivateKey() (PrivateKey, error) { 76 pub, priv, err := ed25519.GenerateKey(crypto_rand.Reader) 77 if err != nil { 78 return nil, err 79 } 80 var publicKey PublicKey 81 copy(publicKey[:], pub) 82 return PrivateKey(priv), nil 83 } 84 85 func (k PrivateKey) Sign(payload []byte) (Signature, error) { 86 p := ed25519.PrivateKey(k) 87 signData, err := p.Sign(crypto_rand.Reader, payload, crypto.Hash(0)) 88 if err != nil { 89 return Signature{}, err 90 } 91 92 var signature Signature 93 copy(signature[:], signData) 94 95 return signature, err 96 } 97 98 func (k PrivateKey) PublicKey() PublicKey { 99 p := ed25519.PrivateKey(k) 100 pub := p.Public().(ed25519.PublicKey) 101 102 var publicKey PublicKey 103 copy(publicKey[:], pub) 104 105 return publicKey 106 } 107 108 // PK is a convenience alias for PublicKey 109 type PK = PublicKey 110 111 func (p PublicKey) Verify(message []byte, signature Signature) bool { 112 pub := ed25519.PublicKey(p[:]) 113 return ed25519.Verify(pub, message, signature[:]) 114 } 115 116 type PublicKey [PublicKeyLength]byte 117 118 func PublicKeyFromBytes(in []byte) (out PublicKey) { 119 byteCount := len(in) 120 if byteCount == 0 { 121 return 122 } 123 124 max := PublicKeyLength 125 if byteCount < max { 126 max = byteCount 127 } 128 129 copy(out[:], in[0:max]) 130 return 131 } 132 133 // MPK is a convenience alias for MustPublicKeyFromBase58 134 func MPK(in string) PublicKey { 135 return MustPublicKeyFromBase58(in) 136 } 137 138 func MustPublicKeyFromBase58(in string) PublicKey { 139 out, err := PublicKeyFromBase58(in) 140 if err != nil { 141 panic(err) 142 } 143 return out 144 } 145 146 func PublicKeyFromBase58(in string) (out PublicKey, err error) { 147 val, err := base58.Decode(in) 148 if err != nil { 149 return out, fmt.Errorf("decode: %w", err) 150 } 151 152 if len(val) != PublicKeyLength { 153 return out, fmt.Errorf("invalid length, expected %v, got %d", PublicKeyLength, len(val)) 154 } 155 156 copy(out[:], val) 157 return 158 } 159 160 func (p PublicKey) MarshalText() ([]byte, error) { 161 return []byte(base58.Encode(p[:])), nil 162 } 163 164 func (p *PublicKey) UnmarshalText(data []byte) error { 165 return p.Set(string(data)) 166 } 167 168 func (p PublicKey) MarshalJSON() ([]byte, error) { 169 return json.Marshal(base58.Encode(p[:])) 170 } 171 172 func (p *PublicKey) UnmarshalJSON(data []byte) (err error) { 173 var s string 174 if err := json.Unmarshal(data, &s); err != nil { 175 return err 176 } 177 178 *p, err = PublicKeyFromBase58(s) 179 if err != nil { 180 return fmt.Errorf("invalid public key %q: %w", s, err) 181 } 182 return 183 } 184 185 // MarshalBSON implements the bson.Marshaler interface. 186 func (p PublicKey) MarshalBSON() ([]byte, error) { 187 return bson.Marshal(p.String()) 188 } 189 190 // UnmarshalBSON implements the bson.Unmarshaler interface. 191 func (p *PublicKey) UnmarshalBSON(data []byte) (err error) { 192 var s string 193 if err := bson.Unmarshal(data, &s); err != nil { 194 return err 195 } 196 197 *p, err = PublicKeyFromBase58(s) 198 if err != nil { 199 return fmt.Errorf("invalid public key %q: %w", s, err) 200 } 201 return nil 202 } 203 204 // MarshalBSONValue implements the bson.ValueMarshaler interface. 205 func (p PublicKey) MarshalBSONValue() (bsontype.Type, []byte, error) { 206 return bson.MarshalValue(p.String()) 207 } 208 209 // UnmarshalBSONValue implements the bson.ValueUnmarshaler interface. 210 func (p *PublicKey) UnmarshalBSONValue(t bsontype.Type, data []byte) (err error) { 211 var s string 212 if err := bson.Unmarshal(data, &s); err != nil { 213 return err 214 } 215 216 *p, err = PublicKeyFromBase58(s) 217 if err != nil { 218 return fmt.Errorf("invalid public key %q: %w", s, err) 219 } 220 return nil 221 } 222 223 func (p PublicKey) Equals(pb PublicKey) bool { 224 return p == pb 225 } 226 227 // IsAnyOf checks if p is equals to any of the provided keys. 228 func (p PublicKey) IsAnyOf(keys ...PublicKey) bool { 229 for _, k := range keys { 230 if p.Equals(k) { 231 return true 232 } 233 } 234 return false 235 } 236 237 // ToPointer returns a pointer to the pubkey. 238 func (p PublicKey) ToPointer() *PublicKey { 239 return &p 240 } 241 242 func (p PublicKey) Bytes() []byte { 243 return []byte(p[:]) 244 } 245 246 // Check if a `Pubkey` is on the ed25519 curve. 247 func (p PublicKey) IsOnCurve() bool { 248 return IsOnCurve(p[:]) 249 } 250 251 var zeroPublicKey = PublicKey{} 252 253 // IsZero returns whether the public key is zero. 254 // NOTE: the System Program public key is also zero. 255 func (p PublicKey) IsZero() bool { 256 return p == zeroPublicKey 257 } 258 259 func (p *PublicKey) Set(s string) (err error) { 260 *p, err = PublicKeyFromBase58(s) 261 if err != nil { 262 return fmt.Errorf("invalid public key %s: %w", s, err) 263 } 264 return 265 } 266 267 func (p PublicKey) String() string { 268 return base58.Encode(p[:]) 269 } 270 271 // Short returns a shortened pubkey string, 272 // only including the first n chars, ellipsis, and the last n characters. 273 // NOTE: this is ONLY for visual representation for humans, 274 // and cannot be used for anything else. 275 func (p PublicKey) Short(n int) string { 276 return formatShortPubkey(n, p) 277 } 278 279 func formatShortPubkey(n int, pubkey PublicKey) string { 280 str := pubkey.String() 281 if n > (len(str)/2)-1 { 282 n = (len(str) / 2) - 1 283 } 284 if n < 2 { 285 n = 2 286 } 287 return str[:n] + "..." + str[len(str)-n:] 288 } 289 290 type PublicKeySlice []PublicKey 291 292 // UniqueAppend appends the provided pubkey only if it is not 293 // already present in the slice. 294 // Returns true when the provided pubkey wasn't already present. 295 func (slice *PublicKeySlice) UniqueAppend(pubkey PublicKey) bool { 296 if !slice.Has(pubkey) { 297 slice.Append(pubkey) 298 return true 299 } 300 return false 301 } 302 303 func (slice *PublicKeySlice) Append(pubkeys ...PublicKey) { 304 *slice = append(*slice, pubkeys...) 305 } 306 307 func (slice PublicKeySlice) Has(pubkey PublicKey) bool { 308 for _, key := range slice { 309 if key.Equals(pubkey) { 310 return true 311 } 312 } 313 return false 314 } 315 316 func (slice PublicKeySlice) Len() int { 317 return len(slice) 318 } 319 320 func (slice PublicKeySlice) Less(i, j int) bool { 321 return bytes.Compare(slice[i][:], slice[j][:]) < 0 322 } 323 324 func (slice PublicKeySlice) Swap(i, j int) { 325 slice[i], slice[j] = slice[j], slice[i] 326 } 327 328 // Sort sorts the slice. 329 func (slice PublicKeySlice) Sort() { 330 sort.Sort(slice) 331 } 332 333 // Dedupe returns a new slice with all duplicate pubkeys removed. 334 func (slice PublicKeySlice) Dedupe() PublicKeySlice { 335 slice.Sort() 336 deduped := make(PublicKeySlice, 0) 337 for i := 0; i < len(slice); i++ { 338 if i == 0 || !slice[i].Equals(slice[i-1]) { 339 deduped = append(deduped, slice[i]) 340 } 341 } 342 return deduped 343 } 344 345 // Contains returns true if the slice contains the provided pubkey. 346 func (slice PublicKeySlice) Contains(pubkey PublicKey) bool { 347 for _, key := range slice { 348 if key.Equals(pubkey) { 349 return true 350 } 351 } 352 return false 353 } 354 355 // ContainsAll returns true if all the provided pubkeys are present in the slice. 356 func (slice PublicKeySlice) ContainsAll(pubkeys PublicKeySlice) bool { 357 for _, pubkey := range pubkeys { 358 if !slice.Contains(pubkey) { 359 return false 360 } 361 } 362 return true 363 } 364 365 // ContainsAny returns true if any of the provided pubkeys are present in the slice. 366 func (slice PublicKeySlice) ContainsAny(pubkeys ...PublicKey) bool { 367 for _, pubkey := range pubkeys { 368 if slice.Contains(pubkey) { 369 return true 370 } 371 } 372 return false 373 } 374 375 func (slice PublicKeySlice) ToBase58() []string { 376 out := make([]string, len(slice)) 377 for i, pubkey := range slice { 378 out[i] = pubkey.String() 379 } 380 return out 381 } 382 383 func (slice PublicKeySlice) ToBytes() [][]byte { 384 out := make([][]byte, len(slice)) 385 for i, pubkey := range slice { 386 out[i] = pubkey.Bytes() 387 } 388 return out 389 } 390 391 func (slice PublicKeySlice) ToPointers() []*PublicKey { 392 out := make([]*PublicKey, len(slice)) 393 for i, pubkey := range slice { 394 out[i] = pubkey.ToPointer() 395 } 396 return out 397 } 398 399 // Removed returns the elements that are present in `a` but not in `b`. 400 func (a PublicKeySlice) Removed(b PublicKeySlice) PublicKeySlice { 401 var diff PublicKeySlice 402 for _, pubkey := range a { 403 if !b.Contains(pubkey) { 404 diff = append(diff, pubkey) 405 } 406 } 407 return diff.Dedupe() 408 } 409 410 // Added returns the elements that are present in `b` but not in `a`. 411 func (a PublicKeySlice) Added(b PublicKeySlice) PublicKeySlice { 412 return b.Removed(a) 413 } 414 415 // Intersect returns the intersection of two PublicKeySlices, i.e. the elements 416 // that are in both PublicKeySlices. 417 // The returned PublicKeySlice is sorted and deduped. 418 func (prev PublicKeySlice) Intersect(next PublicKeySlice) PublicKeySlice { 419 var intersect PublicKeySlice 420 for _, pubkey := range prev { 421 if next.Contains(pubkey) { 422 intersect = append(intersect, pubkey) 423 } 424 } 425 return intersect.Dedupe() 426 } 427 428 // Equals returns true if the two PublicKeySlices are equal (same order of same keys). 429 func (slice PublicKeySlice) Equals(other PublicKeySlice) bool { 430 if len(slice) != len(other) { 431 return false 432 } 433 for i, pubkey := range slice { 434 if !pubkey.Equals(other[i]) { 435 return false 436 } 437 } 438 return true 439 } 440 441 // Same returns true if the two slices contain the same public keys, 442 // but not necessarily in the same order. 443 func (slice PublicKeySlice) Same(other PublicKeySlice) bool { 444 if len(slice) != len(other) { 445 return false 446 } 447 for _, pubkey := range slice { 448 if !other.Contains(pubkey) { 449 return false 450 } 451 } 452 return true 453 } 454 455 // Split splits the slice into chunks of the specified size. 456 func (slice PublicKeySlice) Split(chunkSize int) []PublicKeySlice { 457 divided := make([]PublicKeySlice, 0) 458 if len(slice) == 0 || chunkSize < 1 { 459 return divided 460 } 461 if len(slice) == 1 { 462 return append(divided, slice) 463 } 464 465 for i := 0; i < len(slice); i += chunkSize { 466 end := i + chunkSize 467 468 if end > len(slice) { 469 end = len(slice) 470 } 471 472 divided = append(divided, slice[i:end]) 473 } 474 475 return divided 476 } 477 478 // Last returns the last element of the slice. 479 // Returns nil if the slice is empty. 480 func (slice PublicKeySlice) Last() *PublicKey { 481 if len(slice) == 0 { 482 return nil 483 } 484 return slice[len(slice)-1].ToPointer() 485 } 486 487 // First returns the first element of the slice. 488 // Returns nil if the slice is empty. 489 func (slice PublicKeySlice) First() *PublicKey { 490 if len(slice) == 0 { 491 return nil 492 } 493 return slice[0].ToPointer() 494 } 495 496 // GetAddedRemoved compares to the `next` pubkey slice, and returns 497 // two slices: 498 // - `added` is the slice of pubkeys that are present in `next` but NOT present in `previous`. 499 // - `removed` is the slice of pubkeys that are present in `previous` but are NOT present in `next`. 500 func (prev PublicKeySlice) GetAddedRemoved(next PublicKeySlice) (added PublicKeySlice, removed PublicKeySlice) { 501 return next.Removed(prev), prev.Removed(next) 502 } 503 504 // GetAddedRemovedPubkeys accepts two slices of pubkeys (`previous` and `next`), and returns 505 // two slices: 506 // - `added` is the slice of pubkeys that are present in `next` but NOT present in `previous`. 507 // - `removed` is the slice of pubkeys that are present in `previous` but are NOT present in `next`. 508 func GetAddedRemovedPubkeys(previous PublicKeySlice, next PublicKeySlice) (added PublicKeySlice, removed PublicKeySlice) { 509 added = make(PublicKeySlice, 0) 510 removed = make(PublicKeySlice, 0) 511 512 for _, prev := range previous { 513 if !next.Has(prev) { 514 removed = append(removed, prev) 515 } 516 } 517 518 for _, nx := range next { 519 if !previous.Has(nx) { 520 added = append(added, nx) 521 } 522 } 523 524 return 525 } 526 527 var nativeProgramIDs = PublicKeySlice{ 528 BPFLoaderProgramID, 529 BPFLoaderDeprecatedProgramID, 530 FeatureProgramID, 531 ConfigProgramID, 532 StakeProgramID, 533 VoteProgramID, 534 Secp256k1ProgramID, 535 SystemProgramID, 536 SysVarClockPubkey, 537 SysVarEpochSchedulePubkey, 538 SysVarFeesPubkey, 539 SysVarInstructionsPubkey, 540 SysVarRecentBlockHashesPubkey, 541 SysVarRentPubkey, 542 SysVarRewardsPubkey, 543 SysVarSlotHashesPubkey, 544 SysVarSlotHistoryPubkey, 545 SysVarStakeHistoryPubkey, 546 } 547 548 // https://github.com/solana-labs/solana/blob/216983c50e0a618facc39aa07472ba6d23f1b33a/sdk/program/src/pubkey.rs#L372 549 func isNativeProgramID(key PublicKey) bool { 550 return nativeProgramIDs.Has(key) 551 } 552 553 const ( 554 /// Number of bytes in a pubkey. 555 PublicKeyLength = 32 556 // Maximum length of derived pubkey seed. 557 MaxSeedLength = 32 558 // Maximum number of seeds. 559 MaxSeeds = 16 560 /// Number of bytes in a signature. 561 SignatureLength = 64 562 563 // // Maximum string length of a base58 encoded pubkey. 564 // MaxBase58Length = 44 565 ) 566 567 // Ported from https://github.com/solana-labs/solana/blob/216983c50e0a618facc39aa07472ba6d23f1b33a/sdk/program/src/pubkey.rs#L159 568 func CreateWithSeed(base PublicKey, seed string, owner PublicKey) (PublicKey, error) { 569 if len(seed) > MaxSeedLength { 570 return PublicKey{}, ErrMaxSeedLengthExceeded 571 } 572 573 // let owner = owner.as_ref(); 574 // if owner.len() >= PDA_MARKER.len() { 575 // let slice = &owner[owner.len() - PDA_MARKER.len()..]; 576 // if slice == PDA_MARKER { 577 // return Err(PubkeyError::IllegalOwner); 578 // } 579 // } 580 581 b := make([]byte, 0, 64+len(seed)) 582 b = append(b, base[:]...) 583 b = append(b, seed[:]...) 584 b = append(b, owner[:]...) 585 hash := sha256.Sum256(b) 586 return PublicKeyFromBytes(hash[:]), nil 587 } 588 589 const PDA_MARKER = "ProgramDerivedAddress" 590 591 var ErrMaxSeedLengthExceeded = errors.New("max seed length exceeded") 592 593 // Create a program address. 594 // Ported from https://github.com/solana-labs/solana/blob/216983c50e0a618facc39aa07472ba6d23f1b33a/sdk/program/src/pubkey.rs#L204 595 func CreateProgramAddress(seeds [][]byte, programID PublicKey) (PublicKey, error) { 596 if len(seeds) > MaxSeeds { 597 return PublicKey{}, ErrMaxSeedLengthExceeded 598 } 599 600 for _, seed := range seeds { 601 if len(seed) > MaxSeedLength { 602 return PublicKey{}, ErrMaxSeedLengthExceeded 603 } 604 } 605 606 buf := []byte{} 607 for _, seed := range seeds { 608 buf = append(buf, seed...) 609 } 610 611 buf = append(buf, programID[:]...) 612 buf = append(buf, []byte(PDA_MARKER)...) 613 hash := sha256.Sum256(buf) 614 615 if IsOnCurve(hash[:]) { 616 return PublicKey{}, errors.New("invalid seeds; address must fall off the curve") 617 } 618 619 return PublicKeyFromBytes(hash[:]), nil 620 } 621 622 // Check if the provided `b` is on the ed25519 curve. 623 func IsOnCurve(b []byte) bool { 624 _, err := new(edwards25519.Point).SetBytes(b) 625 isOnCurve := err == nil 626 return isOnCurve 627 } 628 629 // Find a valid program address and its corresponding bump seed. 630 func FindProgramAddress(seed [][]byte, programID PublicKey) (PublicKey, uint8, error) { 631 var address PublicKey 632 var err error 633 bumpSeed := uint8(math.MaxUint8) 634 for bumpSeed != 0 { 635 address, err = CreateProgramAddress(append(seed, []byte{byte(bumpSeed)}), programID) 636 if err == nil { 637 return address, bumpSeed, nil 638 } 639 bumpSeed-- 640 } 641 return PublicKey{}, bumpSeed, errors.New("unable to find a valid program address") 642 } 643 644 func FindAssociatedTokenAddress( 645 wallet PublicKey, 646 mint PublicKey, 647 ) (PublicKey, uint8, error) { 648 return findAssociatedTokenAddressAndBumpSeed( 649 wallet, 650 mint, 651 SPLAssociatedTokenAccountProgramID, 652 ) 653 } 654 655 func findAssociatedTokenAddressAndBumpSeed( 656 walletAddress PublicKey, 657 splTokenMintAddress PublicKey, 658 programID PublicKey, 659 ) (PublicKey, uint8, error) { 660 return FindProgramAddress([][]byte{ 661 walletAddress[:], 662 TokenProgramID[:], 663 splTokenMintAddress[:], 664 }, 665 programID, 666 ) 667 } 668 669 // FindTokenMetadataAddress returns the token metadata program-derived address given a SPL token mint address. 670 func FindTokenMetadataAddress(mint PublicKey) (PublicKey, uint8, error) { 671 seed := [][]byte{ 672 []byte("metadata"), 673 TokenMetadataProgramID[:], 674 mint[:], 675 } 676 return FindProgramAddress(seed, TokenMetadataProgramID) 677 }