github.com/vedadiyan/sqlparser@v1.0.0/pkg/sqltypes/parse_rows.go (about) 1 /* 2 Copyright 2023 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqltypes 18 19 import ( 20 "fmt" 21 "io" 22 "reflect" 23 "strconv" 24 "strings" 25 "text/scanner" 26 27 querypb "github.com/vedadiyan/sqlparser/pkg/query" 28 ) 29 30 // ParseRows parses the output generated by fmt.Sprintf("#v", rows), and reifies the original []sqltypes.Row 31 // NOTE: This is not meant for production use! 32 func ParseRows(input string) ([]Row, error) { 33 type state int 34 const ( 35 stInvalid state = iota 36 stInit 37 stBeginRow 38 stInRow 39 stInValue0 40 stInValue1 41 stInValue2 42 ) 43 44 var ( 45 scan scanner.Scanner 46 result []Row 47 row Row 48 vtype int32 49 st = stInit 50 ) 51 52 scan.Init(strings.NewReader(input)) 53 54 for tok := scan.Scan(); tok != scanner.EOF; tok = scan.Scan() { 55 var next state 56 57 switch st { 58 case stInit: 59 if tok == '[' { 60 next = stBeginRow 61 } 62 case stBeginRow: 63 switch tok { 64 case '[': 65 next = stInRow 66 case ']': 67 return result, nil 68 } 69 case stInRow: 70 switch tok { 71 case ']': 72 result = append(result, row) 73 row = nil 74 next = stBeginRow 75 case scanner.Ident: 76 ident := scan.TokenText() 77 78 if ident == "NULL" { 79 row = append(row, NULL) 80 continue 81 } 82 83 var ok bool 84 vtype, ok = querypb.Type_value[ident] 85 if !ok { 86 return nil, fmt.Errorf("unknown SQL type %q at %s", ident, scan.Position) 87 } 88 next = stInValue0 89 } 90 case stInValue0: 91 if tok == '(' { 92 next = stInValue1 93 } 94 case stInValue1: 95 literal := scan.TokenText() 96 switch tok { 97 case scanner.String: 98 var err error 99 literal, err = strconv.Unquote(literal) 100 if err != nil { 101 return nil, fmt.Errorf("failed to parse literal string at %s: %w", scan.Position, err) 102 } 103 fallthrough 104 case scanner.Int, scanner.Float: 105 row = append(row, MakeTrusted(Type(vtype), []byte(literal))) 106 next = stInValue2 107 } 108 case stInValue2: 109 if tok == ')' { 110 next = stInRow 111 } 112 } 113 if next == stInvalid { 114 return nil, fmt.Errorf("unexpected token '%s' at %s", scan.TokenText(), scan.Position) 115 } 116 st = next 117 } 118 return nil, io.ErrUnexpectedEOF 119 } 120 121 type RowMismatchError struct { 122 err error 123 want, got []Row 124 } 125 126 func (e *RowMismatchError) Error() string { 127 return fmt.Sprintf("results differ: %v\n\twant: %v\n\tgot: %v", e.err, e.want, e.got) 128 } 129 130 func RowsEquals(want, got []Row) error { 131 if len(want) != len(got) { 132 return &RowMismatchError{ 133 err: fmt.Errorf("expected %d rows in result, got %d", len(want), len(got)), 134 want: want, 135 got: got, 136 } 137 } 138 139 var matched = make([]bool, len(want)) 140 for _, aa := range want { 141 var ok bool 142 for i, bb := range got { 143 if matched[i] { 144 continue 145 } 146 if reflect.DeepEqual(aa, bb) { 147 matched[i] = true 148 ok = true 149 break 150 } 151 } 152 if !ok { 153 return &RowMismatchError{ 154 err: fmt.Errorf("row %v is missing from result", aa), 155 want: want, 156 got: got, 157 } 158 } 159 } 160 for _, m := range matched { 161 if !m { 162 return fmt.Errorf("not all elements matched") 163 } 164 } 165 return nil 166 } 167 168 func RowsEqualsStr(wantStr string, got []Row) error { 169 want, err := ParseRows(wantStr) 170 if err != nil { 171 return fmt.Errorf("malformed row assertion: %w", err) 172 } 173 return RowsEquals(want, got) 174 }