vitess.io/vitess@v0.16.2/go/mysql/fakesqldb/server.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 // Package fakesqldb provides a MySQL server for tests. 18 package fakesqldb 19 20 import ( 21 "errors" 22 "fmt" 23 "os" 24 "path" 25 "regexp" 26 "strings" 27 "sync" 28 "sync/atomic" 29 "testing" 30 "time" 31 32 "vitess.io/vitess/go/vt/sqlparser" 33 34 "vitess.io/vitess/go/vt/log" 35 36 "vitess.io/vitess/go/mysql" 37 "vitess.io/vitess/go/sqltypes" 38 39 "vitess.io/vitess/go/vt/dbconfigs" 40 querypb "vitess.io/vitess/go/vt/proto/query" 41 ) 42 43 const appendEntry = -1 44 45 // DB is a fake database and all its methods are thread safe. It 46 // creates a mysql.Listener and implements the mysql.Handler 47 // interface. We use a Unix socket to connect to the database, as 48 // this is the most common way for clients to connect to MySQL. This 49 // impacts the error codes we're getting back: when the server side is 50 // closed, the client queries will return CRServerGone(2006) when sending 51 // the data, as opposed to CRServerLost(2013) when reading the response. 52 type DB struct { 53 mysql.UnimplementedHandler 54 55 // Fields set at construction time. 56 57 // t is our testing.TB instance 58 t testing.TB 59 60 // listener is our mysql.Listener. 61 listener *mysql.Listener 62 63 // socketFile is the path to the unix socket file. 
64 socketFile string 65 66 // acceptWG is set when we listen, and can be waited on to 67 // make sure we don't accept any more. 68 acceptWG sync.WaitGroup 69 70 // orderMatters is set when the query order matters. 71 orderMatters atomic.Bool 72 73 // Fields set at runtime. 74 75 // mu protects all the following fields. 76 mu sync.Mutex 77 // name is the name of this DB. Set to 'fakesqldb' by default. 78 // Use SetName() to change. 79 name string 80 // isConnFail trigger a panic in the connection handler. 81 isConnFail atomic.Bool 82 // connDelay causes a sleep in the connection handler 83 connDelay time.Duration 84 // shouldClose, if true, tells ComQuery() to close the connection when 85 // processing the next query. This will trigger a MySQL client error with 86 // errno 2013 ("server lost"). 87 shouldClose atomic.Bool 88 // allowAll: if set to true, ComQuery returns an empty result 89 // for all queries. This flag is used for benchmarking. 90 allowAll atomic.Bool 91 92 // Handler: interface that allows a caller to override the query handling 93 // implementation. By default it points to the DB itself 94 Handler QueryHandler 95 96 // This next set of fields is used when ordering of the queries doesn't 97 // matter. 98 99 // data maps tolower(query) to a result. 100 data map[string]*ExpectedResult 101 // rejectedData maps tolower(query) to an error. 102 rejectedData map[string]error 103 // patternData is a map of regexp queries to results. 104 patternData map[string]exprResult 105 // queryCalled keeps track of how many times a query was called. 106 queryCalled map[string]int 107 // querylog keeps track of all called queries 108 querylog []string 109 110 // This next set of fields is used when ordering of the queries matters. 111 112 // expectedExecuteFetch is the array of expected queries. 113 expectedExecuteFetch []ExpectedExecuteFetch 114 // expectedExecuteFetchIndex is the current index of the query. 
115 expectedExecuteFetchIndex int 116 117 // connections tracks all open connections. 118 // The key for the map is the value of mysql.Conn.ConnectionID. 119 connections map[uint32]*mysql.Conn 120 121 // queryPatternUserCallback stores optional callbacks when a query with a pattern is called 122 queryPatternUserCallback map[*regexp.Regexp]func(string) 123 124 // if fakesqldb is asked to serve queries or query patterns that it has not been explicitly told about it will 125 // error out by default. However if you set this flag then any unmatched query results in an empty result 126 neverFail atomic.Bool 127 } 128 129 // QueryHandler is the interface used by the DB to simulate executed queries 130 type QueryHandler interface { 131 HandleQuery(*mysql.Conn, string, func(*sqltypes.Result) error) error 132 } 133 134 // ExpectedResult holds the data for a matched query. 135 type ExpectedResult struct { 136 *sqltypes.Result 137 // BeforeFunc() is synchronously called before the server returns the result. 138 BeforeFunc func() 139 } 140 141 type exprResult struct { 142 queryPattern string 143 expr *regexp.Regexp 144 result *sqltypes.Result 145 err string 146 } 147 148 // ExpectedExecuteFetch defines for an expected query the to be faked output. 149 // It is used for ordered expected output. 150 type ExpectedExecuteFetch struct { 151 Query string 152 QueryResult *sqltypes.Result 153 Error error 154 // AfterFunc is a callback which is executed while the query 155 // is executed i.e., before the fake responds to the client. 156 AfterFunc func() 157 } 158 159 // New creates a server, and starts listening. 160 func New(t testing.TB) *DB { 161 // Pick a path for our socket. 162 socketDir, err := os.MkdirTemp("", "fakesqldb") 163 if err != nil { 164 t.Fatalf("os.MkdirTemp failed: %v", err) 165 } 166 socketFile := path.Join(socketDir, "fakesqldb.sock") 167 168 // Create our DB. 
169 db := &DB{ 170 t: t, 171 socketFile: socketFile, 172 name: "fakesqldb", 173 data: make(map[string]*ExpectedResult), 174 rejectedData: make(map[string]error), 175 queryCalled: make(map[string]int), 176 connections: make(map[uint32]*mysql.Conn), 177 queryPatternUserCallback: make(map[*regexp.Regexp]func(string)), 178 patternData: make(map[string]exprResult), 179 } 180 181 db.Handler = db 182 183 authServer := mysql.NewAuthServerNone() 184 185 // Start listening. 186 db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false) 187 if err != nil { 188 t.Fatalf("NewListener failed: %v", err) 189 } 190 191 db.acceptWG.Add(1) 192 go func() { 193 defer db.acceptWG.Done() 194 db.listener.Accept() 195 }() 196 197 db.AddQuery("use `fakesqldb`", &sqltypes.Result{}) 198 // Return the db. 199 return db 200 } 201 202 // Name returns the name of the DB. 203 func (db *DB) Name() string { 204 db.mu.Lock() 205 defer db.mu.Unlock() 206 207 return db.name 208 } 209 210 // SetName sets the name of the DB. to differentiate them in tests if needed. 211 func (db *DB) SetName(name string) *DB { 212 db.mu.Lock() 213 defer db.mu.Unlock() 214 215 db.name = name 216 return db 217 } 218 219 // OrderMatters sets the orderMatters flag. 220 func (db *DB) OrderMatters() { 221 db.orderMatters.Store(true) 222 } 223 224 // Close closes the Listener and waits for it to stop accepting. 225 // It then closes all connections, and cleans up the temporary directory. 226 func (db *DB) Close() { 227 db.listener.Close() 228 db.acceptWG.Wait() 229 230 db.CloseAllConnections() 231 232 tmpDir := path.Dir(db.socketFile) 233 os.RemoveAll(tmpDir) 234 } 235 236 // CloseAllConnections can be used to provoke MySQL client errors for open 237 // connections. 238 // Make sure to call WaitForClose() as well. 
func (db *DB) CloseAllConnections() {
	db.mu.Lock()
	defer db.mu.Unlock()

	// Closing only tears down the sockets; entries are removed from
	// db.connections later, when ConnectionClosed is invoked for each one.
	for id := range db.connections {
		db.connections[id].Close()
	}
}

