github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/tests/rlp_test_util.go (about) 1 package tests 2 3 import ( 4 "bytes" 5 "encoding/hex" 6 "errors" 7 "fmt" 8 "math/big" 9 "strings" 10 11 "github.com/neatlab/neatio/utilities/rlp" 12 ) 13 14 type RLPTest struct { 15 In interface{} 16 17 Out string 18 } 19 20 func (t *RLPTest) Run() error { 21 outb, err := hex.DecodeString(t.Out) 22 if err != nil { 23 return fmt.Errorf("invalid hex in Out") 24 } 25 26 if t.In == "VALID" || t.In == "INVALID" { 27 return checkDecodeInterface(outb, t.In == "VALID") 28 } 29 30 in := translateJSON(t.In) 31 b, err := rlp.EncodeToBytes(in) 32 if err != nil { 33 return fmt.Errorf("encode failed: %v", err) 34 } 35 if !bytes.Equal(b, outb) { 36 return fmt.Errorf("encode produced %x, want %x", b, outb) 37 } 38 s := rlp.NewStream(bytes.NewReader(outb), 0) 39 return checkDecodeFromJSON(s, in) 40 } 41 42 func checkDecodeInterface(b []byte, isValid bool) error { 43 err := rlp.DecodeBytes(b, new(interface{})) 44 switch { 45 case isValid && err != nil: 46 return fmt.Errorf("decoding failed: %v", err) 47 case !isValid && err == nil: 48 return fmt.Errorf("decoding of invalid value succeeded") 49 } 50 return nil 51 } 52 53 func translateJSON(v interface{}) interface{} { 54 switch v := v.(type) { 55 case float64: 56 return uint64(v) 57 case string: 58 if len(v) > 0 && v[0] == '#' { 59 big, ok := new(big.Int).SetString(v[1:], 10) 60 if !ok { 61 panic(fmt.Errorf("bad test: bad big int: %q", v)) 62 } 63 return big 64 } 65 return []byte(v) 66 case []interface{}: 67 new := make([]interface{}, len(v)) 68 for i := range v { 69 new[i] = translateJSON(v[i]) 70 } 71 return new 72 default: 73 panic(fmt.Errorf("can't handle %T", v)) 74 } 75 } 76 77 func checkDecodeFromJSON(s *rlp.Stream, exp interface{}) error { 78 switch exp := exp.(type) { 79 case uint64: 80 i, err := s.Uint() 81 if err != nil { 82 return addStack("Uint", exp, err) 83 } 84 if i != exp { 85 return addStack("Uint", exp, fmt.Errorf("result mismatch: got %d", i)) 86 } 87 case *big.Int: 88 big := new(big.Int) 89 if err := s.Decode(&big); err != nil { 90 return addStack("Big", exp, err) 91 } 92 if big.Cmp(exp) != 0 { 93 return addStack("Big", exp, fmt.Errorf("result mismatch: got %d", big)) 94 } 95 case []byte: 96 b, err := s.Bytes() 97 if err != nil { 98 return addStack("Bytes", exp, err) 99 } 100 if !bytes.Equal(b, exp) { 101 return addStack("Bytes", exp, fmt.Errorf("result mismatch: got %x", b)) 102 } 103 case []interface{}: 104 if _, err := s.List(); err != nil { 105 return addStack("List", exp, err) 106 } 107 for i, v := range exp { 108 if err := checkDecodeFromJSON(s, v); err != nil { 109 return addStack(fmt.Sprintf("[%d]", i), exp, err) 110 } 111 } 112 if err := s.ListEnd(); err != nil { 113 return addStack("ListEnd", exp, err) 114 } 115 default: 116 panic(fmt.Errorf("unhandled type: %T", exp)) 117 } 118 return nil 119 } 120 121 func addStack(op string, val interface{}, err error) error { 122 lines := strings.Split(err.Error(), "\n") 123 lines = append(lines, fmt.Sprintf("\t%s: %v", op, val)) 124 return errors.New(strings.Join(lines, "\n")) 125 }