github.com/aergoio/aergo@v1.3.1/contract/ethstorageproof.go (about)

     1  package contract
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"encoding/hex"
     7  	"errors"
     8  
     9  	"golang.org/x/crypto/sha3"
    10  )
    11  
    12  const (
    13  	shortNode  = 2
    14  	branchNode = 17
    15  	hexChar    = "0123456789abcdef"
    16  )
    17  
    18  type (
    19  	rlpNode   [][]byte
    20  	keyStream struct {
    21  		*bytes.Buffer
    22  	}
    23  )
    24  
    25  var (
    26  	errDecode = errors.New("storage proof decode error")
    27  	lenBuf    = make([]byte, 8)
    28  	nilBuf    = make([]byte, 8)
    29  )
    30  
    31  func verifyEthStorageProof(key, value, expectedHash []byte, proof [][]byte) bool {
    32  	if len(key) == 0 || value == nil || len(proof) == 0 {
    33  		return false
    34  	}
    35  	key = []byte(hex.EncodeToString(keccak256(key)))
    36  	value = rlpEncodeString55(value)
    37  	ks := keyStream{bytes.NewBuffer(key)}
    38  	for i, p := range proof {
    39  		if ((i != 0 && len(p) < 32) || !bytes.Equal(expectedHash, keccak256(p))) && !bytes.Equal(expectedHash, p) {
    40  			return false
    41  		}
    42  		n := decodeRlpTrieNode(p)
    43  		switch len(n) {
    44  		case shortNode:
    45  			if len(n[0]) == 0 {
    46  				return false
    47  			}
    48  			leaf, sharedNibbles, err := decodeHpHeader(n[0][0])
    49  			if err != nil {
    50  				return false
    51  			}
    52  			sharedNibbles = append(sharedNibbles, []byte(hex.EncodeToString(n[0][1:]))...)
    53  			if len(sharedNibbles) == 0 {
    54  				return false
    55  			}
    56  			if leaf {
    57  				return bytes.Equal(sharedNibbles, ks.Key(-1)) && bytes.Equal(n[1], value)
    58  			}
    59  			if !bytes.Equal(sharedNibbles, ks.Key(len(sharedNibbles))) {
    60  				return false
    61  			}
    62  			expectedHash = n[1]
    63  		case branchNode:
    64  			if ks.Len() == 0 {
    65  				return bytes.Equal(n[16], value)
    66  			}
    67  			k := ks.Index()
    68  			if k > 0x0f {
    69  				return false
    70  			}
    71  			expectedHash = n[k]
    72  		default:
    73  			return false
    74  		}
    75  	}
    76  	return false
    77  }
    78  
    79  func decodeRlpTrieNode(data []byte) rlpNode {
    80  	var (
    81  		dataLen = uint64(len(data))
    82  		node    rlpNode
    83  	)
    84  	if dataLen == uint64(0) {
    85  		return nil
    86  	}
    87  	switch {
    88  	case data[0] >= 0xf8:
    89  		lenLen := int(data[0]) - 0xf7
    90  		l, err := decodeLen(data[1:], lenLen)
    91  		if err != nil {
    92  			return nil
    93  		}
    94  		if dataLen != uint64(1)+uint64(lenLen)+l {
    95  			return nil
    96  		}
    97  		node = toList(data[1+lenLen:], l)
    98  	case data[0] >= 0xc0:
    99  		l := uint64(data[0]) - 0xc0
   100  		if dataLen != uint64(1+l) {
   101  			return nil
   102  		}
   103  		node = toList(data[1:], l)
   104  	}
   105  	return node
   106  }
   107  
   108  func decodeLen(data []byte, lenLen int) (uint64, error) {
   109  	if len(data) <= lenLen || lenLen > 8 {
   110  		return 0, errDecode
   111  	}
   112  	switch lenLen {
   113  	case 1:
   114  		return uint64(data[0]), nil
   115  	default:
   116  		start := int(8 - lenLen)
   117  		copy(lenBuf[:], nilBuf[:start])
   118  		copy(lenBuf[start:], data[:lenLen])
   119  		return binary.BigEndian.Uint64(lenBuf), nil
   120  	}
   121  }
   122  
   123  func toList(data []byte, dataLen uint64) rlpNode {
   124  	var (
   125  		node   rlpNode
   126  		l      uint64
   127  		offset = uint64(0)
   128  	)
   129  	for {
   130  		e, l, err := toString(data[offset:])
   131  		if err != nil {
   132  			return nil
   133  		}
   134  		node = append(node, e)
   135  		offset += l
   136  		if dataLen == offset {
   137  			break
   138  		}
   139  		if dataLen < offset {
   140  			return nil
   141  		}
   142  	}
   143  	l = uint64(len(node))
   144  	if l != uint64(2) && l != uint64(17) {
   145  		return nil
   146  	}
   147  	return node
   148  }
   149  
   150  func toString(data []byte) ([]byte, uint64, error) {
   151  	if len(data) == 0 {
   152  		return nil, 0, errDecode
   153  	}
   154  	switch {
   155  	case data[0] <= 0x7f: // character
   156  		return data[0:1], 1, nil
   157  	case data[0] <= 0xb7: // string <= 55
   158  		end := 1 + data[0] - 0x80
   159  		return data[1:end], uint64(end), nil
   160  	case data[0] <= 0xbf: // string > 55
   161  		lenLen := data[0] - 0xb7
   162  		l, err := decodeLen(data[1:], int(lenLen))
   163  		if err != nil {
   164  			return nil, 0, err
   165  		}
   166  		start := 1 + lenLen
   167  		end := uint64(start) + l
   168  		return data[start:end], end, nil
   169  	default:
   170  		return nil, 0, errDecode
   171  	}
   172  }
   173  
   174  func keccak256(data ...[]byte) []byte {
   175  	h := sha3.NewLegacyKeccak256()
   176  	for _, d := range data {
   177  		h.Write(d)
   178  	}
   179  	return h.Sum(nil)
   180  }
   181  
   182  func keccak256Hex(data ...[]byte) string {
   183  	return hex.EncodeToString(keccak256(data...))
   184  }
   185  
   186  func decodeHpHeader(b byte) (bool, []byte, error) {
   187  	switch b >> 4 {
   188  	case 0:
   189  		return false, []byte{}, nil
   190  	case 1:
   191  		return false, []byte{hexChar[b&0x0f]}, nil
   192  	case 2:
   193  		return true, []byte{}, nil
   194  	case 3:
   195  		return true, []byte{hexChar[b&0x0f]}, nil
   196  	default:
   197  		return false, []byte{}, errDecode
   198  	}
   199  }
   200  
   201  func hexToIndex(c byte) (byte, error) {
   202  	switch {
   203  	case '0' <= c && c <= '9':
   204  		return c - '0', nil
   205  	case 'a' <= c && c <= 'f':
   206  		return c - 'a' + 10, nil
   207  	case 'A' <= c && c <= 'F':
   208  		return c - 'A' + 10, nil
   209  	}
   210  	return 0, errDecode
   211  }
   212  
   213  func (ks keyStream) Index() byte {
   214  	b, err := ks.ReadByte()
   215  	if err != nil {
   216  		return 0x10
   217  	}
   218  	i, err := hexToIndex(b)
   219  	if err != nil {
   220  		return 0x10
   221  	}
   222  	return i
   223  }
   224  
   225  func (ks keyStream) Key(l int) []byte {
   226  	if l == -1 {
   227  		return ks.Buffer.Bytes()
   228  	}
   229  	return ks.Buffer.Next(l)
   230  }
   231  
   232  func rlpEncodeString55(b []byte) []byte {
   233  	var rlpBytes []byte
   234  	l := len(b)
   235  	if l == 1 && b[0] < 0x80 {
   236  		rlpBytes = append(rlpBytes, b[0])
   237  	} else if l < 56 {
   238  		rlpBytes = append(rlpBytes, 0x80+byte(l))
   239  		rlpBytes = append(rlpBytes, b...)
   240  	}
   241  	return rlpBytes
   242  }