// WaitForClose blocks until no connection is tracked by the DB anymore, or
// returns an error once timeout has elapsed while some are still open. Use
// it after CloseAllConnections() when a test needs the MySQL client to fail
// with errno 2006 ("gone away").
//
// Skipping the wait and issuing the next query right away will most likely
// produce errno 2013 ("server lost") instead, because of this race:
//
//  1. The vttablet MySQL client manages to send the query to this fake server.
//  2. The fake server receives it and invokes the ComQuery() callback.
//  3. The fake server tries to write the response on the connection, which
//     only then fails because the connection is already closed.
//
// Because the query was sent, the client reports errno 2013 ("server lost").
// If sending already fails in step 1, the client reports errno 2006
// ("gone away") instead; waiting for the connections to close ensures that.
func (db *DB) WaitForClose(timeout time.Duration) error {
	openConns := func() int {
		db.mu.Lock()
		defer db.mu.Unlock()
		return len(db.connections)
	}

	for begin := time.Now(); ; time.Sleep(time.Microsecond) {
		left := openConns()
		if left == 0 {
			return nil
		}
		if elapsed := time.Since(begin); elapsed > timeout {
			return fmt.Errorf("connections were not correctly closed after %v: %v are left", elapsed, left)
		}
	}
}

// ConnParams returns the ConnParams to connect to the DB as user "user1".
func (db *DB) ConnParams() dbconfigs.Connector {
	return db.ConnParamsWithUname("user1")
}

