github.com/klaytn/klaytn@v1.12.1/accounts/keystore/account_cache.go (about)

     1  // Modifications Copyright 2018 The klaytn Authors
     2  // Copyright 2017 The go-ethereum Authors
     3  // This file is part of the go-ethereum library.
     4  //
     5  // The go-ethereum library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Lesser General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // The go-ethereum library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    13  // GNU Lesser General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public License
    16  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    17  //
    18  // This file is derived from accounts/keystore/account_cache.go (2018/06/04).
    19  // Modified and improved for the klaytn development.
    20  
    21  package keystore
    22  
    23  import (
    24  	"bufio"
    25  	"encoding/json"
    26  	"fmt"
    27  	"os"
    28  	"path/filepath"
    29  	"sort"
    30  	"strings"
    31  	"sync"
    32  	"time"
    33  
    34  	"github.com/klaytn/klaytn/accounts"
    35  	"github.com/klaytn/klaytn/common"
    36  	"github.com/klaytn/klaytn/log"
    37  	"gopkg.in/fatih/set.v0"
    38  )
    39  
// minReloadInterval is the minimum amount of time between cache reloads. This
// limit applies if the platform does not support change notifications. It also
// applies if the keystore directory does not exist yet: in that case the code
// will attempt to create a watcher at most this often.
const minReloadInterval = 2 * time.Second
    44  
    45  type accountsByURL []accounts.Account
    46  
    47  func (s accountsByURL) Len() int           { return len(s) }
    48  func (s accountsByURL) Less(i, j int) bool { return s[i].URL.Cmp(s[j].URL) < 0 }
    49  func (s accountsByURL) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
    50  
// logger is the package-level logger, created for the AccountsKeystore log module.
var logger = log.NewModuleLogger(log.AccountsKeystore)
    52  
// AmbiguousAddrError is returned when attempting to unlock
// an address for which more than one file exists.
type AmbiguousAddrError struct {
	Addr    common.Address     // address that matched more than one cached account
	Matches []accounts.Account // the matching accounts; their URL paths are listed by Error
}
    59  
    60  func (err *AmbiguousAddrError) Error() string {
    61  	files := ""
    62  	for i, a := range err.Matches {
    63  		files += a.URL.Path
    64  		if i < len(err.Matches)-1 {
    65  			files += ", "
    66  		}
    67  	}
    68  	return fmt.Sprintf("multiple keys match address (%s)", files)
    69  }
    70  
