vitess.io/vitess@v0.16.2/go/mysql/query_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package mysql
    18  
    19  import (
    20  	"fmt"
    21  	"reflect"
    22  	"sync"
    23  	"testing"
    24  
    25  	"google.golang.org/protobuf/proto"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  
    30  	"vitess.io/vitess/go/sqltypes"
    31  
    32  	querypb "vitess.io/vitess/go/vt/proto/query"
    33  )
    34  
    35  // Utility function to write sql query as packets to test parseComPrepare
    36  func preparePacket(t *testing.T, query string) []byte {
    37  	data := make([]byte, len(query)+1+packetHeaderSize)
    38  	// Not sure if it makes a difference
    39  	pos := packetHeaderSize
    40  	pos = writeByte(data, pos, ComPrepare)
    41  	copy(data[pos:], query)
    42  	return data
    43  }
    44  
    45  func MockPrepareData(t *testing.T) (*PrepareData, *sqltypes.Result) {
    46  	sql := "select * from test_table where id = ?"
    47  
    48  	result := &sqltypes.Result{
    49  		Fields: []*querypb.Field{
    50  			{
    51  				Name: "id",
    52  				Type: querypb.Type_INT32,
    53  			},
    54  		},
    55  		Rows: [][]sqltypes.Value{
    56  			{
    57  				sqltypes.MakeTrusted(querypb.Type_INT32, []byte("1")),
    58  			},
    59  		},
    60  		RowsAffected: 1,
    61  	}
    62  
    63  	prepare := &PrepareData{
    64  		StatementID: 18,
    65  		PrepareStmt: sql,
    66  		ParamsCount: 1,
    67  		ParamsType:  []int32{263},
    68  		ColumnNames: []string{"id"},
    69  		BindVars: map[string]*querypb.BindVariable{
    70  			"v1": sqltypes.Int32BindVariable(10),
    71  		},
    72  	}
    73  
    74  	return prepare, result
    75  }
    76  
    77  func TestComInitDB(t *testing.T) {
    78  	listener, sConn, cConn := createSocketPair(t)
    79  	defer func() {
    80  		listener.Close()
    81  		sConn.Close()
    82  		cConn.Close()
    83  	}()
    84  
    85  	// Write ComInitDB packet, read it, compare.
    86  	if err := cConn.writeComInitDB("my_db"); err != nil {
    87  		t.Fatalf("writeComInitDB failed: %v", err)
    88  	}
    89  	data, err := sConn.ReadPacket()
    90  	if err != nil || len(data) == 0 || data[0] != ComInitDB {
    91  		t.Fatalf("sConn.ReadPacket - ComInitDB failed: %v %v", data, err)
    92  	}
    93  	db := sConn.parseComInitDB(data)
    94  	assert.Equal(t, "my_db", db, "parseComInitDB returned unexpected data: %v", db)
    95  
    96  }
    97  
    98  func TestComSetOption(t *testing.T) {
    99  	listener, sConn, cConn := createSocketPair(t)
   100  	defer func() {
   101  		listener.Close()
   102  		sConn.Close()
   103  		cConn.Close()
   104  	}()
   105  
   106  	// Write ComSetOption packet, read it, compare.
   107  	if err := cConn.writeComSetOption(1); err != nil {
   108  		t.Fatalf("writeComSetOption failed: %v", err)
   109  	}
   110  	data, err := sConn.ReadPacket()
   111  	if err != nil || len(data) == 0 || data[0] != ComSetOption {
   112  		t.Fatalf("sConn.ReadPacket - ComSetOption failed: %v %v", data, err)
   113  	}
   114  	operation, ok := sConn.parseComSetOption(data)
   115  	require.True(t, ok, "parseComSetOption failed unexpectedly")
   116  	assert.Equal(t, uint16(1), operation, "parseComSetOption returned unexpected data: %v", operation)
   117  
   118  }
   119  
   120  func TestComStmtPrepare(t *testing.T) {
   121  	listener, sConn, cConn := createSocketPair(t)
   122  	defer func() {
   123  		listener.Close()
   124  		sConn.Close()
   125  		cConn.Close()
   126  	}()
   127  
   128  	sql := "select * from test_table where id = ?"
   129  	mockData := preparePacket(t, sql)
   130  
   131  	if err := cConn.writePacket(mockData); err != nil {
   132  		t.Fatalf("writePacket failed: %v", err)
   133  	}
   134  
   135  	data, err := sConn.ReadPacket()
   136  	require.NoError(t, err, "sConn.ReadPacket - ComPrepare failed: %v", err)
   137  
   138  	parsedQuery := sConn.parseComPrepare(data)
   139  	require.Equal(t, sql, parsedQuery, "Received incorrect query, want: %v, got: %v", sql, parsedQuery)
   140  
   141  	prepare, result := MockPrepareData(t)
   142  	sConn.PrepareData = make(map[uint32]*PrepareData)
   143  	sConn.PrepareData[prepare.StatementID] = prepare
   144  
   145  	// write the response to the client
   146  	if err := sConn.writePrepare(result.Fields, prepare); err != nil {
   147  		t.Fatalf("sConn.writePrepare failed: %v", err)
   148  	}
   149  
   150  	resp, err := cConn.ReadPacket()
   151  	require.NoError(t, err, "cConn.ReadPacket failed: %v", err)
   152  	require.Equal(t, prepare.StatementID, uint32(resp[1]), "Received incorrect Statement ID, want: %v, got: %v", prepare.StatementID, resp[1])
   153  
   154  }
   155  
   156  func TestComStmtPrepareUpdStmt(t *testing.T) {
   157  	listener, sConn, cConn := createSocketPair(t)
   158  	defer func() {
   159  		listener.Close()
   160  		sConn.Close()
   161  		cConn.Close()
   162  	}()
   163  
   164  	sql := "UPDATE test SET __bit = ?, __tinyInt = ?, __tinyIntU = ?, __smallInt = ?, __smallIntU = ?, __mediumInt = ?, __mediumIntU = ?, __int = ?, __intU = ?, __bigInt = ?, __bigIntU = ?, __decimal = ?, __float = ?, __double = ?, __date = ?, __datetime = ?, __timestamp = ?, __time = ?, __year = ?, __char = ?, __varchar = ?, __binary = ?, __varbinary = ?, __tinyblob = ?, __tinytext = ?, __blob = ?, __text = ?, __enum = ?, __set = ? WHERE __id = 0"
   165  	mockData := preparePacket(t, sql)
   166  
   167  	err := cConn.writePacket(mockData)
   168  	require.NoError(t, err, "writePacket failed")
   169  
   170  	data, err := sConn.ReadPacket()
   171  	require.NoError(t, err, "sConn.ReadPacket - ComPrepare failed")
   172  
   173  	parsedQuery := sConn.parseComPrepare(data)
   174  	require.Equal(t, sql, parsedQuery, "Received incorrect query")
   175  
   176  	paramsCount := uint16(29)
   177  	prepare := &PrepareData{
   178  		StatementID: 1,
   179  		PrepareStmt: sql,
   180  		ParamsCount: paramsCount,
   181  	}
   182  	sConn.PrepareData = make(map[uint32]*PrepareData)
   183  	sConn.PrepareData[prepare.StatementID] = prepare
   184  
   185  	// write the response to the client
   186  	err = sConn.writePrepare(nil, prepare)
   187  	require.NoError(t, err, "sConn.writePrepare failed")
   188  
   189  	resp, err := cConn.ReadPacket()
   190  	require.NoError(t, err, "cConn.ReadPacket failed")
   191  	require.EqualValues(t, prepare.StatementID, resp[1], "Received incorrect Statement ID")
   192  
   193  	for i := uint16(0); i < paramsCount; i++ {
   194  		resp, err := cConn.ReadPacket()
   195  		require.NoError(t, err, "cConn.ReadPacket failed")
   196  		require.EqualValues(t, 0xfd, resp[17], "Received incorrect Statement ID")
   197  	}
   198  }
   199  
   200  func TestComStmtSendLongData(t *testing.T) {
   201  	listener, sConn, cConn := createSocketPair(t)
   202  	defer func() {
   203  		listener.Close()
   204  		sConn.Close()
   205  		cConn.Close()
   206  	}()
   207  
   208  	prepare, result := MockPrepareData(t)
   209  	cConn.PrepareData = make(map[uint32]*PrepareData)
   210  	cConn.PrepareData[prepare.StatementID] = prepare
   211  	if err := cConn.writePrepare(result.Fields, prepare); err != nil {
   212  		t.Fatalf("writePrepare failed: %v", err)
   213  	}
   214  
   215  	// Since there's no writeComStmtSendLongData, we'll write a prepareStmt and check if we can read the StatementID
   216  	data, err := sConn.ReadPacket()
   217  	if err != nil || len(data) == 0 {
   218  		t.Fatalf("sConn.ReadPacket - ComStmtClose failed: %v %v", data, err)
   219  	}
   220  	stmtID, paramID, chunkData, ok := sConn.parseComStmtSendLongData(data)
   221  	require.True(t, ok, "parseComStmtSendLongData failed")
   222  	require.Equal(t, uint16(1), paramID, "Received incorrect ParamID, want %v, got %v:", paramID, 1)
   223  	require.Equal(t, prepare.StatementID, stmtID, "Received incorrect value, want: %v, got: %v", uint32(data[1]), prepare.StatementID)
   224  	// Check length of chunkData, Since its a subset of `data` and compare with it after we subtract the number of bytes that was read from it.
   225  	// sizeof(uint32) + sizeof(uint16) + 1 = 7
   226  	require.Equal(t, len(data)-7, len(chunkData), "Received bad chunkData")
   227  
   228  }
   229  
   230  func TestComStmtExecute(t *testing.T) {
   231  	listener, sConn, cConn := createSocketPair(t)
   232  	defer func() {
   233  		listener.Close()
   234  		sConn.Close()
   235  		cConn.Close()
   236  	}()
   237  
   238  	prepare, _ := MockPrepareData(t)
   239  	cConn.PrepareData = make(map[uint32]*PrepareData)
   240  	cConn.PrepareData[prepare.StatementID] = prepare
   241  
   242  	// This is simulated packets for `select * from test_table where id = ?`
   243  	data := []byte{23, 18, 0, 0, 0, 128, 1, 0, 0, 0, 0, 1, 1, 128, 1}
   244  
   245  	stmtID, _, err := sConn.parseComStmtExecute(cConn.PrepareData, data)
   246  	require.NoError(t, err, "parseComStmtExeute failed: %v", err)
   247  	require.Equal(t, uint32(18), stmtID, "Parsed incorrect values")
   248  
   249  }
   250  
   251  func TestComStmtExecuteUpdStmt(t *testing.T) {
   252  	listener, sConn, cConn := createSocketPair(t)
   253  	defer func() {
   254  		listener.Close()
   255  		sConn.Close()
   256  		cConn.Close()
   257  	}()
   258  
   259  	prepareDataMap := map[uint32]*PrepareData{
   260  		1: {
   261  			StatementID: 1,
   262  			ParamsCount: 29,
   263  			ParamsType:  make([]int32, 29),
   264  			BindVars:    map[string]*querypb.BindVariable{},
   265  		}}
   266  
   267  	// This is simulated packets for update query
   268  	data := []byte{
   269  		0x29, 0x01, 0x00, 0x00, 0x17, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
   270  		0x00, 0x00, 0x01, 0x10, 0x00, 0x01, 0x00, 0x01, 0x80, 0x02, 0x00, 0x02, 0x80, 0x03, 0x00, 0x03,
   271  		0x80, 0x03, 0x00, 0x03, 0x80, 0x08, 0x00, 0x08, 0x80, 0x00, 0x00, 0x04, 0x00, 0x05, 0x00, 0x0a,
   272  		0x00, 0x0c, 0x00, 0x07, 0x00, 0x0b, 0x00, 0x0d, 0x80, 0xfe, 0x00, 0xfe, 0x00, 0xfc, 0x00, 0xfc,
   273  		0x00, 0xfc, 0x00, 0xfe, 0x00, 0xfc, 0x00, 0xfe, 0x00, 0xfe, 0x00, 0xfe, 0x00, 0x08, 0x00, 0x00,
   274  		0x00, 0x00, 0x00, 0x00, 0xaa, 0xe0, 0x80, 0xff, 0x00, 0x80, 0xff, 0xff, 0x00, 0x00, 0x80, 0xff,
   275  		0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
   276  		0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x15, 0x31, 0x32, 0x33,
   277  		0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x30, 0x2e, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
   278  		0x38, 0x39, 0xd0, 0x0f, 0x49, 0x40, 0x44, 0x17, 0x41, 0x54, 0xfb, 0x21, 0x09, 0x40, 0x04, 0xe0,
   279  		0x07, 0x08, 0x08, 0x0b, 0xe0, 0x07, 0x08, 0x08, 0x11, 0x19, 0x3b, 0x00, 0x00, 0x00, 0x00, 0x0b,
   280  		0xe0, 0x07, 0x08, 0x08, 0x11, 0x19, 0x3b, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x01, 0x08, 0x00, 0x00,
   281  		0x00, 0x07, 0x3b, 0x3b, 0x00, 0x00, 0x00, 0x00, 0x04, 0x31, 0x39, 0x39, 0x39, 0x08, 0x31, 0x32,
   282  		0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x0c, 0xe9, 0x9f, 0xa9, 0xe5, 0x86, 0xac, 0xe7, 0x9c, 0x9f,
   283  		0xe8, 0xb5, 0x9e, 0x08, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x08, 0x31, 0x32, 0x33,
   284  		0x34, 0x35, 0x36, 0x37, 0x38, 0x08, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x0c, 0xe9,
   285  		0x9f, 0xa9, 0xe5, 0x86, 0xac, 0xe7, 0x9c, 0x9f, 0xe8, 0xb5, 0x9e, 0x08, 0x31, 0x32, 0x33, 0x34,
   286  		0x35, 0x36, 0x37, 0x38, 0x0c, 0xe9, 0x9f, 0xa9, 0xe5, 0x86, 0xac, 0xe7, 0x9c, 0x9f, 0xe8, 0xb5,
   287  		0x9e, 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x66, 0x6f, 0x6f, 0x2c, 0x62, 0x61, 0x72}
   288  
   289  	stmtID, _, err := sConn.parseComStmtExecute(prepareDataMap, data[4:]) // first 4 are header
   290  	require.NoError(t, err)
   291  	require.EqualValues(t, 1, stmtID)
   292  
   293  	prepData := prepareDataMap[stmtID]
   294  	assert.EqualValues(t, querypb.Type_BIT, prepData.ParamsType[0], "got: %s", querypb.Type(prepData.ParamsType[0]))
   295  	assert.EqualValues(t, querypb.Type_INT8, prepData.ParamsType[1], "got: %s", querypb.Type(prepData.ParamsType[1]))
   296  	assert.EqualValues(t, querypb.Type_INT8, prepData.ParamsType[2], "got: %s", querypb.Type(prepData.ParamsType[2]))
   297  	assert.EqualValues(t, querypb.Type_INT16, prepData.ParamsType[3], "got: %s", querypb.Type(prepData.ParamsType[3]))
   298  	assert.EqualValues(t, querypb.Type_INT16, prepData.ParamsType[4], "got: %s", querypb.Type(prepData.ParamsType[4]))
   299  	assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[5], "got: %s", querypb.Type(prepData.ParamsType[5]))
   300  	assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[6], "got: %s", querypb.Type(prepData.ParamsType[6]))
   301  	assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[7], "got: %s", querypb.Type(prepData.ParamsType[7]))
   302  	assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[8], "got: %s", querypb.Type(prepData.ParamsType[8]))
   303  	assert.EqualValues(t, querypb.Type_INT64, prepData.ParamsType[9], "got: %s", querypb.Type(prepData.ParamsType[9]))
   304  	assert.EqualValues(t, querypb.Type_INT64, prepData.ParamsType[10], "got: %s", querypb.Type(prepData.ParamsType[10]))
   305  	assert.EqualValues(t, querypb.Type_DECIMAL, prepData.ParamsType[11], "got: %s", querypb.Type(prepData.ParamsType[11]))
   306  	assert.EqualValues(t, querypb.Type_FLOAT32, prepData.ParamsType[12], "got: %s", querypb.Type(prepData.ParamsType[12]))
   307  	assert.EqualValues(t, querypb.Type_FLOAT64, prepData.ParamsType[13], "got: %s", querypb.Type(prepData.ParamsType[13]))
   308  	assert.EqualValues(t, querypb.Type_DATE, prepData.ParamsType[14], "got: %s", querypb.Type(prepData.ParamsType[14]))
   309  	assert.EqualValues(t, querypb.Type_DATETIME, prepData.ParamsType[15], "got: %s", querypb.Type(prepData.ParamsType[15]))
   310  	assert.EqualValues(t, querypb.Type_TIMESTAMP, prepData.ParamsType[16], "got: %s", querypb.Type(prepData.ParamsType[16]))
   311  	assert.EqualValues(t, querypb.Type_TIME, prepData.ParamsType[17], "got: %s", querypb.Type(prepData.ParamsType[17]))
   312  
   313  	// this is year but in binary it is changed to varbinary
   314  	assert.EqualValues(t, querypb.Type_VARBINARY, prepData.ParamsType[18], "got: %s", querypb.Type(prepData.ParamsType[18]))
   315  
   316  	assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[19], "got: %s", querypb.Type(prepData.ParamsType[19]))
   317  	assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[20], "got: %s", querypb.Type(prepData.ParamsType[20]))
   318  	assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[21], "got: %s", querypb.Type(prepData.ParamsType[21]))
   319  	assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[22], "got: %s", querypb.Type(prepData.ParamsType[22]))
   320  	assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[23], "got: %s", querypb.Type(prepData.ParamsType[23]))
   321  	assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[24], "got: %s", querypb.Type(prepData.ParamsType[24]))
   322  	assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[25], "got: %s", querypb.Type(prepData.ParamsType[25]))
   323  	assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[26], "got: %s", querypb.Type(prepData.ParamsType[26]))
   324  	assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[27], "got: %s", querypb.Type(prepData.ParamsType[27]))
   325  	assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[28], "got: %s", querypb.Type(prepData.ParamsType[28]))
   326  }
   327  
   328  func TestComStmtClose(t *testing.T) {
   329  	listener, sConn, cConn := createSocketPair(t)
   330  	defer func() {
   331  		listener.Close()
   332  		sConn.Close()
   333  		cConn.Close()
   334  	}()
   335  
   336  	prepare, result := MockPrepareData(t)
   337  	cConn.PrepareData = make(map[uint32]*PrepareData)
   338  	cConn.PrepareData[prepare.StatementID] = prepare
   339  	if err := cConn.writePrepare(result.Fields, prepare); err != nil {
   340  		t.Fatalf("writePrepare failed: %v", err)
   341  	}
   342  
   343  	// Since there's no writeComStmtClose, we'll write a prepareStmt and check if we can read the StatementID
   344  	data, err := sConn.ReadPacket()
   345  	if err != nil || len(data) == 0 {
   346  		t.Fatalf("sConn.ReadPacket - ComStmtClose failed: %v %v", data, err)
   347  	}
   348  	stmtID, ok := sConn.parseComStmtClose(data)
   349  	require.True(t, ok, "parseComStmtClose failed")
   350  	require.Equal(t, prepare.StatementID, stmtID, "Received incorrect value, want: %v, got: %v", uint32(data[1]), prepare.StatementID)
   351  
   352  }
   353  
   354  // This test has been added to verify that IO errors in a connection lead to SQL Server lost errors
   355  // So that we end up closing the connection higher up the stack and not reusing it.
   356  // This test was added in response to a panic that was run into.
   357  func TestSQLErrorOnServerClose(t *testing.T) {
   358  	// Create socket pair for the server and client
   359  	listener, sConn, cConn := createSocketPair(t)
   360  	defer func() {
   361  		listener.Close()
   362  		sConn.Close()
   363  		cConn.Close()
   364  	}()
   365  
   366  	err := cConn.WriteComQuery("close before rows read")
   367  	require.NoError(t, err)
   368  
   369  	handler := &testRun{t: t}
   370  	_ = sConn.handleNextCommand(handler)
   371  
   372  	// From the server we will receive a field packet which the client will read
   373  	// At that point, if the server crashes and closes the connection.
   374  	// We should be getting a Connection lost error.
   375  	_, _, _, err = cConn.ReadQueryResult(100, true)
   376  	require.Error(t, err)
   377  	require.True(t, IsConnLostDuringQuery(err), err.Error())
   378  }
   379  
   380  func TestQueries(t *testing.T) {
   381  	listener, sConn, cConn := createSocketPair(t)
   382  	defer func() {
   383  		listener.Close()
   384  		sConn.Close()
   385  		cConn.Close()
   386  	}()
   387  
   388  	// Smallest result
   389  	checkQuery(t, "tiny", sConn, cConn, &sqltypes.Result{})
   390  
   391  	// Typical Insert result
   392  	checkQuery(t, "insert", sConn, cConn, &sqltypes.Result{
   393  		RowsAffected: 0x8010203040506070,
   394  		InsertID:     0x0102030405060708,
   395  	})
   396  
   397  	// Typical Select with TYPE_AND_NAME.
   398  	// One value is also NULL.
   399  	checkQuery(t, "type and name", sConn, cConn, &sqltypes.Result{
   400  		Fields: []*querypb.Field{
   401  			{
   402  				Name: "id",
   403  				Type: querypb.Type_INT32,
   404  			},
   405  			{
   406  				Name: "name",
   407  				Type: querypb.Type_VARCHAR,
   408  			},
   409  		},
   410  		Rows: [][]sqltypes.Value{
   411  			{
   412  				sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")),
   413  				sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")),
   414  			},
   415  			{
   416  				sqltypes.MakeTrusted(querypb.Type_INT32, []byte("20")),
   417  				sqltypes.NULL,
   418  			},
   419  		},
   420  	})
   421  
   422  	// Typical Select with TYPE_AND_NAME.
   423  	// All types are represented.
   424  	// One row has all NULL values.
   425  	checkQuery(t, "all types", sConn, cConn, &sqltypes.Result{
   426  		Fields: []*querypb.Field{
   427  			{Name: "Type_INT8     ", Type: querypb.Type_INT8},
   428  			{Name: "Type_UINT8    ", Type: querypb.Type_UINT8},
   429  			{Name: "Type_INT16    ", Type: querypb.Type_INT16},
   430  			{Name: "Type_UINT16   ", Type: querypb.Type_UINT16},
   431  			{Name: "Type_INT24    ", Type: querypb.Type_INT24},
   432  			{Name: "Type_UINT24   ", Type: querypb.Type_UINT24},
   433  			{Name: "Type_INT32    ", Type: querypb.Type_INT32},
   434  			{Name: "Type_UINT32   ", Type: querypb.Type_UINT32},
   435  			{Name: "Type_INT64    ", Type: querypb.Type_INT64},
   436  			{Name: "Type_UINT64   ", Type: querypb.Type_UINT64},
   437  			{Name: "Type_FLOAT32  ", Type: querypb.Type_FLOAT32},
   438  			{Name: "Type_FLOAT64  ", Type: querypb.Type_FLOAT64},
   439  			{Name: "Type_TIMESTAMP", Type: querypb.Type_TIMESTAMP},
   440  			{Name: "Type_DATE     ", Type: querypb.Type_DATE},
   441  			{Name: "Type_TIME     ", Type: querypb.Type_TIME},
   442  			{Name: "Type_DATETIME ", Type: querypb.Type_DATETIME},
   443  			{Name: "Type_YEAR     ", Type: querypb.Type_YEAR},
   444  			{Name: "Type_DECIMAL  ", Type: querypb.Type_DECIMAL},
   445  			{Name: "Type_TEXT     ", Type: querypb.Type_TEXT},
   446  			{Name: "Type_BLOB     ", Type: querypb.Type_BLOB},
   447  			{Name: "Type_VARCHAR  ", Type: querypb.Type_VARCHAR},
   448  			{Name: "Type_VARBINARY", Type: querypb.Type_VARBINARY},
   449  			{Name: "Type_CHAR     ", Type: querypb.Type_CHAR},
   450  			{Name: "Type_BINARY   ", Type: querypb.Type_BINARY},
   451  			{Name: "Type_BIT      ", Type: querypb.Type_BIT},
   452  			{Name: "Type_ENUM     ", Type: querypb.Type_ENUM},
   453  			{Name: "Type_SET      ", Type: querypb.Type_SET},
   454  			// Skip TUPLE, not possible in Result.
   455  			{Name: "Type_GEOMETRY ", Type: querypb.Type_GEOMETRY},
   456  			{Name: "Type_JSON     ", Type: querypb.Type_JSON},
   457  		},
   458  		Rows: [][]sqltypes.Value{
   459  			{
   460  				sqltypes.MakeTrusted(querypb.Type_INT8, []byte("Type_INT8")),
   461  				sqltypes.MakeTrusted(querypb.Type_UINT8, []byte("Type_UINT8")),
   462  				sqltypes.MakeTrusted(querypb.Type_INT16, []byte("Type_INT16")),
   463  				sqltypes.MakeTrusted(querypb.Type_UINT16, []byte("Type_UINT16")),
   464  				sqltypes.MakeTrusted(querypb.Type_INT24, []byte("Type_INT24")),
   465  				sqltypes.MakeTrusted(querypb.Type_UINT24, []byte("Type_UINT24")),
   466  				sqltypes.MakeTrusted(querypb.Type_INT32, []byte("Type_INT32")),
   467  				sqltypes.MakeTrusted(querypb.Type_UINT32, []byte("Type_UINT32")),
   468  				sqltypes.MakeTrusted(querypb.Type_INT64, []byte("Type_INT64")),
   469  				sqltypes.MakeTrusted(querypb.Type_UINT64, []byte("Type_UINT64")),
   470  				sqltypes.MakeTrusted(querypb.Type_FLOAT32, []byte("Type_FLOAT32")),
   471  				sqltypes.MakeTrusted(querypb.Type_FLOAT64, []byte("Type_FLOAT64")),
   472  				sqltypes.MakeTrusted(querypb.Type_TIMESTAMP, []byte("Type_TIMESTAMP")),
   473  				sqltypes.MakeTrusted(querypb.Type_DATE, []byte("Type_DATE")),
   474  				sqltypes.MakeTrusted(querypb.Type_TIME, []byte("Type_TIME")),
   475  				sqltypes.MakeTrusted(querypb.Type_DATETIME, []byte("Type_DATETIME")),
   476  				sqltypes.MakeTrusted(querypb.Type_YEAR, []byte("Type_YEAR")),
   477  				sqltypes.MakeTrusted(querypb.Type_DECIMAL, []byte("Type_DECIMAL")),
   478  				sqltypes.MakeTrusted(querypb.Type_TEXT, []byte("Type_TEXT")),
   479  				sqltypes.MakeTrusted(querypb.Type_BLOB, []byte("Type_BLOB")),
   480  				sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("Type_VARCHAR")),
   481  				sqltypes.MakeTrusted(querypb.Type_VARBINARY, []byte("Type_VARBINARY")),
   482  				sqltypes.MakeTrusted(querypb.Type_CHAR, []byte("Type_CHAR")),
   483  				sqltypes.MakeTrusted(querypb.Type_BINARY, []byte("Type_BINARY")),
   484  				sqltypes.MakeTrusted(querypb.Type_BIT, []byte("Type_BIT")),
   485  				sqltypes.MakeTrusted(querypb.Type_ENUM, []byte("Type_ENUM")),
   486  				sqltypes.MakeTrusted(querypb.Type_SET, []byte("Type_SET")),
   487  				sqltypes.MakeTrusted(querypb.Type_GEOMETRY, []byte("Type_GEOMETRY")),
   488  				sqltypes.MakeTrusted(querypb.Type_JSON, []byte("Type_JSON")),
   489  			},
   490  			{
   491  				sqltypes.NULL,
   492  				sqltypes.NULL,
   493  				sqltypes.NULL,
   494  				sqltypes.NULL,
   495  				sqltypes.NULL,
   496  				sqltypes.NULL,
   497  				sqltypes.NULL,
   498  				sqltypes.NULL,
   499  				sqltypes.NULL,
   500  				sqltypes.NULL,
   501  				sqltypes.NULL,
   502  				sqltypes.NULL,
   503  				sqltypes.NULL,
   504  				sqltypes.NULL,
   505  				sqltypes.NULL,
   506  				sqltypes.NULL,
   507  				sqltypes.NULL,
   508  				sqltypes.NULL,
   509  				sqltypes.NULL,
   510  				sqltypes.NULL,
   511  				sqltypes.NULL,
   512  				sqltypes.NULL,
   513  				sqltypes.NULL,
   514  				sqltypes.NULL,
   515  				sqltypes.NULL,
   516  				sqltypes.NULL,
   517  				sqltypes.NULL,
   518  				sqltypes.NULL,
   519  				sqltypes.NULL,
   520  			},
   521  		},
   522  	})
   523  
   524  	// Typical Select with TYPE_AND_NAME.
   525  	// First value first column is an empty string, so it's encoded as 0.
   526  	checkQuery(t, "first empty string", sConn, cConn, &sqltypes.Result{
   527  		Fields: []*querypb.Field{
   528  			{
   529  				Name: "name",
   530  				Type: querypb.Type_VARCHAR,
   531  			},
   532  		},
   533  		Rows: [][]sqltypes.Value{
   534  			{
   535  				sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("")),
   536  			},
   537  			{
   538  				sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")),
   539  			},
   540  		},
   541  	})
   542  
   543  	// Typical Select with TYPE_ONLY.
   544  	checkQuery(t, "type only", sConn, cConn, &sqltypes.Result{
   545  		Fields: []*querypb.Field{
   546  			{
   547  				Type: querypb.Type_INT64,
   548  			},
   549  		},
   550  		Rows: [][]sqltypes.Value{
   551  			{
   552  				sqltypes.MakeTrusted(querypb.Type_INT64, []byte("10")),
   553  			},
   554  			{
   555  				sqltypes.MakeTrusted(querypb.Type_INT64, []byte("20")),
   556  			},
   557  		},
   558  	})
   559  
   560  	// Typical Select with ALL.
   561  	checkQuery(t, "complete", sConn, cConn, &sqltypes.Result{
   562  		Fields: []*querypb.Field{
   563  			{
   564  				Type:         querypb.Type_INT64,
   565  				Name:         "cool column name",
   566  				Table:        "table name",
   567  				OrgTable:     "org table",
   568  				Database:     "fine db",
   569  				OrgName:      "crazy org",
   570  				ColumnLength: 0x80020304,
   571  				Charset:      0x1234,
   572  				Decimals:     36,
   573  				Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG |
   574  					querypb.MySqlFlag_PRI_KEY_FLAG |
   575  					querypb.MySqlFlag_PART_KEY_FLAG |
   576  					querypb.MySqlFlag_NUM_FLAG),
   577  			},
   578  		},
   579  		Rows: [][]sqltypes.Value{
   580  			{
   581  				sqltypes.MakeTrusted(querypb.Type_INT64, []byte("10")),
   582  			},
   583  			{
   584  				sqltypes.MakeTrusted(querypb.Type_INT64, []byte("20")),
   585  			},
   586  			{
   587  				sqltypes.MakeTrusted(querypb.Type_INT64, []byte("30")),
   588  			},
   589  		},
   590  	})
   591  }
   592  
   593  func checkQuery(t *testing.T, query string, sConn, cConn *Conn, result *sqltypes.Result) {
   594  	// The protocol depends on the CapabilityClientDeprecateEOF flag.
   595  	// So we want to test both cases.
   596  
   597  	sConn.Capabilities = 0
   598  	cConn.Capabilities = 0
   599  	checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, false /* warnings */)
   600  	checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, true /* allRows */, false /* warnings */)
   601  	checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, false /* allRows */, false /* warnings */)
   602  	checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, false /* allRows */, false /* warnings */)
   603  
   604  	checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, true /* warnings */)
   605  
   606  	sConn.Capabilities = CapabilityClientDeprecateEOF
   607  	cConn.Capabilities = CapabilityClientDeprecateEOF
   608  	checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, false /* warnings */)
   609  	checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, true /* allRows */, false /* warnings */)
   610  	checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, false /* allRows */, false /* warnings */)
   611  	checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, false /* allRows */, false /* warnings */)
   612  
   613  	checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, true /* warnings */)
   614  }
   615  
   616  func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result *sqltypes.Result, wantfields, allRows, warnings bool) {
   617  
   618  	if sConn.Capabilities&CapabilityClientDeprecateEOF > 0 {
   619  		query += " NOEOF"
   620  	} else {
   621  		query += " EOF"
   622  	}
   623  	if wantfields {
   624  		query += " FIELDS"
   625  	} else {
   626  		query += " NOFIELDS"
   627  	}
   628  	if allRows {
   629  		query += " ALL"
   630  	} else {
   631  		query += " PARTIAL"
   632  	}
   633  
   634  	var warningCount uint16
   635  	if warnings {
   636  		query += " WARNINGS"
   637  		warningCount = 99
   638  	} else {
   639  		query += " NOWARNINGS"
   640  	}
   641  
   642  	var fatalError string
   643  	// Use a go routine to run ExecuteFetch.
   644  	wg := sync.WaitGroup{}
   645  	wg.Add(1)
   646  	go func() {
   647  		defer wg.Done()
   648  
   649  		maxrows := 10000
   650  		if !allRows {
   651  			// Asking for just one row max. The results that have more will fail.
   652  			maxrows = 1
   653  		}
   654  		got, gotWarnings, err := cConn.ExecuteFetchWithWarningCount(query, maxrows, wantfields)
   655  		if !allRows && len(result.Rows) > 1 {
   656  			require.ErrorContains(t, err, "Row count exceeded")
   657  			return
   658  		}
   659  		if err != nil {
   660  			fatalError = fmt.Sprintf("executeFetch failed: %v", err)
   661  			return
   662  		}
   663  		expected := *result
   664  		if !wantfields {
   665  			expected.Fields = nil
   666  		}
   667  		if !got.Equal(&expected) {
   668  			for i, f := range got.Fields {
   669  				if i < len(expected.Fields) && !proto.Equal(f, expected.Fields[i]) {
   670  					t.Logf("Got      field(%v) = %v", i, f)
   671  					t.Logf("Expected field(%v) = %v", i, expected.Fields[i])
   672  				}
   673  			}
   674  			fatalError = fmt.Sprintf("ExecuteFetch(wantfields=%v) returned:\n%v\nBut was expecting:\n%v", wantfields, got, expected)
   675  			return
   676  		}
   677  
   678  		if gotWarnings != warningCount {
   679  			t.Errorf("ExecuteFetch(%v) expected %v warnings got %v", query, warningCount, gotWarnings)
   680  			return
   681  		}
   682  
   683  		// Test ExecuteStreamFetch, build a Result.
   684  		expected = *result
   685  		if err := cConn.ExecuteStreamFetch(query); err != nil {
   686  			fatalError = fmt.Sprintf("ExecuteStreamFetch(%v) failed: %v", query, err)
   687  			return
   688  		}
   689  		got = &sqltypes.Result{}
   690  		got.RowsAffected = result.RowsAffected
   691  		got.InsertID = result.InsertID
   692  		got.Fields, err = cConn.Fields()
   693  		if err != nil {
   694  			fatalError = fmt.Sprintf("Fields(%v) failed: %v", query, err)
   695  			return
   696  		}
   697  		if len(got.Fields) == 0 {
   698  			got.Fields = nil
   699  		}
   700  		for {
   701  			row, err := cConn.FetchNext(nil)
   702  			if err != nil {
   703  				fatalError = fmt.Sprintf("FetchNext(%v) failed: %v", query, err)
   704  				return
   705  			}
   706  			if row == nil {
   707  				// Done.
   708  				break
   709  			}
   710  			got.Rows = append(got.Rows, row)
   711  		}
   712  		cConn.CloseResult()
   713  
   714  		if !got.Equal(&expected) {
   715  			for i, f := range got.Fields {
   716  				if i < len(expected.Fields) && !proto.Equal(f, expected.Fields[i]) {
   717  					t.Logf("========== Got      field(%v) = %v", i, f)
   718  					t.Logf("========== Expected field(%v) = %v", i, expected.Fields[i])
   719  				}
   720  			}
   721  			for i, row := range got.Rows {
   722  				if i < len(expected.Rows) && !reflect.DeepEqual(row, expected.Rows[i]) {
   723  					t.Logf("========== Got      row(%v) = %v", i, RowString(row))
   724  					t.Logf("========== Expected row(%v) = %v", i, RowString(expected.Rows[i]))
   725  				}
   726  			}
   727  			if expected.RowsAffected != got.RowsAffected {
   728  				t.Logf("========== Got      RowsAffected = %v", got.RowsAffected)
   729  				t.Logf("========== Expected RowsAffected = %v", expected.RowsAffected)
   730  			}
   731  			t.Errorf("\nExecuteStreamFetch(%v) returned:\n%+v\nBut was expecting:\n%+v\n", query, got, &expected)
   732  		}
   733  	}()
   734  
   735  	// The other side gets the request, and sends the result.
   736  	// Twice, once for ExecuteFetch, once for ExecuteStreamFetch.
   737  	count := 2
   738  	if !allRows && len(result.Rows) > 1 {
   739  		// short-circuit one test, the go routine returned and didn't
   740  		// do the streaming query.
   741  		count--
   742  	}
   743  
   744  	handler := testHandler{
   745  		result:   result,
   746  		warnings: warningCount,
   747  	}
   748  
   749  	for i := 0; i < count; i++ {
   750  		kontinue := sConn.handleNextCommand(&handler)
   751  		require.True(t, kontinue, "error handling command: %d", i)
   752  
   753  	}
   754  
   755  	wg.Wait()
   756  	require.Equal(t, "", fatalError, fatalError)
   757  
   758  }
   759  
   760  // nolint
   761  func writeResult(conn *Conn, result *sqltypes.Result) error {
   762  	if len(result.Fields) == 0 {
   763  		return conn.writeOKPacket(&PacketOK{
   764  			affectedRows: result.RowsAffected,
   765  			lastInsertID: result.InsertID,
   766  			statusFlags:  conn.StatusFlags,
   767  			warnings:     0,
   768  		})
   769  	}
   770  	if err := conn.writeFields(result); err != nil {
   771  		return err
   772  	}
   773  	if err := conn.writeRows(result); err != nil {
   774  		return err
   775  	}
   776  	return conn.writeEndResult(false, 0, 0, 0)
   777  }
   778  
   779  func RowString(row []sqltypes.Value) string {
   780  	l := len(row)
   781  	result := fmt.Sprintf("%v values:", l)
   782  	for _, val := range row {
   783  		result += fmt.Sprintf(" %v", val)
   784  	}
   785  	return result
   786  }