github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/tests/rlp_test_util.go (about)

     1  package tests
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/hex"
     6  	"errors"
     7  	"fmt"
     8  	"math/big"
     9  	"strings"
    10  
    11  	"github.com/neatlab/neatio/utilities/rlp"
    12  )
    13  
    14  type RLPTest struct {
    15  	In interface{}
    16  
    17  	Out string
    18  }
    19  
    20  func (t *RLPTest) Run() error {
    21  	outb, err := hex.DecodeString(t.Out)
    22  	if err != nil {
    23  		return fmt.Errorf("invalid hex in Out")
    24  	}
    25  
    26  	if t.In == "VALID" || t.In == "INVALID" {
    27  		return checkDecodeInterface(outb, t.In == "VALID")
    28  	}
    29  
    30  	in := translateJSON(t.In)
    31  	b, err := rlp.EncodeToBytes(in)
    32  	if err != nil {
    33  		return fmt.Errorf("encode failed: %v", err)
    34  	}
    35  	if !bytes.Equal(b, outb) {
    36  		return fmt.Errorf("encode produced %x, want %x", b, outb)
    37  	}
    38  	s := rlp.NewStream(bytes.NewReader(outb), 0)
    39  	return checkDecodeFromJSON(s, in)
    40  }
    41  
    42  func checkDecodeInterface(b []byte, isValid bool) error {
    43  	err := rlp.DecodeBytes(b, new(interface{}))
    44  	switch {
    45  	case isValid && err != nil:
    46  		return fmt.Errorf("decoding failed: %v", err)
    47  	case !isValid && err == nil:
    48  		return fmt.Errorf("decoding of invalid value succeeded")
    49  	}
    50  	return nil
    51  }
    52  
    53  func translateJSON(v interface{}) interface{} {
    54  	switch v := v.(type) {
    55  	case float64:
    56  		return uint64(v)
    57  	case string:
    58  		if len(v) > 0 && v[0] == '#' {
    59  			big, ok := new(big.Int).SetString(v[1:], 10)
    60  			if !ok {
    61  				panic(fmt.Errorf("bad test: bad big int: %q", v))
    62  			}
    63  			return big
    64  		}
    65  		return []byte(v)
    66  	case []interface{}:
    67  		new := make([]interface{}, len(v))
    68  		for i := range v {
    69  			new[i] = translateJSON(v[i])
    70  		}
    71  		return new
    72  	default:
    73  		panic(fmt.Errorf("can't handle %T", v))
    74  	}
    75  }
    76  
    77  func checkDecodeFromJSON(s *rlp.Stream, exp interface{}) error {
    78  	switch exp := exp.(type) {
    79  	case uint64:
    80  		i, err := s.Uint()
    81  		if err != nil {
    82  			return addStack("Uint", exp, err)
    83  		}
    84  		if i != exp {
    85  			return addStack("Uint", exp, fmt.Errorf("result mismatch: got %d", i))
    86  		}
    87  	case *big.Int:
    88  		big := new(big.Int)
    89  		if err := s.Decode(&big); err != nil {
    90  			return addStack("Big", exp, err)
    91  		}
    92  		if big.Cmp(exp) != 0 {
    93  			return addStack("Big", exp, fmt.Errorf("result mismatch: got %d", big))
    94  		}
    95  	case []byte:
    96  		b, err := s.Bytes()
    97  		if err != nil {
    98  			return addStack("Bytes", exp, err)
    99  		}
   100  		if !bytes.Equal(b, exp) {
   101  			return addStack("Bytes", exp, fmt.Errorf("result mismatch: got %x", b))
   102  		}
   103  	case []interface{}:
   104  		if _, err := s.List(); err != nil {
   105  			return addStack("List", exp, err)
   106  		}
   107  		for i, v := range exp {
   108  			if err := checkDecodeFromJSON(s, v); err != nil {
   109  				return addStack(fmt.Sprintf("[%d]", i), exp, err)
   110  			}
   111  		}
   112  		if err := s.ListEnd(); err != nil {
   113  			return addStack("ListEnd", exp, err)
   114  		}
   115  	default:
   116  		panic(fmt.Errorf("unhandled type: %T", exp))
   117  	}
   118  	return nil
   119  }
   120  
   121  func addStack(op string, val interface{}, err error) error {
   122  	lines := strings.Split(err.Error(), "\n")
   123  	lines = append(lines, fmt.Sprintf("\t%s: %v", op, val))
   124  	return errors.New(strings.Join(lines, "\n"))
   125  }