github.com/dolthub/go-mysql-server@v0.18.0/server/golden/validator.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package golden 16 17 import ( 18 "bytes" 19 "context" 20 "fmt" 21 "sort" 22 "strings" 23 24 "github.com/dolthub/vitess/go/mysql" 25 "github.com/dolthub/vitess/go/sqltypes" 26 "github.com/dolthub/vitess/go/vt/proto/query" 27 "github.com/dolthub/vitess/go/vt/sqlparser" 28 _ "github.com/go-sql-driver/mysql" 29 "github.com/sirupsen/logrus" 30 "golang.org/x/sync/errgroup" 31 32 "github.com/dolthub/go-mysql-server/sql" 33 ) 34 35 type Validator struct { 36 handler mysql.Handler 37 golden MySqlProxy 38 logger *logrus.Logger 39 } 40 41 // NewValidatingHandler creates a new Validator wrapping a MySQL connection. 42 func NewValidatingHandler(handler mysql.Handler, mySqlConn string, logger *logrus.Logger) (Validator, error) { 43 golden, err := NewMySqlProxyHandler(logger, mySqlConn) 44 if err != nil { 45 return Validator{}, err 46 } 47 48 // todo: setup mirroring 49 // - assert that both |handler| and |golden| are 50 // working against empty databases 51 // - possibly sync database set between both 52 53 return Validator{ 54 handler: handler, 55 golden: golden, 56 logger: logger, 57 }, nil 58 } 59 60 var _ mysql.Handler = Validator{} 61 62 // NewConnection reports that a new connection has been established. 63 func (v Validator) NewConnection(c *mysql.Conn) { 64 return 65 } 66 67 func (v Validator) ComInitDB(c *mysql.Conn, schemaName string) error { 68 if err := v.handler.ComInitDB(c, schemaName); err != nil { 69 return err 70 } 71 return v.golden.ComInitDB(c, schemaName) 72 } 73 74 // ComPrepare parses, partially analyzes, and caches a prepared statement's plan 75 // with the given [c.ConnectionID]. 76 func (v Validator) ComPrepare(_ *mysql.Conn, _ string, _ *mysql.PrepareData) ([]*query.Field, error) { 77 return nil, fmt.Errorf("ComPrepare unsupported") 78 } 79 80 func (v Validator) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { 81 return fmt.Errorf("ComStmtExecute unsupported") 82 } 83 84 func (v Validator) ComResetConnection(_ *mysql.Conn) error { 85 return nil 86 } 87 88 // ConnectionClosed reports that a connection has been closed. 89 func (v Validator) ConnectionClosed(c *mysql.Conn) { 90 v.handler.ConnectionClosed(c) 91 v.golden.ConnectionClosed(c) 92 } 93 94 func (v Validator) ComMultiQuery( 95 c *mysql.Conn, 96 query string, 97 callback mysql.ResultSpoolFn, 98 ) (string, error) { 99 ag := newResultAggregator(callback) 100 var remainder string 101 eg, _ := errgroup.WithContext(context.Background()) 102 eg.Go(func() (err error) { 103 remainder, err = v.handler.ComMultiQuery(c, query, ag.processResults) 104 return 105 }) 106 eg.Go(func() error { 107 // ignore errors from MySQL connection 108 _, _ = v.golden.ComMultiQuery(c, query, ag.processGoldenResults) 109 return nil 110 }) 111 112 err := eg.Wait() 113 if err != nil { 114 return "", err 115 } 116 ag.compareResults(v.getLogger(c).WithField("query", query)) 117 118 return remainder, nil 119 } 120 121 // ComQuery executes a SQL query on the SQLe engine. 122 func (v Validator) ComQuery( 123 c *mysql.Conn, 124 query string, 125 callback mysql.ResultSpoolFn, 126 ) error { 127 ag := newResultAggregator(callback) 128 eg, _ := errgroup.WithContext(context.Background()) 129 eg.Go(func() error { 130 return v.handler.ComQuery(c, query, ag.processResults) 131 }) 132 eg.Go(func() error { 133 // ignore errors from MySQL connection 134 _ = v.golden.ComQuery(c, query, ag.processGoldenResults) 135 return nil 136 }) 137 138 err := eg.Wait() 139 if err != nil { 140 return err 141 } 142 ag.compareResults(v.getLogger(c).WithField("query", query)) 143 return nil 144 } 145 146 // ComQuery executes a SQL query on the SQLe engine. 147 func (v Validator) ComParsedQuery( 148 c *mysql.Conn, 149 query string, 150 parsed sqlparser.Statement, 151 callback func(*sqltypes.Result, bool) error, 152 ) error { 153 return v.ComQuery(c, query, callback) 154 } 155 156 // WarningCount is called at the end of each query to obtain 157 // the value to be returned to the client in the EOF packet. 158 // Note that this will be called either in the context of the 159 // ComQuery resultsCB if the result does not contain any fields, 160 // or after the last ComQuery call completes. 161 func (v Validator) WarningCount(c *mysql.Conn) uint16 { 162 return 0 163 } 164 165 func (v Validator) ParserOptionsForConnection(_ *mysql.Conn) (sqlparser.ParserOptions, error) { 166 return sqlparser.ParserOptions{}, nil 167 } 168 169 func (v Validator) getLogger(c *mysql.Conn) *logrus.Entry { 170 return logrus.NewEntry(v.logger).WithField( 171 sql.ConnectionIdLogField, c.ConnectionID) 172 } 173 174 type aggregator struct { 175 results []*sqltypes.Result 176 golden []*sqltypes.Result 177 callback func(*sqltypes.Result, bool) error 178 } 179 180 const maxRows = 1024 181 182 func newResultAggregator(cb func(*sqltypes.Result, bool) error) *aggregator { 183 return &aggregator{callback: cb} 184 } 185 186 func (ag *aggregator) processResults(result *sqltypes.Result, more bool) error { 187 if len(ag.results) <= maxRows { 188 ag.results = append(ag.results, result) 189 } 190 return ag.callback(result, more) 191 } 192 193 func (ag *aggregator) processGoldenResults(result *sqltypes.Result, _ bool) error { 194 if len(ag.golden) <= maxRows { 195 ag.golden = append(ag.golden, result) 196 } 197 return nil 198 } 199 200 func (ag *aggregator) compareResults(logger *logrus.Entry) { 201 actual, err := sortResults(ag.results) 202 if err != nil { 203 logger.Errorf("Error comparing result sets (%s)", err) 204 } 205 expected, err := sortResults(ag.golden) 206 if err != nil { 207 logger.Errorf("Error comparing result sets (%s)", err) 208 } 209 logger.Debugf("Validting query expected=(%d) actual=(%d)", 210 len(actual), len(expected)) 211 212 if len(actual) > maxRows || len(expected) > maxRows { 213 logger.Warnf("result set too large to validate") 214 return 215 } 216 217 if len(actual) != len(expected) { 218 logger.Warnf("Incorrect result set expected=%s actual=%s)", 219 formatRowSet(actual), formatRowSet(expected)) 220 return 221 } 222 for i := range actual { 223 left, right := actual[i], expected[i] 224 cmp, err := compareRows(left, right) 225 if err != nil { 226 logger.Errorf("Error comparing result sets (%s)", err) 227 return 228 } else if cmp != 0 { 229 logger.Warnf("Incorrect result set expected=%s actual=%s)", 230 formatRowSet(actual), formatRowSet(expected)) 231 return 232 } 233 } 234 return 235 } 236 237 func sortResults(results []*sqltypes.Result) ([][]sqltypes.Value, error) { 238 var sz uint64 239 for _, r := range results { 240 sz += r.RowsAffected 241 } 242 rows := make([][]sqltypes.Value, 0, sz) 243 for _, r := range results { 244 rows = append(rows, r.Rows...) 245 } 246 247 var cerr error 248 sort.Slice(rows, func(i, j int) bool { 249 cmp, err := compareRows(rows[i], rows[j]) 250 if err != nil { 251 cerr = err 252 } 253 return cmp < 0 254 }) 255 if cerr != nil { 256 return nil, cerr 257 } 258 return rows, nil 259 } 260 261 func compareRows(left, right []sqltypes.Value) (cmp int, err error) { 262 if len(left) != len(right) { 263 return 0, fmt.Errorf("rows differ in length (%s != %s)", 264 formatRow(left), formatRow(right)) 265 } 266 for i := range left { 267 cmp, err = sqltypes.NullsafeCompare(left[i], right[i]) 268 if err != nil { 269 // ignore incompatible types error if types equal 270 if left[i].Type() == right[i].Type() { 271 cmp = bytes.Compare(left[i].Raw(), right[i].Raw()) 272 err = nil 273 } else { 274 return 0, err 275 } 276 } 277 if cmp != 0 { 278 break 279 } 280 } 281 return 282 } 283 284 func formatRowSet(rows [][]sqltypes.Value) string { 285 var seenOne bool 286 var sb strings.Builder 287 sb.WriteString("{") 288 for _, r := range rows { 289 if seenOne { 290 sb.WriteRune(',') 291 } 292 seenOne = true 293 sb.WriteString(formatRow(r)) 294 } 295 sb.WriteString("}") 296 return sb.String() 297 } 298 299 func formatRow(row []sqltypes.Value) string { 300 var seenOne bool 301 var sb strings.Builder 302 sb.WriteRune('[') 303 for _, v := range row { 304 if seenOne { 305 sb.WriteRune(',') 306 } 307 seenOne = true 308 sb.WriteString(v.String()) 309 } 310 sb.WriteRune(']') 311 return sb.String() 312 }