github.com/aergoio/aergo@v1.3.1/state/statebuffer.go (about)

     1  package state
     2  
import (
	"fmt"
	"sort"

	"github.com/aergoio/aergo/internal/common"
	"github.com/aergoio/aergo/pkg/trie"
	"github.com/aergoio/aergo/types"
	"github.com/golang/protobuf/proto"
)
    11  
    12  type entry interface {
    13  	KeyID() types.HashID
    14  	Hash() []byte
    15  	Value() interface{}
    16  }
    17  
    18  type cached interface {
    19  	cache() *stateBuffer
    20  }
    21  
    22  type valueEntry struct {
    23  	key   types.HashID
    24  	value interface{}
    25  }
    26  
    27  func newValueEntry(key types.HashID, value interface{}) entry {
    28  	return &valueEntry{
    29  		key:   key,
    30  		value: value,
    31  	}
    32  }
    33  func newValueEntryDelete(key types.HashID) entry {
    34  	return &valueEntry{
    35  		key:   key,
    36  		value: nil,
    37  	}
    38  }
    39  func (et *valueEntry) KeyID() types.HashID {
    40  	return et.key
    41  }
    42  func (et *valueEntry) Hash() []byte {
    43  	if hash := getHashBytes(et.value); hash != nil {
    44  		return hash
    45  	}
    46  	return []byte{0}
    47  }
    48  func (et *valueEntry) Value() interface{} {
    49  	return et.value
    50  }
    51  
    52  type metaEntry struct {
    53  	*valueEntry
    54  }
    55  
    56  func newMetaEntry(key types.HashID, value interface{}) entry {
    57  	return &metaEntry{
    58  		valueEntry: &valueEntry{
    59  			key:   key,
    60  			value: value,
    61  		},
    62  	}
    63  }
    64  
    65  type bufferIndex map[types.HashID]*stack
    66  
    67  func (idxs *bufferIndex) peek(key types.HashID) int {
    68  	return (*idxs)[key].peek()
    69  }
    70  func (idxs *bufferIndex) pop(key types.HashID) int {
    71  	return (*idxs)[key].pop()
    72  }
    73  func (idxs *bufferIndex) push(key types.HashID, argv ...int) {
    74  	(*idxs)[key] = (*idxs)[key].push(argv...)
    75  }
    76  func (idxs *bufferIndex) rollback(snapshot int) {
    77  	for k, v := range *idxs {
    78  		for v.peek() >= snapshot {
    79  			v.pop()
    80  		}
    81  		if v.peek() < 0 {
    82  			delete(*idxs, k)
    83  		}
    84  	}
    85  }
    86  
    87  type stateBuffer struct {
    88  	entries []entry
    89  	indexes bufferIndex
    90  	nextIdx int
    91  }
    92  
    93  func newStateBuffer() *stateBuffer {
    94  	buffer := stateBuffer{
    95  		entries: []entry{},
    96  		indexes: bufferIndex{},
    97  		nextIdx: 0,
    98  	}
    99  	return &buffer
   100  }
   101  
   102  func (buffer *stateBuffer) reset() error {
   103  	return buffer.rollback(0)
   104  }
   105  
   106  func (buffer *stateBuffer) get(key types.HashID) entry {
   107  	if index, ok := buffer.indexes[key]; ok {
   108  		return buffer.entries[index.peek()]
   109  	}
   110  	return nil
   111  }
   112  func (buffer *stateBuffer) has(key types.HashID) bool {
   113  	_, ok := buffer.indexes[key]
   114  	return ok
   115  }
   116  
   117  func (buffer *stateBuffer) put(et entry) {
   118  	snapshot := buffer.snapshot()
   119  	buffer.entries = append(buffer.entries, et)
   120  	buffer.indexes[et.KeyID()] = buffer.indexes[et.KeyID()].push(snapshot)
   121  	buffer.nextIdx++
   122  }
   123  
   124  func (buffer *stateBuffer) snapshot() int {
   125  	return buffer.nextIdx
   126  }
   127  
   128  func (buffer *stateBuffer) rollback(snapshot int) error {
   129  	for i := buffer.nextIdx - 1; i >= snapshot; i-- {
   130  		et := buffer.entries[i]
   131  		buffer.indexes.pop(et.KeyID())
   132  		idx := buffer.indexes.peek(et.KeyID())
   133  		if idx < 0 {
   134  			delete(buffer.indexes, et.KeyID())
   135  			continue
   136  		}
   137  	}
   138  	buffer.entries = buffer.entries[:snapshot]
   139  	//buffer.indexes.rollback(snapshot)
   140  	buffer.nextIdx = snapshot
   141  	return nil
   142  }
   143  
   144  func (buffer *stateBuffer) isEmpty() bool {
   145  	return len(buffer.entries) == 0
   146  }
   147  
   148  func (buffer *stateBuffer) export() ([][]byte, [][]byte) {
   149  	bufs := make([]entry, 0, len(buffer.indexes))
   150  	for _, v := range buffer.indexes {
   151  		idx := v.peek()
   152  		if idx < 0 {
   153  			continue
   154  		}
   155  		et := buffer.entries[idx]
   156  		if _, ok := et.(metaEntry); ok {
   157  			// skip meta entry
   158  			continue
   159  		}
   160  		bufs = append(bufs, et)
   161  	}
   162  	sort.Slice(bufs, func(i, j int) bool {
   163  		return -1 == (bufs[i].KeyID().Compare(bufs[j].KeyID()))
   164  	})
   165  	size := len(bufs)
   166  	keys := make([][]byte, size)
   167  	vals := make([][]byte, size)
   168  	for i, et := range bufs {
   169  		keys[i] = append(keys[i], et.KeyID().Bytes()...)
   170  		vals[i] = append(vals[i], et.Hash()...)
   171  	}
   172  	return keys, vals
   173  }
   174  
   175  func (buffer *stateBuffer) updateTrie(tr *trie.Trie) error {
   176  	keys, vals := buffer.export()
   177  	if len(keys) == 0 || len(vals) == 0 {
   178  		// nothing to update
   179  		return nil
   180  	}
   181  	if _, err := tr.Update(keys, vals); err != nil {
   182  		return err
   183  	}
   184  	return nil
   185  }
   186  
   187  func (buffer *stateBuffer) stage(txn trie.DbTx) error {
   188  	for _, v := range buffer.indexes {
   189  		et := buffer.entries[v.peek()]
   190  		buf, err := marshal(et.Value())
   191  		if err != nil {
   192  			return err
   193  		}
   194  		txn.Set(et.Hash(), buf)
   195  	}
   196  	return nil
   197  }
   198  
   199  func marshal(data interface{}) ([]byte, error) {
   200  	switch data.(type) {
   201  	case ([]byte):
   202  		return data.([]byte), nil
   203  	case (*[]byte):
   204  		return *(data.(*[]byte)), nil
   205  	case (types.ImplMarshal):
   206  		return data.(types.ImplMarshal).Marshal()
   207  	case (proto.Message):
   208  		return proto.Marshal(data.(proto.Message))
   209  	}
   210  	return nil, nil
   211  }
   212  
   213  func getHashBytes(data interface{}) []byte {
   214  	if data == nil {
   215  		return nil
   216  	}
   217  	switch data.(type) {
   218  	case (types.ImplHashBytes):
   219  		return data.(types.ImplHashBytes).Hash()
   220  	default:
   221  	}
   222  	buf, err := marshal(data)
   223  	if err != nil {
   224  		logger.Error().Err(err).Msg("failed to get hash bytes: marshal")
   225  		return nil
   226  	}
   227  	return common.Hasher(buf)
   228  }