github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/core/state/statedb_forbidden.go (about)

     1  package state
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"sort"
     8  
     9  	"github.com/neatlab/neatio/utilities/common"
    10  	"github.com/neatlab/neatio/utilities/rlp"
    11  )
    12  
    13  // ----- banned Set
    14  
    15  // MarkAddressBanned adds the specified object to the dirty map
    16  func (self *StateDB) MarkAddressBanned(addr common.Address) {
    17  	if _, exist := self.GetBannedSet()[addr]; !exist {
    18  		self.bannedSet[addr] = struct{}{}
    19  		self.bannedSetDirty = true
    20  	}
    21  }
    22  
    23  func (self *StateDB) GetBannedSet() BannedSet {
    24  	if len(self.bannedSet) != 0 {
    25  		return self.bannedSet
    26  	}
    27  	// Try to get from Trie
    28  	enc, err := self.trie.TryGet(bannedSetKey)
    29  	if err != nil {
    30  		self.setError(err)
    31  		return nil
    32  	}
    33  	var value BannedSet
    34  	if len(enc) > 0 {
    35  		err := rlp.DecodeBytes(enc, &value)
    36  		if err != nil {
    37  			self.setError(err)
    38  		}
    39  		self.bannedSet = value
    40  	}
    41  	return value
    42  }
    43  
    44  func (self *StateDB) commitBannedSet() {
    45  	data, err := rlp.EncodeToBytes(self.bannedSet)
    46  	if err != nil {
    47  		panic(fmt.Errorf("can't encode banned set : %v", err))
    48  	}
    49  	self.setError(self.trie.TryUpdate(bannedSetKey, data))
    50  }
    51  
    52  func (self *StateDB) ClearBannedSetByAddress(addr common.Address) {
    53  	delete(self.bannedSet, addr)
    54  	self.bannedSetDirty = true
    55  }
    56  
// Storage of the banned address set.

// bannedSetKey is the trie key under which the RLP-encoded banned set is read
// and written.
var bannedSetKey = []byte("BannedSet")

// BannedSet is a set of addresses, represented as a map with empty-struct
// values. It is RLP-encoded as a list of addresses (see EncodeRLP and
// DecodeRLP).
type BannedSet map[common.Address]struct{}
    62  
    63  func (set BannedSet) EncodeRLP(w io.Writer) error {
    64  	var list []common.Address
    65  	for addr := range set {
    66  		list = append(list, addr)
    67  	}
    68  	sort.Slice(list, func(i, j int) bool {
    69  		return bytes.Compare(list[i].Bytes(), list[j].Bytes()) == 1
    70  	})
    71  	return rlp.Encode(w, list)
    72  }
    73  
    74  func (set *BannedSet) DecodeRLP(s *rlp.Stream) error {
    75  	var list []common.Address
    76  	if err := s.Decode(&list); err != nil {
    77  		return err
    78  	}
    79  	bannedSet := make(BannedSet, len(list))
    80  	for _, addr := range list {
    81  		bannedSet[addr] = struct{}{}
    82  	}
    83  	*set = bannedSet
    84  	return nil
    85  }