github.com/dolthub/go-mysql-server@v0.18.0/server/golden/validator.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package golden
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"fmt"
    21  	"sort"
    22  	"strings"
    23  
    24  	"github.com/dolthub/vitess/go/mysql"
    25  	"github.com/dolthub/vitess/go/sqltypes"
    26  	"github.com/dolthub/vitess/go/vt/proto/query"
    27  	"github.com/dolthub/vitess/go/vt/sqlparser"
    28  	_ "github.com/go-sql-driver/mysql"
    29  	"github.com/sirupsen/logrus"
    30  	"golang.org/x/sync/errgroup"
    31  
    32  	"github.com/dolthub/go-mysql-server/sql"
    33  )
    34  
    35  type Validator struct {
    36  	handler mysql.Handler
    37  	golden  MySqlProxy
    38  	logger  *logrus.Logger
    39  }
    40  
    41  // NewValidatingHandler creates a new Validator wrapping a MySQL connection.
    42  func NewValidatingHandler(handler mysql.Handler, mySqlConn string, logger *logrus.Logger) (Validator, error) {
    43  	golden, err := NewMySqlProxyHandler(logger, mySqlConn)
    44  	if err != nil {
    45  		return Validator{}, err
    46  	}
    47  
    48  	// todo: setup mirroring
    49  	//  - assert that both |handler| and |golden| are
    50  	//    working against empty databases
    51  	//  - possibly sync database set between both
    52  
    53  	return Validator{
    54  		handler: handler,
    55  		golden:  golden,
    56  		logger:  logger,
    57  	}, nil
    58  }
    59  
    60  var _ mysql.Handler = Validator{}
    61  
    62  // NewConnection reports that a new connection has been established.
    63  func (v Validator) NewConnection(c *mysql.Conn) {
    64  	return
    65  }
    66  
    67  func (v Validator) ComInitDB(c *mysql.Conn, schemaName string) error {
    68  	if err := v.handler.ComInitDB(c, schemaName); err != nil {
    69  		return err
    70  	}
    71  	return v.golden.ComInitDB(c, schemaName)
    72  }
    73  
    74  // ComPrepare parses, partially analyzes, and caches a prepared statement's plan
    75  // with the given [c.ConnectionID].
    76  func (v Validator) ComPrepare(_ *mysql.Conn, _ string, _ *mysql.PrepareData) ([]*query.Field, error) {
    77  	return nil, fmt.Errorf("ComPrepare unsupported")
    78  }
    79  
    80  func (v Validator) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
    81  	return fmt.Errorf("ComStmtExecute unsupported")
    82  }
    83  
    84  func (v Validator) ComResetConnection(_ *mysql.Conn) error {
    85  	return nil
    86  }
    87  
    88  // ConnectionClosed reports that a connection has been closed.
    89  func (v Validator) ConnectionClosed(c *mysql.Conn) {
    90  	v.handler.ConnectionClosed(c)
    91  	v.golden.ConnectionClosed(c)
    92  }
    93  
    94  func (v Validator) ComMultiQuery(
    95  	c *mysql.Conn,
    96  	query string,
    97  	callback mysql.ResultSpoolFn,
    98  ) (string, error) {
    99  	ag := newResultAggregator(callback)
   100  	var remainder string
   101  	eg, _ := errgroup.WithContext(context.Background())
   102  	eg.Go(func() (err error) {
   103  		remainder, err = v.handler.ComMultiQuery(c, query, ag.processResults)
   104  		return
   105  	})
   106  	eg.Go(func() error {
   107  		// ignore errors from MySQL connection
   108  		_, _ = v.golden.ComMultiQuery(c, query, ag.processGoldenResults)
   109  		return nil
   110  	})
   111  
   112  	err := eg.Wait()
   113  	if err != nil {
   114  		return "", err
   115  	}
   116  	ag.compareResults(v.getLogger(c).WithField("query", query))
   117  
   118  	return remainder, nil
   119  }
   120  
   121  // ComQuery executes a SQL query on the SQLe engine.
   122  func (v Validator) ComQuery(
   123  	c *mysql.Conn,
   124  	query string,
   125  	callback mysql.ResultSpoolFn,
   126  ) error {
   127  	ag := newResultAggregator(callback)
   128  	eg, _ := errgroup.WithContext(context.Background())
   129  	eg.Go(func() error {
   130  		return v.handler.ComQuery(c, query, ag.processResults)
   131  	})
   132  	eg.Go(func() error {
   133  		// ignore errors from MySQL connection
   134  		_ = v.golden.ComQuery(c, query, ag.processGoldenResults)
   135  		return nil
   136  	})
   137  
   138  	err := eg.Wait()
   139  	if err != nil {
   140  		return err
   141  	}
   142  	ag.compareResults(v.getLogger(c).WithField("query", query))
   143  	return nil
   144  }
   145  
   146  // ComQuery executes a SQL query on the SQLe engine.
   147  func (v Validator) ComParsedQuery(
   148  	c *mysql.Conn,
   149  	query string,
   150  	parsed sqlparser.Statement,
   151  	callback func(*sqltypes.Result, bool) error,
   152  ) error {
   153  	return v.ComQuery(c, query, callback)
   154  }
   155  
   156  // WarningCount is called at the end of each query to obtain
   157  // the value to be returned to the client in the EOF packet.
   158  // Note that this will be called either in the context of the
   159  // ComQuery resultsCB if the result does not contain any fields,
   160  // or after the last ComQuery call completes.
   161  func (v Validator) WarningCount(c *mysql.Conn) uint16 {
   162  	return 0
   163  }
   164  
   165  func (v Validator) ParserOptionsForConnection(_ *mysql.Conn) (sqlparser.ParserOptions, error) {
   166  	return sqlparser.ParserOptions{}, nil
   167  }
   168  
   169  func (v Validator) getLogger(c *mysql.Conn) *logrus.Entry {
   170  	return logrus.NewEntry(v.logger).WithField(
   171  		sql.ConnectionIdLogField, c.ConnectionID)
   172  }
   173  
   174  type aggregator struct {
   175  	results  []*sqltypes.Result
   176  	golden   []*sqltypes.Result
   177  	callback func(*sqltypes.Result, bool) error
   178  }
   179  
   180  const maxRows = 1024
   181  
   182  func newResultAggregator(cb func(*sqltypes.Result, bool) error) *aggregator {
   183  	return &aggregator{callback: cb}
   184  }
   185  
   186  func (ag *aggregator) processResults(result *sqltypes.Result, more bool) error {
   187  	if len(ag.results) <= maxRows {
   188  		ag.results = append(ag.results, result)
   189  	}
   190  	return ag.callback(result, more)
   191  }
   192  
   193  func (ag *aggregator) processGoldenResults(result *sqltypes.Result, _ bool) error {
   194  	if len(ag.golden) <= maxRows {
   195  		ag.golden = append(ag.golden, result)
   196  	}
   197  	return nil
   198  }
   199  
   200  func (ag *aggregator) compareResults(logger *logrus.Entry) {
   201  	actual, err := sortResults(ag.results)
   202  	if err != nil {
   203  		logger.Errorf("Error comparing result sets (%s)", err)
   204  	}
   205  	expected, err := sortResults(ag.golden)
   206  	if err != nil {
   207  		logger.Errorf("Error comparing result sets (%s)", err)
   208  	}
   209  	logger.Debugf("Validting query expected=(%d) actual=(%d)",
   210  		len(actual), len(expected))
   211  
   212  	if len(actual) > maxRows || len(expected) > maxRows {
   213  		logger.Warnf("result set too large to validate")
   214  		return
   215  	}
   216  
   217  	if len(actual) != len(expected) {
   218  		logger.Warnf("Incorrect result set expected=%s actual=%s)",
   219  			formatRowSet(actual), formatRowSet(expected))
   220  		return
   221  	}
   222  	for i := range actual {
   223  		left, right := actual[i], expected[i]
   224  		cmp, err := compareRows(left, right)
   225  		if err != nil {
   226  			logger.Errorf("Error comparing result sets (%s)", err)
   227  			return
   228  		} else if cmp != 0 {
   229  			logger.Warnf("Incorrect result set expected=%s actual=%s)",
   230  				formatRowSet(actual), formatRowSet(expected))
   231  			return
   232  		}
   233  	}
   234  	return
   235  }
   236  
   237  func sortResults(results []*sqltypes.Result) ([][]sqltypes.Value, error) {
   238  	var sz uint64
   239  	for _, r := range results {
   240  		sz += r.RowsAffected
   241  	}
   242  	rows := make([][]sqltypes.Value, 0, sz)
   243  	for _, r := range results {
   244  		rows = append(rows, r.Rows...)
   245  	}
   246  
   247  	var cerr error
   248  	sort.Slice(rows, func(i, j int) bool {
   249  		cmp, err := compareRows(rows[i], rows[j])
   250  		if err != nil {
   251  			cerr = err
   252  		}
   253  		return cmp < 0
   254  	})
   255  	if cerr != nil {
   256  		return nil, cerr
   257  	}
   258  	return rows, nil
   259  }
   260  
   261  func compareRows(left, right []sqltypes.Value) (cmp int, err error) {
   262  	if len(left) != len(right) {
   263  		return 0, fmt.Errorf("rows differ in length (%s != %s)",
   264  			formatRow(left), formatRow(right))
   265  	}
   266  	for i := range left {
   267  		cmp, err = sqltypes.NullsafeCompare(left[i], right[i])
   268  		if err != nil {
   269  			// ignore incompatible types error if types equal
   270  			if left[i].Type() == right[i].Type() {
   271  				cmp = bytes.Compare(left[i].Raw(), right[i].Raw())
   272  				err = nil
   273  			} else {
   274  				return 0, err
   275  			}
   276  		}
   277  		if cmp != 0 {
   278  			break
   279  		}
   280  	}
   281  	return
   282  }
   283  
   284  func formatRowSet(rows [][]sqltypes.Value) string {
   285  	var seenOne bool
   286  	var sb strings.Builder
   287  	sb.WriteString("{")
   288  	for _, r := range rows {
   289  		if seenOne {
   290  			sb.WriteRune(',')
   291  		}
   292  		seenOne = true
   293  		sb.WriteString(formatRow(r))
   294  	}
   295  	sb.WriteString("}")
   296  	return sb.String()
   297  }
   298  
   299  func formatRow(row []sqltypes.Value) string {
   300  	var seenOne bool
   301  	var sb strings.Builder
   302  	sb.WriteRune('[')
   303  	for _, v := range row {
   304  		if seenOne {
   305  			sb.WriteRune(',')
   306  		}
   307  		seenOne = true
   308  		sb.WriteString(v.String())
   309  	}
   310  	sb.WriteRune(']')
   311  	return sb.String()
   312  }