github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/chain/accounts/keystore/account_cache.go (about)

     1  package keystore
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/json"
     6  	"fmt"
     7  	"os"
     8  	"path/filepath"
     9  	"sort"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/neatio-net/neatio/chain/accounts"
    15  	"github.com/neatio-net/neatio/chain/log"
    16  	"github.com/neatio-net/neatio/utilities/common"
    17  	"github.com/neatio-net/set-go"
    18  )
    19  
    20  const minReloadInterval = 2 * time.Second
    21  
    22  type accountsByURL []accounts.Account
    23  
    24  func (s accountsByURL) Len() int           { return len(s) }
    25  func (s accountsByURL) Less(i, j int) bool { return s[i].URL.Cmp(s[j].URL) < 0 }
    26  func (s accountsByURL) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
    27  
    28  type AmbiguousAddrError struct {
    29  	Addr    common.Address
    30  	Matches []accounts.Account
    31  }
    32  
    33  func (err *AmbiguousAddrError) Error() string {
    34  	files := ""
    35  	for i, a := range err.Matches {
    36  		files += a.URL.Path
    37  		if i < len(err.Matches)-1 {
    38  			files += ", "
    39  		}
    40  	}
    41  	return fmt.Sprintf("multiple keys match address (%s)", files)
    42  }
    43  