// ConnParamsWithUname returns ConnParams to connect to the DB with the Uname set to the provided value.
func (db *DB) ConnParamsWithUname(uname string) dbconfigs.Connector {
	params := &mysql.ConnParams{
		UnixSocket: db.socketFile,
		DbName:     "fakesqldb",
		Uname:      uname,
		Pass:       "password1",
	}
	return dbconfigs.New(params)
}

//
// mysql.Handler interface
//

// NewConnection is part of the mysql.Handler interface. It registers c in the
// set of open connections, panicking instead when connection failures are
// being simulated via EnableConnFail. Any delay configured with SetConnDelay
// is slept while db.mu is held.
func (db *DB) NewConnection(c *mysql.Conn) {
	db.mu.Lock()
	defer db.mu.Unlock()

	if db.isConnFail.Load() {
		panic(fmt.Errorf("simulating a connection failure"))
	}

	if delay := db.connDelay; delay != 0 {
		time.Sleep(delay)
	}

	id := c.ConnectionID
	if existing, dup := db.connections[id]; dup {
		db.t.Fatalf("BUG: connection with id: %v is already active. existing conn: %v new conn: %v", id, existing, c)
	}
	db.connections[id] = c
}

// ConnectionClosed is part of the mysql.Handler interface. It unregisters c
// and panics if c was never registered by NewConnection.
func (db *DB) ConnectionClosed(c *mysql.Conn) {
	db.mu.Lock()
	defer db.mu.Unlock()

	id := c.ConnectionID
	if _, registered := db.connections[id]; !registered {
		panic(fmt.Errorf("BUG: Cannot delete connection from list of open connections because it is not registered. ID: %v Conn: %v", id, c))
	}
	delete(db.connections, id)
}

// ComQuery is part of the mysql.Handler interface. It forwards the query to
// db.Handler, which is the DB itself unless a test replaced it.
func (db *DB) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
	handler := db.Handler
	return handler.HandleQuery(c, query, callback)
}

