
     1  /*
     2  Copyright 2023 The Vitess Authors.
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    17  package sqltypes
    19  import (
    20  	"fmt"
    21  	"io"
    22  	"reflect"
    23  	"strconv"
    24  	"strings"
    25  	"text/scanner"
    27  	querypb ""
    28  )
    30  // ParseRows parses the output generated by fmt.Sprintf("#v", rows), and reifies the original []sqltypes.Row
    31  // NOTE: This is not meant for production use!
    32  func ParseRows(input string) ([]Row, error) {
    33  	type state int
    34  	const (
    35  		stInvalid state = iota
    36  		stInit
    37  		stBeginRow
    38  		stInRow
    39  		stInValue0
    40  		stInValue1
    41  		stInValue2
    42  	)
    44  	var (
    45  		scan   scanner.Scanner
    46  		result []Row
    47  		row    Row
    48  		vtype  int32
    49  		st     = stInit
    50  	)
    52  	scan.Init(strings.NewReader(input))
    54  	for tok := scan.Scan(); tok != scanner.EOF; tok = scan.Scan() {
    55  		var next state
    57  		switch st {
    58  		case stInit:
    59  			if tok == '[' {
    60  				next = stBeginRow
    61  			}
    62  		case stBeginRow:
    63  			switch tok {
    64  			case '[':
    65  				next = stInRow
    66  			case ']':
    67  				return result, nil
    68  			}
    69  		case stInRow:
    70  			switch tok {
    71  			case ']':
    72  				result = append(result, row)
    73  				row = nil
    74  				next = stBeginRow
    75  			case scanner.Ident:
    76  				ident := scan.TokenText()
    78  				if ident == "NULL" {
    79  					row = append(row, NULL)
    80  					continue
    81  				}
    83  				var ok bool
    84  				vtype, ok = querypb.Type_value[ident]
    85  				if !ok {
    86  					return nil, fmt.Errorf("unknown SQL type %q at %s", ident, scan.Position)
    87  				}
    88  				next = stInValue0
    89  			}
    90  		case stInValue0:
    91  			if tok == '(' {
    92  				next = stInValue1
    93  			}
    94  		case stInValue1:
    95  			literal := scan.TokenText()
    96  			switch tok {
    97  			case scanner.String:
    98  				var err error
    99  				literal, err = strconv.Unquote(literal)
   100  				if err != nil {
   101  					return nil, fmt.Errorf("failed to parse literal string at %s: %w", scan.Position, err)
   102  				}
   103  				fallthrough
   104  			case scanner.Int, scanner.Float:
   105  				row = append(row, MakeTrusted(Type(vtype), []byte(literal)))
   106  				next = stInValue2
   107  			}
   108  		case stInValue2:
   109  			if tok == ')' {
   110  				next = stInRow
   111  			}
   112  		}
   113  		if next == stInvalid {
   114  			return nil, fmt.Errorf("unexpected token '%s' at %s", scan.TokenText(), scan.Position)
   115  		}
   116  		st = next
   117  	}
   118  	return nil, io.ErrUnexpectedEOF
   119  }
   121  type RowMismatchError struct {
   122  	err       error
   123  	want, got []Row
   124  }
   126  func (e *RowMismatchError) Error() string {
   127  	return fmt.Sprintf("results differ: %v\n\twant: %v\n\tgot:  %v", e.err, e.want,
   128  }
   130  func RowsEquals(want, got []Row) error {
   131  	if len(want) != len(got) {
   132  		return &RowMismatchError{
   133  			err:  fmt.Errorf("expected %d rows in result, got %d", len(want), len(got)),
   134  			want: want,
   135  			got:  got,
   136  		}
   137  	}
   139  	var matched = make([]bool, len(want))
   140  	for _, aa := range want {
   141  		var ok bool
   142  		for i, bb := range got {
   143  			if matched[i] {
   144  				continue
   145  			}
   146  			if reflect.DeepEqual(aa, bb) {
   147  				matched[i] = true
   148  				ok = true
   149  				break
   150  			}
   151  		}
   152  		if !ok {
   153  			return &RowMismatchError{
   154  				err:  fmt.Errorf("row %v is missing from result", aa),
   155  				want: want,
   156  				got:  got,
   157  			}
   158  		}
   159  	}
   160  	for _, m := range matched {
   161  		if !m {
   162  			return fmt.Errorf("not all elements matched")
   163  		}
   164  	}
   165  	return nil
   166  }
   168  func RowsEqualsStr(wantStr string, got []Row) error {
   169  	want, err := ParseRows(wantStr)
   170  	if err != nil {
   171  		return fmt.Errorf("malformed row assertion: %w", err)
   172  	}
   173  	return RowsEquals(want, got)
   174  }