github.com/hyperledger/burrow@v0.34.5-0.20220512172541-77f09336001d/encoding/rlp/rlp.go (about)

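// Package rlp implements encoding and decoding of Recursive Length Prefix (RLP),
// the serialization format used by Ethereum.
//
// See https://eth.wiki/fundamentals/rlp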
     4  package rlp
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/binary"
     9  	"fmt"
    10  	"math/big"
    11  	"math/bits"
    12  	"reflect"
    13  
    14  	binary2 "github.com/hyperledger/burrow/binary"
    15  )
    16  
    17  type magicOffset uint8
    18  
    19  const (
    20  	ShortLength              = 55
    21  	StringOffset magicOffset = 0x80 // 128 - if string length is less than or equal to 55 [inclusive]
    22  	SliceOffset  magicOffset = 0xC0 // 192 - if slice length is less than or equal to 55 [inclusive]
    23  	SmallByte                = 0x7f // 247 - value less than or equal is itself [inclusive
    24  )
    25  
    26  type Code uint32
    27  
    28  const (
    29  	ErrUnknown Code = iota
    30  	ErrNoInput
    31  	ErrInvalid
    32  )
    33  
    34  var bigIntType = reflect.TypeOf(&big.Int{})
    35  
    36  func (c Code) Error() string {
    37  	switch c {
    38  	case ErrNoInput:
    39  		return "no input"
    40  	case ErrInvalid:
    41  		return "input not valid RLP encoding"
    42  	default:
    43  		return "unknown error"
    44  	}
    45  }
    46  
    47  func Encode(input interface{}) ([]byte, error) {
    48  	val := reflect.ValueOf(input)
    49  	if val.Kind() == reflect.Ptr {
    50  		val = val.Elem()
    51  	}
    52  	return encode(val)
    53  }
    54  
    55  func Decode(src []byte, dst interface{}) error {
    56  	fields, err := decode(src)
    57  	if err != nil {
    58  		return err
    59  	}
    60  
    61  	val := reflect.ValueOf(dst)
    62  	typ := reflect.TypeOf(dst)
    63  
    64  	if val.Kind() == reflect.Ptr {
    65  		val = val.Elem()
    66  	}
    67  
    68  	switch val.Kind() {
    69  	case reflect.Slice:
    70  		switch typ.Elem().Kind() {
    71  		case reflect.Uint8:
    72  			out, ok := dst.([]byte)
    73  			if !ok {
    74  				return fmt.Errorf("cannot decode into type %s", val.Type())
    75  			}
    76  			found := bytes.Join(fields, []byte(""))
    77  			if len(out) < len(found) {
    78  				return fmt.Errorf("cannot decode %d bytes into slice of size %d", len(found), len(out))
    79  			}
    80  			for i, b := range found {
    81  				out[i] = b
    82  			}
    83  		default:
    84  			for i := 0; i < val.Len(); i++ {
    85  				elem := val.Index(i)
    86  				err = decodeField(elem, fields[i])
    87  				if err != nil {
    88  					return err
    89  				}
    90  			}
    91  		}
    92  	case reflect.Struct:
    93  		rt := val.Type()
    94  		numExportedFields := 0
    95  		for i := 0; i < val.NumField(); i++ {
    96  			// Skip unexported fields
    97  			if rt.Field(i).PkgPath == "" {
    98  				err := decodeField(val.Field(i), fields[numExportedFields])
    99  				if err != nil {
   100  					return err
   101  				}
   102  				numExportedFields++
   103  			}
   104  		}
   105  		if numExportedFields != len(fields) {
   106  			return fmt.Errorf("wrong number of fields; have %d, want %d", len(fields), numExportedFields)
   107  		}
   108  
   109  	default:
   110  		return fmt.Errorf("cannot decode into unsupported type %v", reflect.TypeOf(dst))
   111  	}
   112  	return nil
   113  }
   114  
   115  func encodeUint8(input uint8) ([]byte, error) {
   116  	if input == 0 {
   117  		// yes this makes no sense, but it does seem to be what everyone else does, apparently 'no leading zeroes'.
   118  		// It means we cannot store []byte{0} because that is indistinguishable from byte{}
   119  		return []byte{uint8(StringOffset)}, nil
   120  	} else if input <= SmallByte {
   121  		return []byte{input}, nil
   122  	} else if input >= uint8(StringOffset) {
   123  		return []byte{0x81, input}, nil
   124  	}
   125  	return []byte{uint8(StringOffset)}, nil
   126  }
   127  
   128  func encodeUint64(i uint64) ([]byte, error) {
   129  	// Byte-wise ceiling
   130  	byteCount := (bits.Len64(i) + 7) / 8
   131  	if byteCount == 1 {
   132  		return encodeUint8(uint8(i))
   133  	}
   134  	b := make([]byte, 8)
   135  	binary.BigEndian.PutUint64(b, uint64(i))
   136  	return encodeString(b[8-byteCount:])
   137  }
   138  
   139  func encodeBigInt(b *big.Int) ([]byte, error) {
   140  	if b.Sign() == -1 {
   141  		return nil, fmt.Errorf("cannot RLP encode negative number")
   142  	}
   143  	if b.IsUint64() {
   144  		return encodeUint64(b.Uint64())
   145  	}
   146  	bs := b.Bytes()
   147  	length := encodeLength(len(bs), StringOffset)
   148  	return append(length, bs...), nil
   149  }
   150  
   151  func encodeLength(n int, offset magicOffset) []byte {
   152  	// > if a string is 0-55 bytes long, the RLP encoding consists of a single byte with value 0x80 plus
   153  	// > the length of the string followed by the string.
   154  	if n <= ShortLength {
   155  		return []uint8{uint8(offset) + uint8(n)}
   156  	}
   157  	i := uint64(n)
   158  	b := make([]byte, 8)
   159  	binary.BigEndian.PutUint64(b, i)
   160  	// Byte-wise ceiling
   161  	byteLengthOfLength := (bits.Len64(i) + 7) / 8
   162  	// > If a string is more than 55 bytes long, the RLP encoding consists of a single byte with value 0xb7
   163  	// > plus the length in bytes of the length of the string in binary form, followed by the length of the string,
   164  	// > followed by the string
   165  	return append([]byte{uint8(offset) + ShortLength + uint8(byteLengthOfLength)}, b[8-byteLengthOfLength:]...)
   166  }
   167  
   168  func encodeString(input []byte) ([]byte, error) {
   169  	if len(input) == 1 && input[0] <= SmallByte {
   170  		return encodeUint8(input[0])
   171  	} else {
   172  		return append(encodeLength(len(input), StringOffset), input...), nil
   173  	}
   174  }
   175  
   176  func encodeList(val reflect.Value) ([]byte, error) {
   177  	if val.Len() == 0 {
   178  		return []byte{uint8(SliceOffset)}, nil
   179  	}
   180  
   181  	out := make([][]byte, 0)
   182  	for i := 0; i < val.Len(); i++ {
   183  		data, err := encode(val.Index(i))
   184  		if err != nil {
   185  			return nil, err
   186  		}
   187  		out = append(out, data)
   188  	}
   189  
   190  	sum := bytes.Join(out, []byte{})
   191  	return append(encodeLength(len(sum), SliceOffset), sum...), nil
   192  }
   193  
   194  func encodeStruct(val reflect.Value) ([]byte, error) {
   195  	out := make([][]byte, 0)
   196  
   197  	rt := val.Type()
   198  
   199  	for i := 0; i < val.NumField(); i++ {
   200  		field := val.Field(i)
   201  		// Skip unexported fields
   202  		if rt.Field(i).PkgPath == "" {
   203  			data, err := encode(field)
   204  			if err != nil {
   205  				return nil, err
   206  			}
   207  			out = append(out, data)
   208  		}
   209  	}
   210  	sum := bytes.Join(out, []byte{})
   211  	length := encodeLength(len(sum), SliceOffset)
   212  	return append(length, sum...), nil
   213  }
   214  
   215  func encode(val reflect.Value) ([]byte, error) {
   216  	if val.Kind() == reflect.Interface {
   217  		val = val.Elem()
   218  	}
   219  
   220  	switch val.Kind() {
   221  	case reflect.Ptr:
   222  		if !val.Type().AssignableTo(bigIntType) {
   223  			return nil, fmt.Errorf("cannot encode pointer type %v", val.Type())
   224  		}
   225  		return encodeBigInt(val.Interface().(*big.Int))
   226  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   227  		i := val.Int()
   228  		if i < 0 {
   229  			return nil, fmt.Errorf("cannot rlp encode negative integer")
   230  		}
   231  		return encodeUint64(uint64(i))
   232  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   233  		return encodeUint64(val.Uint())
   234  	case reflect.Bool:
   235  		if val.Bool() {
   236  			return []byte{0x01}, nil
   237  		}
   238  		return []byte{uint8(StringOffset)}, nil
   239  	case reflect.String:
   240  		return encodeString([]byte(val.String()))
   241  	case reflect.Slice:
   242  		switch val.Type().Elem().Kind() {
   243  		case reflect.Uint8:
   244  			i, err := encodeString(val.Bytes())
   245  			return i, err
   246  		default:
   247  			return encodeList(val)
   248  		}
   249  	case reflect.Struct:
   250  		return encodeStruct(val)
   251  	default:
   252  		return []byte{uint8(StringOffset)}, nil
   253  	}
   254  }
   255  
   256  // Split into RLP fields by reading length prefixes and consuming chunks
   257  func decode(in []byte) ([][]byte, error) {
   258  	if len(in) == 0 {
   259  		return nil, nil
   260  	}
   261  
   262  	offset, length, typ := decodeLength(in)
   263  	end := offset + length
   264  
   265  	if end > uint64(len(in)) {
   266  		return nil, fmt.Errorf("read length prefix of %d but there is only %d bytes of unconsumed input",
   267  			length, uint64(len(in))-offset)
   268  	}
   269  
   270  	suffix, err := decode(in[end:])
   271  	if err != nil {
   272  		return nil, err
   273  	}
   274  	switch typ {
   275  	case reflect.String:
   276  		return append([][]byte{in[offset:end]}, suffix...), nil
   277  	case reflect.Slice:
   278  		prefix, err := decode(in[offset:end])
   279  		if err != nil {
   280  			return nil, err
   281  		}
   282  		return append(prefix, suffix...), nil
   283  	}
   284  
   285  	return suffix, nil
   286  }
   287  
   288  func decodeLength(input []byte) (uint64, uint64, reflect.Kind) {
   289  	magicByte := magicOffset(input[0])
   290  
   291  	switch {
   292  	case magicByte <= SmallByte:
   293  		// small byte: sufficiently small single byte
   294  		return 0, 1, reflect.String
   295  
   296  	case magicByte <= StringOffset+ShortLength:
   297  		// short string: length less than or equal to 55 bytes
   298  		length := uint64(magicByte - StringOffset)
   299  		return 1, length, reflect.String
   300  
   301  	case magicByte < SliceOffset:
   302  		// long string: length described by magic = 0xb7 + <byte length of length of string>
   303  		byteLengthOfLength := magicByte - StringOffset - ShortLength
   304  		length := getUint64(input[1:byteLengthOfLength])
   305  		offset := uint64(byteLengthOfLength + 1)
   306  		return offset, length, reflect.String
   307  
   308  	case magicByte <= SliceOffset+ShortLength:
   309  		// short slice: length less than or equal to 55 bytes
   310  		length := uint64(magicByte - SliceOffset)
   311  		return 1, length, reflect.Slice
   312  
   313  	// Note this takes us all the way up to <= 255 so this switch is exhaustive
   314  	default:
   315  		// long string: length described by magic = 0xf7 + <byte length of length of string>
   316  		byteLengthOfLength := magicByte - SliceOffset - ShortLength
   317  		length := getUint64(input[1:byteLengthOfLength])
   318  		offset := uint64(byteLengthOfLength + 1)
   319  		return offset, length, reflect.Slice
   320  	}
   321  }
   322  
   323  func getUint64(bs []byte) uint64 {
   324  	bs = binary2.LeftPadBytes(bs, 8)
   325  	return binary.BigEndian.Uint64(bs)
   326  }
   327  
   328  func decodeField(val reflect.Value, field []byte) error {
   329  	typ := val.Type()
   330  
   331  	switch val.Kind() {
   332  	case reflect.Ptr:
   333  		if !typ.AssignableTo(bigIntType) {
   334  			return fmt.Errorf("cannot decode into pointer type %v", typ)
   335  		}
   336  		bi := new(big.Int).SetBytes(field)
   337  		val.Set(reflect.ValueOf(bi))
   338  
   339  	case reflect.String:
   340  		val.SetString(string(field))
   341  	case reflect.Uint64:
   342  		out := make([]byte, 8)
   343  		for j := range field {
   344  			out[len(out)-(len(field)-j)] = field[j]
   345  		}
   346  		val.SetUint(binary.BigEndian.Uint64(out))
   347  	case reflect.Slice:
   348  		if typ.Elem().Kind() != reflect.Uint8 {
   349  			// skip
   350  			return nil
   351  		}
   352  		out := make([]byte, len(field))
   353  		for i, b := range field {
   354  			out[i] = b
   355  		}
   356  		val.SetBytes(out)
   357  	}
   358  	return nil
   359  }