// WarningCount is part of the mysql.Handler interface.
340 func (db *DB) WarningCount(c *mysql.Conn) uint16 { 341 return 0 342 } 343 344 // HandleQuery is the default implementation of the QueryHandler interface 345 func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { 346 if db.allowAll.Load() { 347 return callback(&sqltypes.Result{}) 348 } 349 350 if db.orderMatters.Load() { 351 result, err := db.comQueryOrdered(query) 352 if err != nil { 353 return err 354 } 355 return callback(result) 356 } 357 key := strings.ToLower(query) 358 db.mu.Lock() 359 defer db.mu.Unlock() 360 db.queryCalled[key]++ 361 db.querylog = append(db.querylog, key) 362 // Check if we should close the connection and provoke errno 2013. 363 if db.shouldClose.Load() { 364 c.Close() 365 366 //log error 367 if err := callback(&sqltypes.Result{}); err != nil { 368 log.Errorf("callback failed : %v", err) 369 } 370 return nil 371 } 372 373 // Using special handling for setting the charset and connection collation. 374 // The driver may send this at connection time, and we don't want it to 375 // interfere. 376 if key == "set names utf8" || strings.HasPrefix(key, "set collation_connection = ") { 377 //log error 378 if err := callback(&sqltypes.Result{}); err != nil { 379 log.Errorf("callback failed : %v", err) 380 } 381 return nil 382 } 383 384 // check if we should reject it. 385 if err, ok := db.rejectedData[key]; ok { 386 return err 387 } 388 389 // Check explicit queries from AddQuery(). 390 result, ok := db.data[key] 391 if ok { 392 if f := result.BeforeFunc; f != nil { 393 f() 394 } 395 return callback(result.Result) 396 } 397 398 // Check query patterns from AddQueryPattern(). 
399 for _, pat := range db.patternData { 400 if pat.expr.MatchString(query) { 401 userCallback, ok := db.queryPatternUserCallback[pat.expr] 402 if ok { 403 userCallback(query) 404 } 405 if pat.err != "" { 406 return fmt.Errorf(pat.err) 407 } 408 return callback(pat.result) 409 } 410 } 411 412 if db.neverFail.Load() { 413 return callback(&sqltypes.Result{}) 414 } 415 // Nothing matched. 416 err := fmt.Errorf("fakesqldb:: query: '%s' is not supported on %v", 417 sqlparser.TruncateForUI(query), db.name) 418 log.Errorf("Query not found: %s", sqlparser.TruncateForUI(query)) 419 420 return err 421 } 422 423 func (db *DB) comQueryOrdered(query string) (*sqltypes.Result, error) { 424 var ( 425 afterFn func() 426 entry ExpectedExecuteFetch 427 err error 428 expected string 429 result *sqltypes.Result 430 ) 431 432 defer func() { 433 if afterFn != nil { 434 afterFn() 435 } 436 }() 437 db.mu.Lock() 438 defer db.mu.Unlock() 439 440 // when creating a connection to the database, we send an initial query to set the connection's 441 // collation, we want to skip the query check if we get such initial query. 442 // this is done to ease the test readability. 
443 if strings.HasPrefix(query, "SET collation_connection =") || strings.EqualFold(query, "use `fakesqldb`") { 444 return &sqltypes.Result{}, nil 445 } 446 447 index := db.expectedExecuteFetchIndex 448 449 if index >= len(db.expectedExecuteFetch) { 450 if db.neverFail.Load() { 451 return &sqltypes.Result{}, nil 452 } 453 db.t.Errorf("%v: got unexpected out of bound fetch: %v >= %v", db.name, index, len(db.expectedExecuteFetch)) 454 return nil, errors.New("unexpected out of bound fetch") 455 } 456 457 entry = db.expectedExecuteFetch[index] 458 afterFn = entry.AfterFunc 459 err = entry.Error 460 expected = entry.Query 461 result = entry.QueryResult 462 463 if strings.HasSuffix(expected, "*") { 464 if !strings.HasPrefix(query, expected[0:len(expected)-1]) { 465 if db.neverFail.Load() { 466 return &sqltypes.Result{}, nil 467 } 468 db.t.Errorf("%v: got unexpected query start (index=%v): %v != %v", db.name, index, query, expected) 469 return nil, errors.New("unexpected query") 470 } 471 } else { 472 if query != expected { 473 if db.neverFail.Load() { 474 return &sqltypes.Result{}, nil 475 } 476 db.t.Errorf("%v: got unexpected query (index=%v): %v != %v", db.name, index, query, expected) 477 return nil, errors.New("unexpected query") 478 } 479 } 480 481 db.expectedExecuteFetchIndex++ 482 db.t.Logf("ExecuteFetch: %v: %v", db.name, query) 483 484 if err != nil { 485 return nil, err 486 } 487 return result, nil 488 } 489 490 // ComPrepare is part of the mysql.Handler interface. 491 func (db *DB) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) { 492 return nil, nil 493 } 494 495 // ComStmtExecute is part of the mysql.Handler interface. 496 func (db *DB) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { 497 return nil 498 } 499 500 // ComRegisterReplica is part of the mysql.Handler interface. 
func (db *DB) ComRegisterReplica(c *mysql.Conn, replicaHost string, replicaPort uint16, replicaUser, replicaPassword string) error {
	// No-op: the fake accepts every replica registration.
	return nil
}

