github.com/0xPolygon/supernets2-node@v0.0.0-20230711153321-2fe574524eaa/merkletree/tree.go (about)

     1  package merkletree
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math/big"
     7  	"strings"
     8  
     9  	"github.com/0xPolygon/supernets2-node/hex"
    10  	"github.com/0xPolygon/supernets2-node/merkletree/pb"
    11  	"github.com/ethereum/go-ethereum/common"
    12  	"google.golang.org/protobuf/types/known/emptypb"
    13  )
    14  
    15  // StateTree provides methods to access and modify state in merkletree
    16  type StateTree struct {
    17  	grpcClient pb.StateDBServiceClient
    18  }
    19  
    20  // NewStateTree creates new StateTree.
    21  func NewStateTree(client pb.StateDBServiceClient) *StateTree {
    22  	return &StateTree{
    23  		grpcClient: client,
    24  	}
    25  }
    26  
    27  // GetBalance returns balance.
    28  func (tree *StateTree) GetBalance(ctx context.Context, address common.Address, root []byte) (*big.Int, error) {
    29  	r := new(big.Int).SetBytes(root)
    30  
    31  	key, err := KeyEthAddrBalance(address)
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	k := new(big.Int).SetBytes(key[:])
    37  	proof, err := tree.get(ctx, scalarToh4(r), scalarToh4(k))
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  	if proof == nil || proof.Value == nil {
    42  		return big.NewInt(0), nil
    43  	}
    44  	return fea2scalar(proof.Value), nil
    45  }
    46  
    47  // GetNonce returns nonce.
    48  func (tree *StateTree) GetNonce(ctx context.Context, address common.Address, root []byte) (*big.Int, error) {
    49  	r := new(big.Int).SetBytes(root)
    50  
    51  	key, err := KeyEthAddrNonce(address)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	k := new(big.Int).SetBytes(key[:])
    57  	proof, err := tree.get(ctx, scalarToh4(r), scalarToh4(k))
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	if proof == nil || proof.Value == nil {
    62  		return big.NewInt(0), nil
    63  	}
    64  	return fea2scalar(proof.Value), nil
    65  }
    66  
    67  // GetCodeHash returns code hash.
    68  func (tree *StateTree) GetCodeHash(ctx context.Context, address common.Address, root []byte) ([]byte, error) {
    69  	r := new(big.Int).SetBytes(root)
    70  
    71  	key, err := KeyContractCode(address)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	// this code gets only the hash of the smart contract code from the merkle tree
    76  	k := new(big.Int).SetBytes(key[:])
    77  	proof, err := tree.get(ctx, scalarToh4(r), scalarToh4(k))
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	if proof.Value == nil {
    82  		return nil, nil
    83  	}
    84  
    85  	valueBi := fea2scalar(proof.Value)
    86  	return ScalarToFilledByteSlice(valueBi), nil
    87  }
    88  
    89  // GetCode returns code.
    90  func (tree *StateTree) GetCode(ctx context.Context, address common.Address, root []byte) ([]byte, error) {
    91  	scCodeHash, err := tree.GetCodeHash(ctx, address, root)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	k := new(big.Int).SetBytes(scCodeHash[:])
    97  
    98  	// this code gets actual smart contract code from sc code storage
    99  	scCode, err := tree.getProgram(ctx, scalarToh4(k))
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	return scCode.Data, nil
   105  }
   106  
   107  // GetStorageAt returns Storage Value at specified position.
   108  func (tree *StateTree) GetStorageAt(ctx context.Context, address common.Address, position *big.Int, root []byte) (*big.Int, error) {
   109  	r := new(big.Int).SetBytes(root)
   110  
   111  	key, err := KeyContractStorage(address, position.Bytes())
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	k := new(big.Int).SetBytes(key[:])
   117  	proof, err := tree.get(ctx, scalarToh4(r), scalarToh4(k))
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	if proof == nil || proof.Value == nil {
   122  		return big.NewInt(0), nil
   123  	}
   124  	return fea2scalar(proof.Value), nil
   125  }
   126  
   127  // SetBalance sets balance.
   128  func (tree *StateTree) SetBalance(ctx context.Context, address common.Address, balance *big.Int, root []byte) (newRoot []byte, proof *UpdateProof, err error) {
   129  	if balance.Cmp(big.NewInt(0)) == -1 {
   130  		return nil, nil, fmt.Errorf("invalid balance")
   131  	}
   132  
   133  	r := new(big.Int).SetBytes(root)
   134  	key, err := KeyEthAddrBalance(address)
   135  	if err != nil {
   136  		return nil, nil, err
   137  	}
   138  
   139  	k := new(big.Int).SetBytes(key)
   140  	balanceH8 := scalar2fea(balance)
   141  
   142  	updateProof, err := tree.set(ctx, scalarToh4(r), scalarToh4(k), balanceH8)
   143  	if err != nil {
   144  		return nil, nil, err
   145  	}
   146  
   147  	return h4ToFilledByteSlice(updateProof.NewRoot), updateProof, nil
   148  }
   149  
   150  // SetNonce sets nonce.
   151  func (tree *StateTree) SetNonce(ctx context.Context, address common.Address, nonce *big.Int, root []byte) (newRoot []byte, proof *UpdateProof, err error) {
   152  	if nonce.Cmp(big.NewInt(0)) == -1 {
   153  		return nil, nil, fmt.Errorf("invalid nonce")
   154  	}
   155  
   156  	r := new(big.Int).SetBytes(root)
   157  	key, err := KeyEthAddrNonce(address)
   158  	if err != nil {
   159  		return nil, nil, err
   160  	}
   161  
   162  	k := new(big.Int).SetBytes(key[:])
   163  
   164  	nonceH8 := scalar2fea(nonce)
   165  
   166  	updateProof, err := tree.set(ctx, scalarToh4(r), scalarToh4(k), nonceH8)
   167  	if err != nil {
   168  		return nil, nil, err
   169  	}
   170  
   171  	return h4ToFilledByteSlice(updateProof.NewRoot), updateProof, nil
   172  }
   173  
   174  // SetCode sets smart contract code.
   175  func (tree *StateTree) SetCode(ctx context.Context, address common.Address, code []byte, root []byte) (newRoot []byte, proof *UpdateProof, err error) {
   176  	// calculating smart contract code hash
   177  	scCodeHash4, err := hashContractBytecode(code)
   178  	if err != nil {
   179  		return nil, nil, err
   180  	}
   181  
   182  	// store smart contract code by its hash
   183  	err = tree.setProgram(ctx, scCodeHash4, code, true)
   184  	if err != nil {
   185  		return nil, nil, err
   186  	}
   187  
   188  	// set smart contract code hash as a leaf value in merkle tree
   189  	r := new(big.Int).SetBytes(root)
   190  	key, err := KeyContractCode(address)
   191  	if err != nil {
   192  		return nil, nil, err
   193  	}
   194  	k := new(big.Int).SetBytes(key[:])
   195  
   196  	scCodeHash, err := hex.DecodeHex(H4ToString(scCodeHash4))
   197  	if err != nil {
   198  		return nil, nil, err
   199  	}
   200  
   201  	scCodeHashBI := new(big.Int).SetBytes(scCodeHash[:])
   202  	scCodeHashH8 := scalar2fea(scCodeHashBI)
   203  
   204  	updateProof, err := tree.set(ctx, scalarToh4(r), scalarToh4(k), scCodeHashH8)
   205  	if err != nil {
   206  		return nil, nil, err
   207  	}
   208  
   209  	// set code length as a leaf value in merkle tree
   210  	key, err = KeyCodeLength(address)
   211  	if err != nil {
   212  		return nil, nil, err
   213  	}
   214  	k = new(big.Int).SetBytes(key[:])
   215  	scCodeLengthBI := new(big.Int).SetInt64(int64(len(code)))
   216  	scCodeLengthH8 := scalar2fea(scCodeLengthBI)
   217  
   218  	updateProof, err = tree.set(ctx, updateProof.NewRoot, scalarToh4(k), scCodeLengthH8)
   219  	if err != nil {
   220  		return nil, nil, err
   221  	}
   222  
   223  	return h4ToFilledByteSlice(updateProof.NewRoot), updateProof, nil
   224  }
   225  
   226  // SetStorageAt sets storage value at specified position.
   227  func (tree *StateTree) SetStorageAt(ctx context.Context, address common.Address, position *big.Int, value *big.Int, root []byte) (newRoot []byte, proof *UpdateProof, err error) {
   228  	r := new(big.Int).SetBytes(root)
   229  	key, err := KeyContractStorage(address, position.Bytes())
   230  	if err != nil {
   231  		return nil, nil, err
   232  	}
   233  
   234  	k := new(big.Int).SetBytes(key[:])
   235  	valueH8 := scalar2fea(value)
   236  	updateProof, err := tree.set(ctx, scalarToh4(r), scalarToh4(k), valueH8)
   237  	if err != nil {
   238  		return nil, nil, err
   239  	}
   240  
   241  	return h4ToFilledByteSlice(updateProof.NewRoot), updateProof, nil
   242  }
   243  
   244  func (tree *StateTree) get(ctx context.Context, root, key []uint64) (*Proof, error) {
   245  	result, err := tree.grpcClient.Get(ctx, &pb.GetRequest{
   246  		Root: &pb.Fea{Fe0: root[0], Fe1: root[1], Fe2: root[2], Fe3: root[3]},
   247  		Key:  &pb.Fea{Fe0: key[0], Fe1: key[1], Fe2: key[2], Fe3: key[3]},
   248  	})
   249  	if err != nil {
   250  		return nil, err
   251  	}
   252  
   253  	value, err := string2fea(result.Value)
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  	return &Proof{
   258  		Root:  []uint64{root[0], root[1], root[2], root[3]},
   259  		Key:   key,
   260  		Value: value,
   261  	}, nil
   262  }
   263  
   264  func (tree *StateTree) getProgram(ctx context.Context, key []uint64) (*ProgramProof, error) {
   265  	result, err := tree.grpcClient.GetProgram(ctx, &pb.GetProgramRequest{
   266  		Key: &pb.Fea{Fe0: key[0], Fe1: key[1], Fe2: key[2], Fe3: key[3]},
   267  	})
   268  	if err != nil {
   269  		return nil, err
   270  	}
   271  
   272  	return &ProgramProof{
   273  		Data: result.Data,
   274  	}, nil
   275  }
   276  
   277  func (tree *StateTree) set(ctx context.Context, oldRoot, key, value []uint64) (*UpdateProof, error) {
   278  	feaValue := fea2string(value)
   279  	if strings.HasPrefix(feaValue, "0x") { // nolint
   280  		feaValue = feaValue[2:]
   281  	}
   282  	result, err := tree.grpcClient.Set(ctx, &pb.SetRequest{
   283  		OldRoot:    &pb.Fea{Fe0: oldRoot[0], Fe1: oldRoot[1], Fe2: oldRoot[2], Fe3: oldRoot[3]},
   284  		Key:        &pb.Fea{Fe0: key[0], Fe1: key[1], Fe2: key[2], Fe3: key[3]},
   285  		Value:      feaValue,
   286  		Persistent: true,
   287  	})
   288  	if err != nil {
   289  		return nil, err
   290  	}
   291  
   292  	var newValue []uint64
   293  	if result.NewValue != "" {
   294  		newValue, err = string2fea(result.NewValue)
   295  		if err != nil {
   296  			return nil, err
   297  		}
   298  	}
   299  
   300  	return &UpdateProof{
   301  		OldRoot:  oldRoot,
   302  		NewRoot:  []uint64{result.NewRoot.Fe0, result.NewRoot.Fe1, result.NewRoot.Fe2, result.NewRoot.Fe3},
   303  		Key:      key,
   304  		NewValue: newValue,
   305  	}, nil
   306  }
   307  
   308  func (tree *StateTree) setProgram(ctx context.Context, key []uint64, data []byte, persistent bool) error {
   309  	_, err := tree.grpcClient.SetProgram(ctx, &pb.SetProgramRequest{
   310  		Key:        &pb.Fea{Fe0: key[0], Fe1: key[1], Fe2: key[2], Fe3: key[3]},
   311  		Data:       data,
   312  		Persistent: persistent,
   313  	})
   314  	return err
   315  }
   316  
   317  // Flush flushes all changes to the persistent storage.
   318  func (tree *StateTree) Flush(ctx context.Context) error {
   319  	_, err := tree.grpcClient.Flush(ctx, &emptypb.Empty{})
   320  	return err
   321  }