github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/accounts/keystore/keystore.go (about)

     1  package keystore
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	crand "crypto/rand"
     6  	"errors"
     7  	"fmt"
     8  	"math/big"
     9  	"os"
    10  	"path/filepath"
    11  	"reflect"
    12  	"runtime"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/neatlab/neatio/chain/accounts"
    17  	"github.com/neatlab/neatio/chain/core/types"
    18  	"github.com/neatlab/neatio/utilities/common"
    19  	"github.com/neatlab/neatio/utilities/crypto"
    20  	"github.com/neatlab/neatio/utilities/event"
    21  )
    22  
    23  var (
    24  	ErrLocked  = accounts.NewAuthNeededError("password or unlock")
    25  	ErrNoMatch = errors.New("no key for given address or file")
    26  	ErrDecrypt = errors.New("could not decrypt key with given passphrase")
    27  )
    28  
    29  var KeyStoreType = reflect.TypeOf(&KeyStore{})
    30  
    31  var KeyStoreScheme = "keystore"
    32  
    33  const walletRefreshCycle = 3 * time.Second
    34  
    35  type KeyStore struct {
    36  	storage  keyStore
    37  	cache    *accountCache
    38  	changes  chan struct{}
    39  	unlocked map[common.Address]*unlocked
    40  
    41  	wallets     []accounts.Wallet
    42  	updateFeed  event.Feed
    43  	updateScope event.SubscriptionScope
    44  	updating    bool
    45  
    46  	mu sync.RWMutex
    47  }
    48  
    49  type unlocked struct {
    50  	*Key
    51  	abort chan struct{}
    52  }
    53  
    54  func NewKeyStore(keydir string, scryptN, scryptP int) *KeyStore {
    55  	keydir, _ = filepath.Abs(keydir)
    56  	ks := &KeyStore{storage: &keyStorePassphrase{keydir, scryptN, scryptP}}
    57  	ks.init(keydir)
    58  	return ks
    59  }
    60  
    61  func NewPlaintextKeyStore(keydir string) *KeyStore {
    62  	keydir, _ = filepath.Abs(keydir)
    63  	ks := &KeyStore{storage: &keyStorePlain{keydir}}
    64  	ks.init(keydir)
    65  	return ks
    66  }
    67  
    68  func (ks *KeyStore) init(keydir string) {
    69  
    70  	ks.mu.Lock()
    71  	defer ks.mu.Unlock()
    72  
    73  	ks.unlocked = make(map[common.Address]*unlocked)
    74  	ks.cache, ks.changes = newAccountCache(keydir)
    75  
    76  	runtime.SetFinalizer(ks, func(m *KeyStore) {
    77  		m.cache.close()
    78  	})
    79  
    80  	accs := ks.cache.accounts()
    81  	ks.wallets = make([]accounts.Wallet, len(accs))
    82  	for i := 0; i < len(accs); i++ {
    83  		ks.wallets[i] = &keystoreWallet{account: accs[i], keystore: ks}
    84  	}
    85  }
    86  
    87  func (ks *KeyStore) Wallets() []accounts.Wallet {
    88  
    89  	ks.refreshWallets()
    90  
    91  	ks.mu.RLock()
    92  	defer ks.mu.RUnlock()
    93  
    94  	cpy := make([]accounts.Wallet, len(ks.wallets))
    95  	copy(cpy, ks.wallets)
    96  	return cpy
    97  }
    98  
    99  func (ks *KeyStore) refreshWallets() {
   100  
   101  	ks.mu.Lock()
   102  	accs := ks.cache.accounts()
   103  
   104  	wallets := make([]accounts.Wallet, 0, len(accs))
   105  	events := []accounts.WalletEvent{}
   106  
   107  	for _, account := range accs {
   108  
   109  		for len(ks.wallets) > 0 && ks.wallets[0].URL().Cmp(account.URL) < 0 {
   110  			events = append(events, accounts.WalletEvent{Wallet: ks.wallets[0], Kind: accounts.WalletDropped})
   111  			ks.wallets = ks.wallets[1:]
   112  		}
   113  
   114  		if len(ks.wallets) == 0 || ks.wallets[0].URL().Cmp(account.URL) > 0 {
   115  			wallet := &keystoreWallet{account: account, keystore: ks}
   116  
   117  			events = append(events, accounts.WalletEvent{Wallet: wallet, Kind: accounts.WalletArrived})
   118  			wallets = append(wallets, wallet)
   119  			continue
   120  		}
   121  
   122  		if ks.wallets[0].Accounts()[0] == account {
   123  			wallets = append(wallets, ks.wallets[0])
   124  			ks.wallets = ks.wallets[1:]
   125  			continue
   126  		}
   127  	}
   128  
   129  	for _, wallet := range ks.wallets {
   130  		events = append(events, accounts.WalletEvent{Wallet: wallet, Kind: accounts.WalletDropped})
   131  	}
   132  	ks.wallets = wallets
   133  	ks.mu.Unlock()
   134  
   135  	for _, event := range events {
   136  		ks.updateFeed.Send(event)
   137  	}
   138  }
   139  
   140  func (ks *KeyStore) Subscribe(sink chan<- accounts.WalletEvent) event.Subscription {
   141  
   142  	ks.mu.Lock()
   143  	defer ks.mu.Unlock()
   144  
   145  	sub := ks.updateScope.Track(ks.updateFeed.Subscribe(sink))
   146  
   147  	if !ks.updating {
   148  		ks.updating = true
   149  		go ks.updater()
   150  	}
   151  	return sub
   152  }
   153  
   154  func (ks *KeyStore) updater() {
   155  	for {
   156  
   157  		select {
   158  		case <-ks.changes:
   159  		case <-time.After(walletRefreshCycle):
   160  		}
   161  
   162  		ks.refreshWallets()
   163  
   164  		ks.mu.Lock()
   165  		if ks.updateScope.Count() == 0 {
   166  			ks.updating = false
   167  			ks.mu.Unlock()
   168  			return
   169  		}
   170  		ks.mu.Unlock()
   171  	}
   172  }
   173  
   174  func (ks *KeyStore) HasAddress(addr common.Address) bool {
   175  	return ks.cache.hasAddress(addr)
   176  }
   177  
   178  func (ks *KeyStore) Accounts() []accounts.Account {
   179  	return ks.cache.accounts()
   180  }
   181  
   182  func (ks *KeyStore) Delete(a accounts.Account, passphrase string) error {
   183  
   184  	a, key, err := ks.getDecryptedKey(a, passphrase)
   185  	if key != nil {
   186  		zeroKey(key.PrivateKey)
   187  	}
   188  	if err != nil {
   189  		return err
   190  	}
   191  
   192  	err = os.Remove(a.URL.Path)
   193  	if err == nil {
   194  		ks.cache.delete(a)
   195  		ks.refreshWallets()
   196  	}
   197  	return err
   198  }
   199  
   200  func (ks *KeyStore) SignHash(a accounts.Account, hash []byte) ([]byte, error) {
   201  
   202  	ks.mu.RLock()
   203  	defer ks.mu.RUnlock()
   204  
   205  	unlockedKey, found := ks.unlocked[a.Address]
   206  	if !found {
   207  		return nil, ErrLocked
   208  	}
   209  
   210  	return crypto.Sign(hash, unlockedKey.PrivateKey)
   211  }
   212  
   213  func (ks *KeyStore) SignTx(a accounts.Account, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) {
   214  
   215  	ks.mu.RLock()
   216  	defer ks.mu.RUnlock()
   217  
   218  	unlockedKey, found := ks.unlocked[a.Address]
   219  	if !found {
   220  		return nil, ErrLocked
   221  	}
   222  
   223  	if chainID != nil {
   224  		return types.SignTx(tx, types.NewEIP155Signer(chainID), unlockedKey.PrivateKey)
   225  	}
   226  	return types.SignTx(tx, types.HomesteadSigner{}, unlockedKey.PrivateKey)
   227  }
   228  
   229  func (ks *KeyStore) SignTxWithAddress(a accounts.Account, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) {
   230  
   231  	ks.mu.RLock()
   232  	defer ks.mu.RUnlock()
   233  
   234  	unlockedKey, found := ks.unlocked[a.Address]
   235  	if !found {
   236  		return nil, ErrLocked
   237  	}
   238  
   239  	if chainID != nil {
   240  		return types.SignTxWithAddress(tx, types.NewEIP155Signer(chainID), unlockedKey.PrivateKey)
   241  	}
   242  	return types.SignTxWithAddress(tx, types.HomesteadSigner{}, unlockedKey.PrivateKey)
   243  }
   244  
   245  func (ks *KeyStore) SignHashWithPassphrase(a accounts.Account, passphrase string, hash []byte) (signature []byte, err error) {
   246  	_, key, err := ks.getDecryptedKey(a, passphrase)
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  	defer zeroKey(key.PrivateKey)
   251  	return crypto.Sign(hash, key.PrivateKey)
   252  }
   253  
   254  func (ks *KeyStore) SignTxWithPassphrase(a accounts.Account, passphrase string, tx *types.Transaction, chainID *big.Int) (*types.Transaction, error) {
   255  	_, key, err := ks.getDecryptedKey(a, passphrase)
   256  	if err != nil {
   257  		return nil, err
   258  	}
   259  	defer zeroKey(key.PrivateKey)
   260  
   261  	if chainID != nil {
   262  		return types.SignTx(tx, types.NewEIP155Signer(chainID), key.PrivateKey)
   263  	}
   264  	return types.SignTx(tx, types.HomesteadSigner{}, key.PrivateKey)
   265  }
   266  
   267  func (ks *KeyStore) Unlock(a accounts.Account, passphrase string) error {
   268  	return ks.TimedUnlock(a, passphrase, 0)
   269  }
   270  
   271  func (ks *KeyStore) Lock(addr common.Address) error {
   272  	ks.mu.Lock()
   273  	if unl, found := ks.unlocked[addr]; found {
   274  		ks.mu.Unlock()
   275  		ks.expire(addr, unl, time.Duration(0)*time.Nanosecond)
   276  	} else {
   277  		ks.mu.Unlock()
   278  	}
   279  	return nil
   280  }
   281  
   282  func (ks *KeyStore) TimedUnlock(a accounts.Account, passphrase string, timeout time.Duration) error {
   283  	a, key, err := ks.getDecryptedKey(a, passphrase)
   284  	if err != nil {
   285  		return err
   286  	}
   287  
   288  	ks.mu.Lock()
   289  	defer ks.mu.Unlock()
   290  	u, found := ks.unlocked[a.Address]
   291  	if found {
   292  		if u.abort == nil {
   293  
   294  			zeroKey(key.PrivateKey)
   295  			return nil
   296  		}
   297  
   298  		close(u.abort)
   299  	}
   300  	if timeout > 0 {
   301  		u = &unlocked{Key: key, abort: make(chan struct{})}
   302  		go ks.expire(a.Address, u, timeout)
   303  	} else {
   304  		u = &unlocked{Key: key}
   305  	}
   306  	ks.unlocked[a.Address] = u
   307  	return nil
   308  }
   309  
   310  func (ks *KeyStore) Find(a accounts.Account) (accounts.Account, error) {
   311  	ks.cache.maybeReload()
   312  	ks.cache.mu.Lock()
   313  	a, err := ks.cache.find(a)
   314  	ks.cache.mu.Unlock()
   315  	return a, err
   316  }
   317  
   318  func (ks *KeyStore) getDecryptedKey(a accounts.Account, auth string) (accounts.Account, *Key, error) {
   319  	a, err := ks.Find(a)
   320  	if err != nil {
   321  		return a, nil, err
   322  	}
   323  	key, err := ks.storage.GetKey(a.Address, a.URL.Path, auth)
   324  	return a, key, err
   325  }
   326  
   327  func (ks *KeyStore) expire(addr common.Address, u *unlocked, timeout time.Duration) {
   328  	t := time.NewTimer(timeout)
   329  	defer t.Stop()
   330  	select {
   331  	case <-u.abort:
   332  
   333  	case <-t.C:
   334  		ks.mu.Lock()
   335  
   336  		if ks.unlocked[addr] == u {
   337  			zeroKey(u.PrivateKey)
   338  			delete(ks.unlocked, addr)
   339  		}
   340  		ks.mu.Unlock()
   341  	}
   342  }
   343  
   344  func (ks *KeyStore) NewAccount(passphrase string) (accounts.Account, error) {
   345  	_, account, err := storeNewKey(ks.storage, crand.Reader, passphrase)
   346  	if err != nil {
   347  		return accounts.Account{}, err
   348  	}
   349  
   350  	ks.cache.add(account)
   351  	ks.refreshWallets()
   352  	return account, nil
   353  }
   354  
   355  func (ks *KeyStore) Export(a accounts.Account, passphrase, newPassphrase string) (keyJSON []byte, err error) {
   356  	_, key, err := ks.getDecryptedKey(a, passphrase)
   357  	if err != nil {
   358  		return nil, err
   359  	}
   360  	var N, P int
   361  	if store, ok := ks.storage.(*keyStorePassphrase); ok {
   362  		N, P = store.scryptN, store.scryptP
   363  	} else {
   364  		N, P = StandardScryptN, StandardScryptP
   365  	}
   366  	return EncryptKey(key, newPassphrase, N, P)
   367  }
   368  
   369  func (ks *KeyStore) Import(keyJSON []byte, passphrase, newPassphrase string) (accounts.Account, error) {
   370  	key, err := DecryptKey(keyJSON, passphrase)
   371  	if key != nil && key.PrivateKey != nil {
   372  		defer zeroKey(key.PrivateKey)
   373  	}
   374  	if err != nil {
   375  		return accounts.Account{}, err
   376  	}
   377  	return ks.importKey(key, newPassphrase)
   378  }
   379  
   380  func (ks *KeyStore) ImportECDSA(priv *ecdsa.PrivateKey, passphrase string) (accounts.Account, error) {
   381  	key := newKeyFromECDSA(priv)
   382  	if ks.cache.hasAddress(key.Address) {
   383  		return accounts.Account{}, fmt.Errorf("account already exists")
   384  	}
   385  	return ks.importKey(key, passphrase)
   386  }
   387  
   388  func (ks *KeyStore) importKey(key *Key, passphrase string) (accounts.Account, error) {
   389  	a := accounts.Account{Address: key.Address, URL: accounts.URL{Scheme: KeyStoreScheme, Path: ks.storage.JoinPath(keyFileName(key.Address))}}
   390  	if err := ks.storage.StoreKey(a.URL.Path, key, passphrase); err != nil {
   391  		return accounts.Account{}, err
   392  	}
   393  	ks.cache.add(a)
   394  	ks.refreshWallets()
   395  	return a, nil
   396  }
   397  
   398  func (ks *KeyStore) Update(a accounts.Account, passphrase, newPassphrase string) error {
   399  	a, key, err := ks.getDecryptedKey(a, passphrase)
   400  	if err != nil {
   401  		return err
   402  	}
   403  	return ks.storage.StoreKey(a.URL.Path, key, newPassphrase)
   404  }
   405  
   406  func (ks *KeyStore) ImportPreSaleKey(keyJSON []byte, passphrase string) (accounts.Account, error) {
   407  	a, _, err := importPreSaleKey(ks.storage, keyJSON, passphrase)
   408  	if err != nil {
   409  		return a, err
   410  	}
   411  	ks.cache.add(a)
   412  	ks.refreshWallets()
   413  	return a, nil
   414  }
   415  
   416  func zeroKey(k *ecdsa.PrivateKey) {
   417  	b := k.D.Bits()
   418  	for i := range b {
   419  		b[i] = 0
   420  	}
   421  }