// ComBinlogDump is part of the mysql.Handler interface. It is a no-op.
func (db *DB) ComBinlogDump(c *mysql.Conn, logFile string, binlogPos uint32) error {
	return nil
}

// ComBinlogDumpGTID is part of the mysql.Handler interface. It is a no-op.
func (db *DB) ComBinlogDumpGTID(c *mysql.Conn, logFile string, logPos uint64, gtidSet mysql.GTIDSet) error {
	return nil
}

//
// Methods to add expected queries and results.
//

// AddQuery registers query (matched case-insensitively) with a shallow copy
// of expectedResult and resets its call counter. It panics when the result
// has rows but no fields. The returned entry can be used to attach a
// BeforeFunc.
func (db *DB) AddQuery(query string, expectedResult *sqltypes.Result) *ExpectedResult {
	if len(expectedResult.Fields) == 0 && len(expectedResult.Rows) > 0 {
		panic(fmt.Errorf("please add Fields to this Result so it's valid: %v", query))
	}
	copied := *expectedResult
	entry := &ExpectedResult{Result: &copied}
	key := strings.ToLower(query)

	db.mu.Lock()
	defer db.mu.Unlock()
	db.data[key] = entry
	db.queryCalled[key] = 0
	return entry
}

// SetBeforeFunc attaches f as the BeforeFunc of a query previously registered
// with AddQuery; the test is failed via t.Fatalf if no such query exists.
func (db *DB) SetBeforeFunc(query string, f func()) {
	db.mu.Lock()
	defer db.mu.Unlock()

	entry, found := db.data[strings.ToLower(query)]
	if !found {
		db.t.Fatalf("BUG: no query registered for: %v", query)
	}
	entry.BeforeFunc = f
}

// AddQueryPattern adds an expected result for a set of queries.
// These patterns are checked if no exact matches from AddQuery() are found.
// This function forces the addition of begin/end anchors (^$) and turns on
// case-insensitive matching mode.
func (db *DB) AddQueryPattern(queryPattern string, expectedResult *sqltypes.Result) {
	if len(expectedResult.Rows) > 0 && len(expectedResult.Fields) == 0 {
		panic(fmt.Errorf("please add Fields to this Result so it's valid: %v", queryPattern))
	}
	// (?is): case-insensitive, and '.' also matches newlines.
	expr := regexp.MustCompile("(?is)^" + queryPattern + "$")
	result := *expectedResult
	db.mu.Lock()
	defer db.mu.Unlock()
	db.patternData[queryPattern] = exprResult{queryPattern: queryPattern, expr: expr, result: &result}
}

// RejectQueryPattern makes queries matching queryPattern (anchored and
// case-insensitive, like AddQueryPattern) fail with an error whose message is
// errMsg.
func (db *DB) RejectQueryPattern(queryPattern, errMsg string) {
	expr := regexp.MustCompile("(?is)^" + queryPattern + "$")
	db.mu.Lock()
	defer db.mu.Unlock()
	db.patternData[queryPattern] = exprResult{queryPattern: queryPattern, expr: expr, err: errMsg}
}

// ClearQueryPattern removes all query patterns set up, together with the
// callbacks registered for them via AddQueryPatternWithCallback.
func (db *DB) ClearQueryPattern() {
	db.mu.Lock()
	defer db.mu.Unlock()
	db.patternData = make(map[string]exprResult)
	// Callbacks are keyed by the compiled patterns just dropped, so they can
	// never fire again; release them as well.
	db.queryPatternUserCallback = make(map[*regexp.Regexp]func(string))
}

// AddQueryPatternWithCallback is similar to AddQueryPattern: in addition it calls the provided callback function
// The callback can be used to set user counters/variables for testing specific usecases
func (db *DB) AddQueryPatternWithCallback(queryPattern string, expectedResult *sqltypes.Result, callback func(string)) {
	db.AddQueryPattern(queryPattern, expectedResult)
	// Both maps are read by HandleQuery on connection goroutines under db.mu.
	db.mu.Lock()
	defer db.mu.Unlock()
	db.queryPatternUserCallback[db.patternData[queryPattern].expr] = callback
}

