github.com/MetalBlockchain/metalgo@v1.11.9/vms/components/keystore/user.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package keystore
     5  
     6  import (
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/MetalBlockchain/metalgo/api/keystore"
    11  	"github.com/MetalBlockchain/metalgo/database"
    12  	"github.com/MetalBlockchain/metalgo/database/encdb"
    13  	"github.com/MetalBlockchain/metalgo/ids"
    14  	"github.com/MetalBlockchain/metalgo/utils/crypto/secp256k1"
    15  	"github.com/MetalBlockchain/metalgo/utils/set"
    16  	"github.com/MetalBlockchain/metalgo/vms/secp256k1fx"
    17  )
    18  
    19  // Max number of addresses allowed for a single keystore user
    20  const maxKeystoreAddresses = 5000
    21  
    22  var (
    23  	// Key in the database whose corresponding value is the list of addresses
    24  	// this user controls
    25  	addressesKey = ids.Empty[:]
    26  
    27  	errMaxAddresses = fmt.Errorf("keystore user has reached its limit of %d addresses", maxKeystoreAddresses)
    28  
    29  	_ User = (*user)(nil)
    30  )
    31  
    32  type User interface {
    33  	io.Closer
    34  
    35  	// Get the addresses controlled by this user
    36  	GetAddresses() ([]ids.ShortID, error)
    37  
    38  	// PutKeys persists [privKeys]
    39  	PutKeys(privKeys ...*secp256k1.PrivateKey) error
    40  
    41  	// GetKey returns the private key that controls the given address
    42  	GetKey(address ids.ShortID) (*secp256k1.PrivateKey, error)
    43  }
    44  
    45  type user struct {
    46  	db *encdb.Database
    47  }
    48  
    49  // NewUserFromKeystore tracks a keystore user from the provided keystore
    50  func NewUserFromKeystore(ks keystore.BlockchainKeystore, username, password string) (User, error) {
    51  	db, err := ks.GetDatabase(username, password)
    52  	if err != nil {
    53  		return nil, fmt.Errorf("problem retrieving user %q: %w", username, err)
    54  	}
    55  	return NewUserFromDB(db), nil
    56  }
    57  
    58  // NewUserFromDB tracks a keystore user from a database
    59  func NewUserFromDB(db *encdb.Database) User {
    60  	return &user{db: db}
    61  }
    62  
    63  func (u *user) GetAddresses() ([]ids.ShortID, error) {
    64  	// Get user's addresses
    65  	addressBytes, err := u.db.Get(addressesKey)
    66  	if err == database.ErrNotFound {
    67  		// If user has no addresses, return empty list
    68  		return nil, nil
    69  	}
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	var addresses []ids.ShortID
    75  	_, err = LegacyCodec.Unmarshal(addressBytes, &addresses)
    76  	return addresses, err
    77  }
    78  
    79  func (u *user) PutKeys(privKeys ...*secp256k1.PrivateKey) error {
    80  	toStore := make([]*secp256k1.PrivateKey, 0, len(privKeys))
    81  	for _, privKey := range privKeys {
    82  		address := privKey.PublicKey().Address() // address the privKey controls
    83  		hasAddress, err := u.db.Has(address.Bytes())
    84  		if err != nil {
    85  			return err
    86  		}
    87  		if !hasAddress {
    88  			toStore = append(toStore, privKey)
    89  		}
    90  	}
    91  
    92  	// there's nothing to store
    93  	if len(toStore) == 0 {
    94  		return nil
    95  	}
    96  
    97  	addresses, err := u.GetAddresses()
    98  	if err != nil {
    99  		return err
   100  	}
   101  
   102  	if len(toStore) > maxKeystoreAddresses || len(addresses) > maxKeystoreAddresses-len(toStore) {
   103  		return errMaxAddresses
   104  	}
   105  
   106  	for _, privKey := range toStore {
   107  		address := privKey.PublicKey().Address() // address the privKey controls
   108  		// Address --> private key
   109  		if err := u.db.Put(address.Bytes(), privKey.Bytes()); err != nil {
   110  			return err
   111  		}
   112  		addresses = append(addresses, address)
   113  	}
   114  
   115  	addressBytes, err := Codec.Marshal(CodecVersion, addresses)
   116  	if err != nil {
   117  		return err
   118  	}
   119  	return u.db.Put(addressesKey, addressBytes)
   120  }
   121  
   122  func (u *user) GetKey(address ids.ShortID) (*secp256k1.PrivateKey, error) {
   123  	bytes, err := u.db.Get(address.Bytes())
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  	return secp256k1.ToPrivateKey(bytes)
   128  }
   129  
   130  func (u *user) Close() error {
   131  	return u.db.Close()
   132  }
   133  
   134  // Create and store a new key that will be controlled by this user.
   135  func NewKey(u User) (*secp256k1.PrivateKey, error) {
   136  	keys, err := NewKeys(u, 1)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	return keys[0], nil
   141  }
   142  
   143  // Create and store [numKeys] new keys that will be controlled by this user.
   144  func NewKeys(u User, numKeys int) ([]*secp256k1.PrivateKey, error) {
   145  	keys := make([]*secp256k1.PrivateKey, numKeys)
   146  	for i := range keys {
   147  		sk, err := secp256k1.NewPrivateKey()
   148  		if err != nil {
   149  			return nil, err
   150  		}
   151  		keys[i] = sk
   152  	}
   153  	return keys, u.PutKeys(keys...)
   154  }
   155  
   156  // Keychain returns a new keychain from the [user].
   157  // If [addresses] is non-empty it fetches only the keys in addresses. If a key
   158  // is missing, it will be ignored.
   159  // If [addresses] is empty, then it will create a keychain using every address
   160  // in the provided [user].
   161  func GetKeychain(u User, addresses set.Set[ids.ShortID]) (*secp256k1fx.Keychain, error) {
   162  	addrsList := addresses.List()
   163  	if len(addrsList) == 0 {
   164  		var err error
   165  		addrsList, err = u.GetAddresses()
   166  		if err != nil {
   167  			return nil, err
   168  		}
   169  	}
   170  
   171  	kc := secp256k1fx.NewKeychain()
   172  	for _, addr := range addrsList {
   173  		sk, err := u.GetKey(addr)
   174  		if err == database.ErrNotFound {
   175  			continue
   176  		}
   177  		if err != nil {
   178  			return nil, fmt.Errorf("problem retrieving private key for address %s: %w", addr, err)
   179  		}
   180  		kc.Add(sk)
   181  	}
   182  	return kc, nil
   183  }