// accountCache is a live index of all accounts in the keystore.
type accountCache struct {
	keydir   string                                // directory scanned for key files; also used by find to complete basename-only paths
	watcher  *watcher                              // created by newWatcher; maybeReload starts it when it is not running
	mu       sync.Mutex                            // locked by the cache methods around accesses to all and byAddr
	all      accountsByURL                         // every cached account; add inserts at the position found by binary search on URL
	byAddr   map[common.Address][]accounts.Account // cached accounts grouped by address
	throttle *time.Timer                           // limits rescans in maybeReload to one per minReloadInterval while no watcher is running
	notify   chan struct{}                         // buffered (capacity 1); signaled by scanAccounts after changes, closed by close
	fileC    fileCache                             // scan() on it yields the created, deleted and updated files of keydir
}
    82  
    83  func newAccountCache(keydir string) (*accountCache, chan struct{}) {
    84  	ac := &accountCache{
    85  		keydir: keydir,
    86  		byAddr: make(map[common.Address][]accounts.Account),
    87  		notify: make(chan struct{}, 1),
    88  		fileC:  fileCache{all: set.NewNonTS()},
    89  	}
    90  	ac.watcher = newWatcher(ac)
    91  	return ac, ac.notify
    92  }
    93  
    94  func (ac *accountCache) accounts() []accounts.Account {
    95  	ac.maybeReload()
    96  	ac.mu.Lock()
    97  	defer ac.mu.Unlock()
    98  	cpy := make([]accounts.Account, len(ac.all))
    99  	copy(cpy, ac.all)
   100  	return cpy
   101  }
   102  
   103  func (ac *accountCache) hasAddress(addr common.Address) bool {
   104  	ac.maybeReload()
   105  	ac.mu.Lock()
   106  	defer ac.mu.Unlock()
   107  	return len(ac.byAddr[addr]) > 0
   108  }
   109  
   110  func (ac *accountCache) add(newAccount accounts.Account) {
   111  	ac.mu.Lock()
   112  	defer ac.mu.Unlock()
   113  
   114  	i := sort.Search(len(ac.all), func(i int) bool { return ac.all[i].URL.Cmp(newAccount.URL) >= 0 })
   115  	if i < len(ac.all) && ac.all[i] == newAccount {
   116  		return
   117  	}
   118  	// newAccount is not in the cache.
   119  	ac.all = append(ac.all, accounts.Account{})
   120  	copy(ac.all[i+1:], ac.all[i:])
   121  	ac.all[i] = newAccount
   122  	ac.byAddr[newAccount.Address] = append(ac.byAddr[newAccount.Address], newAccount)
   123  }
   124  
   125  // note: removed needs to be unique here (i.e. both File and Address must be set).
   126  func (ac *accountCache) delete(removed accounts.Account) {
   127  	ac.mu.Lock()
   128  	defer ac.mu.Unlock()
   129  
   130  	ac.all = removeAccount(ac.all, removed)
   131  	if ba := removeAccount(ac.byAddr[removed.Address], removed); len(ba) == 0 {
   132  		delete(ac.byAddr, removed.Address)
   133  	} else {
   134  		ac.byAddr[removed.Address] = ba
   135  	}
   136  }
   137  
   138  // deleteByFile removes an account referenced by the given path.
   139  func (ac *accountCache) deleteByFile(path string) {
   140  	ac.mu.Lock()
   141  	defer ac.mu.Unlock()
   142  	i := sort.Search(len(ac.all), func(i int) bool { return ac.all[i].URL.Path >= path })
   143  
   144  	if i < len(ac.all) && ac.all[i].URL.Path == path {
   145  		removed := ac.all[i]
   146  		ac.all = append(ac.all[:i], ac.all[i+1:]...)
   147  		if ba := removeAccount(ac.byAddr[removed.Address], removed); len(ba) == 0 {
   148  			delete(ac.byAddr, removed.Address)
   149  		} else {
   150  			ac.byAddr[removed.Address] = ba
   151  		}
   152  	}
   153  }
   154  
   155  func removeAccount(slice []accounts.Account, elem accounts.Account) []accounts.Account {
   156  	for i := range slice {
   157  		if slice[i] == elem {
   158  			return append(slice[:i], slice[i+1:]...)
   159  		}
   160  	}
   161  	return slice
   162  }
   163  
   164  // find returns the cached account for address if there is a unique match.
   165  // The exact matching rules are explained by the documentation of accounts.Account.
   166  // Callers must hold ac.mu.
   167  func (ac *accountCache) find(a accounts.Account) (accounts.Account, error) {
   168  	// Limit search to address candidates if possible.
   169  	matches := ac.all
   170  	if (a.Address != common.Address{}) {
   171  		matches = ac.byAddr[a.Address]
   172  	}
   173  	if a.URL.Path != "" {
   174  		// If only the basename is specified, complete the path.
   175  		if !strings.ContainsRune(a.URL.Path, filepath.Separator) {
   176  			a.URL.Path = filepath.Join(ac.keydir, a.URL.Path)
   177  		}
   178  		for i := range matches {
   179  			if matches[i].URL == a.URL {
   180  				return matches[i], nil
   181  			}
   182  		}
   183  		if (a.Address == common.Address{}) {
   184  			return accounts.Account{}, ErrNoMatch
   185  		}
   186  	}
   187  	switch len(matches) {
   188  	case 1:
   189  		return matches[0], nil
   190  	case 0:
   191  		return accounts.Account{}, ErrNoMatch
   192  	default:
   193  		err := &AmbiguousAddrError{Addr: a.Address, Matches: make([]accounts.Account, len(matches))}
   194  		copy(err.Matches, matches)
   195  		sort.Sort(accountsByURL(err.Matches))
   196  		return accounts.Account{}, err
   197  	}
   198  }
   199  
   200  func (ac *accountCache) maybeReload() {
   201  	ac.mu.Lock()
   202  
   203  	if ac.watcher.running {
   204  		ac.mu.Unlock()
   205  		return // A watcher is running and will keep the cache up-to-date.
   206  	}
   207  	if ac.throttle == nil {
   208  		ac.throttle = time.NewTimer(0)
   209  	} else {
   210  		select {
   211  		case <-ac.throttle.C:
   212  		default:
   213  			ac.mu.Unlock()
   214  			return // The cache was reloaded recently.
   215  		}
   216  	}
   217  	// No watcher running, start it.
   218  	ac.watcher.start()
   219  	ac.throttle.Reset(minReloadInterval)
   220  	ac.mu.Unlock()
   221  	ac.scanAccounts()
   222  }
   223  
   224  func (ac *accountCache) close() {
   225  	ac.mu.Lock()
   226  	ac.watcher.close()
   227  	if ac.throttle != nil {
   228  		ac.throttle.Stop()
   229  	}
   230  	if ac.notify != nil {
   231  		close(ac.notify)
   232  		ac.notify = nil
   233  	}
   234  	ac.mu.Unlock()
   235  }
   236  
   237  // scanAccounts checks if any changes have occurred on the filesystem, and
   238  // updates the account cache accordingly
   239  func (ac *accountCache) scanAccounts() error {
   240  	// Scan the entire folder metadata for file changes
   241  	creates, deletes, updates, err := ac.fileC.scan(ac.keydir)
   242  	if err != nil {
   243  		logger.Debug("Failed to reload keystore contents", "err", err)
   244  		return err
   245  	}
   246  	if creates.Size() == 0 && deletes.Size() == 0 && updates.Size() == 0 {
   247  		return nil
   248  	}
   249  	// Create a helper method to scan the contents of the key files
   250  	var (
   251  		buf = new(bufio.Reader)
   252  		key struct {
   253  			Address string `json:"address"`
   254  		}
   255  	)
   256  	readAccount := func(path string) *accounts.Account {
   257  		fd, err := os.Open(path)
   258  		if err != nil {
   259  			logger.Trace("Failed to open keystore file", "path", path, "err", err)
   260  			return nil
   261  		}
   262  		defer fd.Close()
   263  		buf.Reset(fd)
   264  		// Parse the address.
   265  		key.Address = ""
   266  		err = json.NewDecoder(buf).Decode(&key)
   267  		addr := common.HexToAddress(key.Address)
   268  		switch {
   269  		case err != nil:
   270  			logger.Debug("Failed to decode keystore key", "path", path, "err", err)
   271  		case (addr == common.Address{}):
   272  			logger.Debug("Failed to decode keystore key", "path", path, "err", "missing or zero address")
   273  		default:
   274  			return &accounts.Account{Address: addr, URL: accounts.URL{Scheme: KeyStoreScheme, Path: path}}
   275  		}
   276  		return nil
   277  	}
   278  	// Process all the file diffs
   279  	start := time.Now()
   280  
   281  	for _, p := range creates.List() {
   282  		if a := readAccount(p.(string)); a != nil {
   283  			ac.add(*a)
   284  		}
   285  	}
   286  	for _, p := range deletes.List() {
   287  		ac.deleteByFile(p.(string))
   288  	}
   289  	for _, p := range updates.List() {
   290  		path := p.(string)
   291  		ac.deleteByFile(path)
   292  		if a := readAccount(path); a != nil {
   293  			ac.add(*a)
   294  		}
   295  	}
   296  	end := time.Now()
   297  
   298  	select {
   299  	case ac.notify <- struct{}{}:
   300  	default:
   301  	}
   302  	logger.Trace("Handled keystore changes", "time", end.Sub(start))
   303  	return nil
   304  }