// DeleteQuery deletes query (and its call counter) from the fake DB.
func (db *DB) DeleteQuery(query string) {
	db.mu.Lock()
	defer db.mu.Unlock()
	key := strings.ToLower(query)
	delete(db.data, key)
	delete(db.queryCalled, key)
}

// AddRejectedQuery adds a query which will be rejected at execution time.
func (db *DB) AddRejectedQuery(query string, err error) {
	db.mu.Lock()
	defer db.mu.Unlock()
	db.rejectedData[strings.ToLower(query)] = err
}

// DeleteRejectedQuery deletes query from the fake DB.
func (db *DB) DeleteRejectedQuery(query string) {
	db.mu.Lock()
	defer db.mu.Unlock()
	delete(db.rejectedData, strings.ToLower(query))
}

// GetQueryCalledNum returns how many times db executes a certain query.
// Unknown queries report 0.
func (db *DB) GetQueryCalledNum(query string) int {
	db.mu.Lock()
	defer db.mu.Unlock()
	// A missing key yields the zero value, i.e. 0 calls.
	return db.queryCalled[strings.ToLower(query)]
}

// QueryLog returns the query log in a semicolon separated string.
func (db *DB) QueryLog() string {
	// querylog is appended by HandleQuery on connection goroutines under
	// db.mu, so reading it requires the lock too.
	db.mu.Lock()
	defer db.mu.Unlock()
	return strings.Join(db.querylog, ";")
}

// ResetQueryLog resets the query log.
func (db *DB) ResetQueryLog() {
	db.mu.Lock()
	defer db.mu.Unlock()
	db.querylog = nil
}

// EnableConnFail makes connection to this fake DB fail.
func (db *DB) EnableConnFail() {
	db.isConnFail.Store(true)
}

// DisableConnFail makes connection to this fake DB success.
func (db *DB) DisableConnFail() {
	db.isConnFail.Store(false)
}

// SetConnDelay delays connections to this fake DB for the given duration.
func (db *DB) SetConnDelay(d time.Duration) {
	db.mu.Lock()
	defer db.mu.Unlock()
	db.connDelay = d
}

// EnableShouldClose closes the connection when processing the next query.
func (db *DB) EnableShouldClose() {
	db.shouldClose.Store(true)
}

//
// The following methods are used for ordered expected queries.
//

// AddExpectedExecuteFetch appends an ExpectedExecuteFetch to the ordered
// expectations.
func (db *DB) AddExpectedExecuteFetch(entry ExpectedExecuteFetch) {
	db.AddExpectedExecuteFetchAtIndex(appendEntry, entry)
}

