github.com/ledgerwatch/erigon-lib@v1.0.0/commitment/patricia_state_mock_test.go (about)

     1  package commitment
     2  
     3  import (
     4  	"encoding/binary"
     5  	"encoding/hex"
     6  	"fmt"
     7  	"testing"
     8  
     9  	"github.com/holiman/uint256"
    10  	"golang.org/x/crypto/sha3"
    11  	"golang.org/x/exp/slices"
    12  
    13  	"github.com/ledgerwatch/erigon-lib/common"
    14  	"github.com/ledgerwatch/erigon-lib/common/length"
    15  )
    16  
    17  // In memory commitment and state to use with the tests
    18  type MockState struct {
    19  	t      *testing.T
    20  	sm     map[string][]byte     // backbone of the state
    21  	cm     map[string]BranchData // backbone of the commitments
    22  	numBuf [binary.MaxVarintLen64]byte
    23  }
    24  
    25  func NewMockState(t *testing.T) *MockState {
    26  	t.Helper()
    27  	return &MockState{
    28  		t:  t,
    29  		sm: make(map[string][]byte),
    30  		cm: make(map[string]BranchData),
    31  	}
    32  }
    33  
    34  func (ms MockState) branchFn(prefix []byte) ([]byte, error) {
    35  	if exBytes, ok := ms.cm[string(prefix)]; ok {
    36  		return exBytes[2:], nil // Skip touchMap, but keep afterMap
    37  	}
    38  	return nil, nil
    39  }
    40  
    41  func (ms MockState) accountFn(plainKey []byte, cell *Cell) error {
    42  	exBytes, ok := ms.sm[string(plainKey[:])]
    43  	if !ok {
    44  		ms.t.Logf("accountFn not found key [%x]", plainKey)
    45  		cell.Delete = true
    46  		return nil
    47  	}
    48  	var ex Update
    49  	pos, err := ex.Decode(exBytes, 0)
    50  	if err != nil {
    51  		ms.t.Fatalf("accountFn decode existing [%x], bytes: [%x]: %v", plainKey, exBytes, err)
    52  		return nil
    53  	}
    54  	if pos != len(exBytes) {
    55  		ms.t.Fatalf("accountFn key [%x] leftover bytes in [%x], comsumed %x", plainKey, exBytes, pos)
    56  		return nil
    57  	}
    58  	if ex.Flags&StorageUpdate != 0 {
    59  		ms.t.Logf("accountFn reading storage item for key [%x]", plainKey)
    60  		return fmt.Errorf("storage read by accountFn")
    61  	}
    62  	if ex.Flags&DeleteUpdate != 0 {
    63  		ms.t.Fatalf("accountFn reading deleted account for key [%x]", plainKey)
    64  		return nil
    65  	}
    66  	if ex.Flags&BalanceUpdate != 0 {
    67  		cell.Balance.Set(&ex.Balance)
    68  	} else {
    69  		cell.Balance.Clear()
    70  	}
    71  	if ex.Flags&NonceUpdate != 0 {
    72  		cell.Nonce = ex.Nonce
    73  	} else {
    74  		cell.Nonce = 0
    75  	}
    76  	if ex.Flags&CodeUpdate != 0 {
    77  		copy(cell.CodeHash[:], ex.CodeHashOrStorage[:])
    78  	} else {
    79  		copy(cell.CodeHash[:], EmptyCodeHash)
    80  	}
    81  	return nil
    82  }
    83  
    84  func (ms MockState) storageFn(plainKey []byte, cell *Cell) error {
    85  	exBytes, ok := ms.sm[string(plainKey[:])]
    86  	if !ok {
    87  		ms.t.Logf("storageFn not found key [%x]", plainKey)
    88  		cell.Delete = true
    89  		return nil
    90  	}
    91  	var ex Update
    92  	pos, err := ex.Decode(exBytes, 0)
    93  	if err != nil {
    94  		ms.t.Fatalf("storageFn decode existing [%x], bytes: [%x]: %v", plainKey, exBytes, err)
    95  		return nil
    96  	}
    97  	if pos != len(exBytes) {
    98  		ms.t.Fatalf("storageFn key [%x] leftover bytes in [%x], comsumed %x", plainKey, exBytes, pos)
    99  		return nil
   100  	}
   101  	if ex.Flags&BalanceUpdate != 0 {
   102  		ms.t.Logf("storageFn reading balance for key [%x]", plainKey)
   103  		return nil
   104  	}
   105  	if ex.Flags&NonceUpdate != 0 {
   106  		ms.t.Fatalf("storageFn reading nonce for key [%x]", plainKey)
   107  		return nil
   108  	}
   109  	if ex.Flags&CodeUpdate != 0 {
   110  		ms.t.Fatalf("storageFn reading codeHash for key [%x]", plainKey)
   111  		return nil
   112  	}
   113  	if ex.Flags&DeleteUpdate != 0 {
   114  		ms.t.Fatalf("storageFn reading deleted item for key [%x]", plainKey)
   115  		return nil
   116  	}
   117  	if ex.Flags&StorageUpdate != 0 {
   118  		copy(cell.Storage[:], ex.CodeHashOrStorage[:])
   119  		cell.StorageLen = len(ex.CodeHashOrStorage)
   120  	} else {
   121  		cell.StorageLen = 0
   122  		cell.Storage = [length.Hash]byte{}
   123  	}
   124  	return nil
   125  }
   126  
   127  func (ms *MockState) applyPlainUpdates(plainKeys [][]byte, updates []Update) error {
   128  	for i, key := range plainKeys {
   129  		update := updates[i]
   130  		if update.Flags&DeleteUpdate != 0 {
   131  			delete(ms.sm, string(key))
   132  		} else {
   133  			if exBytes, ok := ms.sm[string(key)]; ok {
   134  				var ex Update
   135  				pos, err := ex.Decode(exBytes, 0)
   136  				if err != nil {
   137  					return fmt.Errorf("applyPlainUpdates decode existing [%x], bytes: [%x]: %w", key, exBytes, err)
   138  				}
   139  				if pos != len(exBytes) {
   140  					return fmt.Errorf("applyPlainUpdates key [%x] leftover bytes in [%x], comsumed %x", key, exBytes, pos)
   141  				}
   142  				if update.Flags&BalanceUpdate != 0 {
   143  					ex.Flags |= BalanceUpdate
   144  					ex.Balance.Set(&update.Balance)
   145  				}
   146  				if update.Flags&NonceUpdate != 0 {
   147  					ex.Flags |= NonceUpdate
   148  					ex.Nonce = update.Nonce
   149  				}
   150  				if update.Flags&CodeUpdate != 0 {
   151  					ex.Flags |= CodeUpdate
   152  					copy(ex.CodeHashOrStorage[:], update.CodeHashOrStorage[:])
   153  				}
   154  				if update.Flags&StorageUpdate != 0 {
   155  					ex.Flags |= StorageUpdate
   156  					copy(ex.CodeHashOrStorage[:], update.CodeHashOrStorage[:])
   157  				}
   158  				ms.sm[string(key)] = ex.Encode(nil, ms.numBuf[:])
   159  			} else {
   160  				ms.sm[string(key)] = update.Encode(nil, ms.numBuf[:])
   161  			}
   162  		}
   163  	}
   164  	return nil
   165  }
   166  
   167  func (ms *MockState) applyBranchNodeUpdates(updates map[string]BranchData) {
   168  	for key, update := range updates {
   169  		if pre, ok := ms.cm[key]; ok {
   170  			// Merge
   171  			merged, err := pre.MergeHexBranches(update, nil)
   172  			if err != nil {
   173  				panic(err)
   174  			}
   175  			ms.cm[key] = merged
   176  		} else {
   177  			ms.cm[key] = update
   178  		}
   179  	}
   180  }
   181  
   182  func decodeHex(in string) []byte {
   183  	payload, err := hex.DecodeString(in)
   184  	if err != nil {
   185  		panic(err)
   186  	}
   187  	return payload
   188  }
   189  
   190  // UpdateBuilder collects updates to the state
   191  // and provides them in properly sorted form
   192  type UpdateBuilder struct {
   193  	balances   map[string]*uint256.Int
   194  	nonces     map[string]uint64
   195  	codeHashes map[string][length.Hash]byte
   196  	storages   map[string]map[string][]byte
   197  	deletes    map[string]struct{}
   198  	deletes2   map[string]map[string]struct{}
   199  	keyset     map[string]struct{}
   200  	keyset2    map[string]map[string]struct{}
   201  }
   202  
   203  func NewUpdateBuilder() *UpdateBuilder {
   204  	return &UpdateBuilder{
   205  		balances:   make(map[string]*uint256.Int),
   206  		nonces:     make(map[string]uint64),
   207  		codeHashes: make(map[string][length.Hash]byte),
   208  		storages:   make(map[string]map[string][]byte),
   209  		deletes:    make(map[string]struct{}),
   210  		deletes2:   make(map[string]map[string]struct{}),
   211  		keyset:     make(map[string]struct{}),
   212  		keyset2:    make(map[string]map[string]struct{}),
   213  	}
   214  }
   215  
   216  func (ub *UpdateBuilder) Balance(addr string, balance uint64) *UpdateBuilder {
   217  	sk := string(decodeHex(addr))
   218  	delete(ub.deletes, sk)
   219  	ub.balances[sk] = uint256.NewInt(balance)
   220  	ub.keyset[sk] = struct{}{}
   221  	return ub
   222  }
   223  
   224  func (ub *UpdateBuilder) Nonce(addr string, nonce uint64) *UpdateBuilder {
   225  	sk := string(decodeHex(addr))
   226  	delete(ub.deletes, sk)
   227  	ub.nonces[sk] = nonce
   228  	ub.keyset[sk] = struct{}{}
   229  	return ub
   230  }
   231  
   232  func (ub *UpdateBuilder) CodeHash(addr string, hash string) *UpdateBuilder {
   233  	sk := string(decodeHex(addr))
   234  	delete(ub.deletes, sk)
   235  	hcode, err := hex.DecodeString(hash)
   236  	if err != nil {
   237  		panic(fmt.Errorf("invalid code hash provided: %w", err))
   238  	}
   239  	if len(hcode) != length.Hash {
   240  		panic(fmt.Errorf("code hash should be %d bytes long, got %d", length.Hash, len(hcode)))
   241  	}
   242  
   243  	dst := [length.Hash]byte{}
   244  	copy(dst[:32], hcode)
   245  
   246  	ub.codeHashes[sk] = dst
   247  	ub.keyset[sk] = struct{}{}
   248  	return ub
   249  }
   250  
   251  func (ub *UpdateBuilder) Storage(addr string, loc string, value string) *UpdateBuilder {
   252  	sk1 := string(decodeHex(addr))
   253  	sk2 := string(decodeHex(loc))
   254  	v := decodeHex(value)
   255  	if d, ok := ub.deletes2[sk1]; ok {
   256  		delete(d, sk2)
   257  		if len(d) == 0 {
   258  			delete(ub.deletes2, sk1)
   259  		}
   260  	}
   261  	if k, ok := ub.keyset2[sk1]; ok {
   262  		k[sk2] = struct{}{}
   263  	} else {
   264  		ub.keyset2[sk1] = make(map[string]struct{})
   265  		ub.keyset2[sk1][sk2] = struct{}{}
   266  	}
   267  	if s, ok := ub.storages[sk1]; ok {
   268  		s[sk2] = v
   269  	} else {
   270  		ub.storages[sk1] = make(map[string][]byte)
   271  		ub.storages[sk1][sk2] = v
   272  	}
   273  	return ub
   274  }
   275  
   276  func (ub *UpdateBuilder) IncrementBalance(addr string, balance []byte) *UpdateBuilder {
   277  	sk := string(decodeHex(addr))
   278  	delete(ub.deletes, sk)
   279  	increment := uint256.NewInt(0)
   280  	increment.SetBytes(balance)
   281  	if old, ok := ub.balances[sk]; ok {
   282  		balance := uint256.NewInt(0)
   283  		balance.Add(old, increment)
   284  		ub.balances[sk] = balance
   285  	} else {
   286  		ub.balances[sk] = increment
   287  	}
   288  	ub.keyset[sk] = struct{}{}
   289  	return ub
   290  }
   291  
   292  func (ub *UpdateBuilder) Delete(addr string) *UpdateBuilder {
   293  	sk := string(decodeHex(addr))
   294  	delete(ub.balances, sk)
   295  	delete(ub.nonces, sk)
   296  	delete(ub.codeHashes, sk)
   297  	delete(ub.storages, sk)
   298  	ub.deletes[sk] = struct{}{}
   299  	ub.keyset[sk] = struct{}{}
   300  	return ub
   301  }
   302  
   303  func (ub *UpdateBuilder) DeleteStorage(addr string, loc string) *UpdateBuilder {
   304  	sk1 := string(decodeHex(addr))
   305  	sk2 := string(decodeHex(loc))
   306  	if s, ok := ub.storages[sk1]; ok {
   307  		delete(s, sk2)
   308  		if len(s) == 0 {
   309  			delete(ub.storages, sk1)
   310  		}
   311  	}
   312  	if k, ok := ub.keyset2[sk1]; ok {
   313  		k[sk2] = struct{}{}
   314  	} else {
   315  		ub.keyset2[sk1] = make(map[string]struct{})
   316  		ub.keyset2[sk1][sk2] = struct{}{}
   317  	}
   318  	if d, ok := ub.deletes2[sk1]; ok {
   319  		d[sk2] = struct{}{}
   320  	} else {
   321  		ub.deletes2[sk1] = make(map[string]struct{})
   322  		ub.deletes2[sk1][sk2] = struct{}{}
   323  	}
   324  	return ub
   325  }
   326  
   327  // Build returns three slices (in the order sorted by the hashed keys)
   328  // 1. Plain keys
   329  // 2. Corresponding hashed keys
   330  // 3. Corresponding updates
   331  func (ub *UpdateBuilder) Build() (plainKeys, hashedKeys [][]byte, updates []Update) {
   332  	hashed := make([]string, 0, len(ub.keyset)+len(ub.keyset2))
   333  	preimages := make(map[string][]byte)
   334  	preimages2 := make(map[string][]byte)
   335  	keccak := sha3.NewLegacyKeccak256()
   336  	for key := range ub.keyset {
   337  		keccak.Reset()
   338  		keccak.Write([]byte(key))
   339  		h := keccak.Sum(nil)
   340  		hashedKey := make([]byte, len(h)*2)
   341  		for i, c := range h {
   342  			hashedKey[i*2] = (c >> 4) & 0xf
   343  			hashedKey[i*2+1] = c & 0xf
   344  		}
   345  		hashed = append(hashed, string(hashedKey))
   346  		preimages[string(hashedKey)] = []byte(key)
   347  	}
   348  	hashedKey := make([]byte, 128)
   349  	for sk1, k := range ub.keyset2 {
   350  		keccak.Reset()
   351  		keccak.Write([]byte(sk1))
   352  		h := keccak.Sum(nil)
   353  		for i, c := range h {
   354  			hashedKey[i*2] = (c >> 4) & 0xf
   355  			hashedKey[i*2+1] = c & 0xf
   356  		}
   357  		for sk2 := range k {
   358  			keccak.Reset()
   359  			keccak.Write([]byte(sk2))
   360  			h2 := keccak.Sum(nil)
   361  			for i, c := range h2 {
   362  				hashedKey[64+i*2] = (c >> 4) & 0xf
   363  				hashedKey[64+i*2+1] = c & 0xf
   364  			}
   365  			hs := string(common.Copy(hashedKey))
   366  			hashed = append(hashed, hs)
   367  			preimages[hs] = []byte(sk1)
   368  			preimages2[hs] = []byte(sk2)
   369  		}
   370  
   371  	}
   372  	slices.Sort(hashed)
   373  	plainKeys = make([][]byte, len(hashed))
   374  	hashedKeys = make([][]byte, len(hashed))
   375  	updates = make([]Update, len(hashed))
   376  	for i, hashedKey := range hashed {
   377  		hashedKeys[i] = []byte(hashedKey)
   378  		key := preimages[hashedKey]
   379  		key2 := preimages2[hashedKey]
   380  		plainKey := make([]byte, len(key)+len(key2))
   381  		copy(plainKey[:], key)
   382  		if key2 != nil {
   383  			copy(plainKey[len(key):], key2)
   384  		}
   385  		plainKeys[i] = plainKey
   386  		u := &updates[i]
   387  		if key2 == nil {
   388  			if balance, ok := ub.balances[string(key)]; ok {
   389  				u.Flags |= BalanceUpdate
   390  				u.Balance.Set(balance)
   391  			}
   392  			if nonce, ok := ub.nonces[string(key)]; ok {
   393  				u.Flags |= NonceUpdate
   394  				u.Nonce = nonce
   395  			}
   396  			if codeHash, ok := ub.codeHashes[string(key)]; ok {
   397  				u.Flags |= CodeUpdate
   398  				copy(u.CodeHashOrStorage[:], codeHash[:])
   399  			}
   400  			if _, del := ub.deletes[string(key)]; del {
   401  				u.Flags = DeleteUpdate
   402  				continue
   403  			}
   404  		} else {
   405  			if dm, ok1 := ub.deletes2[string(key)]; ok1 {
   406  				if _, ok2 := dm[string(key2)]; ok2 {
   407  					u.Flags = DeleteUpdate
   408  					continue
   409  				}
   410  			}
   411  			if sm, ok1 := ub.storages[string(key)]; ok1 {
   412  				if storage, ok2 := sm[string(key2)]; ok2 {
   413  					u.Flags |= StorageUpdate
   414  					u.CodeHashOrStorage = [length.Hash]byte{}
   415  					u.ValLength = len(storage)
   416  					copy(u.CodeHashOrStorage[:], storage)
   417  				}
   418  			}
   419  		}
   420  	}
   421  	return
   422  }