github.com/linapex/ethereum-dpos-chinese@v0.0.0-20190316121959-b78b3a4a1ece/tests/rlp_test_util.go (about)

     1  
     2  //<developer>
     3  //    <name>linapex 曹一峰</name>
     4  //    <email>linapex@163.com</email>
     5  //    <wx>superexc</wx>
     6  //    <qqgroup>128148617</qqgroup>
     7  //    <url>https://jsq.ink</url>
     8  //    <role>pku engineer</role>
     9  //    <date>2019-03-16 12:09:50</date>
    10  //</developer>
    11  
    12  //
    13  //
    14  //
    15  //
    16  //
    17  //
    18  //
    19  //
    20  //
    21  //
    22  //
    23  //
    24  //
    25  //
    26  //
    27  
    28  package tests
    29  
    30  import (
    31  	"bytes"
    32  	"encoding/hex"
    33  	"errors"
    34  	"fmt"
    35  	"math/big"
    36  	"strings"
    37  
    38  	"github.com/ethereum/go-ethereum/rlp"
    39  )
    40  
    41  //
    42  type RLPTest struct {
    43  //
    44  //
    45  //
    46  //
    47  //
    48  //
    49  //
    50  	In interface{}
    51  
    52  //
    53  	Out string
    54  }
    55  
    56  //
    57  func (t *RLPTest) Run() error {
    58  	outb, err := hex.DecodeString(t.Out)
    59  	if err != nil {
    60  		return fmt.Errorf("invalid hex in Out")
    61  	}
    62  
    63  //
    64  	if t.In == "VALID" || t.In == "INVALID" {
    65  		return checkDecodeInterface(outb, t.In == "VALID")
    66  	}
    67  
    68  //
    69  	in := translateJSON(t.In)
    70  	b, err := rlp.EncodeToBytes(in)
    71  	if err != nil {
    72  		return fmt.Errorf("encode failed: %v", err)
    73  	}
    74  	if !bytes.Equal(b, outb) {
    75  		return fmt.Errorf("encode produced %x, want %x", b, outb)
    76  	}
    77  //
    78  	s := rlp.NewStream(bytes.NewReader(outb), 0)
    79  	return checkDecodeFromJSON(s, in)
    80  }
    81  
    82  func checkDecodeInterface(b []byte, isValid bool) error {
    83  	err := rlp.DecodeBytes(b, new(interface{}))
    84  	switch {
    85  	case isValid && err != nil:
    86  		return fmt.Errorf("decoding failed: %v", err)
    87  	case !isValid && err == nil:
    88  		return fmt.Errorf("decoding of invalid value succeeded")
    89  	}
    90  	return nil
    91  }
    92  
    93  //
    94  func translateJSON(v interface{}) interface{} {
    95  	switch v := v.(type) {
    96  	case float64:
    97  		return uint64(v)
    98  	case string:
    99  if len(v) > 0 && v[0] == '#' { //
   100  			big, ok := new(big.Int).SetString(v[1:], 10)
   101  			if !ok {
   102  				panic(fmt.Errorf("bad test: bad big int: %q", v))
   103  			}
   104  			return big
   105  		}
   106  		return []byte(v)
   107  	case []interface{}:
   108  		new := make([]interface{}, len(v))
   109  		for i := range v {
   110  			new[i] = translateJSON(v[i])
   111  		}
   112  		return new
   113  	default:
   114  		panic(fmt.Errorf("can't handle %T", v))
   115  	}
   116  }
   117  
   118  //
   119  //
   120  //
   121  //
   122  func checkDecodeFromJSON(s *rlp.Stream, exp interface{}) error {
   123  	switch exp := exp.(type) {
   124  	case uint64:
   125  		i, err := s.Uint()
   126  		if err != nil {
   127  			return addStack("Uint", exp, err)
   128  		}
   129  		if i != exp {
   130  			return addStack("Uint", exp, fmt.Errorf("result mismatch: got %d", i))
   131  		}
   132  	case *big.Int:
   133  		big := new(big.Int)
   134  		if err := s.Decode(&big); err != nil {
   135  			return addStack("Big", exp, err)
   136  		}
   137  		if big.Cmp(exp) != 0 {
   138  			return addStack("Big", exp, fmt.Errorf("result mismatch: got %d", big))
   139  		}
   140  	case []byte:
   141  		b, err := s.Bytes()
   142  		if err != nil {
   143  			return addStack("Bytes", exp, err)
   144  		}
   145  		if !bytes.Equal(b, exp) {
   146  			return addStack("Bytes", exp, fmt.Errorf("result mismatch: got %x", b))
   147  		}
   148  	case []interface{}:
   149  		if _, err := s.List(); err != nil {
   150  			return addStack("List", exp, err)
   151  		}
   152  		for i, v := range exp {
   153  			if err := checkDecodeFromJSON(s, v); err != nil {
   154  				return addStack(fmt.Sprintf("[%d]", i), exp, err)
   155  			}
   156  		}
   157  		if err := s.ListEnd(); err != nil {
   158  			return addStack("ListEnd", exp, err)
   159  		}
   160  	default:
   161  		panic(fmt.Errorf("unhandled type: %T", exp))
   162  	}
   163  	return nil
   164  }
   165  
   166  func addStack(op string, val interface{}, err error) error {
   167  	lines := strings.Split(err.Error(), "\n")
   168  	lines = append(lines, fmt.Sprintf("\t%s: %v", op, val))
   169  	return errors.New(strings.Join(lines, "\n"))
   170  }
   171