vitess.io/vitess@v0.16.2/go/mysql/query_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package mysql 18 19 import ( 20 "fmt" 21 "reflect" 22 "sync" 23 "testing" 24 25 "google.golang.org/protobuf/proto" 26 27 "github.com/stretchr/testify/assert" 28 "github.com/stretchr/testify/require" 29 30 "vitess.io/vitess/go/sqltypes" 31 32 querypb "vitess.io/vitess/go/vt/proto/query" 33 ) 34 35 // Utility function to write sql query as packets to test parseComPrepare 36 func preparePacket(t *testing.T, query string) []byte { 37 data := make([]byte, len(query)+1+packetHeaderSize) 38 // Not sure if it makes a difference 39 pos := packetHeaderSize 40 pos = writeByte(data, pos, ComPrepare) 41 copy(data[pos:], query) 42 return data 43 } 44 45 func MockPrepareData(t *testing.T) (*PrepareData, *sqltypes.Result) { 46 sql := "select * from test_table where id = ?" 47 48 result := &sqltypes.Result{ 49 Fields: []*querypb.Field{ 50 { 51 Name: "id", 52 Type: querypb.Type_INT32, 53 }, 54 }, 55 Rows: [][]sqltypes.Value{ 56 { 57 sqltypes.MakeTrusted(querypb.Type_INT32, []byte("1")), 58 }, 59 }, 60 RowsAffected: 1, 61 } 62 63 prepare := &PrepareData{ 64 StatementID: 18, 65 PrepareStmt: sql, 66 ParamsCount: 1, 67 ParamsType: []int32{263}, 68 ColumnNames: []string{"id"}, 69 BindVars: map[string]*querypb.BindVariable{ 70 "v1": sqltypes.Int32BindVariable(10), 71 }, 72 } 73 74 return prepare, result 75 } 76 77 func TestComInitDB(t *testing.T) { 78 listener, sConn, cConn := createSocketPair(t) 79 defer func() { 80 listener.Close() 81 sConn.Close() 82 cConn.Close() 83 }() 84 85 // Write ComInitDB packet, read it, compare. 86 if err := cConn.writeComInitDB("my_db"); err != nil { 87 t.Fatalf("writeComInitDB failed: %v", err) 88 } 89 data, err := sConn.ReadPacket() 90 if err != nil || len(data) == 0 || data[0] != ComInitDB { 91 t.Fatalf("sConn.ReadPacket - ComInitDB failed: %v %v", data, err) 92 } 93 db := sConn.parseComInitDB(data) 94 assert.Equal(t, "my_db", db, "parseComInitDB returned unexpected data: %v", db) 95 96 } 97 98 func TestComSetOption(t *testing.T) { 99 listener, sConn, cConn := createSocketPair(t) 100 defer func() { 101 listener.Close() 102 sConn.Close() 103 cConn.Close() 104 }() 105 106 // Write ComSetOption packet, read it, compare. 107 if err := cConn.writeComSetOption(1); err != nil { 108 t.Fatalf("writeComSetOption failed: %v", err) 109 } 110 data, err := sConn.ReadPacket() 111 if err != nil || len(data) == 0 || data[0] != ComSetOption { 112 t.Fatalf("sConn.ReadPacket - ComSetOption failed: %v %v", data, err) 113 } 114 operation, ok := sConn.parseComSetOption(data) 115 require.True(t, ok, "parseComSetOption failed unexpectedly") 116 assert.Equal(t, uint16(1), operation, "parseComSetOption returned unexpected data: %v", operation) 117 118 } 119 120 func TestComStmtPrepare(t *testing.T) { 121 listener, sConn, cConn := createSocketPair(t) 122 defer func() { 123 listener.Close() 124 sConn.Close() 125 cConn.Close() 126 }() 127 128 sql := "select * from test_table where id = ?" 129 mockData := preparePacket(t, sql) 130 131 if err := cConn.writePacket(mockData); err != nil { 132 t.Fatalf("writePacket failed: %v", err) 133 } 134 135 data, err := sConn.ReadPacket() 136 require.NoError(t, err, "sConn.ReadPacket - ComPrepare failed: %v", err) 137 138 parsedQuery := sConn.parseComPrepare(data) 139 require.Equal(t, sql, parsedQuery, "Received incorrect query, want: %v, got: %v", sql, parsedQuery) 140 141 prepare, result := MockPrepareData(t) 142 sConn.PrepareData = make(map[uint32]*PrepareData) 143 sConn.PrepareData[prepare.StatementID] = prepare 144 145 // write the response to the client 146 if err := sConn.writePrepare(result.Fields, prepare); err != nil { 147 t.Fatalf("sConn.writePrepare failed: %v", err) 148 } 149 150 resp, err := cConn.ReadPacket() 151 require.NoError(t, err, "cConn.ReadPacket failed: %v", err) 152 require.Equal(t, prepare.StatementID, uint32(resp[1]), "Received incorrect Statement ID, want: %v, got: %v", prepare.StatementID, resp[1]) 153 154 } 155 156 func TestComStmtPrepareUpdStmt(t *testing.T) { 157 listener, sConn, cConn := createSocketPair(t) 158 defer func() { 159 listener.Close() 160 sConn.Close() 161 cConn.Close() 162 }() 163 164 sql := "UPDATE test SET __bit = ?, __tinyInt = ?, __tinyIntU = ?, __smallInt = ?, __smallIntU = ?, __mediumInt = ?, __mediumIntU = ?, __int = ?, __intU = ?, __bigInt = ?, __bigIntU = ?, __decimal = ?, __float = ?, __double = ?, __date = ?, __datetime = ?, __timestamp = ?, __time = ?, __year = ?, __char = ?, __varchar = ?, __binary = ?, __varbinary = ?, __tinyblob = ?, __tinytext = ?, __blob = ?, __text = ?, __enum = ?, __set = ? WHERE __id = 0" 165 mockData := preparePacket(t, sql) 166 167 err := cConn.writePacket(mockData) 168 require.NoError(t, err, "writePacket failed") 169 170 data, err := sConn.ReadPacket() 171 require.NoError(t, err, "sConn.ReadPacket - ComPrepare failed") 172 173 parsedQuery := sConn.parseComPrepare(data) 174 require.Equal(t, sql, parsedQuery, "Received incorrect query") 175 176 paramsCount := uint16(29) 177 prepare := &PrepareData{ 178 StatementID: 1, 179 PrepareStmt: sql, 180 ParamsCount: paramsCount, 181 } 182 sConn.PrepareData = make(map[uint32]*PrepareData) 183 sConn.PrepareData[prepare.StatementID] = prepare 184 185 // write the response to the client 186 err = sConn.writePrepare(nil, prepare) 187 require.NoError(t, err, "sConn.writePrepare failed") 188 189 resp, err := cConn.ReadPacket() 190 require.NoError(t, err, "cConn.ReadPacket failed") 191 require.EqualValues(t, prepare.StatementID, resp[1], "Received incorrect Statement ID") 192 193 for i := uint16(0); i < paramsCount; i++ { 194 resp, err := cConn.ReadPacket() 195 require.NoError(t, err, "cConn.ReadPacket failed") 196 require.EqualValues(t, 0xfd, resp[17], "Received incorrect Statement ID") 197 } 198 } 199 200 func TestComStmtSendLongData(t *testing.T) { 201 listener, sConn, cConn := createSocketPair(t) 202 defer func() { 203 listener.Close() 204 sConn.Close() 205 cConn.Close() 206 }() 207 208 prepare, result := MockPrepareData(t) 209 cConn.PrepareData = make(map[uint32]*PrepareData) 210 cConn.PrepareData[prepare.StatementID] = prepare 211 if err := cConn.writePrepare(result.Fields, prepare); err != nil { 212 t.Fatalf("writePrepare failed: %v", err) 213 } 214 215 // Since there's no writeComStmtSendLongData, we'll write a prepareStmt and check if we can read the StatementID 216 data, err := sConn.ReadPacket() 217 if err != nil || len(data) == 0 { 218 t.Fatalf("sConn.ReadPacket - ComStmtClose failed: %v %v", data, err) 219 } 220 stmtID, paramID, chunkData, ok := sConn.parseComStmtSendLongData(data) 221 require.True(t, ok, "parseComStmtSendLongData failed") 222 require.Equal(t, uint16(1), paramID, "Received incorrect ParamID, want %v, got %v:", paramID, 1) 223 require.Equal(t, prepare.StatementID, stmtID, "Received incorrect value, want: %v, got: %v", uint32(data[1]), prepare.StatementID) 224 // Check length of chunkData, Since its a subset of `data` and compare with it after we subtract the number of bytes that was read from it. 225 // sizeof(uint32) + sizeof(uint16) + 1 = 7 226 require.Equal(t, len(data)-7, len(chunkData), "Received bad chunkData") 227 228 } 229 230 func TestComStmtExecute(t *testing.T) { 231 listener, sConn, cConn := createSocketPair(t) 232 defer func() { 233 listener.Close() 234 sConn.Close() 235 cConn.Close() 236 }() 237 238 prepare, _ := MockPrepareData(t) 239 cConn.PrepareData = make(map[uint32]*PrepareData) 240 cConn.PrepareData[prepare.StatementID] = prepare 241 242 // This is simulated packets for `select * from test_table where id = ?` 243 data := []byte{23, 18, 0, 0, 0, 128, 1, 0, 0, 0, 0, 1, 1, 128, 1} 244 245 stmtID, _, err := sConn.parseComStmtExecute(cConn.PrepareData, data) 246 require.NoError(t, err, "parseComStmtExeute failed: %v", err) 247 require.Equal(t, uint32(18), stmtID, "Parsed incorrect values") 248 249 } 250 251 func TestComStmtExecuteUpdStmt(t *testing.T) { 252 listener, sConn, cConn := createSocketPair(t) 253 defer func() { 254 listener.Close() 255 sConn.Close() 256 cConn.Close() 257 }() 258 259 prepareDataMap := map[uint32]*PrepareData{ 260 1: { 261 StatementID: 1, 262 ParamsCount: 29, 263 ParamsType: make([]int32, 29), 264 BindVars: map[string]*querypb.BindVariable{}, 265 }} 266 267 // This is simulated packets for update query 268 data := []byte{ 269 0x29, 0x01, 0x00, 0x00, 0x17, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 270 0x00, 0x00, 0x01, 0x10, 0x00, 0x01, 0x00, 0x01, 0x80, 0x02, 0x00, 0x02, 0x80, 0x03, 0x00, 0x03, 271 0x80, 0x03, 0x00, 0x03, 0x80, 0x08, 0x00, 0x08, 0x80, 0x00, 0x00, 0x04, 0x00, 0x05, 0x00, 0x0a, 272 0x00, 0x0c, 0x00, 0x07, 0x00, 0x0b, 0x00, 0x0d, 0x80, 0xfe, 0x00, 0xfe, 0x00, 0xfc, 0x00, 0xfc, 273 0x00, 0xfc, 0x00, 0xfe, 0x00, 0xfc, 0x00, 0xfe, 0x00, 0xfe, 0x00, 0xfe, 0x00, 0x08, 0x00, 0x00, 274 0x00, 0x00, 0x00, 0x00, 0xaa, 0xe0, 0x80, 0xff, 0x00, 0x80, 0xff, 0xff, 0x00, 0x00, 0x80, 0xff, 275 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 276 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x15, 0x31, 0x32, 0x33, 277 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x30, 0x2e, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 278 0x38, 0x39, 0xd0, 0x0f, 0x49, 0x40, 0x44, 0x17, 0x41, 0x54, 0xfb, 0x21, 0x09, 0x40, 0x04, 0xe0, 279 0x07, 0x08, 0x08, 0x0b, 0xe0, 0x07, 0x08, 0x08, 0x11, 0x19, 0x3b, 0x00, 0x00, 0x00, 0x00, 0x0b, 280 0xe0, 0x07, 0x08, 0x08, 0x11, 0x19, 0x3b, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x01, 0x08, 0x00, 0x00, 281 0x00, 0x07, 0x3b, 0x3b, 0x00, 0x00, 0x00, 0x00, 0x04, 0x31, 0x39, 0x39, 0x39, 0x08, 0x31, 0x32, 282 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x0c, 0xe9, 0x9f, 0xa9, 0xe5, 0x86, 0xac, 0xe7, 0x9c, 0x9f, 283 0xe8, 0xb5, 0x9e, 0x08, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x08, 0x31, 0x32, 0x33, 284 0x34, 0x35, 0x36, 0x37, 0x38, 0x08, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x0c, 0xe9, 285 0x9f, 0xa9, 0xe5, 0x86, 0xac, 0xe7, 0x9c, 0x9f, 0xe8, 0xb5, 0x9e, 0x08, 0x31, 0x32, 0x33, 0x34, 286 0x35, 0x36, 0x37, 0x38, 0x0c, 0xe9, 0x9f, 0xa9, 0xe5, 0x86, 0xac, 0xe7, 0x9c, 0x9f, 0xe8, 0xb5, 287 0x9e, 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x66, 0x6f, 0x6f, 0x2c, 0x62, 0x61, 0x72} 288 289 stmtID, _, err := sConn.parseComStmtExecute(prepareDataMap, data[4:]) // first 4 are header 290 require.NoError(t, err) 291 require.EqualValues(t, 1, stmtID) 292 293 prepData := prepareDataMap[stmtID] 294 assert.EqualValues(t, querypb.Type_BIT, prepData.ParamsType[0], "got: %s", querypb.Type(prepData.ParamsType[0])) 295 assert.EqualValues(t, querypb.Type_INT8, prepData.ParamsType[1], "got: %s", querypb.Type(prepData.ParamsType[1])) 296 assert.EqualValues(t, querypb.Type_INT8, prepData.ParamsType[2], "got: %s", querypb.Type(prepData.ParamsType[2])) 297 assert.EqualValues(t, querypb.Type_INT16, prepData.ParamsType[3], "got: %s", querypb.Type(prepData.ParamsType[3])) 298 assert.EqualValues(t, querypb.Type_INT16, prepData.ParamsType[4], "got: %s", querypb.Type(prepData.ParamsType[4])) 299 assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[5], "got: %s", querypb.Type(prepData.ParamsType[5])) 300 assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[6], "got: %s", querypb.Type(prepData.ParamsType[6])) 301 assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[7], "got: %s", querypb.Type(prepData.ParamsType[7])) 302 assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[8], "got: %s", querypb.Type(prepData.ParamsType[8])) 303 assert.EqualValues(t, querypb.Type_INT64, prepData.ParamsType[9], "got: %s", querypb.Type(prepData.ParamsType[9])) 304 assert.EqualValues(t, querypb.Type_INT64, prepData.ParamsType[10], "got: %s", querypb.Type(prepData.ParamsType[10])) 305 assert.EqualValues(t, querypb.Type_DECIMAL, prepData.ParamsType[11], "got: %s", querypb.Type(prepData.ParamsType[11])) 306 assert.EqualValues(t, querypb.Type_FLOAT32, prepData.ParamsType[12], "got: %s", querypb.Type(prepData.ParamsType[12])) 307 assert.EqualValues(t, querypb.Type_FLOAT64, prepData.ParamsType[13], "got: %s", querypb.Type(prepData.ParamsType[13])) 308 assert.EqualValues(t, querypb.Type_DATE, prepData.ParamsType[14], "got: %s", querypb.Type(prepData.ParamsType[14])) 309 assert.EqualValues(t, querypb.Type_DATETIME, prepData.ParamsType[15], "got: %s", querypb.Type(prepData.ParamsType[15])) 310 assert.EqualValues(t, querypb.Type_TIMESTAMP, prepData.ParamsType[16], "got: %s", querypb.Type(prepData.ParamsType[16])) 311 assert.EqualValues(t, querypb.Type_TIME, prepData.ParamsType[17], "got: %s", querypb.Type(prepData.ParamsType[17])) 312 313 // this is year but in binary it is changed to varbinary 314 assert.EqualValues(t, querypb.Type_VARBINARY, prepData.ParamsType[18], "got: %s", querypb.Type(prepData.ParamsType[18])) 315 316 assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[19], "got: %s", querypb.Type(prepData.ParamsType[19])) 317 assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[20], "got: %s", querypb.Type(prepData.ParamsType[20])) 318 assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[21], "got: %s", querypb.Type(prepData.ParamsType[21])) 319 assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[22], "got: %s", querypb.Type(prepData.ParamsType[22])) 320 assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[23], "got: %s", querypb.Type(prepData.ParamsType[23])) 321 assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[24], "got: %s", querypb.Type(prepData.ParamsType[24])) 322 assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[25], "got: %s", querypb.Type(prepData.ParamsType[25])) 323 assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[26], "got: %s", querypb.Type(prepData.ParamsType[26])) 324 assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[27], "got: %s", querypb.Type(prepData.ParamsType[27])) 325 assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[28], "got: %s", querypb.Type(prepData.ParamsType[28])) 326 } 327 328 func TestComStmtClose(t *testing.T) { 329 listener, sConn, cConn := createSocketPair(t) 330 defer func() { 331 listener.Close() 332 sConn.Close() 333 cConn.Close() 334 }() 335 336 prepare, result := MockPrepareData(t) 337 cConn.PrepareData = make(map[uint32]*PrepareData) 338 cConn.PrepareData[prepare.StatementID] = prepare 339 if err := cConn.writePrepare(result.Fields, prepare); err != nil { 340 t.Fatalf("writePrepare failed: %v", err) 341 } 342 343 // Since there's no writeComStmtClose, we'll write a prepareStmt and check if we can read the StatementID 344 data, err := sConn.ReadPacket() 345 if err != nil || len(data) == 0 { 346 t.Fatalf("sConn.ReadPacket - ComStmtClose failed: %v %v", data, err) 347 } 348 stmtID, ok := sConn.parseComStmtClose(data) 349 require.True(t, ok, "parseComStmtClose failed") 350 require.Equal(t, prepare.StatementID, stmtID, "Received incorrect value, want: %v, got: %v", uint32(data[1]), prepare.StatementID) 351 352 } 353 354 // This test has been added to verify that IO errors in a connection lead to SQL Server lost errors 355 // So that we end up closing the connection higher up the stack and not reusing it. 356 // This test was added in response to a panic that was run into. 357 func TestSQLErrorOnServerClose(t *testing.T) { 358 // Create socket pair for the server and client 359 listener, sConn, cConn := createSocketPair(t) 360 defer func() { 361 listener.Close() 362 sConn.Close() 363 cConn.Close() 364 }() 365 366 err := cConn.WriteComQuery("close before rows read") 367 require.NoError(t, err) 368 369 handler := &testRun{t: t} 370 _ = sConn.handleNextCommand(handler) 371 372 // From the server we will receive a field packet which the client will read 373 // At that point, if the server crashes and closes the connection. 374 // We should be getting a Connection lost error. 375 _, _, _, err = cConn.ReadQueryResult(100, true) 376 require.Error(t, err) 377 require.True(t, IsConnLostDuringQuery(err), err.Error()) 378 } 379 380 func TestQueries(t *testing.T) { 381 listener, sConn, cConn := createSocketPair(t) 382 defer func() { 383 listener.Close() 384 sConn.Close() 385 cConn.Close() 386 }() 387 388 // Smallest result 389 checkQuery(t, "tiny", sConn, cConn, &sqltypes.Result{}) 390 391 // Typical Insert result 392 checkQuery(t, "insert", sConn, cConn, &sqltypes.Result{ 393 RowsAffected: 0x8010203040506070, 394 InsertID: 0x0102030405060708, 395 }) 396 397 // Typical Select with TYPE_AND_NAME. 398 // One value is also NULL. 399 checkQuery(t, "type and name", sConn, cConn, &sqltypes.Result{ 400 Fields: []*querypb.Field{ 401 { 402 Name: "id", 403 Type: querypb.Type_INT32, 404 }, 405 { 406 Name: "name", 407 Type: querypb.Type_VARCHAR, 408 }, 409 }, 410 Rows: [][]sqltypes.Value{ 411 { 412 sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")), 413 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")), 414 }, 415 { 416 sqltypes.MakeTrusted(querypb.Type_INT32, []byte("20")), 417 sqltypes.NULL, 418 }, 419 }, 420 }) 421 422 // Typical Select with TYPE_AND_NAME. 423 // All types are represented. 424 // One row has all NULL values. 425 checkQuery(t, "all types", sConn, cConn, &sqltypes.Result{ 426 Fields: []*querypb.Field{ 427 {Name: "Type_INT8 ", Type: querypb.Type_INT8}, 428 {Name: "Type_UINT8 ", Type: querypb.Type_UINT8}, 429 {Name: "Type_INT16 ", Type: querypb.Type_INT16}, 430 {Name: "Type_UINT16 ", Type: querypb.Type_UINT16}, 431 {Name: "Type_INT24 ", Type: querypb.Type_INT24}, 432 {Name: "Type_UINT24 ", Type: querypb.Type_UINT24}, 433 {Name: "Type_INT32 ", Type: querypb.Type_INT32}, 434 {Name: "Type_UINT32 ", Type: querypb.Type_UINT32}, 435 {Name: "Type_INT64 ", Type: querypb.Type_INT64}, 436 {Name: "Type_UINT64 ", Type: querypb.Type_UINT64}, 437 {Name: "Type_FLOAT32 ", Type: querypb.Type_FLOAT32}, 438 {Name: "Type_FLOAT64 ", Type: querypb.Type_FLOAT64}, 439 {Name: "Type_TIMESTAMP", Type: querypb.Type_TIMESTAMP}, 440 {Name: "Type_DATE ", Type: querypb.Type_DATE}, 441 {Name: "Type_TIME ", Type: querypb.Type_TIME}, 442 {Name: "Type_DATETIME ", Type: querypb.Type_DATETIME}, 443 {Name: "Type_YEAR ", Type: querypb.Type_YEAR}, 444 {Name: "Type_DECIMAL ", Type: querypb.Type_DECIMAL}, 445 {Name: "Type_TEXT ", Type: querypb.Type_TEXT}, 446 {Name: "Type_BLOB ", Type: querypb.Type_BLOB}, 447 {Name: "Type_VARCHAR ", Type: querypb.Type_VARCHAR}, 448 {Name: "Type_VARBINARY", Type: querypb.Type_VARBINARY}, 449 {Name: "Type_CHAR ", Type: querypb.Type_CHAR}, 450 {Name: "Type_BINARY ", Type: querypb.Type_BINARY}, 451 {Name: "Type_BIT ", Type: querypb.Type_BIT}, 452 {Name: "Type_ENUM ", Type: querypb.Type_ENUM}, 453 {Name: "Type_SET ", Type: querypb.Type_SET}, 454 // Skip TUPLE, not possible in Result. 455 {Name: "Type_GEOMETRY ", Type: querypb.Type_GEOMETRY}, 456 {Name: "Type_JSON ", Type: querypb.Type_JSON}, 457 }, 458 Rows: [][]sqltypes.Value{ 459 { 460 sqltypes.MakeTrusted(querypb.Type_INT8, []byte("Type_INT8")), 461 sqltypes.MakeTrusted(querypb.Type_UINT8, []byte("Type_UINT8")), 462 sqltypes.MakeTrusted(querypb.Type_INT16, []byte("Type_INT16")), 463 sqltypes.MakeTrusted(querypb.Type_UINT16, []byte("Type_UINT16")), 464 sqltypes.MakeTrusted(querypb.Type_INT24, []byte("Type_INT24")), 465 sqltypes.MakeTrusted(querypb.Type_UINT24, []byte("Type_UINT24")), 466 sqltypes.MakeTrusted(querypb.Type_INT32, []byte("Type_INT32")), 467 sqltypes.MakeTrusted(querypb.Type_UINT32, []byte("Type_UINT32")), 468 sqltypes.MakeTrusted(querypb.Type_INT64, []byte("Type_INT64")), 469 sqltypes.MakeTrusted(querypb.Type_UINT64, []byte("Type_UINT64")), 470 sqltypes.MakeTrusted(querypb.Type_FLOAT32, []byte("Type_FLOAT32")), 471 sqltypes.MakeTrusted(querypb.Type_FLOAT64, []byte("Type_FLOAT64")), 472 sqltypes.MakeTrusted(querypb.Type_TIMESTAMP, []byte("Type_TIMESTAMP")), 473 sqltypes.MakeTrusted(querypb.Type_DATE, []byte("Type_DATE")), 474 sqltypes.MakeTrusted(querypb.Type_TIME, []byte("Type_TIME")), 475 sqltypes.MakeTrusted(querypb.Type_DATETIME, []byte("Type_DATETIME")), 476 sqltypes.MakeTrusted(querypb.Type_YEAR, []byte("Type_YEAR")), 477 sqltypes.MakeTrusted(querypb.Type_DECIMAL, []byte("Type_DECIMAL")), 478 sqltypes.MakeTrusted(querypb.Type_TEXT, []byte("Type_TEXT")), 479 sqltypes.MakeTrusted(querypb.Type_BLOB, []byte("Type_BLOB")), 480 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("Type_VARCHAR")), 481 sqltypes.MakeTrusted(querypb.Type_VARBINARY, []byte("Type_VARBINARY")), 482 sqltypes.MakeTrusted(querypb.Type_CHAR, []byte("Type_CHAR")), 483 sqltypes.MakeTrusted(querypb.Type_BINARY, []byte("Type_BINARY")), 484 sqltypes.MakeTrusted(querypb.Type_BIT, []byte("Type_BIT")), 485 sqltypes.MakeTrusted(querypb.Type_ENUM, []byte("Type_ENUM")), 486 sqltypes.MakeTrusted(querypb.Type_SET, []byte("Type_SET")), 487 sqltypes.MakeTrusted(querypb.Type_GEOMETRY, []byte("Type_GEOMETRY")), 488 sqltypes.MakeTrusted(querypb.Type_JSON, []byte("Type_JSON")), 489 }, 490 { 491 sqltypes.NULL, 492 sqltypes.NULL, 493 sqltypes.NULL, 494 sqltypes.NULL, 495 sqltypes.NULL, 496 sqltypes.NULL, 497 sqltypes.NULL, 498 sqltypes.NULL, 499 sqltypes.NULL, 500 sqltypes.NULL, 501 sqltypes.NULL, 502 sqltypes.NULL, 503 sqltypes.NULL, 504 sqltypes.NULL, 505 sqltypes.NULL, 506 sqltypes.NULL, 507 sqltypes.NULL, 508 sqltypes.NULL, 509 sqltypes.NULL, 510 sqltypes.NULL, 511 sqltypes.NULL, 512 sqltypes.NULL, 513 sqltypes.NULL, 514 sqltypes.NULL, 515 sqltypes.NULL, 516 sqltypes.NULL, 517 sqltypes.NULL, 518 sqltypes.NULL, 519 sqltypes.NULL, 520 }, 521 }, 522 }) 523 524 // Typical Select with TYPE_AND_NAME. 525 // First value first column is an empty string, so it's encoded as 0. 526 checkQuery(t, "first empty string", sConn, cConn, &sqltypes.Result{ 527 Fields: []*querypb.Field{ 528 { 529 Name: "name", 530 Type: querypb.Type_VARCHAR, 531 }, 532 }, 533 Rows: [][]sqltypes.Value{ 534 { 535 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("")), 536 }, 537 { 538 sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("nice name")), 539 }, 540 }, 541 }) 542 543 // Typical Select with TYPE_ONLY. 544 checkQuery(t, "type only", sConn, cConn, &sqltypes.Result{ 545 Fields: []*querypb.Field{ 546 { 547 Type: querypb.Type_INT64, 548 }, 549 }, 550 Rows: [][]sqltypes.Value{ 551 { 552 sqltypes.MakeTrusted(querypb.Type_INT64, []byte("10")), 553 }, 554 { 555 sqltypes.MakeTrusted(querypb.Type_INT64, []byte("20")), 556 }, 557 }, 558 }) 559 560 // Typical Select with ALL. 561 checkQuery(t, "complete", sConn, cConn, &sqltypes.Result{ 562 Fields: []*querypb.Field{ 563 { 564 Type: querypb.Type_INT64, 565 Name: "cool column name", 566 Table: "table name", 567 OrgTable: "org table", 568 Database: "fine db", 569 OrgName: "crazy org", 570 ColumnLength: 0x80020304, 571 Charset: 0x1234, 572 Decimals: 36, 573 Flags: uint32(querypb.MySqlFlag_NOT_NULL_FLAG | 574 querypb.MySqlFlag_PRI_KEY_FLAG | 575 querypb.MySqlFlag_PART_KEY_FLAG | 576 querypb.MySqlFlag_NUM_FLAG), 577 }, 578 }, 579 Rows: [][]sqltypes.Value{ 580 { 581 sqltypes.MakeTrusted(querypb.Type_INT64, []byte("10")), 582 }, 583 { 584 sqltypes.MakeTrusted(querypb.Type_INT64, []byte("20")), 585 }, 586 { 587 sqltypes.MakeTrusted(querypb.Type_INT64, []byte("30")), 588 }, 589 }, 590 }) 591 } 592 593 func checkQuery(t *testing.T, query string, sConn, cConn *Conn, result *sqltypes.Result) { 594 // The protocol depends on the CapabilityClientDeprecateEOF flag. 595 // So we want to test both cases. 596 597 sConn.Capabilities = 0 598 cConn.Capabilities = 0 599 checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, false /* warnings */) 600 checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, true /* allRows */, false /* warnings */) 601 checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, false /* allRows */, false /* warnings */) 602 checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, false /* allRows */, false /* warnings */) 603 604 checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, true /* warnings */) 605 606 sConn.Capabilities = CapabilityClientDeprecateEOF 607 cConn.Capabilities = CapabilityClientDeprecateEOF 608 checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, false /* warnings */) 609 checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, true /* allRows */, false /* warnings */) 610 checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, false /* allRows */, false /* warnings */) 611 checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, false /* allRows */, false /* warnings */) 612 613 checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, true /* warnings */) 614 } 615 616 func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result *sqltypes.Result, wantfields, allRows, warnings bool) { 617 618 if sConn.Capabilities&CapabilityClientDeprecateEOF > 0 { 619 query += " NOEOF" 620 } else { 621 query += " EOF" 622 } 623 if wantfields { 624 query += " FIELDS" 625 } else { 626 query += " NOFIELDS" 627 } 628 if allRows { 629 query += " ALL" 630 } else { 631 query += " PARTIAL" 632 } 633 634 var warningCount uint16 635 if warnings { 636 query += " WARNINGS" 637 warningCount = 99 638 } else { 639 query += " NOWARNINGS" 640 } 641 642 var fatalError string 643 // Use a go routine to run ExecuteFetch. 644 wg := sync.WaitGroup{} 645 wg.Add(1) 646 go func() { 647 defer wg.Done() 648 649 maxrows := 10000 650 if !allRows { 651 // Asking for just one row max. The results that have more will fail. 652 maxrows = 1 653 } 654 got, gotWarnings, err := cConn.ExecuteFetchWithWarningCount(query, maxrows, wantfields) 655 if !allRows && len(result.Rows) > 1 { 656 require.ErrorContains(t, err, "Row count exceeded") 657 return 658 } 659 if err != nil { 660 fatalError = fmt.Sprintf("executeFetch failed: %v", err) 661 return 662 } 663 expected := *result 664 if !wantfields { 665 expected.Fields = nil 666 } 667 if !got.Equal(&expected) { 668 for i, f := range got.Fields { 669 if i < len(expected.Fields) && !proto.Equal(f, expected.Fields[i]) { 670 t.Logf("Got field(%v) = %v", i, f) 671 t.Logf("Expected field(%v) = %v", i, expected.Fields[i]) 672 } 673 } 674 fatalError = fmt.Sprintf("ExecuteFetch(wantfields=%v) returned:\n%v\nBut was expecting:\n%v", wantfields, got, expected) 675 return 676 } 677 678 if gotWarnings != warningCount { 679 t.Errorf("ExecuteFetch(%v) expected %v warnings got %v", query, warningCount, gotWarnings) 680 return 681 } 682 683 // Test ExecuteStreamFetch, build a Result. 684 expected = *result 685 if err := cConn.ExecuteStreamFetch(query); err != nil { 686 fatalError = fmt.Sprintf("ExecuteStreamFetch(%v) failed: %v", query, err) 687 return 688 } 689 got = &sqltypes.Result{} 690 got.RowsAffected = result.RowsAffected 691 got.InsertID = result.InsertID 692 got.Fields, err = cConn.Fields() 693 if err != nil { 694 fatalError = fmt.Sprintf("Fields(%v) failed: %v", query, err) 695 return 696 } 697 if len(got.Fields) == 0 { 698 got.Fields = nil 699 } 700 for { 701 row, err := cConn.FetchNext(nil) 702 if err != nil { 703 fatalError = fmt.Sprintf("FetchNext(%v) failed: %v", query, err) 704 return 705 } 706 if row == nil { 707 // Done. 708 break 709 } 710 got.Rows = append(got.Rows, row) 711 } 712 cConn.CloseResult() 713 714 if !got.Equal(&expected) { 715 for i, f := range got.Fields { 716 if i < len(expected.Fields) && !proto.Equal(f, expected.Fields[i]) { 717 t.Logf("========== Got field(%v) = %v", i, f) 718 t.Logf("========== Expected field(%v) = %v", i, expected.Fields[i]) 719 } 720 } 721 for i, row := range got.Rows { 722 if i < len(expected.Rows) && !reflect.DeepEqual(row, expected.Rows[i]) { 723 t.Logf("========== Got row(%v) = %v", i, RowString(row)) 724 t.Logf("========== Expected row(%v) = %v", i, RowString(expected.Rows[i])) 725 } 726 } 727 if expected.RowsAffected != got.RowsAffected { 728 t.Logf("========== Got RowsAffected = %v", got.RowsAffected) 729 t.Logf("========== Expected RowsAffected = %v", expected.RowsAffected) 730 } 731 t.Errorf("\nExecuteStreamFetch(%v) returned:\n%+v\nBut was expecting:\n%+v\n", query, got, &expected) 732 } 733 }() 734 735 // The other side gets the request, and sends the result. 736 // Twice, once for ExecuteFetch, once for ExecuteStreamFetch. 737 count := 2 738 if !allRows && len(result.Rows) > 1 { 739 // short-circuit one test, the go routine returned and didn't 740 // do the streaming query. 741 count-- 742 } 743 744 handler := testHandler{ 745 result: result, 746 warnings: warningCount, 747 } 748 749 for i := 0; i < count; i++ { 750 kontinue := sConn.handleNextCommand(&handler) 751 require.True(t, kontinue, "error handling command: %d", i) 752 753 } 754 755 wg.Wait() 756 require.Equal(t, "", fatalError, fatalError) 757 758 } 759 760 // nolint 761 func writeResult(conn *Conn, result *sqltypes.Result) error { 762 if len(result.Fields) == 0 { 763 return conn.writeOKPacket(&PacketOK{ 764 affectedRows: result.RowsAffected, 765 lastInsertID: result.InsertID, 766 statusFlags: conn.StatusFlags, 767 warnings: 0, 768 }) 769 } 770 if err := conn.writeFields(result); err != nil { 771 return err 772 } 773 if err := conn.writeRows(result); err != nil { 774 return err 775 } 776 return conn.writeEndResult(false, 0, 0, 0) 777 } 778 779 func RowString(row []sqltypes.Value) string { 780 l := len(row) 781 result := fmt.Sprintf("%v values:", l) 782 for _, val := range row { 783 result += fmt.Sprintf(" %v", val) 784 } 785 return result 786 }