// accountCache indexes the accounts whose key files live in keydir. Each
// account is held twice: in all, ordered by URL, and in byAddr, grouped by
// address.
//
// mu is locked by add, delete, deleteByFile, accounts, hasAddress,
// maybeReload and close. find does not lock it; NOTE(review): callers of
// find presumably hold mu — confirm at the call sites.
type accountCache struct {
	// keydir is the directory scanned by scanAccounts; find also resolves
	// bare file names against it.
	keydir string

	// watcher is started by maybeReload; while watcher.running is true,
	// maybeReload does not rescan.
	watcher *watcher

	mu sync.Mutex

	// all holds every cached account; add keeps it ordered by URL.Cmp.
	all accountsByURL

	// byAddr groups the cached accounts by address.
	byAddr map[common.Address][]accounts.Account

	// throttle is created on the first maybeReload call and limits rescans
	// started by maybeReload to one per minReloadInterval.
	throttle *time.Timer

	// notify receives a non-blocking signal after a scan applies changes;
	// close closes it and sets it to nil.
	notify chan struct{}

	// fileC reports which key files were created, deleted or updated
	// between scans.
	fileC fileCache
}
    54  
    55  func newAccountCache(keydir string) (*accountCache, chan struct{}) {
    56  	ac := &accountCache{
    57  		keydir: keydir,
    58  		byAddr: make(map[common.Address][]accounts.Account),
    59  		notify: make(chan struct{}, 1),
    60  		fileC:  fileCache{all: set.NewNonTS()},
    61  	}
    62  	ac.watcher = newWatcher(ac)
    63  	return ac, ac.notify
    64  }
    65  
    66  func (ac *accountCache) accounts() []accounts.Account {
    67  	ac.maybeReload()
    68  	ac.mu.Lock()
    69  	defer ac.mu.Unlock()
    70  	cpy := make([]accounts.Account, len(ac.all))
    71  	copy(cpy, ac.all)
    72  	return cpy
    73  }
    74  
    75  func (ac *accountCache) hasAddress(addr common.Address) bool {
    76  	ac.maybeReload()
    77  	ac.mu.Lock()
    78  	defer ac.mu.Unlock()
    79  	return len(ac.byAddr[addr]) > 0
    80  }
    81  
    82  func (ac *accountCache) add(newAccount accounts.Account) {
    83  	ac.mu.Lock()
    84  	defer ac.mu.Unlock()
    85  
    86  	i := sort.Search(len(ac.all), func(i int) bool { return ac.all[i].URL.Cmp(newAccount.URL) >= 0 })
    87  	if i < len(ac.all) && ac.all[i] == newAccount {
    88  		return
    89  	}
    90  
    91  	ac.all = append(ac.all, accounts.Account{})
    92  	copy(ac.all[i+1:], ac.all[i:])
    93  	ac.all[i] = newAccount
    94  	ac.byAddr[newAccount.Address] = append(ac.byAddr[newAccount.Address], newAccount)
    95  }
    96  
    97  func (ac *accountCache) delete(removed accounts.Account) {
    98  	ac.mu.Lock()
    99  	defer ac.mu.Unlock()
   100  
   101  	ac.all = removeAccount(ac.all, removed)
   102  	if ba := removeAccount(ac.byAddr[removed.Address], removed); len(ba) == 0 {
   103  		delete(ac.byAddr, removed.Address)
   104  	} else {
   105  		ac.byAddr[removed.Address] = ba
   106  	}
   107  }
   108  
   109  func (ac *accountCache) deleteByFile(path string) {
   110  	ac.mu.Lock()
   111  	defer ac.mu.Unlock()
   112  	i := sort.Search(len(ac.all), func(i int) bool { return ac.all[i].URL.Path >= path })
   113  
   114  	if i < len(ac.all) && ac.all[i].URL.Path == path {
   115  		removed := ac.all[i]
   116  		ac.all = append(ac.all[:i], ac.all[i+1:]...)
   117  		if ba := removeAccount(ac.byAddr[removed.Address], removed); len(ba) == 0 {
   118  			delete(ac.byAddr, removed.Address)
   119  		} else {
   120  			ac.byAddr[removed.Address] = ba
   121  		}
   122  	}
   123  }
   124  
   125  func removeAccount(slice []accounts.Account, elem accounts.Account) []accounts.Account {
   126  	for i := range slice {
   127  		if slice[i] == elem {
   128  			return append(slice[:i], slice[i+1:]...)
   129  		}
   130  	}
   131  	return slice
   132  }
   133  
   134  func (ac *accountCache) find(a accounts.Account) (accounts.Account, error) {
   135  
   136  	matches := ac.all
   137  	if (a.Address != common.Address{}) {
   138  		matches = ac.byAddr[a.Address]
   139  	}
   140  	if a.URL.Path != "" {
   141  
   142  		if !strings.ContainsRune(a.URL.Path, filepath.Separator) {
   143  			a.URL.Path = filepath.Join(ac.keydir, a.URL.Path)
   144  		}
   145  		for i := range matches {
   146  			if matches[i].URL == a.URL {
   147  				return matches[i], nil
   148  			}
   149  		}
   150  		if (a.Address == common.Address{}) {
   151  			return accounts.Account{}, ErrNoMatch
   152  		}
   153  	}
   154  	switch len(matches) {
   155  	case 1:
   156  		return matches[0], nil
   157  	case 0:
   158  		return accounts.Account{}, ErrNoMatch
   159  	default:
   160  		err := &AmbiguousAddrError{Addr: a.Address, Matches: make([]accounts.Account, len(matches))}
   161  		copy(err.Matches, matches)
   162  		sort.Sort(accountsByURL(err.Matches))
   163  		return accounts.Account{}, err
   164  	}
   165  }
   166  
   167  func (ac *accountCache) maybeReload() {
   168  	ac.mu.Lock()
   169  
   170  	if ac.watcher.running {
   171  		ac.mu.Unlock()
   172  		return
   173  	}
   174  	if ac.throttle == nil {
   175  		ac.throttle = time.NewTimer(0)
   176  	} else {
   177  		select {
   178  		case <-ac.throttle.C:
   179  		default:
   180  			ac.mu.Unlock()
   181  			return
   182  		}
   183  	}
   184  
   185  	ac.watcher.start()
   186  	ac.throttle.Reset(minReloadInterval)
   187  	ac.mu.Unlock()
   188  	ac.scanAccounts()
   189  }
   190  
   191  func (ac *accountCache) close() {
   192  	ac.mu.Lock()
   193  	ac.watcher.close()
   194  	if ac.throttle != nil {
   195  		ac.throttle.Stop()
   196  	}
   197  	if ac.notify != nil {
   198  		close(ac.notify)
   199  		ac.notify = nil
   200  	}
   201  	ac.mu.Unlock()
   202  }
   203  
   204  func (ac *accountCache) scanAccounts() error {
   205  
   206  	creates, deletes, updates, err := ac.fileC.scan(ac.keydir)
   207  	if err != nil {
   208  		log.Debug("Failed to reload keystore contents", "err", err)
   209  		return err
   210  	}
   211  	if creates.Size() == 0 && deletes.Size() == 0 && updates.Size() == 0 {
   212  		return nil
   213  	}
   214  
   215  	var (
   216  		buf = new(bufio.Reader)
   217  		key struct {
   218  			Address string `json:"address"`
   219  		}
   220  	)
   221  	readAccount := func(path string) *accounts.Account {
   222  		fd, err := os.Open(path)
   223  		if err != nil {
   224  			log.Trace("Failed to open keystore file", "path", path, "err", err)
   225  			return nil
   226  		}
   227  		defer fd.Close()
   228  		buf.Reset(fd)
   229  
   230  		key.Address = ""
   231  		err = json.NewDecoder(buf).Decode(&key)
   232  
   233  		addr := common.HexToAddress(key.Address)
   234  		switch {
   235  		case err != nil:
   236  			log.Debug("Failed to decode keystore key", "path", path, "err", err)
   237  		case (addr == common.Address{}):
   238  			log.Debug("Failed to decode keystore key", "path", path, "err", "missing or zero address")
   239  		default:
   240  			return &accounts.Account{Address: addr, URL: accounts.URL{Scheme: KeyStoreScheme, Path: path}}
   241  		}
   242  		return nil
   243  	}
   244  
   245  	start := time.Now()
   246  
   247  	for _, p := range creates.List() {
   248  		if a := readAccount(p.(string)); a != nil {
   249  			ac.add(*a)
   250  		}
   251  	}
   252  	for _, p := range deletes.List() {
   253  		ac.deleteByFile(p.(string))
   254  	}
   255  	for _, p := range updates.List() {
   256  		path := p.(string)
   257  		ac.deleteByFile(path)
   258  		if a := readAccount(path); a != nil {
   259  			ac.add(*a)
   260  		}
   261  	}
   262  	end := time.Now()
   263  
   264  	select {
   265  	case ac.notify <- struct{}{}:
   266  	default:
   267  	}
   268  	log.Trace("Handled keystore changes", "time", end.Sub(start))
   269  	return nil
   270  }