vitess.io/vitess@v0.16.2/go/mysql/query.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package mysql 18 19 import ( 20 "fmt" 21 "math" 22 "strconv" 23 "strings" 24 25 "vitess.io/vitess/go/sqltypes" 26 "vitess.io/vitess/go/vt/proto/vtrpc" 27 "vitess.io/vitess/go/vt/vterrors" 28 29 querypb "vitess.io/vitess/go/vt/proto/query" 30 ) 31 32 // This file contains the methods related to queries. 33 34 // 35 // Client side methods. 36 // 37 38 // WriteComQuery writes a query for the server to execute. 39 // Client -> Server. 40 // Returns SQLError(CRServerGone) if it can't. 41 func (c *Conn) WriteComQuery(query string) error { 42 // This is a new command, need to reset the sequence. 43 c.sequence = 0 44 45 data, pos := c.startEphemeralPacketWithHeader(len(query) + 1) 46 data[pos] = ComQuery 47 pos++ 48 copy(data[pos:], query) 49 if err := c.writeEphemeralPacket(); err != nil { 50 return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error()) 51 } 52 return nil 53 } 54 55 // writeComInitDB changes the default database to use. 56 // Client -> Server. 57 // Returns SQLError(CRServerGone) if it can't. 58 func (c *Conn) writeComInitDB(db string) error { 59 data, pos := c.startEphemeralPacketWithHeader(len(db) + 1) 60 data[pos] = ComInitDB 61 pos++ 62 copy(data[pos:], db) 63 if err := c.writeEphemeralPacket(); err != nil { 64 return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error()) 65 } 66 return nil 67 } 68 69 // writeComSetOption changes the connection's capability of executing multi statements. 70 // Returns SQLError(CRServerGone) if it can't. 71 func (c *Conn) writeComSetOption(operation uint16) error { 72 data, pos := c.startEphemeralPacketWithHeader(16 + 1) 73 data[pos] = ComSetOption 74 pos++ 75 writeUint16(data, pos, operation) 76 if err := c.writeEphemeralPacket(); err != nil { 77 return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error()) 78 } 79 return nil 80 } 81 82 // readColumnDefinition reads the next Column Definition packet. 83 // Returns a SQLError. 84 func (c *Conn) readColumnDefinition(field *querypb.Field, index int) error { 85 colDef, err := c.readEphemeralPacket() 86 if err != nil { 87 return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) 88 } 89 defer c.recycleReadPacket() 90 91 // Catalog is ignored, always set to "def" 92 pos, ok := skipLenEncString(colDef, 0) 93 if !ok { 94 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "skipping col %v catalog failed", index) 95 } 96 97 // schema, table, orgTable, name and OrgName are strings. 98 field.Database, pos, ok = readLenEncString(colDef, pos) 99 if !ok { 100 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v schema failed", index) 101 } 102 field.Table, pos, ok = readLenEncString(colDef, pos) 103 if !ok { 104 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v table failed", index) 105 } 106 field.OrgTable, pos, ok = readLenEncString(colDef, pos) 107 if !ok { 108 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v org_table failed", index) 109 } 110 field.Name, pos, ok = readLenEncString(colDef, pos) 111 if !ok { 112 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v name failed", index) 113 } 114 field.OrgName, pos, ok = readLenEncString(colDef, pos) 115 if !ok { 116 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v org_name failed", index) 117 } 118 119 // Skip length of fixed-length fields. 120 pos++ 121 122 // characterSet is a uint16. 123 characterSet, pos, ok := readUint16(colDef, pos) 124 if !ok { 125 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v characterSet failed", index) 126 } 127 field.Charset = uint32(characterSet) 128 129 // columnLength is a uint32. 130 field.ColumnLength, pos, ok = readUint32(colDef, pos) 131 if !ok { 132 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v columnLength failed", index) 133 } 134 135 // type is one byte. 136 t, pos, ok := readByte(colDef, pos) 137 if !ok { 138 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v type failed", index) 139 } 140 141 // flags is 2 bytes. 142 flags, pos, ok := readUint16(colDef, pos) 143 if !ok { 144 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v flags failed", index) 145 } 146 147 // Convert MySQL type to Vitess type. 148 field.Type, err = sqltypes.MySQLToType(int64(t), int64(flags)) 149 if err != nil { 150 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "MySQLToType(%v,%v) failed for column %v: %v", t, flags, index, err) 151 } 152 // Decimals is a byte. 153 decimals, _, ok := readByte(colDef, pos) 154 if !ok { 155 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v decimals failed", index) 156 } 157 field.Decimals = uint32(decimals) 158 159 // If we didn't get column length or character set, 160 // we assume the original row on the other side was encoded from 161 // a Field without that data, so we don't return the flags. 162 if field.ColumnLength != 0 || field.Charset != 0 { 163 field.Flags = uint32(flags) 164 165 // FIXME(alainjobart): This is something the MySQL 166 // client library does: If the type is numerical, it 167 // adds a NUM_FLAG to the flags. We're doing it here 168 // only to be compatible with the C library. Once 169 // we're not using that library any more, we'll remove this. 170 // See doc.go. 171 if IsNum(t) { 172 field.Flags |= uint32(querypb.MySqlFlag_NUM_FLAG) 173 } 174 } 175 176 return nil 177 } 178 179 // readColumnDefinitionType is a faster version of 180 // readColumnDefinition that only fills in the Type. 181 // Returns a SQLError. 182 func (c *Conn) readColumnDefinitionType(field *querypb.Field, index int) error { 183 colDef, err := c.readEphemeralPacket() 184 if err != nil { 185 return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) 186 } 187 defer c.recycleReadPacket() 188 189 // catalog, schema, table, orgTable, name and orgName are 190 // strings, all skipped. 191 pos, ok := skipLenEncString(colDef, 0) 192 if !ok { 193 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "skipping col %v catalog failed", index) 194 } 195 pos, ok = skipLenEncString(colDef, pos) 196 if !ok { 197 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "skipping col %v schema failed", index) 198 } 199 pos, ok = skipLenEncString(colDef, pos) 200 if !ok { 201 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "skipping col %v table failed", index) 202 } 203 pos, ok = skipLenEncString(colDef, pos) 204 if !ok { 205 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "skipping col %v org_table failed", index) 206 } 207 pos, ok = skipLenEncString(colDef, pos) 208 if !ok { 209 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "skipping col %v name failed", index) 210 } 211 pos, ok = skipLenEncString(colDef, pos) 212 if !ok { 213 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "skipping col %v org_name failed", index) 214 } 215 216 // Skip length of fixed-length fields. 217 pos++ 218 219 // characterSet is a uint16. 220 _, pos, ok = readUint16(colDef, pos) 221 if !ok { 222 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v characterSet failed", index) 223 } 224 225 // columnLength is a uint32. 226 _, pos, ok = readUint32(colDef, pos) 227 if !ok { 228 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v columnLength failed", index) 229 } 230 231 // type is one byte 232 t, pos, ok := readByte(colDef, pos) 233 if !ok { 234 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v type failed", index) 235 } 236 237 // flags is 2 bytes 238 flags, _, ok := readUint16(colDef, pos) 239 if !ok { 240 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extracting col %v flags failed", index) 241 } 242 243 // Convert MySQL type to Vitess type. 244 field.Type, err = sqltypes.MySQLToType(int64(t), int64(flags)) 245 if err != nil { 246 return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "MySQLToType(%v,%v) failed for column %v: %v", t, flags, index, err) 247 } 248 249 // skip decimals 250 251 return nil 252 } 253 254 // parseRow parses an individual row. 255 // Returns a SQLError. 256 func (c *Conn) parseRow(data []byte, fields []*querypb.Field, reader func([]byte, int) ([]byte, int, bool), result []sqltypes.Value) ([]sqltypes.Value, error) { 257 colNumber := len(fields) 258 if result == nil { 259 result = make([]sqltypes.Value, 0, colNumber) 260 } 261 pos := 0 262 for i := 0; i < colNumber; i++ { 263 if data[pos] == NullValue { 264 result = append(result, sqltypes.Value{}) 265 pos++ 266 continue 267 } 268 var s []byte 269 var ok bool 270 s, pos, ok = reader(data, pos) 271 if !ok { 272 return nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "decoding string failed") 273 } 274 result = append(result, sqltypes.MakeTrusted(fields[i].Type, s)) 275 } 276 return result, nil 277 } 278 279 // ExecuteFetch executes a query and returns the result. 280 // Returns a SQLError. Depending on the transport used, the error 281 // returned might be different for the same condition: 282 // 283 // 1. if the server closes the connection when no command is in flight: 284 // 285 // 1.1 unix: WriteComQuery will fail with a 'broken pipe', and we'll 286 // return CRServerGone(2006). 287 // 288 // 1.2 tcp: WriteComQuery will most likely work, but readComQueryResponse 289 // will fail, and we'll return CRServerLost(2013). 290 // 291 // This is because closing a TCP socket on the server side sends 292 // a FIN to the client (telling the client the server is done 293 // writing), but on most platforms doesn't send a RST. So the 294 // client has no idea it can't write. So it succeeds writing data, which 295 // *then* triggers the server to send a RST back, received a bit 296 // later. By then, the client has already started waiting for 297 // the response, and will just return a CRServerLost(2013). 298 // So CRServerGone(2006) will almost never be seen with TCP. 299 // 300 // 2. if the server closes the connection when a command is in flight, 301 // readComQueryResponse will fail, and we'll return CRServerLost(2013). 302 func (c *Conn) ExecuteFetch(query string, maxrows int, wantfields bool) (result *sqltypes.Result, err error) { 303 result, _, err = c.ExecuteFetchMulti(query, maxrows, wantfields) 304 return result, err 305 } 306 307 // ExecuteFetchMulti is for fetching multiple results from a multi-statement result. 308 // It returns an additional 'more' flag. If it is set, you must fetch the additional 309 // results using ReadQueryResult. 310 func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (result *sqltypes.Result, more bool, err error) { 311 defer func() { 312 if err != nil { 313 if sqlerr, ok := err.(*SQLError); ok { 314 sqlerr.Query = query 315 } 316 } 317 }() 318 319 // Send the query as a COM_QUERY packet. 320 if err = c.WriteComQuery(query); err != nil { 321 return nil, false, err 322 } 323 324 res, more, _, err := c.ReadQueryResult(maxrows, wantfields) 325 if err != nil { 326 return nil, false, err 327 } 328 return res, more, err 329 } 330 331 // ExecuteFetchWithWarningCount is for fetching results and a warning count 332 // Note: In a future iteration this should be abolished and merged into the 333 // ExecuteFetch API. 334 func (c *Conn) ExecuteFetchWithWarningCount(query string, maxrows int, wantfields bool) (result *sqltypes.Result, warnings uint16, err error) { 335 defer func() { 336 if err != nil { 337 if sqlerr, ok := err.(*SQLError); ok { 338 sqlerr.Query = query 339 } 340 } 341 }() 342 343 // Send the query as a COM_QUERY packet. 344 if err = c.WriteComQuery(query); err != nil { 345 return nil, 0, err 346 } 347 348 res, _, warnings, err := c.ReadQueryResult(maxrows, wantfields) 349 return res, warnings, err 350 } 351 352 // ReadQueryResult gets the result from the last written query. 353 func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, bool, uint16, error) { 354 // Get the result. 355 colNumber, packetOk, err := c.readComQueryResponse() 356 if err != nil { 357 return nil, false, 0, err 358 } 359 more := packetOk.statusFlags&ServerMoreResultsExists != 0 360 warnings := packetOk.warnings 361 if colNumber == 0 { 362 // OK packet, means no results. Just use the numbers. 363 return &sqltypes.Result{ 364 RowsAffected: packetOk.affectedRows, 365 InsertID: packetOk.lastInsertID, 366 SessionStateChanges: packetOk.sessionStateData, 367 StatusFlags: packetOk.statusFlags, 368 Info: packetOk.info, 369 }, more, warnings, nil 370 } 371 372 fields := make([]querypb.Field, colNumber) 373 result := &sqltypes.Result{ 374 Fields: make([]*querypb.Field, colNumber), 375 } 376 377 // Read column headers. One packet per column. 378 // Build the fields. 379 for i := 0; i < colNumber; i++ { 380 result.Fields[i] = &fields[i] 381 382 if wantfields { 383 if err := c.readColumnDefinition(result.Fields[i], i); err != nil { 384 return nil, false, 0, err 385 } 386 } else { 387 if err := c.readColumnDefinitionType(result.Fields[i], i); err != nil { 388 return nil, false, 0, err 389 } 390 } 391 } 392 393 if c.Capabilities&CapabilityClientDeprecateEOF == 0 { 394 // EOF is only present here if it's not deprecated. 395 data, err := c.readEphemeralPacket() 396 if err != nil { 397 return nil, false, 0, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) 398 } 399 if c.isEOFPacket(data) { 400 401 // This is what we expect. 402 // Warnings and status flags are ignored. 403 c.recycleReadPacket() 404 // goto: read row loop 405 406 } else if isErrorPacket(data) { 407 defer c.recycleReadPacket() 408 return nil, false, 0, ParseErrorPacket(data) 409 } else { 410 defer c.recycleReadPacket() 411 return nil, false, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet after fields: %v", data) 412 } 413 } 414 415 // read each row until EOF or OK packet. 416 for { 417 data, err := c.readEphemeralPacket() 418 if err != nil { 419 return nil, false, 0, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) 420 } 421 422 if c.isEOFPacket(data) { 423 defer c.recycleReadPacket() 424 425 // Strip the partial Fields before returning. 426 if !wantfields { 427 result.Fields = nil 428 } 429 430 // The deprecated EOF packets change means that this is either an 431 // EOF packet or an OK packet with the EOF type code. 432 if c.Capabilities&CapabilityClientDeprecateEOF == 0 { 433 var statusFlags uint16 434 warnings, statusFlags, err = parseEOFPacket(data) 435 if err != nil { 436 return nil, false, 0, err 437 } 438 more = (statusFlags & ServerMoreResultsExists) != 0 439 result.StatusFlags = statusFlags 440 } else { 441 packetOk, err := c.parseOKPacket(data) 442 if err != nil { 443 return nil, false, 0, err 444 } 445 warnings = packetOk.warnings 446 more = (packetOk.statusFlags & ServerMoreResultsExists) != 0 447 result.SessionStateChanges = packetOk.sessionStateData 448 result.StatusFlags = packetOk.statusFlags 449 result.Info = packetOk.info 450 } 451 return result, more, warnings, nil 452 453 } else if isErrorPacket(data) { 454 defer c.recycleReadPacket() 455 // Error packet. 456 return nil, false, 0, ParseErrorPacket(data) 457 } 458 459 // Check we're not over the limit before we add more. 460 if len(result.Rows) == maxrows { 461 c.recycleReadPacket() 462 if err := c.drainResults(); err != nil { 463 return nil, false, 0, err 464 } 465 return nil, false, 0, vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows) 466 } 467 468 // Regular row. 469 row, err := c.parseRow(data, result.Fields, readLenEncStringAsBytesCopy, nil) 470 if err != nil { 471 c.recycleReadPacket() 472 return nil, false, 0, err 473 } 474 result.Rows = append(result.Rows, row) 475 c.recycleReadPacket() 476 } 477 } 478 479 // drainResults will read all packets for a result set and ignore them. 480 func (c *Conn) drainResults() error { 481 for { 482 data, err := c.readEphemeralPacket() 483 if err != nil { 484 return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) 485 } 486 if c.isEOFPacket(data) { 487 c.recycleReadPacket() 488 return nil 489 } else if isErrorPacket(data) { 490 defer c.recycleReadPacket() 491 return ParseErrorPacket(data) 492 } 493 c.recycleReadPacket() 494 } 495 } 496 497 func (c *Conn) readComQueryResponse() (int, *PacketOK, error) { 498 data, err := c.readEphemeralPacket() 499 if err != nil { 500 return 0, nil, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) 501 } 502 defer c.recycleReadPacket() 503 if len(data) == 0 { 504 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "invalid empty COM_QUERY response packet") 505 } 506 507 switch data[0] { 508 case OKPacket: 509 packetOk, err := c.parseOKPacket(data) 510 return 0, packetOk, err 511 case ErrPacket: 512 // Error 513 return 0, nil, ParseErrorPacket(data) 514 case 0xfb: 515 // Local infile 516 return 0, nil, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented") 517 } 518 n, pos, ok := readLenEncInt(data, 0) 519 if !ok { 520 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "cannot get column number") 521 } 522 if pos != len(data) { 523 return 0, nil, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extra data in COM_QUERY response") 524 } 525 return int(n), &PacketOK{}, nil 526 } 527 528 // 529 // Server side methods. 530 // 531 532 func (c *Conn) parseComQuery(data []byte) string { 533 return string(data[1:]) 534 } 535 536 func (c *Conn) parseComSetOption(data []byte) (uint16, bool) { 537 val, _, ok := readUint16(data, 1) 538 return val, ok 539 } 540 541 func (c *Conn) parseComPrepare(data []byte) string { 542 return string(data[1:]) 543 } 544 545 func (c *Conn) parseComStmtExecute(prepareData map[uint32]*PrepareData, data []byte) (uint32, byte, error) { 546 pos := 0 547 payload := data[1:] 548 bitMap := make([]byte, 0) 549 550 // statement ID 551 stmtID, pos, ok := readUint32(payload, 0) 552 if !ok { 553 return 0, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading statement ID failed") 554 } 555 prepare, ok := prepareData[stmtID] 556 if !ok { 557 return 0, 0, NewSQLError(CRCommandsOutOfSync, SSUnknownSQLState, "statement ID is not found from record") 558 } 559 560 // cursor type flags 561 cursorType, pos, ok := readByte(payload, pos) 562 if !ok { 563 return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading cursor type flags failed") 564 } 565 566 // iteration count 567 iterCount, pos, ok := readUint32(payload, pos) 568 if !ok { 569 return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading iteration count failed") 570 } 571 if iterCount != uint32(1) { 572 return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "iteration count is not equal to 1") 573 } 574 575 if prepare.ParamsCount > 0 { 576 bitMap, pos, ok = readBytes(payload, pos, (int(prepare.ParamsCount)+7)/8) 577 if !ok { 578 return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading NULL-bitmap failed") 579 } 580 } 581 582 newParamsBoundFlag, pos, ok := readByte(payload, pos) 583 if ok && newParamsBoundFlag == 0x01 { 584 var mysqlType, flags byte 585 for i := uint16(0); i < prepare.ParamsCount; i++ { 586 mysqlType, pos, ok = readByte(payload, pos) 587 if !ok { 588 return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading parameter type failed") 589 } 590 591 flags, pos, ok = readByte(payload, pos) 592 if !ok { 593 return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading parameter flags failed") 594 } 595 596 // convert MySQL type to internal type. 597 valType, err := sqltypes.MySQLToType(int64(mysqlType), int64(flags)) 598 if err != nil { 599 return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "MySQLToType(%v,%v) failed: %v", mysqlType, flags, err) 600 } 601 602 prepare.ParamsType[i] = int32(valType) 603 } 604 } 605 606 for i := 0; i < len(prepare.ParamsType); i++ { 607 var val sqltypes.Value 608 parameterID := fmt.Sprintf("v%d", i+1) 609 if v, ok := prepare.BindVars[parameterID]; ok { 610 if v != nil { 611 continue 612 } 613 } 614 615 if (bitMap[i/8] & (1 << uint(i%8))) > 0 { 616 val, pos, ok = c.parseStmtArgs(nil, sqltypes.Null, pos) 617 } else { 618 val, pos, ok = c.parseStmtArgs(payload, querypb.Type(prepare.ParamsType[i]), pos) 619 } 620 if !ok { 621 return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "decoding parameter value failed: %v", prepare.ParamsType[i]) 622 } 623 624 prepare.BindVars[parameterID] = sqltypes.ValueBindVariable(val) 625 } 626 627 return stmtID, cursorType, nil 628 } 629 630 func (c *Conn) parseStmtArgs(data []byte, typ querypb.Type, pos int) (sqltypes.Value, int, bool) { 631 switch typ { 632 case sqltypes.Null: 633 return sqltypes.NULL, pos, true 634 case sqltypes.Int8: 635 val, pos, ok := readByte(data, pos) 636 return sqltypes.NewInt64(int64(int8(val))), pos, ok 637 case sqltypes.Uint8: 638 val, pos, ok := readByte(data, pos) 639 return sqltypes.NewUint64(uint64(val)), pos, ok 640 case sqltypes.Uint16: 641 val, pos, ok := readUint16(data, pos) 642 return sqltypes.NewUint64(uint64(val)), pos, ok 643 case sqltypes.Int16, sqltypes.Year: 644 val, pos, ok := readUint16(data, pos) 645 return sqltypes.NewInt64(int64(int16(val))), pos, ok 646 case sqltypes.Uint24, sqltypes.Uint32: 647 val, pos, ok := readUint32(data, pos) 648 return sqltypes.NewUint64(uint64(val)), pos, ok 649 case sqltypes.Int24, sqltypes.Int32: 650 val, pos, ok := readUint32(data, pos) 651 return sqltypes.NewInt64(int64(int32(val))), pos, ok 652 case sqltypes.Float32: 653 val, pos, ok := readUint32(data, pos) 654 return sqltypes.NewFloat64(float64(math.Float32frombits(uint32(val)))), pos, ok 655 case sqltypes.Uint64: 656 val, pos, ok := readUint64(data, pos) 657 return sqltypes.NewUint64(val), pos, ok 658 case sqltypes.Int64: 659 val, pos, ok := readUint64(data, pos) 660 return sqltypes.NewInt64(int64(val)), pos, ok 661 case sqltypes.Float64: 662 val, pos, ok := readUint64(data, pos) 663 return sqltypes.NewFloat64(math.Float64frombits(val)), pos, ok 664 case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime: 665 size, pos, ok := readByte(data, pos) 666 if !ok { 667 return sqltypes.NULL, 0, false 668 } 669 switch size { 670 case 0x00: 671 return sqltypes.NewVarChar(" "), pos, ok 672 case 0x0b: 673 year, pos, ok := readUint16(data, pos) 674 if !ok { 675 return sqltypes.NULL, 0, false 676 } 677 month, pos, ok := readByte(data, pos) 678 if !ok { 679 return sqltypes.NULL, 0, false 680 } 681 day, pos, ok := readByte(data, pos) 682 if !ok { 683 return sqltypes.NULL, 0, false 684 } 685 hour, pos, ok := readByte(data, pos) 686 if !ok { 687 return sqltypes.NULL, 0, false 688 } 689 minute, pos, ok := readByte(data, pos) 690 if !ok { 691 return sqltypes.NULL, 0, false 692 } 693 second, pos, ok := readByte(data, pos) 694 if !ok { 695 return sqltypes.NULL, 0, false 696 } 697 microSecond, pos, ok := readUint32(data, pos) 698 if !ok { 699 return sqltypes.NULL, 0, false 700 } 701 val := strconv.Itoa(int(year)) + "-" + 702 strconv.Itoa(int(month)) + "-" + 703 strconv.Itoa(int(day)) + " " + 704 strconv.Itoa(int(hour)) + ":" + 705 strconv.Itoa(int(minute)) + ":" + 706 strconv.Itoa(int(second)) + "." + 707 fmt.Sprintf("%06d", microSecond) 708 709 return sqltypes.NewVarChar(val), pos, ok 710 case 0x07: 711 year, pos, ok := readUint16(data, pos) 712 if !ok { 713 return sqltypes.NULL, 0, false 714 } 715 month, pos, ok := readByte(data, pos) 716 if !ok { 717 return sqltypes.NULL, 0, false 718 } 719 day, pos, ok := readByte(data, pos) 720 if !ok { 721 return sqltypes.NULL, 0, false 722 } 723 hour, pos, ok := readByte(data, pos) 724 if !ok { 725 return sqltypes.NULL, 0, false 726 } 727 minute, pos, ok := readByte(data, pos) 728 if !ok { 729 return sqltypes.NULL, 0, false 730 } 731 second, pos, ok := readByte(data, pos) 732 if !ok { 733 return sqltypes.NULL, 0, false 734 } 735 val := strconv.Itoa(int(year)) + "-" + 736 strconv.Itoa(int(month)) + "-" + 737 strconv.Itoa(int(day)) + " " + 738 strconv.Itoa(int(hour)) + ":" + 739 strconv.Itoa(int(minute)) + ":" + 740 strconv.Itoa(int(second)) 741 742 return sqltypes.NewVarChar(val), pos, ok 743 case 0x04: 744 year, pos, ok := readUint16(data, pos) 745 if !ok { 746 return sqltypes.NULL, 0, false 747 } 748 month, pos, ok := readByte(data, pos) 749 if !ok { 750 return sqltypes.NULL, 0, false 751 } 752 day, pos, ok := readByte(data, pos) 753 if !ok { 754 return sqltypes.NULL, 0, false 755 } 756 val := strconv.Itoa(int(year)) + "-" + 757 strconv.Itoa(int(month)) + "-" + 758 strconv.Itoa(int(day)) 759 760 return sqltypes.NewVarChar(val), pos, ok 761 default: 762 return sqltypes.NULL, 0, false 763 } 764 case sqltypes.Time: 765 size, pos, ok := readByte(data, pos) 766 if !ok { 767 return sqltypes.NULL, 0, false 768 } 769 switch size { 770 case 0x00: 771 return sqltypes.NewVarChar("00:00:00"), pos, ok 772 case 0x0c: 773 isNegative, pos, ok := readByte(data, pos) 774 if !ok { 775 return sqltypes.NULL, 0, false 776 } 777 days, pos, ok := readUint32(data, pos) 778 if !ok { 779 return sqltypes.NULL, 0, false 780 } 781 hour, pos, ok := readByte(data, pos) 782 if !ok { 783 return sqltypes.NULL, 0, false 784 } 785 786 hours := uint32(hour) + days*uint32(24) 787 788 minute, pos, ok := readByte(data, pos) 789 if !ok { 790 return sqltypes.NULL, 0, false 791 } 792 second, pos, ok := readByte(data, pos) 793 if !ok { 794 return sqltypes.NULL, 0, false 795 } 796 microSecond, pos, ok := readUint32(data, pos) 797 if !ok { 798 return sqltypes.NULL, 0, false 799 } 800 801 val := "" 802 if isNegative == 0x01 { 803 val += "-" 804 } 805 val += strconv.Itoa(int(hours)) + ":" + 806 strconv.Itoa(int(minute)) + ":" + 807 strconv.Itoa(int(second)) + "." + 808 fmt.Sprintf("%06d", microSecond) 809 810 return sqltypes.NewVarChar(val), pos, ok 811 case 0x08: 812 isNegative, pos, ok := readByte(data, pos) 813 if !ok { 814 return sqltypes.NULL, 0, false 815 } 816 days, pos, ok := readUint32(data, pos) 817 if !ok { 818 return sqltypes.NULL, 0, false 819 } 820 hour, pos, ok := readByte(data, pos) 821 if !ok { 822 return sqltypes.NULL, 0, false 823 } 824 825 hours := uint32(hour) + days*uint32(24) 826 827 minute, pos, ok := readByte(data, pos) 828 if !ok { 829 return sqltypes.NULL, 0, false 830 } 831 second, pos, ok := readByte(data, pos) 832 if !ok { 833 return sqltypes.NULL, 0, false 834 } 835 836 val := "" 837 if isNegative == 0x01 { 838 val += "-" 839 } 840 val += strconv.Itoa(int(hours)) + ":" + 841 strconv.Itoa(int(minute)) + ":" + 842 strconv.Itoa(int(second)) 843 844 return sqltypes.NewVarChar(val), pos, ok 845 default: 846 return sqltypes.NULL, 0, false 847 } 848 case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar, sqltypes.VarBinary, sqltypes.Char, 849 sqltypes.Bit, sqltypes.Enum, sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON: 850 val, pos, ok := readLenEncStringAsBytesCopy(data, pos) 851 return sqltypes.MakeTrusted(sqltypes.VarBinary, val), pos, ok 852 default: 853 return sqltypes.NULL, pos, false 854 } 855 } 856 857 func (c *Conn) parseComStmtSendLongData(data []byte) (uint32, uint16, []byte, bool) { 858 pos := 1 859 statementID, pos, ok := readUint32(data, pos) 860 if !ok { 861 return 0, 0, nil, false 862 } 863 864 paramID, pos, ok := readUint16(data, pos) 865 if !ok { 866 return 0, 0, nil, false 867 } 868 869 chunkData := data[pos:] 870 chunk := make([]byte, len(chunkData)) 871 copy(chunk, chunkData) 872 873 return statementID, paramID, chunk, true 874 } 875 876 func (c *Conn) parseComStmtClose(data []byte) (uint32, bool) { 877 val, _, ok := readUint32(data, 1) 878 return val, ok 879 } 880 881 func (c *Conn) parseComStmtReset(data []byte) (uint32, bool) { 882 val, _, ok := readUint32(data, 1) 883 return val, ok 884 } 885 886 func (c *Conn) parseComInitDB(data []byte) string { 887 return string(data[1:]) 888 } 889 890 func (c *Conn) sendColumnCount(count uint64) error { 891 length := lenEncIntSize(count) 892 data, pos := c.startEphemeralPacketWithHeader(length) 893 writeLenEncInt(data, pos, count) 894 return c.writeEphemeralPacket() 895 } 896 897 func (c *Conn) writeColumnDefinition(field *querypb.Field) error { 898 length := 4 + // lenEncStringSize("def") 899 lenEncStringSize(field.Database) + 900 lenEncStringSize(field.Table) + 901 lenEncStringSize(field.OrgTable) + 902 lenEncStringSize(field.Name) + 903 lenEncStringSize(field.OrgName) + 904 1 + // length of fixed length fields 905 2 + // character set 906 4 + // column length 907 1 + // type 908 2 + // flags 909 1 + // decimals 910 2 // filler 911 912 // Get the type and the flags back. If the Field contains 913 // non-zero flags, we use them. Otherwise use the flags we 914 // derive from the type. 915 typ, flags := sqltypes.TypeToMySQL(field.Type) 916 if field.Flags != 0 { 917 flags = int64(field.Flags) 918 } 919 920 data, pos := c.startEphemeralPacketWithHeader(length) 921 922 pos = writeLenEncString(data, pos, "def") // Always the same. 923 pos = writeLenEncString(data, pos, field.Database) 924 pos = writeLenEncString(data, pos, field.Table) 925 pos = writeLenEncString(data, pos, field.OrgTable) 926 pos = writeLenEncString(data, pos, field.Name) 927 pos = writeLenEncString(data, pos, field.OrgName) 928 pos = writeByte(data, pos, 0x0c) 929 pos = writeUint16(data, pos, uint16(field.Charset)) 930 pos = writeUint32(data, pos, field.ColumnLength) 931 pos = writeByte(data, pos, byte(typ)) 932 pos = writeUint16(data, pos, uint16(flags)) 933 pos = writeByte(data, pos, byte(field.Decimals)) 934 pos = writeUint16(data, pos, uint16(0x0000)) 935 936 if pos != len(data) { 937 return vterrors.Errorf(vtrpc.Code_INTERNAL, "packing of column definition used %v bytes instead of %v", pos, len(data)) 938 } 939 940 return c.writeEphemeralPacket() 941 } 942 943 func (c *Conn) writeRow(row []sqltypes.Value) error { 944 length := 0 945 for _, val := range row { 946 if val.IsNull() { 947 length++ 948 } else { 949 l := len(val.Raw()) 950 length += lenEncIntSize(uint64(l)) + l 951 } 952 } 953 954 data, pos := c.startEphemeralPacketWithHeader(length) 955 for _, val := range row { 956 if val.IsNull() { 957 pos = writeByte(data, pos, NullValue) 958 } else { 959 l := len(val.Raw()) 960 pos = writeLenEncInt(data, pos, uint64(l)) 961 pos += copy(data[pos:], val.Raw()) 962 } 963 } 964 965 return c.writeEphemeralPacket() 966 } 967 968 // writeFields writes the fields of a Result. It should be called only 969 // if there are valid columns in the result. 970 func (c *Conn) writeFields(result *sqltypes.Result) error { 971 // Send the number of fields first. 972 if err := c.sendColumnCount(uint64(len(result.Fields))); err != nil { 973 return err 974 } 975 976 // Now send each Field. 977 for _, field := range result.Fields { 978 if err := c.writeColumnDefinition(field); err != nil { 979 return err 980 } 981 } 982 983 // Now send an EOF packet. 984 if c.Capabilities&CapabilityClientDeprecateEOF == 0 { 985 // With CapabilityClientDeprecateEOF, we do not send this EOF. 986 if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { 987 return err 988 } 989 } 990 return nil 991 } 992 993 // writeRows sends the rows of a Result. 994 func (c *Conn) writeRows(result *sqltypes.Result) error { 995 for _, row := range result.Rows { 996 if err := c.writeRow(row); err != nil { 997 return err 998 } 999 } 1000 return nil 1001 } 1002 1003 // writeEndResult concludes the sending of a Result. 1004 // if more is set to true, then it means there are more results afterwords 1005 func (c *Conn) writeEndResult(more bool, affectedRows, lastInsertID uint64, warnings uint16) error { 1006 // Send either an EOF, or an OK packet. 1007 // See doc.go. 1008 flags := c.StatusFlags 1009 if more { 1010 flags |= ServerMoreResultsExists 1011 } 1012 if c.Capabilities&CapabilityClientDeprecateEOF == 0 { 1013 if err := c.writeEOFPacket(flags, warnings); err != nil { 1014 return err 1015 } 1016 } else { 1017 // This will flush too. 1018 if err := c.writeOKPacketWithEOFHeader(&PacketOK{ 1019 affectedRows: affectedRows, 1020 lastInsertID: lastInsertID, 1021 statusFlags: flags, 1022 warnings: warnings, 1023 }); err != nil { 1024 return err 1025 } 1026 } 1027 1028 return nil 1029 } 1030 1031 // PacketComStmtPrepareOK contains the COM_STMT_PREPARE_OK packet details 1032 type PacketComStmtPrepareOK struct { 1033 status uint8 1034 stmtID uint32 1035 numCols uint16 1036 numParams uint16 1037 warningCount uint16 1038 } 1039 1040 // writePrepare writes a prepare query response to the wire. 1041 func (c *Conn) writePrepare(fld []*querypb.Field, prepare *PrepareData) error { 1042 paramsCount := prepare.ParamsCount 1043 columnCount := 0 1044 if len(fld) != 0 { 1045 columnCount = len(fld) 1046 } 1047 if columnCount > 0 { 1048 prepare.ColumnNames = make([]string, columnCount) 1049 } 1050 1051 ok := PacketComStmtPrepareOK{ 1052 status: OKPacket, 1053 stmtID: prepare.StatementID, 1054 numCols: (uint16)(columnCount), 1055 numParams: paramsCount, 1056 warningCount: 0, 1057 } 1058 bytes, pos := c.startEphemeralPacketWithHeader(12) 1059 data := &coder{data: bytes, pos: pos} 1060 data.writeByte(ok.status) 1061 data.writeUint32(ok.stmtID) 1062 data.writeUint16(ok.numCols) 1063 data.writeUint16(ok.numParams) 1064 data.writeByte(0x00) // reserved 1 byte 1065 data.writeUint16(ok.warningCount) 1066 1067 if err := c.writeEphemeralPacket(); err != nil { 1068 return err 1069 } 1070 1071 if paramsCount > 0 { 1072 for i := uint16(0); i < paramsCount; i++ { 1073 if err := c.writeColumnDefinition(&querypb.Field{ 1074 Name: "?", 1075 Type: sqltypes.VarBinary, 1076 Charset: 63}); err != nil { 1077 return err 1078 } 1079 } 1080 1081 // Now send an EOF packet. 1082 if c.Capabilities&CapabilityClientDeprecateEOF == 0 { 1083 // With CapabilityClientDeprecateEOF, we do not send this EOF. 1084 if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { 1085 return err 1086 } 1087 } 1088 } 1089 1090 for i, field := range fld { 1091 field.Name = strings.Replace(field.Name, "'?'", "?", -1) 1092 prepare.ColumnNames[i] = field.Name 1093 if err := c.writeColumnDefinition(field); err != nil { 1094 return err 1095 } 1096 } 1097 1098 if columnCount > 0 { 1099 // Now send an EOF packet. 1100 if c.Capabilities&CapabilityClientDeprecateEOF == 0 { 1101 // With CapabilityClientDeprecateEOF, we do not send this EOF. 1102 if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { 1103 return err 1104 } 1105 } 1106 } 1107 1108 return nil 1109 } 1110 1111 func (c *Conn) writeBinaryRow(fields []*querypb.Field, row []sqltypes.Value) error { 1112 length := 0 1113 nullBitMapLen := (len(fields) + 7 + 2) / 8 1114 for _, val := range row { 1115 if !val.IsNull() { 1116 l, err := val2MySQLLen(val) 1117 if err != nil { 1118 return fmt.Errorf("internal value %v get MySQL value length error: %v", val, err) 1119 } 1120 length += l 1121 } 1122 } 1123 1124 length += nullBitMapLen + 1 1125 1126 data, pos := c.startEphemeralPacketWithHeader(length) 1127 1128 pos = writeByte(data, pos, 0x00) 1129 1130 for i := 0; i < nullBitMapLen; i++ { 1131 pos = writeByte(data, pos, 0x00) 1132 } 1133 1134 for i, val := range row { 1135 if val.IsNull() { 1136 bytePos := (i+2)/8 + 1 + packetHeaderSize 1137 bitPos := (i + 2) % 8 1138 data[bytePos] |= 1 << uint(bitPos) 1139 } else { 1140 v, err := val2MySQL(val) 1141 if err != nil { 1142 c.recycleWritePacket() 1143 return fmt.Errorf("internal value %v to MySQL value error: %v", val, err) 1144 } 1145 pos += copy(data[pos:], v) 1146 } 1147 } 1148 1149 return c.writeEphemeralPacket() 1150 } 1151 1152 // writeBinaryRows sends the rows of a Result with binary form. 1153 func (c *Conn) writeBinaryRows(result *sqltypes.Result) error { 1154 for _, row := range result.Rows { 1155 if err := c.writeBinaryRow(result.Fields, row); err != nil { 1156 return err 1157 } 1158 } 1159 return nil 1160 } 1161 1162 func val2MySQL(v sqltypes.Value) ([]byte, error) { 1163 var out []byte 1164 pos := 0 1165 switch v.Type() { 1166 case sqltypes.Null: 1167 // no-op 1168 case sqltypes.Int8: 1169 val, err := strconv.ParseInt(v.ToString(), 10, 8) 1170 if err != nil { 1171 return []byte{}, err 1172 } 1173 out = make([]byte, 1) 1174 writeByte(out, pos, uint8(val)) 1175 case sqltypes.Uint8: 1176 val, err := strconv.ParseUint(v.ToString(), 10, 8) 1177 if err != nil { 1178 return []byte{}, err 1179 } 1180 out = make([]byte, 1) 1181 writeByte(out, pos, uint8(val)) 1182 case sqltypes.Uint16: 1183 val, err := strconv.ParseUint(v.ToString(), 10, 16) 1184 if err != nil { 1185 return []byte{}, err 1186 } 1187 out = make([]byte, 2) 1188 writeUint16(out, pos, uint16(val)) 1189 case sqltypes.Int16, sqltypes.Year: 1190 val, err := strconv.ParseInt(v.ToString(), 10, 16) 1191 if err != nil { 1192 return []byte{}, err 1193 } 1194 out = make([]byte, 2) 1195 writeUint16(out, pos, uint16(val)) 1196 case sqltypes.Uint24, sqltypes.Uint32: 1197 val, err := strconv.ParseUint(v.ToString(), 10, 32) 1198 if err != nil { 1199 return []byte{}, err 1200 } 1201 out = make([]byte, 4) 1202 writeUint32(out, pos, uint32(val)) 1203 case sqltypes.Int24, sqltypes.Int32: 1204 val, err := strconv.ParseInt(v.ToString(), 10, 32) 1205 if err != nil { 1206 return []byte{}, err 1207 } 1208 out = make([]byte, 4) 1209 writeUint32(out, pos, uint32(val)) 1210 case sqltypes.Float32: 1211 val, err := strconv.ParseFloat(v.ToString(), 32) 1212 if err != nil { 1213 return []byte{}, err 1214 } 1215 bits := math.Float32bits(float32(val)) 1216 out = make([]byte, 4) 1217 writeUint32(out, pos, bits) 1218 case sqltypes.Uint64: 1219 val, err := strconv.ParseUint(v.ToString(), 10, 64) 1220 if err != nil { 1221 return []byte{}, err 1222 } 1223 out = make([]byte, 8) 1224 writeUint64(out, pos, uint64(val)) 1225 case sqltypes.Int64: 1226 val, err := strconv.ParseInt(v.ToString(), 10, 64) 1227 if err != nil { 1228 return []byte{}, err 1229 } 1230 out = make([]byte, 8) 1231 writeUint64(out, pos, uint64(val)) 1232 case sqltypes.Float64: 1233 val, err := strconv.ParseFloat(v.ToString(), 64) 1234 if err != nil { 1235 return []byte{}, err 1236 } 1237 bits := math.Float64bits(val) 1238 out = make([]byte, 8) 1239 writeUint64(out, pos, bits) 1240 case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime: 1241 if len(v.Raw()) > 19 { 1242 out = make([]byte, 1+11) 1243 out[pos] = 0x0b 1244 pos++ 1245 year, err := strconv.ParseUint(string(v.Raw()[0:4]), 10, 16) 1246 if err != nil { 1247 return []byte{}, err 1248 } 1249 month, err := strconv.ParseUint(string(v.Raw()[5:7]), 10, 8) 1250 if err != nil { 1251 return []byte{}, err 1252 } 1253 day, err := strconv.ParseUint(string(v.Raw()[8:10]), 10, 8) 1254 if err != nil { 1255 return []byte{}, err 1256 } 1257 hour, err := strconv.ParseUint(string(v.Raw()[11:13]), 10, 8) 1258 if err != nil { 1259 return []byte{}, err 1260 } 1261 minute, err := strconv.ParseUint(string(v.Raw()[14:16]), 10, 8) 1262 if err != nil { 1263 return []byte{}, err 1264 } 1265 second, err := strconv.ParseUint(string(v.Raw()[17:19]), 10, 8) 1266 if err != nil { 1267 return []byte{}, err 1268 } 1269 val := make([]byte, 6) 1270 count := copy(val, v.Raw()[20:]) 1271 for i := 0; i < (6 - count); i++ { 1272 val[count+i] = 0x30 1273 } 1274 microSecond, err := strconv.ParseUint(string(val), 10, 32) 1275 if err != nil { 1276 return []byte{}, err 1277 } 1278 pos = writeUint16(out, pos, uint16(year)) 1279 pos = writeByte(out, pos, byte(month)) 1280 pos = writeByte(out, pos, byte(day)) 1281 pos = writeByte(out, pos, byte(hour)) 1282 pos = writeByte(out, pos, byte(minute)) 1283 pos = writeByte(out, pos, byte(second)) 1284 writeUint32(out, pos, uint32(microSecond)) 1285 } else if len(v.Raw()) > 10 { 1286 out = make([]byte, 1+7) 1287 out[pos] = 0x07 1288 pos++ 1289 year, err := strconv.ParseUint(string(v.Raw()[0:4]), 10, 16) 1290 if err != nil { 1291 return []byte{}, err 1292 } 1293 month, err := strconv.ParseUint(string(v.Raw()[5:7]), 10, 8) 1294 if err != nil { 1295 return []byte{}, err 1296 } 1297 day, err := strconv.ParseUint(string(v.Raw()[8:10]), 10, 8) 1298 if err != nil { 1299 return []byte{}, err 1300 } 1301 hour, err := strconv.ParseUint(string(v.Raw()[11:13]), 10, 8) 1302 if err != nil { 1303 return []byte{}, err 1304 } 1305 minute, err := strconv.ParseUint(string(v.Raw()[14:16]), 10, 8) 1306 if err != nil { 1307 return []byte{}, err 1308 } 1309 second, err := strconv.ParseUint(string(v.Raw()[17:]), 10, 8) 1310 if err != nil { 1311 return []byte{}, err 1312 } 1313 pos = writeUint16(out, pos, uint16(year)) 1314 pos = writeByte(out, pos, byte(month)) 1315 pos = writeByte(out, pos, byte(day)) 1316 pos = writeByte(out, pos, byte(hour)) 1317 pos = writeByte(out, pos, byte(minute)) 1318 writeByte(out, pos, byte(second)) 1319 } else if len(v.Raw()) > 0 { 1320 out = make([]byte, 1+4) 1321 out[pos] = 0x04 1322 pos++ 1323 year, err := strconv.ParseUint(string(v.Raw()[0:4]), 10, 16) 1324 if err != nil { 1325 return []byte{}, err 1326 } 1327 month, err := strconv.ParseUint(string(v.Raw()[5:7]), 10, 8) 1328 if err != nil { 1329 return []byte{}, err 1330 } 1331 day, err := strconv.ParseUint(string(v.Raw()[8:]), 10, 8) 1332 if err != nil { 1333 return []byte{}, err 1334 } 1335 pos = writeUint16(out, pos, uint16(year)) 1336 pos = writeByte(out, pos, byte(month)) 1337 writeByte(out, pos, byte(day)) 1338 } else { 1339 out = make([]byte, 1) 1340 out[pos] = 0x00 1341 } 1342 case sqltypes.Time: 1343 if string(v.Raw()) == "00:00:00" { 1344 out = make([]byte, 1) 1345 out[pos] = 0x00 1346 } else if strings.Contains(string(v.Raw()), ".") { 1347 out = make([]byte, 1+12) 1348 out[pos] = 0x0c 1349 pos++ 1350 1351 sub1 := strings.Split(string(v.Raw()), ":") 1352 if len(sub1) != 3 { 1353 err := fmt.Errorf("incorrect time value, ':' is not found") 1354 return []byte{}, err 1355 } 1356 sub2 := strings.Split(sub1[2], ".") 1357 if len(sub2) != 2 { 1358 err := fmt.Errorf("incorrect time value, '.' is not found") 1359 return []byte{}, err 1360 } 1361 1362 var total []byte 1363 if strings.HasPrefix(sub1[0], "-") { 1364 out[pos] = 0x01 1365 total = []byte(sub1[0]) 1366 total = total[1:] 1367 } else { 1368 out[pos] = 0x00 1369 total = []byte(sub1[0]) 1370 } 1371 pos++ 1372 1373 h, err := strconv.ParseUint(string(total), 10, 32) 1374 if err != nil { 1375 return []byte{}, err 1376 } 1377 1378 days := uint32(h) / 24 1379 hours := uint32(h) % 24 1380 minute := sub1[1] 1381 second := sub2[0] 1382 microSecond := sub2[1] 1383 1384 minutes, err := strconv.ParseUint(minute, 10, 8) 1385 if err != nil { 1386 return []byte{}, err 1387 } 1388 1389 seconds, err := strconv.ParseUint(second, 10, 8) 1390 if err != nil { 1391 return []byte{}, err 1392 } 1393 pos = writeUint32(out, pos, uint32(days)) 1394 pos = writeByte(out, pos, byte(hours)) 1395 pos = writeByte(out, pos, byte(minutes)) 1396 pos = writeByte(out, pos, byte(seconds)) 1397 1398 val := make([]byte, 6) 1399 count := copy(val, microSecond) 1400 for i := 0; i < (6 - count); i++ { 1401 val[count+i] = 0x30 1402 } 1403 microSeconds, err := strconv.ParseUint(string(val), 10, 32) 1404 if err != nil { 1405 return []byte{}, err 1406 } 1407 writeUint32(out, pos, uint32(microSeconds)) 1408 } else if len(v.Raw()) > 0 { 1409 out = make([]byte, 1+8) 1410 out[pos] = 0x08 1411 pos++ 1412 1413 sub1 := strings.Split(string(v.Raw()), ":") 1414 if len(sub1) != 3 { 1415 err := fmt.Errorf("incorrect time value, ':' is not found") 1416 return []byte{}, err 1417 } 1418 1419 var total []byte 1420 if strings.HasPrefix(sub1[0], "-") { 1421 out[pos] = 0x01 1422 total = []byte(sub1[0]) 1423 total = total[1:] 1424 } else { 1425 out[pos] = 0x00 1426 total = []byte(sub1[0]) 1427 } 1428 pos++ 1429 1430 h, err := strconv.ParseUint(string(total), 10, 32) 1431 if err != nil { 1432 return []byte{}, err 1433 } 1434 1435 days := uint32(h) / 24 1436 hours := uint32(h) % 24 1437 minute := sub1[1] 1438 second := sub1[2] 1439 1440 minutes, err := strconv.ParseUint(minute, 10, 8) 1441 if err != nil { 1442 return []byte{}, err 1443 } 1444 1445 seconds, err := strconv.ParseUint(second, 10, 8) 1446 if err != nil { 1447 return []byte{}, err 1448 } 1449 pos = writeUint32(out, pos, uint32(days)) 1450 pos = writeByte(out, pos, byte(hours)) 1451 pos = writeByte(out, pos, byte(minutes)) 1452 writeByte(out, pos, byte(seconds)) 1453 } else { 1454 err := fmt.Errorf("incorrect time value") 1455 return []byte{}, err 1456 } 1457 case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar, 1458 sqltypes.VarBinary, sqltypes.Char, sqltypes.Bit, sqltypes.Enum, 1459 sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON: 1460 l := len(v.Raw()) 1461 length := lenEncIntSize(uint64(l)) + l 1462 out = make([]byte, length) 1463 pos = writeLenEncInt(out, pos, uint64(l)) 1464 copy(out[pos:], v.Raw()) 1465 default: 1466 out = make([]byte, len(v.Raw())) 1467 copy(out, v.Raw()) 1468 } 1469 return out, nil 1470 } 1471 1472 func val2MySQLLen(v sqltypes.Value) (int, error) { 1473 var length int 1474 var err error 1475 1476 switch v.Type() { 1477 case sqltypes.Null: 1478 length = 0 1479 case sqltypes.Int8, sqltypes.Uint8: 1480 length = 1 1481 case sqltypes.Uint16, sqltypes.Int16, sqltypes.Year: 1482 length = 2 1483 case sqltypes.Uint24, sqltypes.Uint32, sqltypes.Int24, sqltypes.Int32, sqltypes.Float32: 1484 length = 4 1485 case sqltypes.Uint64, sqltypes.Int64, sqltypes.Float64: 1486 length = 8 1487 case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime: 1488 if len(v.Raw()) > 19 { 1489 length = 12 1490 } else if len(v.Raw()) > 10 { 1491 length = 8 1492 } else if len(v.Raw()) > 0 { 1493 length = 5 1494 } else { 1495 length = 1 1496 } 1497 case sqltypes.Time: 1498 if string(v.Raw()) == "00:00:00" { 1499 length = 1 1500 } else if strings.Contains(string(v.Raw()), ".") { 1501 length = 13 1502 } else if len(v.Raw()) > 0 { 1503 length = 9 1504 } else { 1505 err = fmt.Errorf("incorrect time value") 1506 } 1507 case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar, 1508 sqltypes.VarBinary, sqltypes.Char, sqltypes.Bit, sqltypes.Enum, 1509 sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON: 1510 l := len(v.Raw()) 1511 length = lenEncIntSize(uint64(l)) + l 1512 default: 1513 length = len(v.Raw()) 1514 } 1515 if err != nil { 1516 return 0, err 1517 } 1518 return length, nil 1519 }