vitess.io/vitess@v0.16.2/go/mysql/fakesqldb/server.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package fakesqldb provides a MySQL server for tests.
    18  package fakesqldb
    19  
    20  import (
    21  	"errors"
    22  	"fmt"
    23  	"os"
    24  	"path"
    25  	"regexp"
    26  	"strings"
    27  	"sync"
    28  	"sync/atomic"
    29  	"testing"
    30  	"time"
    31  
    32  	"vitess.io/vitess/go/vt/sqlparser"
    33  
    34  	"vitess.io/vitess/go/vt/log"
    35  
    36  	"vitess.io/vitess/go/mysql"
    37  	"vitess.io/vitess/go/sqltypes"
    38  
    39  	"vitess.io/vitess/go/vt/dbconfigs"
    40  	querypb "vitess.io/vitess/go/vt/proto/query"
    41  )
    42  
    43  const appendEntry = -1
    44  
    45  // DB is a fake database and all its methods are thread safe.  It
    46  // creates a mysql.Listener and implements the mysql.Handler
    47  // interface.  We use a Unix socket to connect to the database, as
    48  // this is the most common way for clients to connect to MySQL. This
    49  // impacts the error codes we're getting back: when the server side is
    50  // closed, the client queries will return CRServerGone(2006) when sending
    51  // the data, as opposed to CRServerLost(2013) when reading the response.
    52  type DB struct {
    53  	mysql.UnimplementedHandler
    54  
    55  	// Fields set at construction time.
    56  
    57  	// t is our testing.TB instance
    58  	t testing.TB
    59  
    60  	// listener is our mysql.Listener.
    61  	listener *mysql.Listener
    62  
    63  	// socketFile is the path to the unix socket file.
    64  	socketFile string
    65  
    66  	// acceptWG is set when we listen, and can be waited on to
    67  	// make sure we don't accept any more.
    68  	acceptWG sync.WaitGroup
    69  
    70  	// orderMatters is set when the query order matters.
    71  	orderMatters atomic.Bool
    72  
    73  	// Fields set at runtime.
    74  
    75  	// mu protects all the following fields.
    76  	mu sync.Mutex
    77  	// name is the name of this DB. Set to 'fakesqldb' by default.
    78  	// Use SetName() to change.
    79  	name string
    80  	// isConnFail trigger a panic in the connection handler.
    81  	isConnFail atomic.Bool
    82  	// connDelay causes a sleep in the connection handler
    83  	connDelay time.Duration
    84  	// shouldClose, if true, tells ComQuery() to close the connection when
    85  	// processing the next query. This will trigger a MySQL client error with
    86  	// errno 2013 ("server lost").
    87  	shouldClose atomic.Bool
    88  	// allowAll: if set to true, ComQuery returns an empty result
    89  	// for all queries. This flag is used for benchmarking.
    90  	allowAll atomic.Bool
    91  
    92  	// Handler: interface that allows a caller to override the query handling
    93  	// implementation. By default it points to the DB itself
    94  	Handler QueryHandler
    95  
    96  	// This next set of fields is used when ordering of the queries doesn't
    97  	// matter.
    98  
    99  	// data maps tolower(query) to a result.
   100  	data map[string]*ExpectedResult
   101  	// rejectedData maps tolower(query) to an error.
   102  	rejectedData map[string]error
   103  	// patternData is a map of regexp queries to results.
   104  	patternData map[string]exprResult
   105  	// queryCalled keeps track of how many times a query was called.
   106  	queryCalled map[string]int
   107  	// querylog keeps track of all called queries
   108  	querylog []string
   109  
   110  	// This next set of fields is used when ordering of the queries matters.
   111  
   112  	// expectedExecuteFetch is the array of expected queries.
   113  	expectedExecuteFetch []ExpectedExecuteFetch
   114  	// expectedExecuteFetchIndex is the current index of the query.
   115  	expectedExecuteFetchIndex int
   116  
   117  	// connections tracks all open connections.
   118  	// The key for the map is the value of mysql.Conn.ConnectionID.
   119  	connections map[uint32]*mysql.Conn
   120  
   121  	// queryPatternUserCallback stores optional callbacks when a query with a pattern is called
   122  	queryPatternUserCallback map[*regexp.Regexp]func(string)
   123  
   124  	// if fakesqldb is asked to serve queries or query patterns that it has not been explicitly told about it will
   125  	// error out by default. However if you set this flag then any unmatched query results in an empty result
   126  	neverFail atomic.Bool
   127  }
   128  
   129  // QueryHandler is the interface used by the DB to simulate executed queries
   130  type QueryHandler interface {
   131  	HandleQuery(*mysql.Conn, string, func(*sqltypes.Result) error) error
   132  }
   133  
   134  // ExpectedResult holds the data for a matched query.
   135  type ExpectedResult struct {
   136  	*sqltypes.Result
   137  	// BeforeFunc() is synchronously called before the server returns the result.
   138  	BeforeFunc func()
   139  }
   140  
   141  type exprResult struct {
   142  	queryPattern string
   143  	expr         *regexp.Regexp
   144  	result       *sqltypes.Result
   145  	err          string
   146  }
   147  
   148  // ExpectedExecuteFetch defines for an expected query the to be faked output.
   149  // It is used for ordered expected output.
   150  type ExpectedExecuteFetch struct {
   151  	Query       string
   152  	QueryResult *sqltypes.Result
   153  	Error       error
   154  	// AfterFunc is a callback which is executed while the query
   155  	// is executed i.e., before the fake responds to the client.
   156  	AfterFunc func()
   157  }
   158  
   159  // New creates a server, and starts listening.
   160  func New(t testing.TB) *DB {
   161  	// Pick a path for our socket.
   162  	socketDir, err := os.MkdirTemp("", "fakesqldb")
   163  	if err != nil {
   164  		t.Fatalf("os.MkdirTemp failed: %v", err)
   165  	}
   166  	socketFile := path.Join(socketDir, "fakesqldb.sock")
   167  
   168  	// Create our DB.
   169  	db := &DB{
   170  		t:                        t,
   171  		socketFile:               socketFile,
   172  		name:                     "fakesqldb",
   173  		data:                     make(map[string]*ExpectedResult),
   174  		rejectedData:             make(map[string]error),
   175  		queryCalled:              make(map[string]int),
   176  		connections:              make(map[uint32]*mysql.Conn),
   177  		queryPatternUserCallback: make(map[*regexp.Regexp]func(string)),
   178  		patternData:              make(map[string]exprResult),
   179  	}
   180  
   181  	db.Handler = db
   182  
   183  	authServer := mysql.NewAuthServerNone()
   184  
   185  	// Start listening.
   186  	db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false)
   187  	if err != nil {
   188  		t.Fatalf("NewListener failed: %v", err)
   189  	}
   190  
   191  	db.acceptWG.Add(1)
   192  	go func() {
   193  		defer db.acceptWG.Done()
   194  		db.listener.Accept()
   195  	}()
   196  
   197  	db.AddQuery("use `fakesqldb`", &sqltypes.Result{})
   198  	// Return the db.
   199  	return db
   200  }
   201  
   202  // Name returns the name of the DB.
   203  func (db *DB) Name() string {
   204  	db.mu.Lock()
   205  	defer db.mu.Unlock()
   206  
   207  	return db.name
   208  }
   209  
   210  // SetName sets the name of the DB. to differentiate them in tests if needed.
   211  func (db *DB) SetName(name string) *DB {
   212  	db.mu.Lock()
   213  	defer db.mu.Unlock()
   214  
   215  	db.name = name
   216  	return db
   217  }
   218  
   219  // OrderMatters sets the orderMatters flag.
   220  func (db *DB) OrderMatters() {
   221  	db.orderMatters.Store(true)
   222  }
   223  
   224  // Close closes the Listener and waits for it to stop accepting.
   225  // It then closes all connections, and cleans up the temporary directory.
   226  func (db *DB) Close() {
   227  	db.listener.Close()
   228  	db.acceptWG.Wait()
   229  
   230  	db.CloseAllConnections()
   231  
   232  	tmpDir := path.Dir(db.socketFile)
   233  	os.RemoveAll(tmpDir)
   234  }
   235  
   236  // CloseAllConnections can be used to provoke MySQL client errors for open
   237  // connections.
   238  // Make sure to call WaitForClose() as well.
   239  func (db *DB) CloseAllConnections() {
   240  	db.mu.Lock()
   241  	defer db.mu.Unlock()
   242  
   243  	for _, c := range db.connections {
   244  		c.Close()
   245  	}
   246  }
   247  
   248  // WaitForClose should be used after CloseAllConnections() is closed and
   249  // you want to provoke a MySQL client error with errno 2006.
   250  //
   251  // If you don't call this function and execute the next query right away, you
   252  // will very likely see errno 2013 instead due to timing issues.
   253  // That's because the following can happen:
   254  //
   255  // 1. vttablet MySQL client is able to send the query to this fake server.
   256  // 2. The fake server sees the query and calls the ComQuery() callback.
   257  // 3. The fake server tries to write the response back on the connection.
   258  // => This will finally fail because the connection is already closed.
   259  // In this example, the client would have been able to send off the query and
   260  // therefore return errno 2013 ("server lost").
   261  // Instead, if step 1 already fails, the client returns errno 2006 ("gone away").
   262  // By waiting for the connections to close, you make sure of that.
   263  func (db *DB) WaitForClose(timeout time.Duration) error {
   264  	start := time.Now()
   265  	for {
   266  		db.mu.Lock()
   267  		count := len(db.connections)
   268  		db.mu.Unlock()
   269  
   270  		if count == 0 {
   271  			return nil
   272  		}
   273  		if d := time.Since(start); d > timeout {
   274  			return fmt.Errorf("connections were not correctly closed after %v: %v are left", d, count)
   275  		}
   276  		time.Sleep(1 * time.Microsecond)
   277  	}
   278  }
   279  
   280  // ConnParams returns the ConnParams to connect to the DB.
   281  func (db *DB) ConnParams() dbconfigs.Connector {
   282  	return dbconfigs.New(&mysql.ConnParams{
   283  		UnixSocket: db.socketFile,
   284  		Uname:      "user1",
   285  		Pass:       "password1",
   286  		DbName:     "fakesqldb",
   287  	})
   288  }
   289  
   290  // ConnParamsWithUname returns  ConnParams to connect to the DB with the Uname set to the provided value.
   291  func (db *DB) ConnParamsWithUname(uname string) dbconfigs.Connector {
   292  	return dbconfigs.New(&mysql.ConnParams{
   293  		UnixSocket: db.socketFile,
   294  		Uname:      uname,
   295  		Pass:       "password1",
   296  		DbName:     "fakesqldb",
   297  	})
   298  }
   299  
   300  //
   301  // mysql.Handler interface
   302  //
   303  
   304  // NewConnection is part of the mysql.Handler interface.
   305  func (db *DB) NewConnection(c *mysql.Conn) {
   306  	db.mu.Lock()
   307  	defer db.mu.Unlock()
   308  
   309  	if db.isConnFail.Load() {
   310  		panic(fmt.Errorf("simulating a connection failure"))
   311  	}
   312  
   313  	if db.connDelay != 0 {
   314  		time.Sleep(db.connDelay)
   315  	}
   316  
   317  	if conn, ok := db.connections[c.ConnectionID]; ok {
   318  		db.t.Fatalf("BUG: connection with id: %v is already active. existing conn: %v new conn: %v", c.ConnectionID, conn, c)
   319  	}
   320  	db.connections[c.ConnectionID] = c
   321  }
   322  
   323  // ConnectionClosed is part of the mysql.Handler interface.
   324  func (db *DB) ConnectionClosed(c *mysql.Conn) {
   325  	db.mu.Lock()
   326  	defer db.mu.Unlock()
   327  
   328  	if _, ok := db.connections[c.ConnectionID]; !ok {
   329  		panic(fmt.Errorf("BUG: Cannot delete connection from list of open connections because it is not registered. ID: %v Conn: %v", c.ConnectionID, c))
   330  	}
   331  	delete(db.connections, c.ConnectionID)
   332  }
   333  
   334  // ComQuery is part of the mysql.Handler interface.
   335  func (db *DB) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
   336  	return db.Handler.HandleQuery(c, query, callback)
   337  }
   338  
   339  // WarningCount is part of the mysql.Handler interface.
   340  func (db *DB) WarningCount(c *mysql.Conn) uint16 {
   341  	return 0
   342  }
   343  
   344  // HandleQuery is the default implementation of the QueryHandler interface
   345  func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
   346  	if db.allowAll.Load() {
   347  		return callback(&sqltypes.Result{})
   348  	}
   349  
   350  	if db.orderMatters.Load() {
   351  		result, err := db.comQueryOrdered(query)
   352  		if err != nil {
   353  			return err
   354  		}
   355  		return callback(result)
   356  	}
   357  	key := strings.ToLower(query)
   358  	db.mu.Lock()
   359  	defer db.mu.Unlock()
   360  	db.queryCalled[key]++
   361  	db.querylog = append(db.querylog, key)
   362  	// Check if we should close the connection and provoke errno 2013.
   363  	if db.shouldClose.Load() {
   364  		c.Close()
   365  
   366  		//log error
   367  		if err := callback(&sqltypes.Result{}); err != nil {
   368  			log.Errorf("callback failed : %v", err)
   369  		}
   370  		return nil
   371  	}
   372  
   373  	// Using special handling for setting the charset and connection collation.
   374  	// The driver may send this at connection time, and we don't want it to
   375  	// interfere.
   376  	if key == "set names utf8" || strings.HasPrefix(key, "set collation_connection = ") {
   377  		//log error
   378  		if err := callback(&sqltypes.Result{}); err != nil {
   379  			log.Errorf("callback failed : %v", err)
   380  		}
   381  		return nil
   382  	}
   383  
   384  	// check if we should reject it.
   385  	if err, ok := db.rejectedData[key]; ok {
   386  		return err
   387  	}
   388  
   389  	// Check explicit queries from AddQuery().
   390  	result, ok := db.data[key]
   391  	if ok {
   392  		if f := result.BeforeFunc; f != nil {
   393  			f()
   394  		}
   395  		return callback(result.Result)
   396  	}
   397  
   398  	// Check query patterns from AddQueryPattern().
   399  	for _, pat := range db.patternData {
   400  		if pat.expr.MatchString(query) {
   401  			userCallback, ok := db.queryPatternUserCallback[pat.expr]
   402  			if ok {
   403  				userCallback(query)
   404  			}
   405  			if pat.err != "" {
   406  				return fmt.Errorf(pat.err)
   407  			}
   408  			return callback(pat.result)
   409  		}
   410  	}
   411  
   412  	if db.neverFail.Load() {
   413  		return callback(&sqltypes.Result{})
   414  	}
   415  	// Nothing matched.
   416  	err := fmt.Errorf("fakesqldb:: query: '%s' is not supported on %v",
   417  		sqlparser.TruncateForUI(query), db.name)
   418  	log.Errorf("Query not found: %s", sqlparser.TruncateForUI(query))
   419  
   420  	return err
   421  }
   422  
   423  func (db *DB) comQueryOrdered(query string) (*sqltypes.Result, error) {
   424  	var (
   425  		afterFn  func()
   426  		entry    ExpectedExecuteFetch
   427  		err      error
   428  		expected string
   429  		result   *sqltypes.Result
   430  	)
   431  
   432  	defer func() {
   433  		if afterFn != nil {
   434  			afterFn()
   435  		}
   436  	}()
   437  	db.mu.Lock()
   438  	defer db.mu.Unlock()
   439  
   440  	// when creating a connection to the database, we send an initial query to set the connection's
   441  	// collation, we want to skip the query check if we get such initial query.
   442  	// this is done to ease the test readability.
   443  	if strings.HasPrefix(query, "SET collation_connection =") || strings.EqualFold(query, "use `fakesqldb`") {
   444  		return &sqltypes.Result{}, nil
   445  	}
   446  
   447  	index := db.expectedExecuteFetchIndex
   448  
   449  	if index >= len(db.expectedExecuteFetch) {
   450  		if db.neverFail.Load() {
   451  			return &sqltypes.Result{}, nil
   452  		}
   453  		db.t.Errorf("%v: got unexpected out of bound fetch: %v >= %v", db.name, index, len(db.expectedExecuteFetch))
   454  		return nil, errors.New("unexpected out of bound fetch")
   455  	}
   456  
   457  	entry = db.expectedExecuteFetch[index]
   458  	afterFn = entry.AfterFunc
   459  	err = entry.Error
   460  	expected = entry.Query
   461  	result = entry.QueryResult
   462  
   463  	if strings.HasSuffix(expected, "*") {
   464  		if !strings.HasPrefix(query, expected[0:len(expected)-1]) {
   465  			if db.neverFail.Load() {
   466  				return &sqltypes.Result{}, nil
   467  			}
   468  			db.t.Errorf("%v: got unexpected query start (index=%v): %v != %v", db.name, index, query, expected)
   469  			return nil, errors.New("unexpected query")
   470  		}
   471  	} else {
   472  		if query != expected {
   473  			if db.neverFail.Load() {
   474  				return &sqltypes.Result{}, nil
   475  			}
   476  			db.t.Errorf("%v: got unexpected query (index=%v): %v != %v", db.name, index, query, expected)
   477  			return nil, errors.New("unexpected query")
   478  		}
   479  	}
   480  
   481  	db.expectedExecuteFetchIndex++
   482  	db.t.Logf("ExecuteFetch: %v: %v", db.name, query)
   483  
   484  	if err != nil {
   485  		return nil, err
   486  	}
   487  	return result, nil
   488  }
   489  
   490  // ComPrepare is part of the mysql.Handler interface.
   491  func (db *DB) ComPrepare(c *mysql.Conn, query string, bindVars map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
   492  	return nil, nil
   493  }
   494  
   495  // ComStmtExecute is part of the mysql.Handler interface.
   496  func (db *DB) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
   497  	return nil
   498  }
   499  
   500  // ComRegisterReplica is part of the mysql.Handler interface.
   501  func (db *DB) ComRegisterReplica(c *mysql.Conn, replicaHost string, replicaPort uint16, replicaUser string, replicaPassword string) error {
   502  	return nil
   503  }
   504  
   505  // ComBinlogDump is part of the mysql.Handler interface.
   506  func (db *DB) ComBinlogDump(c *mysql.Conn, logFile string, binlogPos uint32) error {
   507  	return nil
   508  }
   509  
   510  // ComBinlogDumpGTID is part of the mysql.Handler interface.
   511  func (db *DB) ComBinlogDumpGTID(c *mysql.Conn, logFile string, logPos uint64, gtidSet mysql.GTIDSet) error {
   512  	return nil
   513  }
   514  
   515  //
   516  // Methods to add expected queries and results.
   517  //
   518  
   519  // AddQuery adds a query and its expected result.
   520  func (db *DB) AddQuery(query string, expectedResult *sqltypes.Result) *ExpectedResult {
   521  	if len(expectedResult.Rows) > 0 && len(expectedResult.Fields) == 0 {
   522  		panic(fmt.Errorf("please add Fields to this Result so it's valid: %v", query))
   523  	}
   524  	resultCopy := &sqltypes.Result{}
   525  	*resultCopy = *expectedResult
   526  	db.mu.Lock()
   527  	defer db.mu.Unlock()
   528  	key := strings.ToLower(query)
   529  	r := &ExpectedResult{resultCopy, nil}
   530  	db.data[key] = r
   531  	db.queryCalled[key] = 0
   532  	return r
   533  }
   534  
   535  // SetBeforeFunc sets the BeforeFunc field for the previously registered "query".
   536  func (db *DB) SetBeforeFunc(query string, f func()) {
   537  	db.mu.Lock()
   538  	defer db.mu.Unlock()
   539  	key := strings.ToLower(query)
   540  	r, ok := db.data[key]
   541  	if !ok {
   542  		db.t.Fatalf("BUG: no query registered for: %v", query)
   543  	}
   544  
   545  	r.BeforeFunc = f
   546  }
   547  
   548  // AddQueryPattern adds an expected result for a set of queries.
   549  // These patterns are checked if no exact matches from AddQuery() are found.
   550  // This function forces the addition of begin/end anchors (^$) and turns on
   551  // case-insensitive matching mode.
   552  func (db *DB) AddQueryPattern(queryPattern string, expectedResult *sqltypes.Result) {
   553  	if len(expectedResult.Rows) > 0 && len(expectedResult.Fields) == 0 {
   554  		panic(fmt.Errorf("please add Fields to this Result so it's valid: %v", queryPattern))
   555  	}
   556  	expr := regexp.MustCompile("(?is)^" + queryPattern + "$")
   557  	result := *expectedResult
   558  	db.mu.Lock()
   559  	defer db.mu.Unlock()
   560  	db.patternData[queryPattern] = exprResult{queryPattern: queryPattern, expr: expr, result: &result}
   561  }
   562  
   563  // RejectQueryPattern allows a query pattern to be rejected with an error
   564  func (db *DB) RejectQueryPattern(queryPattern, error string) {
   565  	expr := regexp.MustCompile("(?is)^" + queryPattern + "$")
   566  	db.mu.Lock()
   567  	defer db.mu.Unlock()
   568  	db.patternData[queryPattern] = exprResult{queryPattern: queryPattern, expr: expr, err: error}
   569  }
   570  
   571  // ClearQueryPattern removes all query patterns set up
   572  func (db *DB) ClearQueryPattern() {
   573  	db.patternData = make(map[string]exprResult)
   574  }
   575  
   576  // AddQueryPatternWithCallback is similar to AddQueryPattern: in addition it calls the provided callback function
   577  // The callback can be used to set user counters/variables for testing specific usecases
   578  func (db *DB) AddQueryPatternWithCallback(queryPattern string, expectedResult *sqltypes.Result, callback func(string)) {
   579  	db.AddQueryPattern(queryPattern, expectedResult)
   580  	db.queryPatternUserCallback[db.patternData[queryPattern].expr] = callback
   581  }
   582  
   583  // DeleteQuery deletes query from the fake DB.
   584  func (db *DB) DeleteQuery(query string) {
   585  	db.mu.Lock()
   586  	defer db.mu.Unlock()
   587  	key := strings.ToLower(query)
   588  	delete(db.data, key)
   589  	delete(db.queryCalled, key)
   590  }
   591  
   592  // AddRejectedQuery adds a query which will be rejected at execution time.
   593  func (db *DB) AddRejectedQuery(query string, err error) {
   594  	db.mu.Lock()
   595  	defer db.mu.Unlock()
   596  	db.rejectedData[strings.ToLower(query)] = err
   597  }
   598  
   599  // DeleteRejectedQuery deletes query from the fake DB.
   600  func (db *DB) DeleteRejectedQuery(query string) {
   601  	db.mu.Lock()
   602  	defer db.mu.Unlock()
   603  	delete(db.rejectedData, strings.ToLower(query))
   604  }
   605  
   606  // GetQueryCalledNum returns how many times db executes a certain query.
   607  func (db *DB) GetQueryCalledNum(query string) int {
   608  	db.mu.Lock()
   609  	defer db.mu.Unlock()
   610  	num, ok := db.queryCalled[strings.ToLower(query)]
   611  	if !ok {
   612  		return 0
   613  	}
   614  	return num
   615  }
   616  
   617  // QueryLog returns the query log in a semicomma separated string
   618  func (db *DB) QueryLog() string {
   619  	return strings.Join(db.querylog, ";")
   620  }
   621  
   622  // ResetQueryLog resets the query log
   623  func (db *DB) ResetQueryLog() {
   624  	db.querylog = nil
   625  }
   626  
   627  // EnableConnFail makes connection to this fake DB fail.
   628  func (db *DB) EnableConnFail() {
   629  	db.isConnFail.Store(true)
   630  }
   631  
   632  // DisableConnFail makes connection to this fake DB success.
   633  func (db *DB) DisableConnFail() {
   634  	db.isConnFail.Store(false)
   635  }
   636  
   637  // SetConnDelay delays connections to this fake DB for the given duration
   638  func (db *DB) SetConnDelay(d time.Duration) {
   639  	db.mu.Lock()
   640  	defer db.mu.Unlock()
   641  	db.connDelay = d
   642  }
   643  
   644  // EnableShouldClose closes the connection when processing the next query.
   645  func (db *DB) EnableShouldClose() {
   646  	db.shouldClose.Store(true)
   647  }
   648  
   649  //
   650  // The following methods are used for ordered expected queries.
   651  //
   652  
   653  // AddExpectedExecuteFetch adds an ExpectedExecuteFetch directly.
   654  func (db *DB) AddExpectedExecuteFetch(entry ExpectedExecuteFetch) {
   655  	db.AddExpectedExecuteFetchAtIndex(appendEntry, entry)
   656  }
   657  
   658  // AddExpectedExecuteFetchAtIndex inserts a new entry at index.
   659  // index values start at 0.
   660  func (db *DB) AddExpectedExecuteFetchAtIndex(index int, entry ExpectedExecuteFetch) {
   661  	db.mu.Lock()
   662  	defer db.mu.Unlock()
   663  
   664  	if db.expectedExecuteFetch == nil || index < 0 || index >= len(db.expectedExecuteFetch) {
   665  		index = appendEntry
   666  	}
   667  	if index == appendEntry {
   668  		db.expectedExecuteFetch = append(db.expectedExecuteFetch, entry)
   669  	} else {
   670  		// Grow the slice by one element.
   671  		if cap(db.expectedExecuteFetch) == len(db.expectedExecuteFetch) {
   672  			db.expectedExecuteFetch = append(db.expectedExecuteFetch, make([]ExpectedExecuteFetch, 1)...)
   673  		} else {
   674  			db.expectedExecuteFetch = db.expectedExecuteFetch[0 : len(db.expectedExecuteFetch)+1]
   675  		}
   676  		// Use copy to move the upper part of the slice out of the way and open a hole.
   677  		copy(db.expectedExecuteFetch[index+1:], db.expectedExecuteFetch[index:])
   678  		// Store the new value.
   679  		db.expectedExecuteFetch[index] = entry
   680  	}
   681  }
   682  
   683  // AddExpectedQuery adds a single query with no result.
   684  func (db *DB) AddExpectedQuery(query string, err error) {
   685  	db.AddExpectedExecuteFetch(ExpectedExecuteFetch{
   686  		Query:       query,
   687  		QueryResult: &sqltypes.Result{},
   688  		Error:       err,
   689  	})
   690  }
   691  
   692  // AddExpectedQueryAtIndex adds an expected ordered query at an index.
   693  func (db *DB) AddExpectedQueryAtIndex(index int, query string, err error) {
   694  	db.AddExpectedExecuteFetchAtIndex(index, ExpectedExecuteFetch{
   695  		Query:       query,
   696  		QueryResult: &sqltypes.Result{},
   697  		Error:       err,
   698  	})
   699  }
   700  
   701  // GetEntry returns the expected entry at "index". If index is out of bounds,
   702  // the return value will be nil.
   703  func (db *DB) GetEntry(index int) *ExpectedExecuteFetch {
   704  	db.mu.Lock()
   705  	defer db.mu.Unlock()
   706  
   707  	if index < 0 || index >= len(db.expectedExecuteFetch) {
   708  		panic(fmt.Sprintf("index out of range. current length: %v", len(db.expectedExecuteFetch)))
   709  	}
   710  
   711  	return &db.expectedExecuteFetch[index]
   712  }
   713  
   714  // DeleteAllEntries removes all ordered entries.
   715  func (db *DB) DeleteAllEntries() {
   716  	db.mu.Lock()
   717  	defer db.mu.Unlock()
   718  
   719  	db.expectedExecuteFetch = make([]ExpectedExecuteFetch, 0)
   720  	db.expectedExecuteFetchIndex = 0
   721  }
   722  
   723  // DeleteAllEntriesAfterIndex removes all queries after the index.
   724  func (db *DB) DeleteAllEntriesAfterIndex(index int) {
   725  	db.mu.Lock()
   726  	defer db.mu.Unlock()
   727  
   728  	if index < 0 || index >= len(db.expectedExecuteFetch) {
   729  		panic(fmt.Sprintf("index out of range. current length: %v", len(db.expectedExecuteFetch)))
   730  	}
   731  
   732  	if index+1 < db.expectedExecuteFetchIndex {
   733  		// Don't delete entries which were already answered.
   734  		return
   735  	}
   736  
   737  	db.expectedExecuteFetch = db.expectedExecuteFetch[:index+1]
   738  }
   739  
   740  // VerifyAllExecutedOrFail checks that all expected queries where actually
   741  // received and executed. If not, it will let the test fail.
   742  func (db *DB) VerifyAllExecutedOrFail() {
   743  	db.mu.Lock()
   744  	defer db.mu.Unlock()
   745  
   746  	if db.expectedExecuteFetchIndex != len(db.expectedExecuteFetch) {
   747  		db.t.Errorf("%v: not all expected queries were executed. leftovers: %v", db.name, db.expectedExecuteFetch[db.expectedExecuteFetchIndex:])
   748  	}
   749  }
   750  
   751  func (db *DB) SetAllowAll(allowAll bool) {
   752  	db.allowAll.Store(allowAll)
   753  }
   754  
   755  func (db *DB) SetNeverFail(neverFail bool) {
   756  	db.neverFail.Store(neverFail)
   757  }
   758  
   759  func (db *DB) MockQueriesForTable(table string, result *sqltypes.Result) {
   760  	// pattern for selecting explicit list of columns where database is specified
   761  	selectQueryPattern := fmt.Sprintf("select .* from `%s`.`%s` where 1 != 1", db.name, table)
   762  	db.AddQueryPattern(selectQueryPattern, result)
   763  
   764  	// pattern for selecting explicit list of columns where database is not specified
   765  	selectQueryPattern = fmt.Sprintf("select .* from %s where 1 != 1", table)
   766  	db.AddQueryPattern(selectQueryPattern, result)
   767  
   768  	// mock query for returning columns from information_schema.columns based on specified result
   769  	var cols []string
   770  	for _, field := range result.Fields {
   771  		cols = append(cols, field.Name)
   772  	}
   773  	db.AddQueryPattern(fmt.Sprintf(mysql.GetColumnNamesQueryPatternForTable, table), sqltypes.MakeTestResult(
   774  		sqltypes.MakeTestFields(
   775  			"column_name",
   776  			"varchar",
   777  		),
   778  		cols...,
   779  	))
   780  }