github.com/dolthub/go-mysql-server@v0.18.0/server/handler_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package server
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"net"
    22  	"strconv"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/dolthub/vitess/go/mysql"
    27  	"github.com/dolthub/vitess/go/sqltypes"
    28  	"github.com/dolthub/vitess/go/vt/proto/query"
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  
    32  	sqle "github.com/dolthub/go-mysql-server"
    33  	"github.com/dolthub/go-mysql-server/memory"
    34  	"github.com/dolthub/go-mysql-server/sql"
    35  	"github.com/dolthub/go-mysql-server/sql/analyzer"
    36  	"github.com/dolthub/go-mysql-server/sql/types"
    37  	"github.com/dolthub/go-mysql-server/sql/variables"
    38  )
    39  
    40  var samplePrepareData = &mysql.PrepareData{
    41  	StatementID: 42,
    42  	ParamsCount: 1,
    43  }
    44  
    45  func TestHandlerOutput(t *testing.T) {
    46  	e, pro := setupMemDB(require.New(t))
    47  	dbFunc := pro.Database
    48  
    49  	dummyConn := newConn(1)
    50  	handler := &Handler{
    51  		e: e,
    52  		sm: NewSessionManager(
    53  			testSessionBuilder(pro),
    54  			sql.NoopTracer,
    55  			dbFunc,
    56  			sql.NewMemoryManager(nil),
    57  			sqle.NewProcessList(),
    58  			"foo",
    59  		),
    60  		readTimeout: time.Second,
    61  	}
    62  	handler.NewConnection(dummyConn)
    63  
    64  	type expectedValues struct {
    65  		callsToCallback  int
    66  		lenLastBatch     int
    67  		lastRowsAffected uint64
    68  	}
    69  
    70  	tests := []struct {
    71  		name     string
    72  		handler  *Handler
    73  		conn     *mysql.Conn
    74  		query    string
    75  		expected expectedValues
    76  	}{
    77  		{
    78  			name:    "select all without limit",
    79  			handler: handler,
    80  			conn:    dummyConn,
    81  			query:   "SELECT * FROM test",
    82  			expected: expectedValues{
    83  				callsToCallback:  8,
    84  				lenLastBatch:     114,
    85  				lastRowsAffected: uint64(114),
    86  			},
    87  		},
    88  		{
    89  			name:    "with limit equal to batch capacity",
    90  			handler: handler,
    91  			conn:    dummyConn,
    92  			query:   "SELECT * FROM test limit 100",
    93  			expected: expectedValues{
    94  				callsToCallback:  1,
    95  				lenLastBatch:     100,
    96  				lastRowsAffected: uint64(100),
    97  			},
    98  		},
    99  		{
   100  			name:    "with limit less than batch capacity",
   101  			handler: handler,
   102  			conn:    dummyConn,
   103  			query:   "SELECT * FROM test limit 60",
   104  			expected: expectedValues{
   105  				callsToCallback:  1,
   106  				lenLastBatch:     60,
   107  				lastRowsAffected: uint64(60),
   108  			},
   109  		},
   110  		{
   111  			name:    "with limit greater than batch capacity",
   112  			handler: handler,
   113  			conn:    dummyConn,
   114  			query:   "SELECT * FROM test limit 200",
   115  			expected: expectedValues{
   116  				callsToCallback:  2,
   117  				lenLastBatch:     72,
   118  				lastRowsAffected: uint64(72),
   119  			},
   120  		},
   121  		{
   122  			name:    "with limit set to a number not multiple of the batch capacity",
   123  			handler: handler,
   124  			conn:    dummyConn,
   125  			query:   "SELECT * FROM test limit 530",
   126  			expected: expectedValues{
   127  				callsToCallback:  5,
   128  				lenLastBatch:     18,
   129  				lastRowsAffected: uint64(18),
   130  			},
   131  		},
   132  		{
   133  			name:    "with limit zero",
   134  			handler: handler,
   135  			conn:    dummyConn,
   136  			query:   "SELECT * FROM test limit 0",
   137  			expected: expectedValues{
   138  				callsToCallback:  1,
   139  				lenLastBatch:     0,
   140  				lastRowsAffected: uint64(0),
   141  			},
   142  		},
   143  	}
   144  
   145  	for _, test := range tests {
   146  		t.Run(test.name, func(t *testing.T) {
   147  			var callsToCallback int
   148  			var lenLastBatch int
   149  			var lastRowsAffected uint64
   150  			handler.ComInitDB(test.conn, "test")
   151  			err := handler.ComQuery(test.conn, test.query, func(res *sqltypes.Result, more bool) error {
   152  				callsToCallback++
   153  				lenLastBatch = len(res.Rows)
   154  				lastRowsAffected = res.RowsAffected
   155  				return nil
   156  			})
   157  
   158  			require.NoError(t, err)
   159  			assert.Equal(t, test.expected.callsToCallback, callsToCallback)
   160  			assert.Equal(t, test.expected.lenLastBatch, lenLastBatch)
   161  			assert.Equal(t, test.expected.lastRowsAffected, lastRowsAffected)
   162  
   163  		})
   164  	}
   165  
   166  	t.Run("sum aggregation type is correct", func(t *testing.T) {
   167  		handler.ComInitDB(dummyConn, "test")
   168  		var result *sqltypes.Result
   169  		err := handler.ComQuery(dummyConn, "select sum(1) from test", func(res *sqltypes.Result, more bool) error {
   170  			result = res
   171  			return nil
   172  		})
   173  		require.NoError(t, err)
   174  		require.Equal(t, 1, len(result.Rows))
   175  		require.Equal(t, sqltypes.Float64, result.Rows[0][0].Type())
   176  		require.Equal(t, []byte("1010"), result.Rows[0][0].ToBytes())
   177  	})
   178  
   179  	t.Run("avg aggregation type is correct", func(t *testing.T) {
   180  		handler.ComInitDB(dummyConn, "test")
   181  		var result *sqltypes.Result
   182  		err := handler.ComQuery(dummyConn, "select avg(1) from test", func(res *sqltypes.Result, more bool) error {
   183  			result = res
   184  			return nil
   185  		})
   186  		require.NoError(t, err)
   187  		require.Equal(t, 1, len(result.Rows))
   188  		require.Equal(t, sqltypes.Float64, result.Rows[0][0].Type())
   189  		require.Equal(t, []byte("1"), result.Rows[0][0].ToBytes())
   190  	})
   191  
   192  	t.Run("if() type is correct", func(t *testing.T) {
   193  		handler.ComInitDB(dummyConn, "test")
   194  		var result *sqltypes.Result
   195  		err := handler.ComQuery(dummyConn, "select if(1, 123, 'def')", func(res *sqltypes.Result, more bool) error {
   196  			result = res
   197  			return nil
   198  		})
   199  		require.NoError(t, err)
   200  		require.Equal(t, 1, len(result.Rows))
   201  		require.Equal(t, sqltypes.Text, result.Rows[0][0].Type())
   202  		require.Equal(t, []byte("123"), result.Rows[0][0].ToBytes())
   203  
   204  		err = handler.ComQuery(dummyConn, "select if(0, 123, 456)", func(res *sqltypes.Result, more bool) error {
   205  			result = res
   206  			return nil
   207  		})
   208  		require.NoError(t, err)
   209  		require.Equal(t, 1, len(result.Rows))
   210  		require.Equal(t, sqltypes.Int64, result.Rows[0][0].Type())
   211  		require.Equal(t, []byte("456"), result.Rows[0][0].ToBytes())
   212  	})
   213  }
   214  
   215  func TestHandlerErrors(t *testing.T) {
   216  	e, pro := setupMemDB(require.New(t))
   217  	dbFunc := pro.Database
   218  
   219  	dummyConn := newConn(1)
   220  	handler := &Handler{
   221  		e: e,
   222  		sm: NewSessionManager(
   223  			testSessionBuilder(pro),
   224  			sql.NoopTracer,
   225  			dbFunc,
   226  			sql.NewMemoryManager(nil),
   227  			sqle.NewProcessList(),
   228  			"foo",
   229  		),
   230  		readTimeout: time.Second,
   231  	}
   232  	handler.NewConnection(dummyConn)
   233  
   234  	type expectedValues struct {
   235  		callsToCallback  int
   236  		lenLastBatch     int
   237  		lastRowsAffected uint64
   238  	}
   239  
   240  	setupCommands := []string{"CREATE TABLE `test_table` ( `id` INT NOT NULL PRIMARY KEY, `v` INT );"}
   241  
   242  	tests := []struct {
   243  		name              string
   244  		query             string
   245  		expectedErrorCode int
   246  	}{
   247  		{
   248  			name:              "insert with nonexistent field name",
   249  			query:             "INSERT INTO `test_table` (`id`, `v_`) VALUES (1, 2)",
   250  			expectedErrorCode: mysql.ERBadFieldError,
   251  		},
   252  		{
   253  			name:              "insert into nonexistent table",
   254  			query:             "INSERT INTO `test`.`no_such_table` (`id`, `v`) VALUES (1, 2)",
   255  			expectedErrorCode: mysql.ERNoSuchTable,
   256  		},
   257  		{
   258  			name:              "insert into same column twice",
   259  			query:             "INSERT INTO `test`.`test_table` (`id`, `id`, `v`) VALUES (1, 2, 3)",
   260  			expectedErrorCode: mysql.ERFieldSpecifiedTwice,
   261  		},
   262  	}
   263  
   264  	handler.ComInitDB(dummyConn, "test")
   265  	for _, setupCommand := range setupCommands {
   266  		err := handler.ComQuery(dummyConn, setupCommand, func(res *sqltypes.Result, more bool) error {
   267  			return nil
   268  		})
   269  		require.NoError(t, err)
   270  	}
   271  
   272  	for _, test := range tests {
   273  		t.Run(test.name, func(t *testing.T) {
   274  			err := handler.ComQuery(dummyConn, test.query, func(res *sqltypes.Result, more bool) error {
   275  				return nil
   276  			})
   277  			require.NotNil(t, err)
   278  			sqlErr, isSqlError := err.(*mysql.SQLError)
   279  			require.True(t, isSqlError)
   280  			require.Equal(t, test.expectedErrorCode, sqlErr.Number())
   281  		})
   282  	}
   283  }
   284  
   285  // TestHandlerComReset asserts that the Handler.ComResetConnection method correctly clears all session
   286  // state (e.g. table locks, prepared statements, user variables, session variables), and keeps the current
   287  // database selected.
   288  func TestHandlerComResetConnection(t *testing.T) {
   289  	e, pro := setupMemDB(require.New(t))
   290  	dummyConn := newConn(1)
   291  	dbFunc := pro.Database
   292  
   293  	handler := &Handler{
   294  		e: e,
   295  		sm: NewSessionManager(
   296  			testSessionBuilder(pro),
   297  			sql.NoopTracer,
   298  			dbFunc,
   299  			sql.NewMemoryManager(nil),
   300  			sqle.NewProcessList(),
   301  			"foo",
   302  		),
   303  	}
   304  	handler.NewConnection(dummyConn)
   305  	handler.ComInitDB(dummyConn, "test")
   306  
   307  	prepareData := &mysql.PrepareData{
   308  		StatementID: 0,
   309  		PrepareStmt: "select 42 + ? from dual",
   310  		ParamsCount: 0,
   311  		ParamsType:  nil,
   312  		ColumnNames: nil,
   313  		BindVars: map[string]*query.BindVariable{
   314  			"v1": {Type: query.Type_INT8, Value: []byte("5")},
   315  		},
   316  	}
   317  
   318  	// Create a prepared statement, a table lock, and a user var in the current session
   319  	_, err := handler.ComPrepare(dummyConn, prepareData.PrepareStmt, prepareData)
   320  	require.NoError(t, err)
   321  	_, cached := e.PreparedDataCache.GetCachedStmt(dummyConn.ConnectionID, prepareData.PrepareStmt)
   322  	require.True(t, cached)
   323  	err = handler.ComQuery(dummyConn, "SET @userVar = 42;", func(res *sqltypes.Result, more bool) error {
   324  		return nil
   325  	})
   326  	require.NoError(t, err)
   327  
   328  	// Reset the connection to clear all session state
   329  	err = handler.ComResetConnection(dummyConn)
   330  	require.NoError(t, err)
   331  
   332  	// Assert that the session is clean – the selected database should not change, and all session state
   333  	// such as user vars, session vars, prepared statements, table locks, and temporary tables should be cleared.
   334  	err = handler.ComQuery(dummyConn, "SELECT database()", func(res *sqltypes.Result, more bool) error {
   335  		require.Equal(t, "test", res.Rows[0][0].ToString())
   336  		return nil
   337  	})
   338  	require.NoError(t, err)
   339  	_, cached = e.PreparedDataCache.GetCachedStmt(dummyConn.ConnectionID, prepareData.PrepareStmt)
   340  	require.False(t, cached)
   341  	err = handler.ComQuery(dummyConn, "SELECT @userVar;", func(res *sqltypes.Result, more bool) error {
   342  		require.True(t, res.Rows[0][0].IsNull())
   343  		return nil
   344  	})
   345  	require.NoError(t, err)
   346  }
   347  
   348  func TestHandlerComPrepare(t *testing.T) {
   349  	e, pro := setupMemDB(require.New(t))
   350  	dummyConn := newConn(1)
   351  	dbFunc := pro.Database
   352  
   353  	handler := &Handler{
   354  		e: e,
   355  		sm: NewSessionManager(
   356  			testSessionBuilder(pro),
   357  			sql.NoopTracer,
   358  			dbFunc,
   359  			sql.NewMemoryManager(nil),
   360  			sqle.NewProcessList(),
   361  			"foo",
   362  		),
   363  	}
   364  	handler.NewConnection(dummyConn)
   365  
   366  	type testcase struct {
   367  		name        string
   368  		statement   string
   369  		expected    []*query.Field
   370  		expectedErr *mysql.SQLError
   371  	}
   372  
   373  	for _, test := range []testcase{
   374  		{
   375  			name:      "insert statement returns nil schema",
   376  			statement: "insert into test (c1) values (?)",
   377  			expected:  nil,
   378  		},
   379  		{
   380  			name:      "update statement returns nil schema",
   381  			statement: "update test set c1 = ?",
   382  			expected:  nil,
   383  		},
   384  		{
   385  			name:      "delete statement returns nil schema",
   386  			statement: "delete from test where c1 = ?",
   387  			expected:  nil,
   388  		},
   389  		{
   390  			name:      "select statement returns non-nil schema",
   391  			statement: "select c1 from test where c1 > ?",
   392  			expected: []*query.Field{
   393  				{Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   394  			},
   395  		},
   396  		{
   397  			name:        "errors are cast to SQLError",
   398  			statement:   "SELECT * from doesnotexist LIMIT ?",
   399  			expectedErr: mysql.NewSQLError(mysql.ERNoSuchTable, "", "table not found: %s", "doesnotexist"),
   400  		},
   401  	} {
   402  		t.Run(test.name, func(t *testing.T) {
   403  			handler.ComInitDB(dummyConn, "test")
   404  			schema, err := handler.ComPrepare(dummyConn, test.statement, samplePrepareData)
   405  			if test.expectedErr == nil {
   406  				require.NoError(t, err)
   407  				require.Equal(t, test.expected, schema)
   408  			} else {
   409  				require.NotNil(t, err)
   410  				sqlErr, isSqlError := err.(*mysql.SQLError)
   411  				require.True(t, isSqlError)
   412  				require.Equal(t, test.expectedErr.Number(), sqlErr.Number())
   413  				require.Equal(t, test.expectedErr.SQLState(), sqlErr.SQLState())
   414  				require.Equal(t, test.expectedErr.Error(), sqlErr.Error())
   415  			}
   416  		})
   417  	}
   418  }
   419  
   420  func TestHandlerComPrepareExecute(t *testing.T) {
   421  	e, pro := setupMemDB(require.New(t))
   422  	dummyConn := newConn(1)
   423  	dbFunc := pro.Database
   424  
   425  	handler := &Handler{
   426  		e: e,
   427  		sm: NewSessionManager(
   428  			testSessionBuilder(pro),
   429  			sql.NoopTracer,
   430  			dbFunc,
   431  			sql.NewMemoryManager(nil),
   432  			sqle.NewProcessList(),
   433  			"foo",
   434  		),
   435  	}
   436  	handler.NewConnection(dummyConn)
   437  
   438  	type testcase struct {
   439  		name     string
   440  		prepare  *mysql.PrepareData
   441  		execute  map[string]*query.BindVariable
   442  		schema   []*query.Field
   443  		expected []sql.Row
   444  	}
   445  
   446  	for _, test := range []testcase{
   447  		{
   448  			name: "select statement returns nil schema",
   449  			prepare: &mysql.PrepareData{
   450  				StatementID: 0,
   451  				PrepareStmt: "select c1 from test where c1 < ?",
   452  				ParamsCount: 0,
   453  				ParamsType:  nil,
   454  				ColumnNames: nil,
   455  				BindVars: map[string]*query.BindVariable{
   456  					"v1": {Type: query.Type_INT8, Value: []byte("5")},
   457  				},
   458  			},
   459  			schema: []*query.Field{
   460  				{Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   461  			},
   462  			expected: []sql.Row{
   463  				{0}, {1}, {2}, {3}, {4},
   464  			},
   465  		},
   466  	} {
   467  		t.Run(test.name, func(t *testing.T) {
   468  			handler.ComInitDB(dummyConn, "test")
   469  			schema, err := handler.ComPrepare(dummyConn, test.prepare.PrepareStmt, samplePrepareData)
   470  			require.NoError(t, err)
   471  			require.Equal(t, test.schema, schema)
   472  
   473  			var res []sql.Row
   474  			callback := func(r *sqltypes.Result) error {
   475  				for _, r := range r.Rows {
   476  					var vals []interface{}
   477  					for _, v := range r {
   478  						val, err := strconv.ParseInt(string(v.Raw()), 0, 64)
   479  						if err != nil {
   480  							return err
   481  						}
   482  						vals = append(vals, int(val))
   483  					}
   484  					res = append(res, sql.NewRow(vals...))
   485  				}
   486  				return nil
   487  			}
   488  			err = handler.ComStmtExecute(dummyConn, test.prepare, callback)
   489  			require.NoError(t, err)
   490  			require.Equal(t, test.expected, res)
   491  		})
   492  	}
   493  }
   494  
   495  func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) {
   496  	e, pro := setupMemDB(require.New(t))
   497  	dummyConn := newConn(1)
   498  	dbFunc := pro.Database
   499  
   500  	handler := &Handler{
   501  		e: e,
   502  		sm: NewSessionManager(
   503  			testSessionBuilder(pro),
   504  			sql.NoopTracer,
   505  			dbFunc,
   506  			sql.NewMemoryManager(nil),
   507  			sqle.NewProcessList(),
   508  			"foo",
   509  		),
   510  	}
   511  	handler.NewConnection(dummyConn)
   512  	analyzer.SetPreparedStmts(true)
   513  	defer func() {
   514  		analyzer.SetPreparedStmts(false)
   515  	}()
   516  	type testcase struct {
   517  		name     string
   518  		prepare  *mysql.PrepareData
   519  		execute  map[string]*query.BindVariable
   520  		schema   []*query.Field
   521  		expected []sql.Row
   522  	}
   523  
   524  	for _, test := range []testcase{
   525  		{
   526  			name: "select statement returns nil schema",
   527  			prepare: &mysql.PrepareData{
   528  				StatementID: 0,
   529  				PrepareStmt: "select c1 from test where c1 < ?",
   530  				ParamsCount: 0,
   531  				ParamsType:  nil,
   532  				ColumnNames: nil,
   533  				BindVars: map[string]*query.BindVariable{
   534  					"v1": {Type: query.Type_INT8, Value: []byte("5")},
   535  				},
   536  			},
   537  			schema: []*query.Field{
   538  				{Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   539  			},
   540  			expected: []sql.Row{
   541  				{0}, {1}, {2}, {3}, {4},
   542  			},
   543  		},
   544  	} {
   545  		t.Run(test.name, func(t *testing.T) {
   546  			handler.ComInitDB(dummyConn, "test")
   547  			schema, err := handler.ComPrepare(dummyConn, test.prepare.PrepareStmt, samplePrepareData)
   548  			require.NoError(t, err)
   549  			require.Equal(t, test.schema, schema)
   550  
   551  			var res []sql.Row
   552  			callback := func(r *sqltypes.Result) error {
   553  				for _, r := range r.Rows {
   554  					var vals []interface{}
   555  					for _, v := range r {
   556  						val, err := strconv.ParseInt(string(v.Raw()), 0, 64)
   557  						if err != nil {
   558  							return err
   559  						}
   560  						vals = append(vals, int(val))
   561  					}
   562  					res = append(res, sql.NewRow(vals...))
   563  				}
   564  				return nil
   565  			}
   566  			err = handler.ComStmtExecute(dummyConn, test.prepare, callback)
   567  			require.NoError(t, err)
   568  			require.Equal(t, test.expected, res)
   569  		})
   570  	}
   571  }
   572  
   573  type TestListener struct {
   574  	Connections int
   575  	Queries     int
   576  	Disconnects int
   577  	Successes   int
   578  	Failures    int
   579  }
   580  
   581  func (tl *TestListener) ClientConnected() {
   582  	tl.Connections++
   583  }
   584  
   585  func (tl *TestListener) ClientDisconnected() {
   586  	tl.Disconnects++
   587  }
   588  
   589  func (tl *TestListener) QueryStarted() {
   590  	tl.Queries++
   591  }
   592  
   593  func (tl *TestListener) QueryCompleted(success bool, duration time.Duration) {
   594  	if success {
   595  		tl.Successes++
   596  	} else {
   597  		tl.Failures++
   598  	}
   599  }
   600  
   601  func TestServerEventListener(t *testing.T) {
   602  	require := require.New(t)
   603  	e, pro := setupMemDB(require)
   604  	dbFunc := pro.Database
   605  
   606  	listener := &TestListener{}
   607  	handler := &Handler{
   608  		e: e,
   609  		sm: NewSessionManager(
   610  			func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
   611  				return sql.NewBaseSessionWithClientServer(addr, sql.Client{Capabilities: conn.Capabilities}, conn.ConnectionID), nil
   612  			},
   613  			sql.NoopTracer,
   614  			dbFunc,
   615  			e.MemoryManager,
   616  			e.ProcessList,
   617  			"foo",
   618  		),
   619  		sel: listener,
   620  	}
   621  
   622  	cb := func(res *sqltypes.Result, more bool) error {
   623  		return nil
   624  	}
   625  
   626  	require.Equal(listener.Connections, 0)
   627  	require.Equal(listener.Disconnects, 0)
   628  	require.Equal(listener.Queries, 0)
   629  	require.Equal(listener.Successes, 0)
   630  	require.Equal(listener.Failures, 0)
   631  
   632  	conn1 := newConn(1)
   633  	handler.NewConnection(conn1)
   634  	require.Equal(listener.Connections, 1)
   635  	require.Equal(listener.Disconnects, 0)
   636  
   637  	err := handler.sm.SetDB(conn1, "test")
   638  	require.NoError(err)
   639  
   640  	err = handler.ComQuery(conn1, "SELECT 1", cb)
   641  	require.NoError(err)
   642  	require.Equal(listener.Queries, 1)
   643  	require.Equal(listener.Successes, 1)
   644  	require.Equal(listener.Failures, 0)
   645  
   646  	conn2 := newConn(2)
   647  	handler.NewConnection(conn2)
   648  	require.Equal(listener.Connections, 2)
   649  	require.Equal(listener.Disconnects, 0)
   650  
   651  	handler.ComInitDB(conn2, "test")
   652  	err = handler.ComQuery(conn2, "select 1", cb)
   653  	require.NoError(err)
   654  	require.Equal(listener.Queries, 2)
   655  	require.Equal(listener.Successes, 2)
   656  	require.Equal(listener.Failures, 0)
   657  
   658  	err = handler.ComQuery(conn1, "select bad_col from bad_table with illegal syntax", cb)
   659  	require.Error(err)
   660  	require.Equal(listener.Queries, 3)
   661  	require.Equal(listener.Successes, 2)
   662  	require.Equal(listener.Failures, 1)
   663  
   664  	handler.ConnectionClosed(conn1)
   665  	require.Equal(listener.Connections, 2)
   666  	require.Equal(listener.Disconnects, 1)
   667  
   668  	handler.ConnectionClosed(conn2)
   669  	require.Equal(listener.Connections, 2)
   670  	require.Equal(listener.Disconnects, 2)
   671  
   672  	conn3 := newConn(3)
   673  	query := "SELECT ?"
   674  	_, err = handler.ComPrepare(conn3, query, samplePrepareData)
   675  	require.NoError(err)
   676  	require.Equal(1, len(e.PreparedDataCache.CachedStatementsForSession(conn3.ConnectionID)))
   677  	require.NotNil(e.PreparedDataCache.GetCachedStmt(conn3.ConnectionID, query))
   678  
   679  	handler.ConnectionClosed(conn3)
   680  	require.Equal(0, len(e.PreparedDataCache.CachedStatementsForSession(conn3.ConnectionID)))
   681  }
   682  
   683  func TestHandlerKill(t *testing.T) {
   684  	require := require.New(t)
   685  	e, pro := setupMemDB(require)
   686  	dbFunc := pro.Database
   687  
   688  	handler := &Handler{
   689  		e: e,
   690  		sm: NewSessionManager(
   691  			func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
   692  				return sql.NewBaseSessionWithClientServer(addr, sql.Client{Capabilities: conn.Capabilities}, conn.ConnectionID), nil
   693  			},
   694  			sql.NoopTracer,
   695  			dbFunc,
   696  			e.MemoryManager,
   697  			e.ProcessList,
   698  			"foo",
   699  		),
   700  	}
   701  
   702  	conn1 := newConn(1)
   703  	handler.NewConnection(conn1)
   704  
   705  	conn2 := newConn(2)
   706  	handler.NewConnection(conn2)
   707  
   708  	require.Len(handler.sm.connections, 2)
   709  	require.Len(handler.sm.sessions, 0)
   710  
   711  	handler.ComInitDB(conn2, "test")
   712  	err := handler.ComQuery(conn2, "KILL QUERY 1", func(res *sqltypes.Result, more bool) error {
   713  		return nil
   714  	})
   715  	require.NoError(err)
   716  
   717  	require.False(conn1.Conn.(*mockConn).closed)
   718  	require.Len(handler.sm.connections, 2)
   719  	require.Len(handler.sm.sessions, 1)
   720  
   721  	err = handler.sm.SetDB(conn1, "test")
   722  	require.NoError(err)
   723  	ctx1, err := handler.sm.NewContextWithQuery(conn1, "SELECT 1")
   724  	require.NoError(err)
   725  	ctx1, err = handler.e.ProcessList.BeginQuery(ctx1, "SELECT 1")
   726  	require.NoError(err)
   727  
   728  	err = handler.ComQuery(conn2, "KILL "+fmt.Sprint(ctx1.ID()), func(res *sqltypes.Result, more bool) error {
   729  		return nil
   730  	})
   731  	require.NoError(err)
   732  
   733  	require.Error(ctx1.Err())
   734  	require.True(conn1.Conn.(*mockConn).closed)
   735  	handler.ConnectionClosed(conn1)
   736  	require.Len(handler.sm.sessions, 1)
   737  }
   738  
   739  func TestSchemaToFields(t *testing.T) {
   740  	require := require.New(t)
   741  
   742  	schema := sql.Schema{
   743  		// Blob, Text, and JSON Types
   744  		{Name: "tinyblob", Source: "table1", DatabaseSource: "db1", Type: types.TinyBlob},
   745  		{Name: "blob", Source: "table1", DatabaseSource: "db1", Type: types.Blob},
   746  		{Name: "mediumblob", Source: "table1", DatabaseSource: "db1", Type: types.MediumBlob},
   747  		{Name: "longblob", Source: "table1", DatabaseSource: "db1", Type: types.LongBlob},
   748  		{Name: "tinytext", Source: "table1", DatabaseSource: "db1", Type: types.TinyText},
   749  		{Name: "text", Source: "table1", DatabaseSource: "db1", Type: types.Text},
   750  		{Name: "mediumtext", Source: "table1", DatabaseSource: "db1", Type: types.MediumText},
   751  		{Name: "longtext", Source: "table1", DatabaseSource: "db1", Type: types.LongText},
   752  		{Name: "json", Source: "table1", DatabaseSource: "db1", Type: types.JSON},
   753  
   754  		// Geometry Types
   755  		{Name: "geometry", Source: "table1", DatabaseSource: "db1", Type: types.GeometryType{}},
   756  		{Name: "point", Source: "table1", DatabaseSource: "db1", Type: types.PointType{}},
   757  		{Name: "polygon", Source: "table1", DatabaseSource: "db1", Type: types.PolygonType{}},
   758  		{Name: "linestring", Source: "table1", DatabaseSource: "db1", Type: types.LineStringType{}},
   759  
   760  		// Integer Types
   761  		{Name: "uint8", Source: "table1", DatabaseSource: "db1", Type: types.Uint8},
   762  		{Name: "int8", Source: "table1", DatabaseSource: "db1", Type: types.Int8},
   763  		{Name: "uint16", Source: "table1", DatabaseSource: "db1", Type: types.Uint16},
   764  		{Name: "int16", Source: "table1", DatabaseSource: "db1", Type: types.Int16},
   765  		{Name: "uint24", Source: "table1", DatabaseSource: "db1", Type: types.Uint24},
   766  		{Name: "int24", Source: "table1", DatabaseSource: "db1", Type: types.Int24},
   767  		{Name: "uint32", Source: "table1", DatabaseSource: "db1", Type: types.Uint32},
   768  		{Name: "int32", Source: "table1", DatabaseSource: "db1", Type: types.Int32},
   769  		{Name: "uint64", Source: "table1", DatabaseSource: "db1", Type: types.Uint64},
   770  		{Name: "int64", Source: "table1", DatabaseSource: "db1", Type: types.Int64},
   771  
   772  		// Floating Point and Decimal Types
   773  		{Name: "float32", Source: "table1", DatabaseSource: "db1", Type: types.Float32},
   774  		{Name: "float64", Source: "table1", DatabaseSource: "db1", Type: types.Float64},
   775  		{Name: "decimal10_0", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateDecimalType(10, 0)},
   776  		{Name: "decimal60_30", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateDecimalType(60, 30)},
   777  
   778  		// Char, Binary, and Bit Types
   779  		{Name: "varchar50", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateString(sqltypes.VarChar, 50, sql.Collation_Default)},
   780  		{Name: "varbinary12345", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateBinary(sqltypes.VarBinary, 12345)},
   781  		{Name: "binary123", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateBinary(sqltypes.Binary, 123)},
   782  		{Name: "char123", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateString(sqltypes.Char, 123, sql.Collation_Default)},
   783  		{Name: "bit12", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateBitType(12)},
   784  
   785  		// Dates
   786  		{Name: "datetime", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateDatetimeType(sqltypes.Datetime, 0)},
   787  		{Name: "timestamp", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateDatetimeType(sqltypes.Timestamp, 0)},
   788  		{Name: "date", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateDatetimeType(sqltypes.Date, 0)},
   789  		{Name: "time", Source: "table1", DatabaseSource: "db1", Type: types.Time},
   790  		{Name: "year", Source: "table1", DatabaseSource: "db1", Type: types.Year},
   791  
   792  		// Set and Enum Types
   793  		{Name: "set", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateSetType([]string{"one", "two", "three", "four"}, sql.Collation_Default)},
   794  		{Name: "enum", Source: "table1", DatabaseSource: "db1", Type: types.MustCreateEnumType([]string{"one", "two", "three", "four"}, sql.Collation_Default)},
   795  	}
   796  
   797  	expected := []*query.Field{
   798  		// Blob, Text, and JSON Types
   799  		{Name: "tinyblob", OrgName: "tinyblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 255, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   800  		{Name: "blob", OrgName: "blob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 65_535, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   801  		{Name: "mediumblob", OrgName: "mediumblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 16_777_215, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   802  		{Name: "longblob", OrgName: "longblob", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BLOB, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   803  		{Name: "tinytext", OrgName: "tinytext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 1020, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   804  		{Name: "text", OrgName: "text", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 262_140, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   805  		{Name: "mediumtext", OrgName: "mediumtext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 67_108_860, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   806  		{Name: "longtext", OrgName: "longtext", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TEXT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   807  		{Name: "json", OrgName: "json", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_JSON, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   808  
   809  		// Geometry Types
   810  		{Name: "geometry", OrgName: "geometry", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   811  		{Name: "point", OrgName: "point", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   812  		{Name: "polygon", OrgName: "polygon", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   813  		{Name: "linestring", OrgName: "linestring", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_GEOMETRY, Charset: mysql.CharacterSetBinary, ColumnLength: 4_294_967_295, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   814  
   815  		// Integer Types
   816  		{Name: "uint8", OrgName: "uint8", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT8, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 3, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)},
   817  		{Name: "int8", OrgName: "int8", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT8, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   818  		{Name: "uint16", OrgName: "uint16", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 5, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)},
   819  		{Name: "int16", OrgName: "int16", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   820  		{Name: "uint24", OrgName: "uint24", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT24, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 8, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)},
   821  		{Name: "int24", OrgName: "int24", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT24, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 9, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   822  		{Name: "uint32", OrgName: "uint32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 10, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)},
   823  		{Name: "int32", OrgName: "int32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   824  		{Name: "uint64", OrgName: "uint64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_UINT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG | query.MySqlFlag_UNSIGNED_FLAG)},
   825  		{Name: "int64", OrgName: "int64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_INT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   826  
   827  		// Floating Point and Decimal Types
   828  		{Name: "float32", OrgName: "float32", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_FLOAT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 12, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   829  		{Name: "float64", OrgName: "float64", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_FLOAT64, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 22, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   830  		{Name: "decimal10_0", OrgName: "decimal10_0", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DECIMAL, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Decimals: 0, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   831  		{Name: "decimal60_30", OrgName: "decimal60_30", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DECIMAL, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 62, Decimals: 30, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   832  
   833  		// Char, Binary, and Bit Types
   834  		{Name: "varchar50", OrgName: "varchar50", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_VARCHAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 50 * 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   835  		{Name: "varbinary12345", OrgName: "varbinary12345", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_VARBINARY, Charset: mysql.CharacterSetBinary, ColumnLength: 12345, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   836  		{Name: "binary123", OrgName: "binary123", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BINARY, Charset: mysql.CharacterSetBinary, ColumnLength: 123, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   837  		{Name: "char123", OrgName: "char123", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_CHAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 123 * 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   838  		{Name: "bit12", OrgName: "bit12", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_BIT, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 12, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   839  
   840  		// Dates
   841  		{Name: "datetime", OrgName: "datetime", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DATETIME, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 26, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   842  		{Name: "timestamp", OrgName: "timestamp", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TIMESTAMP, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 26, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   843  		{Name: "date", OrgName: "date", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_DATE, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 10, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   844  		{Name: "time", OrgName: "time", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_TIME, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 17, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   845  		{Name: "year", OrgName: "year", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_YEAR, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 4, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   846  
   847  		// Set and Enum Types
   848  		{Name: "set", OrgName: "set", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_SET, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 72, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   849  		{Name: "enum", OrgName: "enum", Table: "table1", OrgTable: "table1", Database: "db1", Type: query.Type_ENUM, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 20, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
   850  	}
   851  
   852  	require.Equal(len(schema), len(expected))
   853  
   854  	e, pro := setupMemDB(require)
   855  	dbFunc := pro.Database
   856  
   857  	handler := &Handler{
   858  		e: e,
   859  		sm: NewSessionManager(
   860  			testSessionBuilder(pro),
   861  			sql.NoopTracer,
   862  			dbFunc,
   863  			sql.NewMemoryManager(nil),
   864  			sqle.NewProcessList(),
   865  			"foo",
   866  		),
   867  		readTimeout: time.Second,
   868  	}
   869  
   870  	conn := newConn(1)
   871  	handler.NewConnection(conn)
   872  
   873  	ctx, err := handler.sm.NewContextWithQuery(conn, "SELECT 1")
   874  	require.NoError(err)
   875  
   876  	fields := schemaToFields(ctx, schema)
   877  	for i := 0; i < len(fields); i++ {
   878  		t.Run(schema[i].Name, func(t *testing.T) {
   879  			assert.Equal(t, expected[i], fields[i])
   880  		})
   881  	}
   882  }
   883  
   884  // TestHandlerMaxTextResponseBytes tests that the handler calculates the correct max text response byte
   885  // metadata for TEXT types, including honoring the character_set_results session variable. This is tested
   886  // here, instead of in string type unit tests, because of the dependency on system variables being loaded.
   887  func TestHandlerMaxTextResponseBytes(t *testing.T) {
   888  	variables.InitSystemVariables()
   889  	session := sql.NewBaseSession()
   890  	ctx := sql.NewContext(
   891  		context.Background(),
   892  		sql.WithSession(session),
   893  	)
   894  
   895  	tinyTextUtf8mb4 := types.MustCreateString(sqltypes.Text, types.TinyTextBlobMax, sql.Collation_Default)
   896  	textUtf8mb4 := types.MustCreateString(sqltypes.Text, types.TextBlobMax, sql.Collation_Default)
   897  	mediumTextUtf8mb4 := types.MustCreateString(sqltypes.Text, types.MediumTextBlobMax, sql.Collation_Default)
   898  	longTextUtf8mb4 := types.MustCreateString(sqltypes.Text, types.LongTextBlobMax, sql.Collation_Default)
   899  
   900  	// When character_set_results is set to utf8mb4, the multibyte character multiplier is 4
   901  	require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", "utf8mb4"))
   902  	require.EqualValues(t, types.TinyTextBlobMax*4, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
   903  	require.EqualValues(t, types.TextBlobMax*4, textUtf8mb4.MaxTextResponseByteLength(ctx))
   904  	require.EqualValues(t, types.MediumTextBlobMax*4, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
   905  	require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))
   906  
   907  	// When character_set_results is set to utf8mb3, the multibyte character multiplier is 3
   908  	require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", "utf8mb3"))
   909  	require.EqualValues(t, types.TinyTextBlobMax*3, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
   910  	require.EqualValues(t, types.TextBlobMax*3, textUtf8mb4.MaxTextResponseByteLength(ctx))
   911  	require.EqualValues(t, types.MediumTextBlobMax*3, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
   912  	require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))
   913  
   914  	// When character_set_results is set to utf8, the multibyte character multiplier is 3
   915  	require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", "utf8"))
   916  	require.EqualValues(t, types.TinyTextBlobMax*3, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
   917  	require.EqualValues(t, types.TextBlobMax*3, textUtf8mb4.MaxTextResponseByteLength(ctx))
   918  	require.EqualValues(t, types.MediumTextBlobMax*3, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
   919  	require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))
   920  
   921  	// When character_set_results is set to NULL, the multibyte character multiplier is taken from
   922  	// the type's charset (4 in this case)
   923  	require.NoError(t, session.SetSessionVariable(ctx, "character_set_results", nil))
   924  	require.EqualValues(t, types.TinyTextBlobMax*4, tinyTextUtf8mb4.MaxTextResponseByteLength(ctx))
   925  	require.EqualValues(t, types.TextBlobMax*4, textUtf8mb4.MaxTextResponseByteLength(ctx))
   926  	require.EqualValues(t, types.MediumTextBlobMax*4, mediumTextUtf8mb4.MaxTextResponseByteLength(ctx))
   927  	require.EqualValues(t, types.LongTextBlobMax, longTextUtf8mb4.MaxTextResponseByteLength(ctx))
   928  }
   929  
   930  func TestHandlerTimeout(t *testing.T) {
   931  	require := require.New(t)
   932  
   933  	e, pro := setupMemDB(require)
   934  	dbFunc := pro.Database
   935  
   936  	e2, pro2 := setupMemDB(require)
   937  	dbFunc2 := pro2.Database
   938  
   939  	timeOutHandler := &Handler{
   940  		e: e,
   941  		sm: NewSessionManager(testSessionBuilder(pro),
   942  			sql.NoopTracer,
   943  			dbFunc,
   944  			sql.NewMemoryManager(nil),
   945  			sqle.NewProcessList(),
   946  			"foo"),
   947  		readTimeout: 1 * time.Second,
   948  	}
   949  
   950  	noTimeOutHandler := &Handler{
   951  		e: e2,
   952  		sm: NewSessionManager(testSessionBuilder(pro2),
   953  			sql.NoopTracer,
   954  			dbFunc2,
   955  			sql.NewMemoryManager(nil),
   956  			sqle.NewProcessList(),
   957  			"foo"),
   958  	}
   959  	require.Equal(1*time.Second, timeOutHandler.readTimeout)
   960  	require.Equal(0*time.Second, noTimeOutHandler.readTimeout)
   961  
   962  	connTimeout := newConn(1)
   963  	timeOutHandler.NewConnection(connTimeout)
   964  
   965  	connNoTimeout := newConn(2)
   966  	noTimeOutHandler.NewConnection(connNoTimeout)
   967  
   968  	timeOutHandler.ComInitDB(connTimeout, "test")
   969  	err := timeOutHandler.ComQuery(connTimeout, "SELECT SLEEP(2)", func(res *sqltypes.Result, more bool) error {
   970  		return nil
   971  	})
   972  	require.EqualError(err, "row read wait bigger than connection timeout (errno 1105) (sqlstate HY000)")
   973  
   974  	err = timeOutHandler.ComQuery(connTimeout, "SELECT SLEEP(0.5)", func(res *sqltypes.Result, more bool) error {
   975  		return nil
   976  	})
   977  	require.NoError(err)
   978  
   979  	noTimeOutHandler.ComInitDB(connNoTimeout, "test")
   980  	err = noTimeOutHandler.ComQuery(connNoTimeout, "SELECT SLEEP(2)", func(res *sqltypes.Result, more bool) error {
   981  		return nil
   982  	})
   983  	require.NoError(err)
   984  }
   985  
   986  func TestOkClosedConnection(t *testing.T) {
   987  	require := require.New(t)
   988  	e, pro := setupMemDB(require)
   989  	dbFunc := pro.Database
   990  
   991  	port, err := getFreePort()
   992  	require.NoError(err)
   993  
   994  	ready := make(chan struct{})
   995  	go okTestServer(t, ready, port)
   996  	<-ready
   997  	conn, err := net.Dial("tcp", "localhost:"+port)
   998  	require.NoError(err)
   999  	defer func() {
  1000  		_ = conn.Close()
  1001  	}()
  1002  
  1003  	h := &Handler{
  1004  		e: e,
  1005  		sm: NewSessionManager(
  1006  			testSessionBuilder(pro),
  1007  			sql.NoopTracer,
  1008  			dbFunc,
  1009  			sql.NewMemoryManager(nil),
  1010  			sqle.NewProcessList(),
  1011  			"foo",
  1012  		),
  1013  	}
  1014  	c := newConn(1)
  1015  	h.NewConnection(c)
  1016  
  1017  	q := fmt.Sprintf("SELECT SLEEP(%d)", (tcpCheckerSleepDuration * 4 / time.Second))
  1018  	h.ComInitDB(c, "test")
  1019  	err = h.ComQuery(c, q, func(res *sqltypes.Result, more bool) error {
  1020  		return nil
  1021  	})
  1022  	require.NoError(err)
  1023  }
  1024  
  1025  // Tests the CLIENT_FOUND_ROWS capabilities flag
  1026  func TestHandlerFoundRowsCapabilities(t *testing.T) {
  1027  	e, pro := setupMemDB(require.New(t))
  1028  	dbFunc := pro.Database
  1029  	dummyConn := newConn(1)
  1030  
  1031  	// Set the capabilities to include found rows
  1032  	dummyConn.Capabilities = mysql.CapabilityClientFoundRows
  1033  
  1034  	// Setup the handler
  1035  	handler := &Handler{
  1036  		e: e,
  1037  		sm: NewSessionManager(
  1038  			testSessionBuilder(pro),
  1039  			sql.NoopTracer,
  1040  			dbFunc,
  1041  			sql.NewMemoryManager(nil),
  1042  			sqle.NewProcessList(),
  1043  			"foo",
  1044  		),
  1045  	}
  1046  
  1047  	tests := []struct {
  1048  		name                 string
  1049  		handler              *Handler
  1050  		conn                 *mysql.Conn
  1051  		query                string
  1052  		expectedRowsAffected uint64
  1053  	}{
  1054  		{
  1055  			name:                 "Update query should return number of rows matched instead of rows affected",
  1056  			handler:              handler,
  1057  			conn:                 dummyConn,
  1058  			query:                "UPDATE test set c1 = c1 where c1 < 10",
  1059  			expectedRowsAffected: uint64(10),
  1060  		},
  1061  		{
  1062  			name:                 "INSERT ON UPDATE returns +1 for every row that already exists",
  1063  			handler:              handler,
  1064  			conn:                 dummyConn,
  1065  			query:                "INSERT INTO test VALUES (1), (2), (3) ON DUPLICATE KEY UPDATE c1=c1",
  1066  			expectedRowsAffected: uint64(3),
  1067  		},
  1068  		{
  1069  			name:                 "SQL_CALC_ROWS should not affect CLIENT_FOUND_ROWS output",
  1070  			handler:              handler,
  1071  			conn:                 dummyConn,
  1072  			query:                "SELECT SQL_CALC_FOUND_ROWS * FROM test LIMIT 5",
  1073  			expectedRowsAffected: uint64(5),
  1074  		},
  1075  		{
  1076  			name:                 "INSERT returns rows affected",
  1077  			handler:              handler,
  1078  			conn:                 dummyConn,
  1079  			query:                "INSERT into test VALUES (10000),(10001),(10002)",
  1080  			expectedRowsAffected: uint64(3),
  1081  		},
  1082  	}
  1083  
  1084  	for _, test := range tests {
  1085  		t.Run(test.name, func(t *testing.T) {
  1086  			handler.ComInitDB(test.conn, "test")
  1087  			var rowsAffected uint64
  1088  			err := handler.ComQuery(test.conn, test.query, func(res *sqltypes.Result, more bool) error {
  1089  				rowsAffected = uint64(res.RowsAffected)
  1090  				return nil
  1091  			})
  1092  
  1093  			require.NoError(t, err)
  1094  			require.Equal(t, test.expectedRowsAffected, rowsAffected)
  1095  		})
  1096  	}
  1097  }
  1098  
  1099  func setupMemDB(require *require.Assertions) (*sqle.Engine, *memory.DbProvider) {
  1100  	db := memory.NewDatabase("test")
  1101  	pro := memory.NewDBProvider(db)
  1102  	e := sqle.NewDefault(pro)
  1103  	ctx := sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), pro)))
  1104  
  1105  	tableTest := memory.NewTable(db, "test", sql.NewPrimaryKeySchema(sql.Schema{{Name: "c1", Type: types.Int32, Source: "test"}}), nil)
  1106  	tableTest.EnablePrimaryKeyIndexes()
  1107  
  1108  	for i := 0; i < 1010; i++ {
  1109  		require.NoError(tableTest.Insert(
  1110  			ctx,
  1111  			sql.NewRow(int32(i)),
  1112  		))
  1113  	}
  1114  
  1115  	db.AddTable("test", tableTest)
  1116  
  1117  	return e, pro
  1118  }
  1119  
  1120  func getFreePort() (string, error) {
  1121  	addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
  1122  	if err != nil {
  1123  		return "", err
  1124  	}
  1125  
  1126  	l, err := net.ListenTCP("tcp", addr)
  1127  	if err != nil {
  1128  		return "", err
  1129  	}
  1130  	defer l.Close()
  1131  	return strconv.Itoa(l.Addr().(*net.TCPAddr).Port), nil
  1132  }
  1133  
  1134  func testServer(t *testing.T, ready chan struct{}, port string, breakConn bool) {
  1135  	l, err := net.Listen("tcp", ":"+port)
  1136  	defer func() {
  1137  		_ = l.Close()
  1138  	}()
  1139  	if err != nil {
  1140  		t.Fatal(err)
  1141  	}
  1142  	close(ready)
  1143  	conn, err := l.Accept()
  1144  	if err != nil {
  1145  		return
  1146  	}
  1147  
  1148  	if !breakConn {
  1149  		defer func() {
  1150  			_ = conn.Close()
  1151  		}()
  1152  
  1153  		_, err = io.ReadAll(conn)
  1154  		if err != nil {
  1155  			t.Fatal(err)
  1156  		}
  1157  	} // else: dirty return without closing or reading to force the socket into TIME_WAIT
  1158  }
  1159  func okTestServer(t *testing.T, ready chan struct{}, port string) {
  1160  	testServer(t, ready, port, false)
  1161  }
  1162  
  1163  // This session builder is used as dummy mysql Conn is not complete and
  1164  // causes panic when accessing remote address.
  1165  func testSessionBuilder(pro *memory.DbProvider) func(ctx context.Context, c *mysql.Conn, addr string) (sql.Session, error) {
  1166  	return func(ctx context.Context, c *mysql.Conn, addr string) (sql.Session, error) {
  1167  		base := sql.NewBaseSessionWithClientServer(addr, sql.Client{Address: "127.0.0.1:34567", User: c.User, Capabilities: c.Capabilities}, c.ConnectionID)
  1168  		return memory.NewSession(base, pro), nil
  1169  	}
  1170  }
  1171  
  1172  type mockConn struct {
  1173  	net.Conn
  1174  	closed bool
  1175  }
  1176  
  1177  func (c *mockConn) Close() error {
  1178  	c.closed = true
  1179  	return nil
  1180  }
  1181  
  1182  func (c *mockConn) RemoteAddr() net.Addr {
  1183  	return mockAddr{}
  1184  }
  1185  
  1186  type mockAddr struct{}
  1187  
  1188  func (mockAddr) Network() string {
  1189  	return "tcp"
  1190  }
  1191  
  1192  func (mockAddr) String() string {
  1193  	return "localhost"
  1194  }
  1195  
  1196  func newConn(id uint32) *mysql.Conn {
  1197  	return &mysql.Conn{
  1198  		ConnectionID: id,
  1199  		Conn:         new(mockConn),
  1200  	}
  1201  }