github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/network/p2p/discover/database.go (about)

     1  package discover
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"encoding/binary"
     7  	"os"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/neatlab/neatio/chain/log"
    12  	"github.com/neatlab/neatio/utilities/crypto"
    13  	"github.com/neatlab/neatio/utilities/rlp"
    14  	"github.com/syndtr/goleveldb/leveldb"
    15  	"github.com/syndtr/goleveldb/leveldb/errors"
    16  	"github.com/syndtr/goleveldb/leveldb/iterator"
    17  	"github.com/syndtr/goleveldb/leveldb/opt"
    18  	"github.com/syndtr/goleveldb/leveldb/storage"
    19  	"github.com/syndtr/goleveldb/leveldb/util"
    20  )
    21  
    22  var (
    23  	nodeDBNilNodeID      = NodeID{}
    24  	nodeDBNodeExpiration = 24 * time.Hour
    25  	nodeDBCleanupCycle   = time.Hour
    26  )
    27  
    28  type nodeDB struct {
    29  	lvl    *leveldb.DB
    30  	self   NodeID
    31  	runner sync.Once
    32  	quit   chan struct{}
    33  }
    34  
    35  var (
    36  	nodeDBVersionKey = []byte("version")
    37  	nodeDBItemPrefix = []byte("n:")
    38  
    39  	nodeDBDiscoverRoot      = ":discover"
    40  	nodeDBDiscoverPing      = nodeDBDiscoverRoot + ":lastping"
    41  	nodeDBDiscoverPong      = nodeDBDiscoverRoot + ":lastpong"
    42  	nodeDBDiscoverFindFails = nodeDBDiscoverRoot + ":findfail"
    43  )
    44  
    45  func newNodeDB(path string, version int, self NodeID) (*nodeDB, error) {
    46  	if path == "" {
    47  		return newMemoryNodeDB(self)
    48  	}
    49  	return newPersistentNodeDB(path, version, self)
    50  }
    51  
    52  func newMemoryNodeDB(self NodeID) (*nodeDB, error) {
    53  	db, err := leveldb.Open(storage.NewMemStorage(), nil)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	return &nodeDB{
    58  		lvl:  db,
    59  		self: self,
    60  		quit: make(chan struct{}),
    61  	}, nil
    62  }
    63  
    64  func newPersistentNodeDB(path string, version int, self NodeID) (*nodeDB, error) {
    65  	opts := &opt.Options{OpenFilesCacheCapacity: 5}
    66  	db, err := leveldb.OpenFile(path, opts)
    67  	if _, iscorrupted := err.(*errors.ErrCorrupted); iscorrupted {
    68  		db, err = leveldb.RecoverFile(path, nil)
    69  	}
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	currentVer := make([]byte, binary.MaxVarintLen64)
    75  	currentVer = currentVer[:binary.PutVarint(currentVer, int64(version))]
    76  
    77  	blob, err := db.Get(nodeDBVersionKey, nil)
    78  	switch err {
    79  	case leveldb.ErrNotFound:
    80  
    81  		if err := db.Put(nodeDBVersionKey, currentVer, nil); err != nil {
    82  			db.Close()
    83  			return nil, err
    84  		}
    85  
    86  	case nil:
    87  
    88  		if !bytes.Equal(blob, currentVer) {
    89  			db.Close()
    90  			if err = os.RemoveAll(path); err != nil {
    91  				return nil, err
    92  			}
    93  			return newPersistentNodeDB(path, version, self)
    94  		}
    95  	}
    96  	return &nodeDB{
    97  		lvl:  db,
    98  		self: self,
    99  		quit: make(chan struct{}),
   100  	}, nil
   101  }
   102  
   103  func makeKey(id NodeID, field string) []byte {
   104  	if bytes.Equal(id[:], nodeDBNilNodeID[:]) {
   105  		return []byte(field)
   106  	}
   107  	return append(nodeDBItemPrefix, append(id[:], field...)...)
   108  }
   109  
   110  func splitKey(key []byte) (id NodeID, field string) {
   111  
   112  	if !bytes.HasPrefix(key, nodeDBItemPrefix) {
   113  		return NodeID{}, string(key)
   114  	}
   115  
   116  	item := key[len(nodeDBItemPrefix):]
   117  	copy(id[:], item[:len(id)])
   118  	field = string(item[len(id):])
   119  
   120  	return id, field
   121  }
   122  
   123  func (db *nodeDB) fetchInt64(key []byte) int64 {
   124  	blob, err := db.lvl.Get(key, nil)
   125  	if err != nil {
   126  		return 0
   127  	}
   128  	val, read := binary.Varint(blob)
   129  	if read <= 0 {
   130  		return 0
   131  	}
   132  	return val
   133  }
   134  
   135  func (db *nodeDB) storeInt64(key []byte, n int64) error {
   136  	blob := make([]byte, binary.MaxVarintLen64)
   137  	blob = blob[:binary.PutVarint(blob, n)]
   138  
   139  	return db.lvl.Put(key, blob, nil)
   140  }
   141  
   142  func (db *nodeDB) node(id NodeID) *Node {
   143  	blob, err := db.lvl.Get(makeKey(id, nodeDBDiscoverRoot), nil)
   144  	if err != nil {
   145  		return nil
   146  	}
   147  	node := new(Node)
   148  	if err := rlp.DecodeBytes(blob, node); err != nil {
   149  		log.Error("Failed to decode node RLP", "err", err)
   150  		return nil
   151  	}
   152  	node.sha = crypto.Keccak256Hash(node.ID[:])
   153  	return node
   154  }
   155  
   156  func (db *nodeDB) updateNode(node *Node) error {
   157  	blob, err := rlp.EncodeToBytes(node)
   158  	if err != nil {
   159  		return err
   160  	}
   161  	return db.lvl.Put(makeKey(node.ID, nodeDBDiscoverRoot), blob, nil)
   162  }
   163  
   164  func (db *nodeDB) deleteNode(id NodeID) error {
   165  	deleter := db.lvl.NewIterator(util.BytesPrefix(makeKey(id, "")), nil)
   166  	for deleter.Next() {
   167  		if err := db.lvl.Delete(deleter.Key(), nil); err != nil {
   168  			return err
   169  		}
   170  	}
   171  	return nil
   172  }
   173  
   174  func (db *nodeDB) ensureExpirer() {
   175  	db.runner.Do(func() { go db.expirer() })
   176  }
   177  
   178  func (db *nodeDB) expirer() {
   179  	tick := time.NewTicker(nodeDBCleanupCycle)
   180  	defer tick.Stop()
   181  	for {
   182  		select {
   183  		case <-tick.C:
   184  			if err := db.expireNodes(); err != nil {
   185  				log.Error("Failed to expire nodedb items", "err", err)
   186  			}
   187  		case <-db.quit:
   188  			return
   189  		}
   190  	}
   191  }
   192  
   193  func (db *nodeDB) expireNodes() error {
   194  	threshold := time.Now().Add(-nodeDBNodeExpiration)
   195  
   196  	it := db.lvl.NewIterator(nil, nil)
   197  	defer it.Release()
   198  
   199  	for it.Next() {
   200  
   201  		id, field := splitKey(it.Key())
   202  		if field != nodeDBDiscoverRoot {
   203  			continue
   204  		}
   205  
   206  		if !bytes.Equal(id[:], db.self[:]) {
   207  			if seen := db.bondTime(id); seen.After(threshold) {
   208  				continue
   209  			}
   210  		}
   211  
   212  		db.deleteNode(id)
   213  	}
   214  	return nil
   215  }
   216  
   217  func (db *nodeDB) lastPing(id NodeID) time.Time {
   218  	return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0)
   219  }
   220  
   221  func (db *nodeDB) updateLastPing(id NodeID, instance time.Time) error {
   222  	return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix())
   223  }
   224  
   225  func (db *nodeDB) bondTime(id NodeID) time.Time {
   226  	return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0)
   227  }
   228  
   229  func (db *nodeDB) hasBond(id NodeID) bool {
   230  	return time.Since(db.bondTime(id)) < nodeDBNodeExpiration
   231  }
   232  
   233  func (db *nodeDB) updateBondTime(id NodeID, instance time.Time) error {
   234  	return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix())
   235  }
   236  
   237  func (db *nodeDB) findFails(id NodeID) int {
   238  	return int(db.fetchInt64(makeKey(id, nodeDBDiscoverFindFails)))
   239  }
   240  
   241  func (db *nodeDB) updateFindFails(id NodeID, fails int) error {
   242  	return db.storeInt64(makeKey(id, nodeDBDiscoverFindFails), int64(fails))
   243  }
   244  
   245  func (db *nodeDB) querySeeds(n int, maxAge time.Duration) []*Node {
   246  	var (
   247  		now   = time.Now()
   248  		nodes = make([]*Node, 0, n)
   249  		it    = db.lvl.NewIterator(nil, nil)
   250  		id    NodeID
   251  	)
   252  	defer it.Release()
   253  
   254  seek:
   255  	for seeks := 0; len(nodes) < n && seeks < n*5; seeks++ {
   256  
   257  		ctr := id[0]
   258  		rand.Read(id[:])
   259  		id[0] = ctr + id[0]%16
   260  		it.Seek(makeKey(id, nodeDBDiscoverRoot))
   261  
   262  		n := nextNode(it)
   263  		if n == nil {
   264  			id[0] = 0
   265  			continue seek
   266  		}
   267  		if n.ID == db.self {
   268  			continue seek
   269  		}
   270  		if now.Sub(db.bondTime(n.ID)) > maxAge {
   271  			continue seek
   272  		}
   273  		for i := range nodes {
   274  			if nodes[i].ID == n.ID {
   275  				continue seek
   276  			}
   277  		}
   278  		nodes = append(nodes, n)
   279  	}
   280  	return nodes
   281  }
   282  
   283  func nextNode(it iterator.Iterator) *Node {
   284  	for end := false; !end; end = !it.Next() {
   285  		id, field := splitKey(it.Key())
   286  		if field != nodeDBDiscoverRoot {
   287  			continue
   288  		}
   289  		var n Node
   290  		if err := rlp.DecodeBytes(it.Value(), &n); err != nil {
   291  			log.Warn("Failed to decode node RLP", "id", id, "err", err)
   292  			continue
   293  		}
   294  		return &n
   295  	}
   296  	return nil
   297  }
   298  
   299  func (db *nodeDB) close() {
   300  	close(db.quit)
   301  	db.lvl.Close()
   302  }