github.com/core-coin/go-core/v2@v2.1.9/tests/rlp_test_util.go (about)

     1  // Copyright 2015 by the Authors
     2  // This file is part of the go-core library.
     3  //
     4  // The go-core library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-core library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-core library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package tests
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/hex"
    22  	"errors"
    23  	"fmt"
    24  	"math/big"
    25  	"strings"
    26  
    27  	"github.com/core-coin/go-core/v2/rlp"
    28  )
    29  
// RLPTest is the JSON structure of a single RLP test.
type RLPTest struct {
	// In selects what the test does.
	//
	// If the value of In is "INVALID" or "VALID", the test
	// checks whether Out can be decoded into a value of
	// type interface{}.
	//
	// For other JSON values, In is treated as a driver for
	// calls to rlp.Stream. The test also verifies that encoding
	// In produces the bytes in Out.
	//
	// Before use, In is converted by translateJSON: JSON numbers
	// become uint64, strings starting with '#' become big integers
	// parsed in base 10, all other strings become byte slices, and
	// arrays become lists converted element by element.
	In interface{}

	// Out is a hex-encoded RLP value. It is parsed with FromHex, so
	// a leading "0x"/"0X" and an odd number of digits are accepted.
	Out string
}
    44  
    45  // FromHex returns the bytes represented by the hexadecimal string s.
    46  // s may be prefixed with "0x".
    47  // This is copy-pasted from bytes.go, which does not return the error
    48  func FromHex(s string) ([]byte, error) {
    49  	if len(s) > 1 && (s[0:2] == "0x" || s[0:2] == "0X") {
    50  		s = s[2:]
    51  	}
    52  	if len(s)%2 == 1 {
    53  		s = "0" + s
    54  	}
    55  	return hex.DecodeString(s)
    56  }
    57  
    58  // Run executes the test.
    59  func (t *RLPTest) Run() error {
    60  	outb, err := FromHex(t.Out)
    61  	if err != nil {
    62  		return fmt.Errorf("invalid hex in Out")
    63  	}
    64  
    65  	// Handle simple decoding tests with no actual In value.
    66  	if t.In == "VALID" || t.In == "INVALID" {
    67  		return checkDecodeInterface(outb, t.In == "VALID")
    68  	}
    69  
    70  	// Check whether encoding the value produces the same bytes.
    71  	in := translateJSON(t.In)
    72  	b, err := rlp.EncodeToBytes(in)
    73  	if err != nil {
    74  		return fmt.Errorf("encode failed: %v", err)
    75  	}
    76  	if !bytes.Equal(b, outb) {
    77  		return fmt.Errorf("encode produced %x, want %x", b, outb)
    78  	}
    79  	// Test stream decoding.
    80  	s := rlp.NewStream(bytes.NewReader(outb), 0)
    81  	return checkDecodeFromJSON(s, in)
    82  }
    83  
    84  func checkDecodeInterface(b []byte, isValid bool) error {
    85  	err := rlp.DecodeBytes(b, new(interface{}))
    86  	switch {
    87  	case isValid && err != nil:
    88  		return fmt.Errorf("decoding failed: %v", err)
    89  	case !isValid && err == nil:
    90  		return fmt.Errorf("decoding of invalid value succeeded")
    91  	}
    92  	return nil
    93  }
    94  
    95  // translateJSON makes test json values encodable with RLP.
    96  func translateJSON(v interface{}) interface{} {
    97  	switch v := v.(type) {
    98  	case float64:
    99  		return uint64(v)
   100  	case string:
   101  		if len(v) > 0 && v[0] == '#' { // # starts a faux big int.
   102  			big, ok := new(big.Int).SetString(v[1:], 10)
   103  			if !ok {
   104  				panic(fmt.Errorf("bad test: bad big int: %q", v))
   105  			}
   106  			return big
   107  		}
   108  		return []byte(v)
   109  	case []interface{}:
   110  		new := make([]interface{}, len(v))
   111  		for i := range v {
   112  			new[i] = translateJSON(v[i])
   113  		}
   114  		return new
   115  	default:
   116  		panic(fmt.Errorf("can't handle %T", v))
   117  	}
   118  }
   119  
   120  // checkDecodeFromJSON decodes from s guided by exp. exp drives the
   121  // Stream by invoking decoding operations (Uint, Big, List, ...) based
   122  // on the type of each value. The value decoded from the RLP stream
   123  // must match the JSON value.
   124  func checkDecodeFromJSON(s *rlp.Stream, exp interface{}) error {
   125  	switch exp := exp.(type) {
   126  	case uint64:
   127  		i, err := s.Uint()
   128  		if err != nil {
   129  			return addStack("Uint", exp, err)
   130  		}
   131  		if i != exp {
   132  			return addStack("Uint", exp, fmt.Errorf("result mismatch: got %d", i))
   133  		}
   134  	case *big.Int:
   135  		big := new(big.Int)
   136  		if err := s.Decode(&big); err != nil {
   137  			return addStack("Big", exp, err)
   138  		}
   139  		if big.Cmp(exp) != 0 {
   140  			return addStack("Big", exp, fmt.Errorf("result mismatch: got %d", big))
   141  		}
   142  	case []byte:
   143  		b, err := s.Bytes()
   144  		if err != nil {
   145  			return addStack("Bytes", exp, err)
   146  		}
   147  		if !bytes.Equal(b, exp) {
   148  			return addStack("Bytes", exp, fmt.Errorf("result mismatch: got %x", b))
   149  		}
   150  	case []interface{}:
   151  		if _, err := s.List(); err != nil {
   152  			return addStack("List", exp, err)
   153  		}
   154  		for i, v := range exp {
   155  			if err := checkDecodeFromJSON(s, v); err != nil {
   156  				return addStack(fmt.Sprintf("[%d]", i), exp, err)
   157  			}
   158  		}
   159  		if err := s.ListEnd(); err != nil {
   160  			return addStack("ListEnd", exp, err)
   161  		}
   162  	default:
   163  		panic(fmt.Errorf("unhandled type: %T", exp))
   164  	}
   165  	return nil
   166  }
   167  
   168  func addStack(op string, val interface{}, err error) error {
   169  	lines := strings.Split(err.Error(), "\n")
   170  	lines = append(lines, fmt.Sprintf("\t%s: %v", op, val))
   171  	return errors.New(strings.Join(lines, "\n"))
   172  }