// AddExpectedExecuteFetchAtIndex inserts a new entry at index.
// index values start at 0.
660 func (db *DB) AddExpectedExecuteFetchAtIndex(index int, entry ExpectedExecuteFetch) { 661 db.mu.Lock() 662 defer db.mu.Unlock() 663 664 if db.expectedExecuteFetch == nil || index < 0 || index >= len(db.expectedExecuteFetch) { 665 index = appendEntry 666 } 667 if index == appendEntry { 668 db.expectedExecuteFetch = append(db.expectedExecuteFetch, entry) 669 } else { 670 // Grow the slice by one element. 671 if cap(db.expectedExecuteFetch) == len(db.expectedExecuteFetch) { 672 db.expectedExecuteFetch = append(db.expectedExecuteFetch, make([]ExpectedExecuteFetch, 1)...) 673 } else { 674 db.expectedExecuteFetch = db.expectedExecuteFetch[0 : len(db.expectedExecuteFetch)+1] 675 } 676 // Use copy to move the upper part of the slice out of the way and open a hole. 677 copy(db.expectedExecuteFetch[index+1:], db.expectedExecuteFetch[index:]) 678 // Store the new value. 679 db.expectedExecuteFetch[index] = entry 680 } 681 } 682 683 // AddExpectedQuery adds a single query with no result. 684 func (db *DB) AddExpectedQuery(query string, err error) { 685 db.AddExpectedExecuteFetch(ExpectedExecuteFetch{ 686 Query: query, 687 QueryResult: &sqltypes.Result{}, 688 Error: err, 689 }) 690 } 691 692 // AddExpectedQueryAtIndex adds an expected ordered query at an index. 693 func (db *DB) AddExpectedQueryAtIndex(index int, query string, err error) { 694 db.AddExpectedExecuteFetchAtIndex(index, ExpectedExecuteFetch{ 695 Query: query, 696 QueryResult: &sqltypes.Result{}, 697 Error: err, 698 }) 699 } 700 701 // GetEntry returns the expected entry at "index". If index is out of bounds, 702 // the return value will be nil. 703 func (db *DB) GetEntry(index int) *ExpectedExecuteFetch { 704 db.mu.Lock() 705 defer db.mu.Unlock() 706 707 if index < 0 || index >= len(db.expectedExecuteFetch) { 708 panic(fmt.Sprintf("index out of range. 
current length: %v", len(db.expectedExecuteFetch))) 709 } 710 711 return &db.expectedExecuteFetch[index] 712 } 713 714 // DeleteAllEntries removes all ordered entries. 715 func (db *DB) DeleteAllEntries() { 716 db.mu.Lock() 717 defer db.mu.Unlock() 718 719 db.expectedExecuteFetch = make([]ExpectedExecuteFetch, 0) 720 db.expectedExecuteFetchIndex = 0 721 } 722 723 // DeleteAllEntriesAfterIndex removes all queries after the index. 724 func (db *DB) DeleteAllEntriesAfterIndex(index int) { 725 db.mu.Lock() 726 defer db.mu.Unlock() 727 728 if index < 0 || index >= len(db.expectedExecuteFetch) { 729 panic(fmt.Sprintf("index out of range. current length: %v", len(db.expectedExecuteFetch))) 730 } 731 732 if index+1 < db.expectedExecuteFetchIndex { 733 // Don't delete entries which were already answered. 734 return 735 } 736 737 db.expectedExecuteFetch = db.expectedExecuteFetch[:index+1] 738 } 739 740 // VerifyAllExecutedOrFail checks that all expected queries where actually 741 // received and executed. If not, it will let the test fail. 742 func (db *DB) VerifyAllExecutedOrFail() { 743 db.mu.Lock() 744 defer db.mu.Unlock() 745 746 if db.expectedExecuteFetchIndex != len(db.expectedExecuteFetch) { 747 db.t.Errorf("%v: not all expected queries were executed. 
leftovers: %v", db.name, db.expectedExecuteFetch[db.expectedExecuteFetchIndex:]) 748 } 749 } 750 751 func (db *DB) SetAllowAll(allowAll bool) { 752 db.allowAll.Store(allowAll) 753 } 754 755 func (db *DB) SetNeverFail(neverFail bool) { 756 db.neverFail.Store(neverFail) 757 } 758 759 func (db *DB) MockQueriesForTable(table string, result *sqltypes.Result) { 760 // pattern for selecting explicit list of columns where database is specified 761 selectQueryPattern := fmt.Sprintf("select .* from `%s`.`%s` where 1 != 1", db.name, table) 762 db.AddQueryPattern(selectQueryPattern, result) 763 764 // pattern for selecting explicit list of columns where database is not specified 765 selectQueryPattern = fmt.Sprintf("select .* from %s where 1 != 1", table) 766 db.AddQueryPattern(selectQueryPattern, result) 767 768 // mock query for returning columns from information_schema.columns based on specified result 769 var cols []string 770 for _, field := range result.Fields { 771 cols = append(cols, field.Name) 772 } 773 db.AddQueryPattern(fmt.Sprintf(mysql.GetColumnNamesQueryPatternForTable, table), sqltypes.MakeTestResult( 774 sqltypes.MakeTestFields( 775 "column_name", 776 "varchar", 777 ), 778 cols..., 779 )) 780 }